summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc55
-rw-r--r--.github/ISSUE_TEMPLATE/bug_report.md31
-rw-r--r--.github/ISSUE_TEMPLATE/config.yml11
-rw-r--r--.github/ISSUE_TEMPLATE/feature_request.md21
-rw-r--r--.github/issue_template.md20
-rw-r--r--.github/labeler.yml42
-rw-r--r--.github/pull_request_template.md5
-rw-r--r--.github/workflows/build.yml21
-rw-r--r--.github/workflows/go.yml75
-rw-r--r--.github/workflows/issue_reviver.yml16
-rw-r--r--.github/workflows/labeler.yml12
-rw-r--r--.github/workflows/stale.yml20
-rw-r--r--.travis.yml47
-rw-r--r--BUILD91
-rw-r--r--CODE_OF_CONDUCT.md5
-rw-r--r--CONTRIBUTING.md44
-rw-r--r--Dockerfile8
-rw-r--r--GOVERNANCE.md113
-rw-r--r--LICENSE22
-rw-r--r--Makefile396
-rw-r--r--README.md102
-rw-r--r--SECURITY.md3
-rw-r--r--WORKSPACE1004
-rw-r--r--g3doc/BUILD44
-rw-r--r--g3doc/Layers.pngbin0 -> 11044 bytes
-rw-r--r--g3doc/Layers.svg1
-rw-r--r--g3doc/Machine-Virtualization.pngbin0 -> 13205 bytes
-rw-r--r--g3doc/Machine-Virtualization.svg1
-rw-r--r--g3doc/README.md166
-rw-r--r--g3doc/Rule-Based-Execution.pngbin0 -> 6780 bytes
-rw-r--r--g3doc/Rule-Based-Execution.svg1
-rw-r--r--g3doc/Sentry-Gofer.pngbin0 -> 9064 bytes
-rw-r--r--g3doc/Sentry-Gofer.svg1
-rw-r--r--g3doc/architecture_guide/BUILD50
-rw-r--r--g3doc/architecture_guide/performance.md277
-rw-r--r--g3doc/architecture_guide/platforms.md61
-rw-r--r--g3doc/architecture_guide/platforms.pngbin0 -> 21384 bytes
-rw-r--r--g3doc/architecture_guide/platforms.svg334
-rw-r--r--g3doc/architecture_guide/resources.md144
-rw-r--r--g3doc/architecture_guide/resources.pngbin0 -> 16621 bytes
-rw-r--r--g3doc/architecture_guide/resources.svg208
-rw-r--r--g3doc/architecture_guide/security.md255
-rw-r--r--g3doc/architecture_guide/security.pngbin0 -> 16932 bytes
-rw-r--r--g3doc/architecture_guide/security.svg153
-rw-r--r--g3doc/community.md31
-rw-r--r--g3doc/logo.txt1
-rw-r--r--g3doc/roadmap.md49
-rw-r--r--g3doc/style.md88
-rw-r--r--g3doc/user_guide/BUILD70
-rw-r--r--g3doc/user_guide/FAQ.md122
-rw-r--r--g3doc/user_guide/checkpoint_restore.md101
-rw-r--r--g3doc/user_guide/compatibility.md93
-rw-r--r--g3doc/user_guide/containerd/BUILD33
-rw-r--r--g3doc/user_guide/containerd/configuration.md70
-rw-r--r--g3doc/user_guide/containerd/containerd_11.md163
-rw-r--r--g3doc/user_guide/containerd/quick_start.md176
-rw-r--r--g3doc/user_guide/debugging.md141
-rw-r--r--g3doc/user_guide/filesystem.md60
-rw-r--r--g3doc/user_guide/install.md157
-rw-r--r--g3doc/user_guide/networking.md85
-rw-r--r--g3doc/user_guide/platforms.md95
-rw-r--r--g3doc/user_guide/quick_start/BUILD33
-rw-r--r--g3doc/user_guide/quick_start/docker.md96
-rw-r--r--g3doc/user_guide/quick_start/kubernetes.md34
-rw-r--r--g3doc/user_guide/quick_start/oci.md43
-rw-r--r--g3doc/user_guide/tutorials/BUILD37
-rw-r--r--g3doc/user_guide/tutorials/add-node-pool.pngbin0 -> 70208 bytes
-rw-r--r--g3doc/user_guide/tutorials/cni.md174
-rw-r--r--g3doc/user_guide/tutorials/docker.md68
-rw-r--r--g3doc/user_guide/tutorials/kubernetes.md134
-rw-r--r--g3doc/user_guide/tutorials/node-pool-button.pngbin0 -> 13757 bytes
-rw-r--r--go.mod63
-rw-r--r--go.sum380
-rw-r--r--images/BUILD11
-rw-r--r--images/Makefile100
-rw-r--r--images/README.md70
-rw-r--r--images/basic/alpine/Dockerfile1
-rw-r--r--images/basic/busybox/Dockerfile1
-rw-r--r--images/basic/hostoverlaytest/Dockerfile8
-rw-r--r--images/basic/hostoverlaytest/copy_up_testfile.txt1
-rw-r--r--images/basic/hostoverlaytest/test_copy_up.c88
-rw-r--r--images/basic/hostoverlaytest/test_rewinddir.c78
-rw-r--r--images/basic/httpd/Dockerfile1
-rw-r--r--images/basic/linktest/Dockerfile7
-rw-r--r--images/basic/linktest/link_test.c93
-rw-r--r--images/basic/mysql/Dockerfile1
-rw-r--r--images/basic/nginx/Dockerfile1
-rw-r--r--images/basic/python/Dockerfile2
-rw-r--r--images/basic/resolv/Dockerfile1
-rw-r--r--images/basic/ruby/Dockerfile1
-rw-r--r--images/basic/tmpfile/Dockerfile4
-rw-r--r--images/basic/tomcat/Dockerfile1
-rw-r--r--images/basic/ubuntu/Dockerfile1
-rw-r--r--images/benchmarks/ab/Dockerfile7
-rw-r--r--images/benchmarks/absl/Dockerfile21
-rw-r--r--images/benchmarks/alpine/Dockerfile1
-rw-r--r--images/benchmarks/ffmpeg/Dockerfile9
-rw-r--r--images/benchmarks/fio/Dockerfile7
-rw-r--r--images/benchmarks/hey/Dockerfile12
-rw-r--r--images/benchmarks/httpd/Dockerfile17
-rw-r--r--images/benchmarks/httpd/apache2-tmpdir.conf5
-rw-r--r--images/benchmarks/iperf/Dockerfile8
-rw-r--r--images/benchmarks/nginx/Dockerfile1
-rw-r--r--images/benchmarks/node/Dockerfile1
-rw-r--r--images/benchmarks/node/index.hbs8
-rw-r--r--images/benchmarks/node/index.js42
-rw-r--r--images/benchmarks/node/package-lock.json486
-rw-r--r--images/benchmarks/node/package.json19
-rw-r--r--images/benchmarks/redis/Dockerfile1
-rwxr-xr-ximages/benchmarks/ruby/Dockerfile27
-rwxr-xr-ximages/benchmarks/ruby/Gemfile5
-rw-r--r--images/benchmarks/ruby/Gemfile.lock26
-rwxr-xr-ximages/benchmarks/ruby/config.ru2
-rwxr-xr-ximages/benchmarks/ruby/index.erb8
-rwxr-xr-ximages/benchmarks/ruby/main.rb27
-rw-r--r--images/benchmarks/runsc/Dockerfile24
-rw-r--r--images/benchmarks/sysbench/Dockerfile7
-rw-r--r--images/benchmarks/tensorflow/Dockerfile7
-rw-r--r--images/benchmarks/util/Dockerfile3
-rw-r--r--images/default/Dockerfile16
-rw-r--r--images/iptables/Dockerfile2
-rw-r--r--images/jekyll/Dockerfile14
-rw-r--r--images/jekyll/checks.rb36
-rw-r--r--images/packetdrill/Dockerfile8
-rw-r--r--images/packetimpact/Dockerfile16
-rw-r--r--images/runtimes/go1.12/Dockerfile4
-rw-r--r--images/runtimes/java11/Dockerfile (renamed from test/runtimes/images/Dockerfile_java11)8
-rw-r--r--images/runtimes/nodejs12.4.0/Dockerfile (renamed from test/runtimes/images/Dockerfile_nodejs12.4.0)9
-rw-r--r--images/runtimes/php7.3.6/Dockerfile (renamed from test/runtimes/images/Dockerfile_php7.3.6)8
-rw-r--r--images/runtimes/python3.7.3/Dockerfile (renamed from test/runtimes/images/Dockerfile_python3.7.3)9
-rw-r--r--kokoro/build.cfg23
-rw-r--r--kokoro/build_tests.cfg1
-rw-r--r--kokoro/common.cfg29
-rw-r--r--kokoro/do_tests.cfg9
-rw-r--r--kokoro/docker_tests.cfg10
-rw-r--r--kokoro/go.cfg20
-rw-r--r--kokoro/go_tests.cfg1
-rw-r--r--kokoro/hostnet_tests.cfg10
-rw-r--r--kokoro/kvm_tests.cfg10
-rw-r--r--kokoro/make_tests.cfg9
-rw-r--r--kokoro/overlay_tests.cfg10
-rw-r--r--kokoro/release.cfg1
-rw-r--r--kokoro/root_tests.cfg10
-rw-r--r--kokoro/simple_tests.cfg9
-rw-r--r--kokoro/swgso_tests.cfg9
-rw-r--r--kokoro/syscall_tests.cfg9
-rwxr-xr-xkokoro/ubuntu1604/30_containerd.sh76
l---------kokoro/ubuntu1804/10_core.sh1
l---------kokoro/ubuntu1804/20_bazel.sh1
l---------kokoro/ubuntu1804/25_docker.sh1
l---------kokoro/ubuntu1804/30_containerd.sh1
l---------kokoro/ubuntu1804/40_kokoro.sh1
-rw-r--r--pkg/abi/BUILD3
-rw-r--r--pkg/abi/abi.go4
-rw-r--r--pkg/abi/linux/BUILD21
-rw-r--r--pkg/abi/linux/aio.go60
-rw-r--r--pkg/abi/linux/arch_amd64.go23
-rw-r--r--pkg/abi/linux/dev.go11
-rw-r--r--pkg/abi/linux/elf.go3
-rw-r--r--pkg/abi/linux/epoll.go22
-rw-r--r--pkg/abi/linux/epoll_amd64.go (renamed from pkg/sentry/fsimpl/proc/mounts.go)26
-rw-r--r--pkg/abi/linux/epoll_arm64.go28
-rw-r--r--pkg/abi/linux/fadvise.go24
-rw-r--r--pkg/abi/linux/fcntl.go41
-rw-r--r--pkg/abi/linux/file.go99
-rw-r--r--pkg/abi/linux/file_amd64.go12
-rw-r--r--pkg/abi/linux/file_arm64.go12
-rw-r--r--pkg/abi/linux/fs.go9
-rw-r--r--pkg/abi/linux/fuse.go303
-rw-r--r--pkg/abi/linux/futex.go18
-rw-r--r--pkg/abi/linux/ioctl.go53
-rw-r--r--pkg/abi/linux/ioctl_tun.go29
-rw-r--r--pkg/abi/linux/ip.go10
-rw-r--r--pkg/abi/linux/mm.go17
-rw-r--r--pkg/abi/linux/netdevice.go4
-rw-r--r--pkg/abi/linux/netfilter.go443
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go310
-rw-r--r--pkg/abi/linux/netfilter_test.go4
-rw-r--r--pkg/abi/linux/netlink_route.go17
-rw-r--r--pkg/abi/linux/ptrace_amd64.go52
-rw-r--r--pkg/abi/linux/ptrace_arm64.go29
-rw-r--r--pkg/abi/linux/rseq.go130
-rw-r--r--pkg/abi/linux/seccomp.go7
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/socket.go49
-rw-r--r--pkg/abi/linux/tcp.go1
-rw-r--r--pkg/abi/linux/time.go21
-rw-r--r--pkg/abi/linux/xattr.go28
-rw-r--r--pkg/amutex/BUILD8
-rw-r--r--pkg/amutex/amutex.go17
-rw-r--r--pkg/amutex/amutex_test.go3
-rw-r--r--pkg/atomicbitops/BUILD17
-rw-r--r--pkg/atomicbitops/atomicbitops.go (renamed from pkg/atomicbitops/atomic_bitops.go)31
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s (renamed from pkg/atomicbitops/atomic_bitops_amd64.s)38
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s (renamed from pkg/atomicbitops/atomic_bitops_arm64.s)34
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go (renamed from pkg/atomicbitops/atomic_bitops_common.go)42
-rw-r--r--pkg/atomicbitops/atomicbitops_test.go (renamed from pkg/atomicbitops/atomic_bitops_test.go)67
-rw-r--r--pkg/binary/BUILD6
-rw-r--r--pkg/binary/binary.go10
-rw-r--r--pkg/bits/BUILD6
-rw-r--r--pkg/bits/bits_template.go8
-rw-r--r--pkg/bits/uint64_test.go18
-rw-r--r--pkg/bpf/BUILD6
-rw-r--r--pkg/bpf/interpreter_test.go2
-rw-r--r--pkg/buffer/BUILD43
-rw-r--r--pkg/buffer/buffer.go94
-rw-r--r--pkg/buffer/safemem.go133
-rw-r--r--pkg/buffer/safemem_test.go170
-rw-r--r--pkg/buffer/view.go390
-rw-r--r--pkg/buffer/view_test.go467
-rw-r--r--pkg/buffer/view_unsafe.go25
-rw-r--r--pkg/cleanup/BUILD17
-rw-r--r--pkg/cleanup/cleanup.go60
-rw-r--r--pkg/cleanup/cleanup_test.go66
-rw-r--r--pkg/compressio/BUILD11
-rw-r--r--pkg/compressio/compressio.go56
-rw-r--r--pkg/context/BUILD (renamed from pkg/sentry/context/BUILD)5
-rw-r--r--pkg/context/context.go (renamed from pkg/sentry/context/context.go)67
-rw-r--r--pkg/control/client/BUILD3
-rw-r--r--pkg/control/server/BUILD4
-rw-r--r--pkg/control/server/server.go2
-rw-r--r--pkg/cpuid/BUILD17
-rw-r--r--pkg/cpuid/cpuid.go1058
-rw-r--r--pkg/cpuid/cpuid_arm64.go483
-rw-r--r--pkg/cpuid/cpuid_arm64_test.go55
-rw-r--r--pkg/cpuid/cpuid_parse_x86_test.go (renamed from pkg/cpuid/cpuid_parse_test.go)2
-rw-r--r--pkg/cpuid/cpuid_x86.go1111
-rw-r--r--pkg/cpuid/cpuid_x86_test.go (renamed from pkg/cpuid/cpuid_test.go)2
-rw-r--r--pkg/eventchannel/BUILD17
-rw-r--r--pkg/eventchannel/event.go2
-rw-r--r--pkg/eventchannel/event_test.go6
-rw-r--r--pkg/fd/BUILD6
-rw-r--r--pkg/fdchannel/BUILD9
-rw-r--r--pkg/fdchannel/fdchannel_test.go3
-rw-r--r--pkg/fdnotifier/BUILD4
-rw-r--r--pkg/fdnotifier/fdnotifier.go2
-rw-r--r--pkg/flipcall/BUILD12
-rw-r--r--pkg/flipcall/ctrl_futex.go2
-rw-r--r--pkg/flipcall/flipcall.go12
-rw-r--r--pkg/flipcall/flipcall_example_test.go3
-rw-r--r--pkg/flipcall/flipcall_test.go3
-rw-r--r--pkg/flipcall/flipcall_unsafe.go10
-rw-r--r--pkg/flipcall/futex_linux.go6
-rw-r--r--pkg/flipcall/packet_window_allocator.go4
-rw-r--r--pkg/flipcall/packet_window_mmap.go25
-rw-r--r--pkg/fspath/BUILD19
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go27
-rw-r--r--pkg/fspath/fspath_test.go25
-rw-r--r--pkg/gate/BUILD5
-rw-r--r--pkg/gate/gate_test.go5
-rw-r--r--pkg/gohacks/BUILD12
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/goid/BUILD25
-rw-r--r--pkg/goid/empty_test.go22
-rw-r--r--pkg/goid/goid.go24
-rw-r--r--pkg/goid/goid_amd64.s21
-rw-r--r--pkg/goid/goid_arm64.s21
-rw-r--r--pkg/goid/goid_race.go25
-rw-r--r--pkg/goid/goid_test.go74
-rw-r--r--pkg/goid/goid_unsafe.go64
-rw-r--r--pkg/ilist/BUILD6
-rw-r--r--pkg/ilist/list.go58
-rw-r--r--pkg/iovec/BUILD18
-rw-r--r--pkg/iovec/iovec.go75
-rw-r--r--pkg/iovec/iovec_test.go121
-rw-r--r--pkg/linewriter/BUILD9
-rw-r--r--pkg/linewriter/linewriter.go3
-rw-r--r--pkg/log/BUILD14
-rw-r--r--pkg/log/glog.go162
-rw-r--r--pkg/log/json.go4
-rw-r--r--pkg/log/json_k8s.go4
-rw-r--r--pkg/log/log.go76
-rw-r--r--pkg/log/log_test.go37
-rw-r--r--pkg/memutil/BUILD3
-rw-r--r--pkg/merkletree/BUILD17
-rw-r--r--pkg/merkletree/merkletree.go314
-rw-r--r--pkg/merkletree/merkletree_test.go353
-rw-r--r--pkg/metric/BUILD24
-rw-r--r--pkg/metric/metric.go46
-rw-r--r--pkg/metric/metric.proto10
-rw-r--r--pkg/metric/metric_test.go22
-rw-r--r--pkg/p9/BUILD10
-rw-r--r--pkg/p9/buffer.go10
-rw-r--r--pkg/p9/client.go56
-rw-r--r--pkg/p9/client_file.go72
-rw-r--r--pkg/p9/client_test.go9
-rw-r--r--pkg/p9/file.go34
-rw-r--r--pkg/p9/handlers.go145
-rw-r--r--pkg/p9/messages.go898
-rw-r--r--pkg/p9/messages_test.go21
-rw-r--r--pkg/p9/p9.go130
-rw-r--r--pkg/p9/p9test/BUILD8
-rw-r--r--pkg/p9/p9test/client_test.go80
-rw-r--r--pkg/p9/p9test/p9test.go2
-rw-r--r--pkg/p9/path_tree.go3
-rw-r--r--pkg/p9/server.go12
-rw-r--r--pkg/p9/transport.go70
-rw-r--r--pkg/p9/transport_flipcall.go6
-rw-r--r--pkg/p9/transport_test.go10
-rw-r--r--pkg/p9/version.go20
-rw-r--r--pkg/pool/BUILD25
-rw-r--r--pkg/pool/pool.go (renamed from pkg/p9/pool.go)28
-rw-r--r--pkg/pool/pool_test.go (renamed from pkg/p9/pool_test.go)8
-rw-r--r--pkg/procid/BUILD10
-rw-r--r--pkg/procid/procid_amd64.s2
-rw-r--r--pkg/procid/procid_arm64.s2
-rw-r--r--pkg/procid/procid_test.go3
-rw-r--r--pkg/rand/BUILD8
-rw-r--r--pkg/rand/rand_linux.go19
-rw-r--r--pkg/refs/BUILD12
-rw-r--r--pkg/refs/refcounter.go38
-rw-r--r--pkg/refs/refcounter_test.go41
-rw-r--r--pkg/refs_vfs2/BUILD26
-rw-r--r--pkg/refs_vfs2/refs.go36
-rw-r--r--pkg/refs_vfs2/refs_template.go133
-rw-r--r--pkg/safecopy/BUILD (renamed from pkg/sentry/platform/safecopy/BUILD)8
-rw-r--r--pkg/safecopy/LICENSE (renamed from pkg/sentry/platform/safecopy/LICENSE)0
-rw-r--r--pkg/safecopy/atomic_amd64.s (renamed from pkg/sentry/platform/safecopy/atomic_amd64.s)0
-rw-r--r--pkg/safecopy/atomic_arm64.s (renamed from pkg/sentry/platform/safecopy/atomic_arm64.s)0
-rw-r--r--pkg/safecopy/memclr_amd64.s (renamed from pkg/sentry/platform/safecopy/memclr_amd64.s)0
-rw-r--r--pkg/safecopy/memclr_arm64.s (renamed from pkg/sentry/platform/safecopy/memclr_arm64.s)0
-rw-r--r--pkg/safecopy/memcpy_amd64.s (renamed from pkg/sentry/platform/safecopy/memcpy_amd64.s)111
-rw-r--r--pkg/safecopy/memcpy_arm64.s (renamed from pkg/sentry/platform/safecopy/memcpy_arm64.s)0
-rw-r--r--pkg/safecopy/safecopy.go (renamed from pkg/sentry/platform/safecopy/safecopy.go)0
-rw-r--r--pkg/safecopy/safecopy_test.go (renamed from pkg/sentry/platform/safecopy/safecopy_test.go)88
-rw-r--r--pkg/safecopy/safecopy_unsafe.go (renamed from pkg/sentry/platform/safecopy/safecopy_unsafe.go)98
-rw-r--r--pkg/safecopy/sighandler_amd64.s (renamed from pkg/sentry/platform/safecopy/sighandler_amd64.s)0
-rw-r--r--pkg/safecopy/sighandler_arm64.s (renamed from pkg/sentry/platform/safecopy/sighandler_arm64.s)0
-rw-r--r--pkg/safemem/BUILD (renamed from pkg/sentry/safemem/BUILD)10
-rw-r--r--pkg/safemem/block_unsafe.go (renamed from pkg/sentry/safemem/block_unsafe.go)2
-rw-r--r--pkg/safemem/io.go (renamed from pkg/sentry/safemem/io.go)0
-rw-r--r--pkg/safemem/io_test.go (renamed from pkg/sentry/safemem/io_test.go)0
-rw-r--r--pkg/safemem/safemem.go (renamed from pkg/sentry/safemem/safemem.go)0
-rw-r--r--pkg/safemem/seq_test.go (renamed from pkg/sentry/safemem/seq_test.go)21
-rw-r--r--pkg/safemem/seq_unsafe.go (renamed from pkg/sentry/safemem/seq_unsafe.go)20
-rw-r--r--pkg/seccomp/BUILD14
-rw-r--r--pkg/seccomp/seccomp.go33
-rw-r--r--pkg/seccomp/seccomp_rules.go13
-rw-r--r--pkg/seccomp/seccomp_test.go123
-rw-r--r--pkg/seccomp/seccomp_test_victim.go9
-rw-r--r--pkg/seccomp/seccomp_test_victim_amd64.go32
-rw-r--r--pkg/seccomp/seccomp_test_victim_arm64.go29
-rw-r--r--pkg/seccomp/seccomp_unsafe.go9
-rw-r--r--pkg/secio/BUILD6
-rw-r--r--pkg/segment/BUILD2
-rw-r--r--pkg/segment/set.go400
-rw-r--r--pkg/segment/test/BUILD24
-rw-r--r--pkg/segment/test/segment_test.go397
-rw-r--r--pkg/segment/test/set_functions.go32
-rw-r--r--pkg/sentry/BUILD6
-rw-r--r--pkg/sentry/arch/BUILD34
-rw-r--r--pkg/sentry/arch/arch.go11
-rw-r--r--pkg/sentry/arch/arch_aarch64.go326
-rw-r--r--pkg/sentry/arch/arch_amd64.go23
-rw-r--r--pkg/sentry/arch/arch_amd64.s7
-rw-r--r--pkg/sentry/arch/arch_arm64.go286
-rw-r--r--pkg/sentry/arch/arch_state_x86.go50
-rw-r--r--pkg/sentry/arch/arch_x86.go54
-rw-r--r--pkg/sentry/arch/arch_x86_impl.go41
-rw-r--r--pkg/sentry/arch/auxv.go2
-rw-r--r--pkg/sentry/arch/registers.proto37
-rw-r--r--pkg/sentry/arch/signal.go253
-rw-r--r--pkg/sentry/arch/signal_act.go4
-rw-r--r--pkg/sentry/arch/signal_amd64.go234
-rw-r--r--pkg/sentry/arch/signal_arm64.go181
-rw-r--r--pkg/sentry/arch/signal_stack.go7
-rw-r--r--pkg/sentry/arch/stack.go7
-rw-r--r--pkg/sentry/arch/syscalls_amd64.go7
-rw-r--r--pkg/sentry/arch/syscalls_arm64.go81
-rw-r--r--pkg/sentry/contexttest/BUILD (renamed from pkg/sentry/context/contexttest/BUILD)5
-rw-r--r--pkg/sentry/contexttest/contexttest.go (renamed from pkg/sentry/context/contexttest/contexttest.go)6
-rw-r--r--pkg/sentry/control/BUILD14
-rw-r--r--pkg/sentry/control/logging.go4
-rw-r--r--pkg/sentry/control/pprof.go51
-rw-r--r--pkg/sentry/control/proc.go189
-rw-r--r--pkg/sentry/control/proc_test.go10
-rw-r--r--pkg/sentry/device/BUILD11
-rw-r--r--pkg/sentry/device/device.go5
-rw-r--r--pkg/sentry/devices/memdev/BUILD28
-rw-r--r--pkg/sentry/devices/memdev/full.go76
-rw-r--r--pkg/sentry/devices/memdev/memdev.go59
-rw-r--r--pkg/sentry/devices/memdev/null.go77
-rw-r--r--pkg/sentry/devices/memdev/random.go93
-rw-r--r--pkg/sentry/devices/memdev/zero.go89
-rw-r--r--pkg/sentry/devices/ttydev/BUILD16
-rw-r--r--pkg/sentry/devices/ttydev/ttydev.go51
-rw-r--r--pkg/sentry/devices/tundev/BUILD23
-rw-r--r--pkg/sentry/devices/tundev/tundev.go178
-rw-r--r--pkg/sentry/fdimport/BUILD19
-rw-r--r--pkg/sentry/fdimport/fdimport.go134
-rw-r--r--pkg/sentry/fs/BUILD21
-rw-r--r--pkg/sentry/fs/anon/BUILD7
-rw-r--r--pkg/sentry/fs/anon/anon.go4
-rw-r--r--pkg/sentry/fs/attr.go7
-rw-r--r--pkg/sentry/fs/context.go2
-rw-r--r--pkg/sentry/fs/copy_up.go55
-rw-r--r--pkg/sentry/fs/copy_up_test.go8
-rw-r--r--pkg/sentry/fs/dev/BUILD15
-rw-r--r--pkg/sentry/fs/dev/dev.go17
-rw-r--r--pkg/sentry/fs/dev/fs.go2
-rw-r--r--pkg/sentry/fs/dev/full.go4
-rw-r--r--pkg/sentry/fs/dev/net_tun.go177
-rw-r--r--pkg/sentry/fs/dev/null.go2
-rw-r--r--pkg/sentry/fs/dev/random.go6
-rw-r--r--pkg/sentry/fs/dev/tty.go2
-rw-r--r--pkg/sentry/fs/dirent.go278
-rw-r--r--pkg/sentry/fs/dirent_cache.go8
-rw-r--r--pkg/sentry/fs/dirent_cache_limiter.go3
-rw-r--r--pkg/sentry/fs/dirent_refs_test.go20
-rw-r--r--pkg/sentry/fs/dirent_state.go3
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD19
-rw-r--r--pkg/sentry/fs/fdpipe/pipe.go10
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_opener.go2
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_opener_test.go23
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_state.go4
-rw-r--r--pkg/sentry/fs/fdpipe/pipe_test.go26
-rw-r--r--pkg/sentry/fs/file.go29
-rw-r--r--pkg/sentry/fs/file_operations.go7
-rw-r--r--pkg/sentry/fs/file_overlay.go30
-rw-r--r--pkg/sentry/fs/file_overlay_test.go85
-rw-r--r--pkg/sentry/fs/filesystems.go18
-rw-r--r--pkg/sentry/fs/filetest/BUILD9
-rw-r--r--pkg/sentry/fs/filetest/filetest.go6
-rw-r--r--pkg/sentry/fs/flags.go7
-rw-r--r--pkg/sentry/fs/fs.go8
-rw-r--r--pkg/sentry/fs/fsutil/BUILD34
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set.go13
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set_test.go2
-rw-r--r--pkg/sentry/fs/fsutil/file.go8
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go35
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go55
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go28
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go2
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go27
-rw-r--r--pkg/sentry/fs/fsutil/inode.go62
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go35
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go8
-rw-r--r--pkg/sentry/fs/g3doc/.gitignore1
-rw-r--r--pkg/sentry/fs/g3doc/fuse.md263
-rw-r--r--pkg/sentry/fs/g3doc/inotify.md16
-rw-r--r--pkg/sentry/fs/gofer/BUILD19
-rw-r--r--pkg/sentry/fs/gofer/attr.go22
-rw-r--r--pkg/sentry/fs/gofer/cache_policy.go5
-rw-r--r--pkg/sentry/fs/gofer/context_file.go30
-rw-r--r--pkg/sentry/fs/gofer/fifo.go40
-rw-r--r--pkg/sentry/fs/gofer/file.go12
-rw-r--r--pkg/sentry/fs/gofer/file_state.go11
-rw-r--r--pkg/sentry/fs/gofer/fs.go5
-rw-r--r--pkg/sentry/fs/gofer/gofer_test.go14
-rw-r--r--pkg/sentry/fs/gofer/handles.go15
-rw-r--r--pkg/sentry/fs/gofer/inode.go59
-rw-r--r--pkg/sentry/fs/gofer/inode_state.go3
-rw-r--r--pkg/sentry/fs/gofer/path.go213
-rw-r--r--pkg/sentry/fs/gofer/session.go205
-rw-r--r--pkg/sentry/fs/gofer/session_state.go18
-rw-r--r--pkg/sentry/fs/gofer/socket.go20
-rw-r--r--pkg/sentry/fs/gofer/util.go18
-rw-r--r--pkg/sentry/fs/host/BUILD25
-rw-r--r--pkg/sentry/fs/host/control.go10
-rw-r--r--pkg/sentry/fs/host/descriptor.go37
-rw-r--r--pkg/sentry/fs/host/descriptor_state.go2
-rw-r--r--pkg/sentry/fs/host/descriptor_test.go4
-rw-r--r--pkg/sentry/fs/host/file.go24
-rw-r--r--pkg/sentry/fs/host/fs.go339
-rw-r--r--pkg/sentry/fs/host/fs_test.go380
-rw-r--r--pkg/sentry/fs/host/host.go59
-rw-r--r--pkg/sentry/fs/host/inode.go153
-rw-r--r--pkg/sentry/fs/host/inode_state.go32
-rw-r--r--pkg/sentry/fs/host/inode_test.go73
-rw-r--r--pkg/sentry/fs/host/ioctl_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/socket.go26
-rw-r--r--pkg/sentry/fs/host/socket_iovec.go9
-rw-r--r--pkg/sentry/fs/host/socket_test.go48
-rw-r--r--pkg/sentry/fs/host/socket_unsafe.go4
-rw-r--r--pkg/sentry/fs/host/tty.go22
-rw-r--r--pkg/sentry/fs/host/util.go90
-rw-r--r--pkg/sentry/fs/host/util_amd64_unsafe.go41
-rw-r--r--pkg/sentry/fs/host/util_arm64_unsafe.go41
-rw-r--r--pkg/sentry/fs/host/util_unsafe.go60
-rw-r--r--pkg/sentry/fs/host/wait_test.go7
-rw-r--r--pkg/sentry/fs/inode.go55
-rw-r--r--pkg/sentry/fs/inode_inotify.go8
-rw-r--r--pkg/sentry/fs/inode_operations.go38
-rw-r--r--pkg/sentry/fs/inode_overlay.go93
-rw-r--r--pkg/sentry/fs/inode_overlay_test.go14
-rw-r--r--pkg/sentry/fs/inotify.go19
-rw-r--r--pkg/sentry/fs/inotify_event.go4
-rw-r--r--pkg/sentry/fs/inotify_watch.go11
-rw-r--r--pkg/sentry/fs/lock/BUILD7
-rw-r--r--pkg/sentry/fs/lock/lock.go46
-rw-r--r--pkg/sentry/fs/lock/lock_set_functions.go8
-rw-r--r--pkg/sentry/fs/lock/lock_test.go111
-rw-r--r--pkg/sentry/fs/mock.go2
-rw-r--r--pkg/sentry/fs/mount.go14
-rw-r--r--pkg/sentry/fs/mount_overlay.go8
-rw-r--r--pkg/sentry/fs/mount_test.go40
-rw-r--r--pkg/sentry/fs/mounts.go134
-rw-r--r--pkg/sentry/fs/mounts_test.go4
-rw-r--r--pkg/sentry/fs/offset.go2
-rw-r--r--pkg/sentry/fs/overlay.go19
-rw-r--r--pkg/sentry/fs/proc/BUILD18
-rw-r--r--pkg/sentry/fs/proc/README.md4
-rw-r--r--pkg/sentry/fs/proc/cgroup.go6
-rw-r--r--pkg/sentry/fs/proc/cpuinfo.go14
-rw-r--r--pkg/sentry/fs/proc/device/BUILD3
-rw-r--r--pkg/sentry/fs/proc/exec_args.go8
-rw-r--r--pkg/sentry/fs/proc/fds.go24
-rw-r--r--pkg/sentry/fs/proc/filesystems.go6
-rw-r--r--pkg/sentry/fs/proc/fs.go6
-rw-r--r--pkg/sentry/fs/proc/inode.go8
-rw-r--r--pkg/sentry/fs/proc/loadavg.go6
-rw-r--r--pkg/sentry/fs/proc/meminfo.go18
-rw-r--r--pkg/sentry/fs/proc/mounts.go57
-rw-r--r--pkg/sentry/fs/proc/net.go68
-rw-r--r--pkg/sentry/fs/proc/proc.go17
-rw-r--r--pkg/sentry/fs/proc/rpcinet_proc.go217
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD17
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile.go6
-rw-r--r--pkg/sentry/fs/proc/seqfile/seqfile_test.go6
-rw-r--r--pkg/sentry/fs/proc/stat.go6
-rw-r--r--pkg/sentry/fs/proc/sys.go17
-rw-r--r--pkg/sentry/fs/proc/sys_net.go135
-rw-r--r--pkg/sentry/fs/proc/sys_net_state.go1
-rw-r--r--pkg/sentry/fs/proc/sys_net_test.go8
-rw-r--r--pkg/sentry/fs/proc/task.go210
-rw-r--r--pkg/sentry/fs/proc/uid_gid_map.go8
-rw-r--r--pkg/sentry/fs/proc/uptime.go8
-rw-r--r--pkg/sentry/fs/proc/version.go6
-rw-r--r--pkg/sentry/fs/ramfs/BUILD13
-rw-r--r--pkg/sentry/fs/ramfs/dir.go22
-rw-r--r--pkg/sentry/fs/ramfs/socket.go2
-rw-r--r--pkg/sentry/fs/ramfs/symlink.go2
-rw-r--r--pkg/sentry/fs/ramfs/tree.go4
-rw-r--r--pkg/sentry/fs/ramfs/tree_test.go4
-rw-r--r--pkg/sentry/fs/restore.go2
-rw-r--r--pkg/sentry/fs/splice.go2
-rw-r--r--pkg/sentry/fs/sys/BUILD7
-rw-r--r--pkg/sentry/fs/sys/devices.go2
-rw-r--r--pkg/sentry/fs/sys/fs.go2
-rw-r--r--pkg/sentry/fs/sys/sys.go4
-rw-r--r--pkg/sentry/fs/timerfd/BUILD7
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go8
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD17
-rw-r--r--pkg/sentry/fs/tmpfs/file_regular.go4
-rw-r--r--pkg/sentry/fs/tmpfs/file_test.go6
-rw-r--r--pkg/sentry/fs/tmpfs/fs.go5
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go20
-rw-r--r--pkg/sentry/fs/tmpfs/tmpfs.go38
-rw-r--r--pkg/sentry/fs/tty/BUILD17
-rw-r--r--pkg/sentry/fs/tty/dir.go14
-rw-r--r--pkg/sentry/fs/tty/fs.go4
-rw-r--r--pkg/sentry/fs/tty/line_discipline.go14
-rw-r--r--pkg/sentry/fs/tty/master.go21
-rw-r--r--pkg/sentry/fs/tty/queue.go24
-rw-r--r--pkg/sentry/fs/tty/slave.go17
-rw-r--r--pkg/sentry/fs/tty/terminal.go12
-rw-r--r--pkg/sentry/fs/tty/tty_test.go4
-rw-r--r--pkg/sentry/fs/user/BUILD40
-rw-r--r--pkg/sentry/fs/user/path.go170
-rw-r--r--pkg/sentry/fs/user/user.go (renamed from runsc/boot/user.go)83
-rw-r--r--pkg/sentry/fs/user/user_test.go (renamed from runsc/boot/user_test.go)188
-rw-r--r--pkg/sentry/fsbridge/BUILD24
-rw-r--r--pkg/sentry/fsbridge/bridge.go54
-rw-r--r--pkg/sentry/fsbridge/fs.go181
-rw-r--r--pkg/sentry/fsbridge/vfs.go142
-rw-r--r--pkg/sentry/fsimpl/devpts/BUILD44
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go233
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts_test.go56
-rw-r--r--pkg/sentry/fsimpl/devpts/line_discipline.go445
-rw-r--r--pkg/sentry/fsimpl/devpts/master.go237
-rw-r--r--pkg/sentry/fsimpl/devpts/queue.go236
-rw-r--r--pkg/sentry/fsimpl/devpts/slave.go197
-rw-r--r--pkg/sentry/fsimpl/devpts/terminal.go120
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/BUILD33
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs.go219
-rw-r--r--pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go122
-rw-r--r--pkg/sentry/fsimpl/eventfd/BUILD33
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd.go285
-rw-r--r--pkg/sentry/fsimpl/eventfd/eventfd_test.go97
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD41
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD7
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go39
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_file.go13
-rw-r--r--pkg/sentry/fsimpl/ext/block_map_test.go27
-rw-r--r--pkg/sentry/fsimpl/ext/dentry.go32
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go48
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD6
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent.go10
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/extent_test.go6
-rw-r--r--pkg/sentry/fsimpl/ext/ext.go36
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go76
-rw-r--r--pkg/sentry/fsimpl/ext/extent_file.go19
-rw-r--r--pkg/sentry/fsimpl/ext/extent_test.go20
-rw-r--r--pkg/sentry/fsimpl/ext/file_description.go26
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go193
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go67
-rw-r--r--pkg/sentry/fsimpl/ext/regular_file.go39
-rw-r--r--pkg/sentry/fsimpl/ext/symlink.go20
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD63
-rw-r--r--pkg/sentry/fsimpl/fuse/connection.go437
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go397
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go428
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go324
-rw-r--r--pkg/sentry/fsimpl/fuse/init.go166
-rw-r--r--pkg/sentry/fsimpl/fuse/register.go42
-rw-r--r--pkg/sentry/fsimpl/gofer/BUILD90
-rw-r--r--pkg/sentry/fsimpl/gofer/directory.go306
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go1550
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go1708
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer_test.go67
-rw-r--r--pkg/sentry/fsimpl/gofer/handle.go130
-rw-r--r--pkg/sentry/fsimpl/gofer/host_named_pipe.go97
-rw-r--r--pkg/sentry/fsimpl/gofer/p9file.go233
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go944
-rw-r--r--pkg/sentry/fsimpl/gofer/socket.go146
-rw-r--r--pkg/sentry/fsimpl/gofer/special_file.go292
-rw-r--r--pkg/sentry/fsimpl/gofer/symlink.go47
-rw-r--r--pkg/sentry/fsimpl/gofer/time.go82
-rw-r--r--pkg/sentry/fsimpl/host/BUILD52
-rw-r--r--pkg/sentry/fsimpl/host/control.go96
-rw-r--r--pkg/sentry/fsimpl/host/host.go769
-rw-r--r--pkg/sentry/fsimpl/host/ioctl_unsafe.go56
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go131
-rw-r--r--pkg/sentry/fsimpl/host/socket.go385
-rw-r--r--pkg/sentry/fsimpl/host/socket_iovec.go110
-rw-r--r--pkg/sentry/fsimpl/host/socket_unsafe.go101
-rw-r--r--pkg/sentry/fsimpl/host/tty.go390
-rw-r--r--pkg/sentry/fsimpl/host/util.go56
-rw-r--r--pkg/sentry/fsimpl/host/util_unsafe.go34
-rw-r--r--pkg/sentry/fsimpl/kernfs/BUILD75
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go147
-rw-r--r--pkg/sentry/fsimpl/kernfs/fd_impl_util.go252
-rw-r--r--pkg/sentry/fsimpl/kernfs/filesystem.go840
-rw-r--r--pkg/sentry/fsimpl/kernfs/inode_impl_util.go613
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go456
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go330
-rw-r--r--pkg/sentry/fsimpl/kernfs/symlink.go66
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD76
-rw-r--r--pkg/sentry/fsimpl/memfs/filesystem.go579
-rw-r--r--pkg/sentry/fsimpl/memfs/memfs.go302
-rw-r--r--pkg/sentry/fsimpl/memfs/regular_file.go154
-rw-r--r--pkg/sentry/fsimpl/overlay/BUILD41
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go262
-rw-r--r--pkg/sentry/fsimpl/overlay/directory.go289
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go1364
-rw-r--r--pkg/sentry/fsimpl/overlay/non_directory.go266
-rw-r--r--pkg/sentry/fsimpl/overlay/overlay.go627
-rw-r--r--pkg/sentry/fsimpl/pipefs/BUILD21
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go165
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD58
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go117
-rw-r--r--pkg/sentry/fsimpl/proc/loadavg.go40
-rw-r--r--pkg/sentry/fsimpl/proc/meminfo.go77
-rw-r--r--pkg/sentry/fsimpl/proc/net.go338
-rw-r--r--pkg/sentry/fsimpl/proc/stat.go127
-rw-r--r--pkg/sentry/fsimpl/proc/subtasks.go182
-rw-r--r--pkg/sentry/fsimpl/proc/sys.go51
-rw-r--r--pkg/sentry/fsimpl/proc/task.go374
-rw-r--r--pkg/sentry/fsimpl/proc/task_fds.go307
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go902
-rw-r--r--pkg/sentry/fsimpl/proc/task_net.go810
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go256
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go384
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go317
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys_test.go (renamed from pkg/sentry/fsimpl/proc/net_test.go)77
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_test.go505
-rw-r--r--pkg/sentry/fsimpl/proc/version.go68
-rw-r--r--pkg/sentry/fsimpl/signalfd/BUILD20
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go136
-rw-r--r--pkg/sentry/fsimpl/sockfs/BUILD18
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go109
-rw-r--r--pkg/sentry/fsimpl/sys/BUILD34
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go159
-rw-r--r--pkg/sentry/fsimpl/sys/sys_test.go89
-rw-r--r--pkg/sentry/fsimpl/testutil/BUILD37
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go180
-rw-r--r--pkg/sentry/fsimpl/testutil/testutil.go284
-rw-r--r--pkg/sentry/fsimpl/timerfd/BUILD17
-rw-r--r--pkg/sentry/fsimpl/timerfd/timerfd.go144
-rw-r--r--pkg/sentry/fsimpl/tmpfs/BUILD125
-rw-r--r--pkg/sentry/fsimpl/tmpfs/benchmark_test.go (renamed from pkg/sentry/fsimpl/memfs/benchmark_test.go)128
-rw-r--r--pkg/sentry/fsimpl/tmpfs/device_file.go49
-rw-r--r--pkg/sentry/fsimpl/tmpfs/directory.go (renamed from pkg/sentry/fsimpl/memfs/directory.go)119
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go860
-rw-r--r--pkg/sentry/fsimpl/tmpfs/named_pipe.go (renamed from pkg/sentry/fsimpl/memfs/named_pipe.go)31
-rw-r--r--pkg/sentry/fsimpl/tmpfs/pipe_test.go (renamed from pkg/sentry/fsimpl/memfs/pipe_test.go)61
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go637
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file_test.go349
-rw-r--r--pkg/sentry/fsimpl/tmpfs/socket_file.go34
-rw-r--r--pkg/sentry/fsimpl/tmpfs/stat_test.go236
-rw-r--r--pkg/sentry/fsimpl/tmpfs/symlink.go (renamed from pkg/sentry/fsimpl/memfs/symlink.go)7
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go775
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs_test.go156
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD23
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go333
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go355
-rw-r--r--pkg/sentry/hostcpu/BUILD6
-rw-r--r--pkg/sentry/hostfd/BUILD17
-rw-r--r--pkg/sentry/hostfd/hostfd.go84
-rw-r--r--pkg/sentry/hostfd/hostfd_unsafe.go85
-rw-r--r--pkg/sentry/hostmm/BUILD5
-rw-r--r--pkg/sentry/hostmm/hostmm.go2
-rw-r--r--pkg/sentry/inet/BUILD7
-rw-r--r--pkg/sentry/inet/context.go2
-rw-r--r--pkg/sentry/inet/inet.go36
-rw-r--r--pkg/sentry/inet/namespace.go102
-rw-r--r--pkg/sentry/inet/test_stack.go33
-rw-r--r--pkg/sentry/kernel/BUILD55
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go84
-rw-r--r--pkg/sentry/kernel/aio.go81
-rw-r--r--pkg/sentry/kernel/auth/BUILD8
-rw-r--r--pkg/sentry/kernel/auth/context.go2
-rw-r--r--pkg/sentry/kernel/auth/credentials.go28
-rw-r--r--pkg/sentry/kernel/auth/id_map.go2
-rw-r--r--pkg/sentry/kernel/auth/user_namespace.go2
-rw-r--r--pkg/sentry/kernel/context.go45
-rw-r--r--pkg/sentry/kernel/contexttest/BUILD7
-rw-r--r--pkg/sentry/kernel/contexttest/contexttest.go4
-rw-r--r--pkg/sentry/kernel/epoll/BUILD14
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go74
-rw-r--r--pkg/sentry/kernel/epoll/epoll_state.go16
-rw-r--r--pkg/sentry/kernel/epoll/epoll_test.go7
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD15
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd.go10
-rw-r--r--pkg/sentry/kernel/eventfd/eventfd_test.go4
-rw-r--r--pkg/sentry/kernel/fasync/BUILD5
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go21
-rw-r--r--pkg/sentry/kernel/fd_table.go367
-rw-r--r--pkg/sentry/kernel/fd_table_test.go18
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go100
-rw-r--r--pkg/sentry/kernel/fs_context.go137
-rw-r--r--pkg/sentry/kernel/futex/BUILD19
-rw-r--r--pkg/sentry/kernel/futex/futex.go48
-rw-r--r--pkg/sentry/kernel/futex/futex_test.go70
-rw-r--r--pkg/sentry/kernel/kernel.go513
-rw-r--r--pkg/sentry/kernel/kernel_opts.go20
-rw-r--r--pkg/sentry/kernel/memevent/BUILD21
-rw-r--r--pkg/sentry/kernel/memevent/memory_events.go2
-rw-r--r--pkg/sentry/kernel/pipe/BUILD38
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go115
-rw-r--r--pkg/sentry/kernel/pipe/node.go11
-rw-r--r--pkg/sentry/kernel/pipe/node_test.go8
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go139
-rw-r--r--pkg/sentry/kernel/pipe/pipe_test.go20
-rw-r--r--pkg/sentry/kernel/pipe/pipe_unsafe.go35
-rw-r--r--pkg/sentry/kernel/pipe/pipe_util.go35
-rw-r--r--pkg/sentry/kernel/pipe/reader.go3
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go4
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go410
-rw-r--r--pkg/sentry/kernel/pipe/writer.go3
-rw-r--r--pkg/sentry/kernel/ptrace.go13
-rw-r--r--pkg/sentry/kernel/ptrace_amd64.go2
-rw-r--r--pkg/sentry/kernel/ptrace_arm64.go3
-rw-r--r--pkg/sentry/kernel/rseq.go373
-rw-r--r--pkg/sentry/kernel/sched/BUILD6
-rw-r--r--pkg/sentry/kernel/seccomp.go2
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD13
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go11
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore_test.go4
-rw-r--r--pkg/sentry/kernel/sessions.go7
-rw-r--r--pkg/sentry/kernel/shm/BUILD9
-rw-r--r--pkg/sentry/kernel/shm/shm.go106
-rw-r--r--pkg/sentry/kernel/signal.go3
-rw-r--r--pkg/sentry/kernel/signal_handlers.go3
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD10
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go9
-rw-r--r--pkg/sentry/kernel/syscalls.go99
-rw-r--r--pkg/sentry/kernel/syscalls_state.go36
-rw-r--r--pkg/sentry/kernel/syslog.go12
-rw-r--r--pkg/sentry/kernel/task.go201
-rw-r--r--pkg/sentry/kernel/task_block.go8
-rw-r--r--pkg/sentry/kernel/task_clone.go56
-rw-r--r--pkg/sentry/kernel/task_context.go11
-rw-r--r--pkg/sentry/kernel/task_exec.go22
-rw-r--r--pkg/sentry/kernel/task_exit.go17
-rw-r--r--pkg/sentry/kernel/task_futex.go127
-rw-r--r--pkg/sentry/kernel/task_identity.go2
-rw-r--r--pkg/sentry/kernel/task_log.go130
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_run.go89
-rw-r--r--pkg/sentry/kernel/task_sched.go4
-rw-r--r--pkg/sentry/kernel/task_signals.go35
-rw-r--r--pkg/sentry/kernel/task_start.go76
-rw-r--r--pkg/sentry/kernel/task_stop.go16
-rw-r--r--pkg/sentry/kernel/task_syscall.go36
-rw-r--r--pkg/sentry/kernel/task_usermem.go4
-rw-r--r--pkg/sentry/kernel/task_work.go38
-rw-r--r--pkg/sentry/kernel/thread_group.go44
-rw-r--r--pkg/sentry/kernel/threads.go9
-rw-r--r--pkg/sentry/kernel/time/BUILD7
-rw-r--r--pkg/sentry/kernel/time/context.go2
-rw-r--r--pkg/sentry/kernel/time/tcpip.go131
-rw-r--r--pkg/sentry/kernel/time/time.go12
-rw-r--r--pkg/sentry/kernel/timekeeper.go30
-rw-r--r--pkg/sentry/kernel/timekeeper_test.go4
-rw-r--r--pkg/sentry/kernel/tty.go15
-rw-r--r--pkg/sentry/kernel/uts_namespace.go3
-rw-r--r--pkg/sentry/kernel/vdso.go10
-rw-r--r--pkg/sentry/limits/BUILD9
-rw-r--r--pkg/sentry/limits/context.go2
-rw-r--r--pkg/sentry/limits/limits.go3
-rw-r--r--pkg/sentry/loader/BUILD16
-rw-r--r--pkg/sentry/loader/elf.go51
-rw-r--r--pkg/sentry/loader/interpreter.go10
-rw-r--r--pkg/sentry/loader/loader.go236
-rw-r--r--pkg/sentry/loader/vdso.go91
-rw-r--r--pkg/sentry/memmap/BUILD27
-rw-r--r--pkg/sentry/memmap/mapping_set.go2
-rw-r--r--pkg/sentry/memmap/mapping_set_test.go2
-rw-r--r--pkg/sentry/memmap/memmap.go79
-rw-r--r--pkg/sentry/mm/BUILD33
-rw-r--r--pkg/sentry/mm/README.md8
-rw-r--r--pkg/sentry/mm/address_space.go54
-rw-r--r--pkg/sentry/mm/aio_context.go117
-rw-r--r--pkg/sentry/mm/aio_context_state.go2
-rw-r--r--pkg/sentry/mm/debug.go2
-rw-r--r--pkg/sentry/mm/io.go6
-rw-r--r--pkg/sentry/mm/lifecycle.go46
-rw-r--r--pkg/sentry/mm/metadata.go29
-rw-r--r--pkg/sentry/mm/mm.go37
-rw-r--r--pkg/sentry/mm/mm_test.go8
-rw-r--r--pkg/sentry/mm/pma.go33
-rw-r--r--pkg/sentry/mm/procfs.go20
-rw-r--r--pkg/sentry/mm/save_restore.go2
-rw-r--r--pkg/sentry/mm/shm.go4
-rw-r--r--pkg/sentry/mm/special_mappable.go15
-rw-r--r--pkg/sentry/mm/syscalls.go12
-rw-r--r--pkg/sentry/mm/vma.go19
-rw-r--r--pkg/sentry/pgalloc/BUILD44
-rw-r--r--pkg/sentry/pgalloc/context.go2
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go298
-rw-r--r--pkg/sentry/pgalloc/pgalloc_test.go208
-rw-r--r--pkg/sentry/pgalloc/save_restore.go16
-rw-r--r--pkg/sentry/platform/BUILD27
-rw-r--r--pkg/sentry/platform/context.go2
-rw-r--r--pkg/sentry/platform/interrupt/BUILD7
-rw-r--r--pkg/sentry/platform/interrupt/interrupt.go3
-rw-r--r--pkg/sentry/platform/kvm/BUILD31
-rw-r--r--pkg/sentry/platform/kvm/address_space.go88
-rw-r--r--pkg/sentry/platform/kvm/bluepill.go32
-rw-r--r--pkg/sentry/platform/kvm/bluepill_allocator.go (renamed from pkg/sentry/platform/kvm/allocator.go)54
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64.go32
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go45
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go124
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.s89
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go97
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go12
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go91
-rw-r--r--pkg/sentry/platform/kvm/context.go15
-rw-r--r--pkg/sentry/platform/kvm/filters_amd64.go (renamed from pkg/sentry/platform/kvm/filters.go)0
-rw-r--r--pkg/sentry/platform/kvm/filters_arm64.go32
-rw-r--r--pkg/sentry/platform/kvm/kvm.go62
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64.go41
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go51
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go67
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64_unsafe.go (renamed from pkg/sentry/usermem/usermem_unsafe.go)30
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go25
-rw-r--r--pkg/sentry/platform/kvm/kvm_const_arm64.go152
-rw-r--r--pkg/sentry/platform/kvm/kvm_test.go60
-rw-r--r--pkg/sentry/platform/kvm/machine.go98
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go44
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go64
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go183
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go286
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go68
-rw-r--r--pkg/sentry/platform/kvm/physical_map.go14
-rw-r--r--pkg/sentry/platform/kvm/physical_map_amd64.go22
-rw-r--r--pkg/sentry/platform/kvm/physical_map_arm64.go (renamed from pkg/sentry/fsimpl/proc/proc.go)7
-rw-r--r--pkg/sentry/platform/kvm/testutil/BUILD4
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_amd64.go17
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.go17
-rw-r--r--pkg/sentry/platform/kvm/testutil/testutil_arm64.s27
-rw-r--r--pkg/sentry/platform/kvm/virtual_map.go2
-rw-r--r--pkg/sentry/platform/kvm/virtual_map_test.go2
-rw-r--r--pkg/sentry/platform/mmap_min_addr.go2
-rw-r--r--pkg/sentry/platform/platform.go119
-rw-r--r--pkg/sentry/platform/ptrace/BUILD12
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go27
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_amd64.go19
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64.go5
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go62
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go6
-rw-r--r--pkg/sentry/platform/ptrace/stub_amd64.s29
-rw-r--r--pkg/sentry/platform/ptrace/stub_arm64.s30
-rw-r--r--pkg/sentry/platform/ptrace/stub_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go60
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_amd64.go130
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_arm64.go62
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go93
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go23
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/BUILD50
-rw-r--r--pkg/sentry/platform/ring0/aarch64.go111
-rw-r--r--pkg/sentry/platform/ring0/defs.go22
-rw-r--r--pkg/sentry/platform/ring0/defs_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/defs_arm64.go143
-rw-r--r--pkg/sentry/platform/ring0/entry_amd64.go6
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.go60
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s786
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD21
-rw-r--r--pkg/sentry/platform/ring0/kernel.go24
-rw-r--r--pkg/sentry/platform/ring0/kernel_amd64.go12
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go72
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go58
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s217
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/platform/ring0/offsets_amd64.go5
-rw-r--r--pkg/sentry/platform/ring0/offsets_arm64.go127
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD28
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator.go11
-rw-r--r--pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables.go19
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go215
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go9
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go57
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go80
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_test.go2
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pagetables_x86.go4
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids.go104
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go32
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s45
-rw-r--r--pkg/sentry/platform/ring0/pagetables/pcids_x86.go95
-rw-r--r--pkg/sentry/platform/ring0/pagetables/walker_arm64.go314
-rw-r--r--pkg/sentry/platform/ring0/x86.go2
-rw-r--r--pkg/sentry/sighandling/BUILD3
-rw-r--r--pkg/sentry/sighandling/sighandling.go80
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go26
-rw-r--r--pkg/sentry/socket/BUILD9
-rw-r--r--pkg/sentry/socket/control/BUILD15
-rw-r--r--pkg/sentry/socket/control/control.go300
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go131
-rw-r--r--pkg/sentry/socket/hostinet/BUILD20
-rw-r--r--pkg/sentry/socket/hostinet/socket.go273
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go14
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go203
-rw-r--r--pkg/sentry/socket/hostinet/sockopt_impl.go27
-rw-r--r--pkg/sentry/socket/hostinet/stack.go37
-rw-r--r--pkg/sentry/socket/netfilter/BUILD13
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go95
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go600
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go149
-rw-r--r--pkg/sentry/socket/netfilter/targets.go282
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go130
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go129
-rw-r--r--pkg/sentry/socket/netlink/BUILD29
-rw-r--r--pkg/sentry/socket/netlink/message.go136
-rw-r--r--pkg/sentry/socket/netlink/message_test.go312
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD7
-rw-r--r--pkg/sentry/socket/netlink/port/port.go3
-rw-r--r--pkg/sentry/socket/netlink/provider.go18
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go69
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD9
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go339
-rw-r--r--pkg/sentry/socket/netlink/socket.go291
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go152
-rw-r--r--pkg/sentry/socket/netlink/uevent/BUILD16
-rw-r--r--pkg/sentry/socket/netlink/uevent/protocol.go60
-rw-r--r--pkg/sentry/socket/netstack/BUILD20
-rw-r--r--pkg/sentry/socket/netstack/netstack.go1064
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go332
-rw-r--r--pkg/sentry/socket/netstack/provider.go22
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go141
-rw-r--r--pkg/sentry/socket/netstack/stack.go198
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD68
-rw-r--r--pkg/sentry/socket/rpcinet/conn/BUILD17
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go187
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD16
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go231
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go909
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go178
-rw-r--r--pkg/sentry/socket/rpcinet/stack_unsafe.go193
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto353
-rw-r--r--pkg/sentry/socket/socket.go102
-rw-r--r--pkg/sentry/socket/unix/BUILD29
-rw-r--r--pkg/sentry/socket/unix/io.go17
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD7
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go20
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go14
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go52
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go141
-rw-r--r--pkg/sentry/socket/unix/unix.go215
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go376
-rw-r--r--pkg/sentry/state/BUILD3
-rw-r--r--pkg/sentry/state/state.go1
-rw-r--r--pkg/sentry/strace/BUILD28
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go (renamed from pkg/sentry/strace/linux64.go)63
-rw-r--r--pkg/sentry/strace/linux64_arm64.go323
-rw-r--r--pkg/sentry/strace/poll.go2
-rw-r--r--pkg/sentry/strace/select.go56
-rw-r--r--pkg/sentry/strace/signal.go2
-rw-r--r--pkg/sentry/strace/socket.go245
-rw-r--r--pkg/sentry/strace/strace.go73
-rw-r--r--pkg/sentry/strace/strace.proto3
-rw-r--r--pkg/sentry/strace/syscalls.go41
-rw-r--r--pkg/sentry/syscalls/BUILD3
-rw-r--r--pkg/sentry/syscalls/epoll.go21
-rw-r--r--pkg/sentry/syscalls/linux/BUILD25
-rw-r--r--pkg/sentry/syscalls/linux/error.go76
-rw-r--r--pkg/sentry/syscalls/linux/flags.go1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go709
-rw-r--r--pkg/sentry/syscalls/linux/linux64_amd64.go386
-rw-r--r--pkg/sentry/syscalls/linux/linux64_arm64.go313
-rw-r--r--pkg/sentry/syscalls/linux/sigset.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go201
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_amd64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_clone_arm64.go35
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go36
-rw-r--r--pkg/sentry/syscalls/linux/sys_eventfd.go19
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go308
-rw-r--r--pkg/sentry/syscalls/linux/sys_futex.go66
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_inotify.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_mempolicy.go20
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_mount.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_pipe.go20
-rw-r--r--pkg/sentry/syscalls/linux/sys_poll.go87
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go9
-rw-r--r--pkg/sentry/syscalls/linux/sys_random.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go28
-rw-r--r--pkg/sentry/syscalls/linux/sys_rlimit.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_rseq.go48
-rw-r--r--pkg/sentry/syscalls/linux/sys_seccomp.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_shm.go9
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go27
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go110
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go57
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go76
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_amd64.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat_arm64.go45
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_sysinfo.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go43
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_timer.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_timerfd.go6
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_amd64.go (renamed from pkg/sentry/syscalls/linux/sys_tls.go)0
-rw-r--r--pkg/sentry/syscalls/linux/sys_tls_arm64.go28
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go12
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go20
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go432
-rw-r--r--pkg/sentry/syscalls/linux/timespec.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD78
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go219
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go228
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/eventfd.go61
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go355
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go334
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go161
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/inotify.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go107
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/lock.go64
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/memfd.go64
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mount.go150
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go63
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go586
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go641
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go484
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/signal.go100
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go1144
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/splice.go490
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go388
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_amd64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat_arm64.go46
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go115
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/timerfd.go127
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go268
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go356
-rw-r--r--pkg/sentry/time/BUILD10
-rw-r--r--pkg/sentry/time/calibrated_clock.go2
-rw-r--r--pkg/sentry/time/muldiv_arm64.s3
-rw-r--r--pkg/sentry/time/parameters.go12
-rw-r--r--pkg/sentry/time/parameters_test.go15
-rw-r--r--pkg/sentry/unimpl/BUILD23
-rw-r--r--pkg/sentry/unimpl/events.go2
-rw-r--r--pkg/sentry/uniqueid/BUILD5
-rw-r--r--pkg/sentry/uniqueid/context.go2
-rw-r--r--pkg/sentry/usage/BUILD6
-rw-r--r--pkg/sentry/usage/memory.go22
-rw-r--r--pkg/sentry/vfs/BUILD71
-rw-r--r--pkg/sentry/vfs/README.md6
-rw-r--r--pkg/sentry/vfs/anonfs.go314
-rw-r--r--pkg/sentry/vfs/context.go46
-rw-r--r--pkg/sentry/vfs/dentry.go344
-rw-r--r--pkg/sentry/vfs/device.go132
-rw-r--r--pkg/sentry/vfs/epoll.go383
-rw-r--r--pkg/sentry/vfs/file_description.go701
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go198
-rw-r--r--pkg/sentry/vfs/file_description_impl_util_test.go175
-rw-r--r--pkg/sentry/vfs/filesystem.go439
-rw-r--r--pkg/sentry/vfs/filesystem_impl_util.go43
-rw-r--r--pkg/sentry/vfs/filesystem_type.go71
-rw-r--r--pkg/sentry/vfs/g3doc/inotify.md210
-rw-r--r--pkg/sentry/vfs/genericfstree/BUILD16
-rw-r--r--pkg/sentry/vfs/genericfstree/genericfstree.go81
-rw-r--r--pkg/sentry/vfs/inotify.go774
-rw-r--r--pkg/sentry/vfs/lock.go109
-rw-r--r--pkg/sentry/vfs/memxattr/BUILD15
-rw-r--r--pkg/sentry/vfs/memxattr/xattr.go102
-rw-r--r--pkg/sentry/vfs/mount.go744
-rw-r--r--pkg/sentry/vfs/mount_test.go39
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go84
-rw-r--r--pkg/sentry/vfs/options.go131
-rw-r--r--pkg/sentry/vfs/pathname.go195
-rw-r--r--pkg/sentry/vfs/permissions.go199
-rw-r--r--pkg/sentry/vfs/resolving_path.go232
-rw-r--r--pkg/sentry/vfs/syscalls.go235
-rw-r--r--pkg/sentry/vfs/testutil.go139
-rw-r--r--pkg/sentry/vfs/vfs.go754
-rw-r--r--pkg/sentry/watchdog/BUILD4
-rw-r--r--pkg/sentry/watchdog/watchdog.go173
-rw-r--r--pkg/shim/runsc/BUILD16
-rw-r--r--pkg/shim/runsc/runsc.go514
-rw-r--r--pkg/shim/runsc/utils.go44
-rw-r--r--pkg/shim/v1/proc/BUILD36
-rw-r--r--pkg/shim/v1/proc/deleted_state.go49
-rw-r--r--pkg/shim/v1/proc/exec.go281
-rw-r--r--pkg/shim/v1/proc/exec_state.go154
-rw-r--r--pkg/shim/v1/proc/init.go460
-rw-r--r--pkg/shim/v1/proc/init_state.go182
-rw-r--r--pkg/shim/v1/proc/io.go162
-rw-r--r--pkg/shim/v1/proc/process.go (renamed from test/root/testdata/simple.go)38
-rw-r--r--pkg/shim/v1/proc/types.go69
-rw-r--r--pkg/shim/v1/proc/utils.go90
-rw-r--r--pkg/shim/v1/shim/BUILD40
-rw-r--r--pkg/shim/v1/shim/api.go28
-rw-r--r--pkg/shim/v1/shim/platform.go106
-rw-r--r--pkg/shim/v1/shim/service.go573
-rw-r--r--pkg/shim/v1/utils/BUILD27
-rw-r--r--pkg/shim/v1/utils/annotations.go25
-rw-r--r--pkg/shim/v1/utils/utils.go56
-rw-r--r--pkg/shim/v1/utils/volumes.go155
-rw-r--r--pkg/shim/v1/utils/volumes_test.go308
-rw-r--r--pkg/shim/v2/BUILD43
-rw-r--r--pkg/shim/v2/api.go22
-rw-r--r--pkg/shim/v2/epoll.go129
-rw-r--r--pkg/shim/v2/options/BUILD11
-rw-r--r--pkg/shim/v2/options/options.go33
-rw-r--r--pkg/shim/v2/runtimeoptions/BUILD20
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions.go27
-rw-r--r--pkg/shim/v2/runtimeoptions/runtimeoptions.proto25
-rw-r--r--pkg/shim/v2/service.go824
-rw-r--r--pkg/shim/v2/service_linux.go108
-rw-r--r--pkg/sleep/BUILD7
-rw-r--r--pkg/sleep/commit_noasm.go13
-rw-r--r--pkg/sleep/sleep_test.go25
-rw-r--r--pkg/sleep/sleep_unsafe.go32
-rw-r--r--pkg/state/BUILD80
-rw-r--r--pkg/state/README.md158
-rw-r--r--pkg/state/decode.go922
-rw-r--r--pkg/state/decode_unsafe.go27
-rw-r--r--pkg/state/encode.go1029
-rw-r--r--pkg/state/encode_unsafe.go48
-rw-r--r--pkg/state/map.go221
-rw-r--r--pkg/state/object.proto140
-rw-r--r--pkg/state/pretty/BUILD13
-rw-r--r--pkg/state/pretty/pretty.go273
-rw-r--r--pkg/state/printer.go251
-rw-r--r--pkg/state/state.go364
-rw-r--r--pkg/state/state_norace.go19
-rw-r--r--pkg/state/state_race.go19
-rw-r--r--pkg/state/state_test.go720
-rw-r--r--pkg/state/statefile/BUILD7
-rw-r--r--pkg/state/statefile/statefile.go15
-rw-r--r--pkg/state/stats.go117
-rw-r--r--pkg/state/tests/BUILD43
-rw-r--r--pkg/state/tests/array.go (renamed from test/root/testdata/sandbox.go)33
-rw-r--r--pkg/state/tests/array_test.go134
-rw-r--r--pkg/state/tests/bench.go (renamed from test/root/testdata/httpd.go)24
-rw-r--r--pkg/state/tests/bench_test.go153
-rw-r--r--pkg/state/tests/bool_test.go (renamed from test/root/testdata/busybox.go)31
-rw-r--r--pkg/state/tests/float_test.go118
-rw-r--r--pkg/state/tests/integer.go163
-rw-r--r--pkg/state/tests/integer_test.go94
-rw-r--r--pkg/state/tests/load.go61
-rw-r--r--pkg/state/tests/load_test.go70
-rw-r--r--pkg/state/tests/map.go28
-rw-r--r--pkg/state/tests/map_test.go90
-rw-r--r--pkg/state/tests/register.go (renamed from pkg/sentry/socket/rpcinet/rpcinet.go)9
-rw-r--r--pkg/state/tests/register_test.go167
-rw-r--r--pkg/state/tests/string_test.go34
-rw-r--r--pkg/state/tests/struct.go65
-rw-r--r--pkg/state/tests/struct_test.go89
-rw-r--r--pkg/state/tests/tests.go215
-rw-r--r--pkg/state/types.go361
-rw-r--r--pkg/state/wire/BUILD12
-rw-r--r--pkg/state/wire/wire.go970
-rw-r--r--pkg/sync/BUILD (renamed from third_party/gvsync/BUILD)23
-rw-r--r--pkg/sync/LICENSE (renamed from third_party/gvsync/LICENSE)0
-rw-r--r--pkg/sync/README.md (renamed from third_party/gvsync/README.md)4
-rw-r--r--pkg/sync/aliases.go36
-rw-r--r--pkg/sync/atomicptr_unsafe.go (renamed from third_party/gvsync/atomicptr_unsafe.go)0
-rw-r--r--pkg/sync/atomicptrtest/BUILD (renamed from third_party/gvsync/atomicptrtest/BUILD)7
-rw-r--r--pkg/sync/atomicptrtest/atomicptr_test.go (renamed from third_party/gvsync/atomicptrtest/atomicptr_test.go)0
-rw-r--r--pkg/sync/memmove_unsafe.go (renamed from third_party/gvsync/memmove_unsafe.go)4
-rw-r--r--pkg/sync/mutex_test.go71
-rw-r--r--pkg/sync/mutex_unsafe.go49
-rw-r--r--pkg/sync/nocopy.go28
-rw-r--r--pkg/sync/norace_unsafe.go (renamed from third_party/gvsync/norace_unsafe.go)2
-rw-r--r--pkg/sync/race_unsafe.go (renamed from third_party/gvsync/race_unsafe.go)2
-rw-r--r--pkg/sync/rwmutex_test.go (renamed from third_party/gvsync/downgradable_rwmutex_test.go)69
-rw-r--r--pkg/sync/rwmutex_unsafe.go (renamed from third_party/gvsync/downgradable_rwmutex_unsafe.go)95
-rw-r--r--pkg/sync/seqatomic_unsafe.go (renamed from third_party/gvsync/seqatomic_unsafe.go)16
-rw-r--r--pkg/sync/seqatomictest/BUILD (renamed from third_party/gvsync/seqatomictest/BUILD)13
-rw-r--r--pkg/sync/seqatomictest/seqatomic_test.go (renamed from third_party/gvsync/seqatomictest/seqatomic_test.go)18
-rw-r--r--pkg/sync/seqcount.go (renamed from third_party/gvsync/seqcount.go)2
-rw-r--r--pkg/sync/seqcount_test.go (renamed from third_party/gvsync/seqcount_test.go)2
-rw-r--r--pkg/sync/sync.go (renamed from third_party/gvsync/gvsync.go)4
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go24
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserr/BUILD3
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/syserror/BUILD4
-rw-r--r--pkg/syserror/syserror.go5
-rw-r--r--pkg/tcpip/BUILD16
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD7
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go132
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go95
-rw-r--r--pkg/tcpip/buffer/BUILD6
-rw-r--r--pkg/tcpip/buffer/prependable.go8
-rw-r--r--pkg/tcpip/buffer/prependable_test.go2
-rw-r--r--pkg/tcpip/buffer/view.go142
-rw-r--r--pkg/tcpip/buffer/view_test.go286
-rw-r--r--pkg/tcpip/checker/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go246
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD10
-rw-r--r--pkg/tcpip/header/BUILD20
-rw-r--r--pkg/tcpip/header/arp.go77
-rw-r--r--pkg/tcpip/header/checksum.go133
-rw-r--r--pkg/tcpip/header/checksum_test.go62
-rw-r--r--pkg/tcpip/header/eth.go45
-rw-r--r--pkg/tcpip/header/eth_test.go34
-rw-r--r--pkg/tcpip/header/icmpv4.go16
-rw-r--r--pkg/tcpip/header/icmpv6.go38
-rw-r--r--pkg/tcpip/header/ipv4.go24
-rw-r--r--pkg/tcpip/header/ipv6.go240
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers.go551
-rw-r--r--pkg/tcpip/header/ipv6_extension_headers_test.go992
-rw-r--r--pkg/tcpip/header/ipv6_test.go417
-rw-r--r--pkg/tcpip/header/ndp_options.go631
-rw-r--r--pkg/tcpip/header/ndp_router_solicit.go36
-rw-r--r--pkg/tcpip/header/ndp_test.go1315
-rw-r--r--pkg/tcpip/header/ndpoptionidentifier_string.go50
-rw-r--r--pkg/tcpip/header/tcp.go91
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/iptables/BUILD15
-rw-r--r--pkg/tcpip/iptables/iptables.go81
-rw-r--r--pkg/tcpip/iptables/targets.go35
-rw-r--r--pkg/tcpip/iptables/types.go196
-rw-r--r--pkg/tcpip/link/channel/BUILD7
-rw-r--r--pkg/tcpip/link/channel/channel.go246
-rw-r--r--pkg/tcpip/link/fdbased/BUILD14
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go248
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go278
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_unsafe.go9
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go16
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go45
-rw-r--r--pkg/tcpip/link/loopback/BUILD5
-rw-r--r--pkg/tcpip/link/loopback/loopback.go49
-rw-r--r--pkg/tcpip/link/muxed/BUILD11
-rw-r--r--pkg/tcpip/link/muxed/injectable.go24
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go21
-rw-r--r--pkg/tcpip/link/nested/BUILD32
-rw-r--r--pkg/tcpip/link/nested/nested.go152
-rw-r--r--pkg/tcpip/link/nested/nested_test.go109
-rw-r--r--pkg/tcpip/link/packetsocket/BUILD14
-rw-r--r--pkg/tcpip/link/packetsocket/endpoint.go50
-rw-r--r--pkg/tcpip/link/qdisc/fifo/BUILD20
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go227
-rw-r--r--pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go84
-rw-r--r--pkg/tcpip/link/rawfile/BUILD7
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go2
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go33
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD12
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD9
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD8
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go53
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go119
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go14
-rw-r--r--pkg/tcpip/link/sniffer/BUILD8
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go311
-rw-r--r--pkg/tcpip/link/tun/BUILD24
-rw-r--r--pkg/tcpip/link/tun/device.go383
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/link/waitable/BUILD12
-rw-r--r--pkg/tcpip/link/waitable/waitable.go38
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go39
-rw-r--r--pkg/tcpip/network/BUILD3
-rw-r--r--pkg/tcpip/network/arp/BUILD8
-rw-r--r--pkg/tcpip/network/arp/arp.go128
-rw-r--r--pkg/tcpip/network/arp/arp_test.go87
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD17
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go108
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go139
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go6
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go2
-rw-r--r--pkg/tcpip/network/hash/BUILD3
-rw-r--r--pkg/tcpip/network/hash/hash.go4
-rw-r--r--pkg/tcpip/network/ip_test.go190
-rw-r--r--pkg/tcpip/network/ipv4/BUILD9
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go104
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go419
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go520
-rw-r--r--pkg/tcpip/network/ipv6/BUILD13
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go509
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go340
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go440
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go1324
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go852
-rw-r--r--pkg/tcpip/ports/BUILD9
-rw-r--r--pkg/tcpip/ports/ports.go346
-rw-r--r--pkg/tcpip/ports/ports_test.go234
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD3
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD3
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go2
-rw-r--r--pkg/tcpip/seqnum/BUILD7
-rw-r--r--pkg/tcpip/seqnum/seqnum.go5
-rw-r--r--pkg/tcpip/stack/BUILD103
-rw-r--r--pkg/tcpip/stack/conntrack.go631
-rw-r--r--pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go40
-rw-r--r--pkg/tcpip/stack/fake_time_test.go209
-rw-r--r--pkg/tcpip/stack/forwarder.go131
-rw-r--r--pkg/tcpip/stack/forwarder_test.go648
-rw-r--r--pkg/tcpip/stack/headertype_string.go39
-rw-r--r--pkg/tcpip/stack/iptables.go423
-rw-r--r--pkg/tcpip/stack/iptables_state.go (renamed from pkg/sentry/fsimpl/proc/filesystems.go)33
-rw-r--r--pkg/tcpip/stack/iptables_targets.go163
-rw-r--r--pkg/tcpip/stack/iptables_types.go262
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go4
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go4
-rw-r--r--pkg/tcpip/stack/ndp.go1910
-rw-r--r--pkg/tcpip/stack/ndp_test.go5204
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go333
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go1726
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go482
-rw-r--r--pkg/tcpip/stack/neighbor_entry_test.go2870
-rw-r--r--pkg/tcpip/stack/neighborstate_string.go44
-rw-r--r--pkg/tcpip/stack/nic.go1346
-rw-r--r--pkg/tcpip/stack/nic_test.go316
-rw-r--r--pkg/tcpip/stack/nud.go466
-rw-r--r--pkg/tcpip/stack/nud_test.go795
-rw-r--r--pkg/tcpip/stack/packet_buffer.go299
-rw-r--r--pkg/tcpip/stack/packet_buffer_test.go397
-rw-r--r--pkg/tcpip/stack/rand.go40
-rw-r--r--pkg/tcpip/stack/registration.go197
-rw-r--r--pkg/tcpip/stack/route.go117
-rw-r--r--pkg/tcpip/stack/stack.go847
-rw-r--r--pkg/tcpip/stack/stack_options.go106
-rw-r--r--pkg/tcpip/stack/stack_test.go2219
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go566
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go294
-rw-r--r--pkg/tcpip/stack/transport_test.go121
-rw-r--r--pkg/tcpip/tcpip.go565
-rw-r--r--pkg/tcpip/tcpip_test.go2
-rw-r--r--pkg/tcpip/tests/integration/BUILD22
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go438
-rw-r--r--pkg/tcpip/time_unsafe.go34
-rw-r--r--pkg/tcpip/timer.go206
-rw-r--r--pkg/tcpip/timer_test.go268
-rw-r--r--pkg/tcpip/transport/icmp/BUILD14
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go196
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go26
-rw-r--r--pkg/tcpip/transport/packet/BUILD13
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go262
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/raw/BUILD13
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go266
-rw-r--r--pkg/tcpip/transport/tcp/BUILD63
-rw-r--r--pkg/tcpip/transport/tcp/accept.go445
-rw-r--r--pkg/tcpip/transport/tcp/connect.go978
-rw-r--r--pkg/tcpip/transport/tcp/connect_unsafe.go (renamed from pkg/log/glog_unsafe.go)20
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go234
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go17
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go1828
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go210
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go14
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go408
-rw-r--r--pkg/tcpip/transport/tcp/rack.go82
-rw-r--r--pkg/tcpip/transport/tcp/rack_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go241
-rw-r--r--pkg/tcpip/transport/tcp/rcv_state.go29
-rw-r--r--pkg/tcpip/transport/tcp/rcv_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/segment.go60
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go17
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go10
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/snd.go512
-rw-r--r--pkg/tcpip/transport/tcp/snd_state.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go95
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go74
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go113
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go3444
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go30
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD5
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go206
-rw-r--r--pkg/tcpip/transport/tcp/timer.go1
-rw-r--r--pkg/tcpip/transport/tcp/timer_test.go47
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD4
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go30
-rw-r--r--pkg/tcpip/transport/udp/BUILD15
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go775
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go23
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go12
-rw-r--r--pkg/tcpip/transport/udp/protocol.go96
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go1240
-rw-r--r--pkg/test/criutil/BUILD (renamed from runsc/criutil/BUILD)8
-rw-r--r--pkg/test/criutil/criutil.go (renamed from runsc/criutil/criutil.go)221
-rw-r--r--pkg/test/dockerutil/BUILD42
-rw-r--r--pkg/test/dockerutil/README.md86
-rw-r--r--pkg/test/dockerutil/container.go539
-rw-r--r--pkg/test/dockerutil/dockerutil.go177
-rw-r--r--pkg/test/dockerutil/exec.go193
-rw-r--r--pkg/test/dockerutil/network.go113
-rw-r--r--pkg/test/dockerutil/profile.go147
-rw-r--r--pkg/test/dockerutil/profile_test.go116
-rw-r--r--pkg/test/testutil/BUILD (renamed from runsc/testutil/BUILD)12
-rw-r--r--pkg/test/testutil/testutil.go (renamed from runsc/testutil/testutil.go)347
-rw-r--r--pkg/test/testutil/testutil_runfiles.go75
-rw-r--r--pkg/tmutex/BUILD18
-rw-r--r--pkg/tmutex/tmutex.go81
-rw-r--r--pkg/tmutex/tmutex_test.go257
-rw-r--r--pkg/unet/BUILD7
-rw-r--r--pkg/unet/unet_test.go3
-rw-r--r--pkg/urpc/BUILD7
-rw-r--r--pkg/urpc/urpc.go2
-rw-r--r--pkg/usermem/BUILD (renamed from pkg/sentry/usermem/BUILD)19
-rw-r--r--pkg/usermem/README.md (renamed from pkg/sentry/usermem/README.md)0
-rw-r--r--pkg/usermem/access_type.go (renamed from pkg/sentry/usermem/access_type.go)0
-rw-r--r--pkg/usermem/addr.go (renamed from pkg/sentry/usermem/addr.go)17
-rw-r--r--pkg/usermem/addr_range_seq_test.go (renamed from pkg/sentry/usermem/addr_range_seq_test.go)0
-rw-r--r--pkg/usermem/addr_range_seq_unsafe.go (renamed from pkg/sentry/usermem/addr_range_seq_unsafe.go)0
-rw-r--r--pkg/usermem/bytes_io.go (renamed from pkg/sentry/usermem/bytes_io.go)41
-rw-r--r--pkg/usermem/bytes_io_unsafe.go (renamed from pkg/sentry/usermem/bytes_io_unsafe.go)2
-rw-r--r--pkg/usermem/usermem.go (renamed from pkg/sentry/usermem/usermem.go)16
-rw-r--r--pkg/usermem/usermem_arm64.go (renamed from pkg/sentry/usermem/usermem_arm64.go)0
-rw-r--r--pkg/usermem/usermem_test.go (renamed from pkg/sentry/usermem/usermem_test.go)4
-rw-r--r--pkg/usermem/usermem_x86.go (renamed from pkg/sentry/usermem/usermem_x86.go)2
-rw-r--r--pkg/waiter/BUILD15
-rw-r--r--pkg/waiter/waiter.go22
-rw-r--r--runsc/BUILD52
-rw-r--r--runsc/boot/BUILD50
-rw-r--r--runsc/boot/compat.go73
-rw-r--r--runsc/boot/compat_amd64.go93
-rw-r--r--runsc/boot/compat_arm64.go95
-rw-r--r--runsc/boot/compat_test.go45
-rw-r--r--runsc/boot/config.go47
-rw-r--r--runsc/boot/controller.go67
-rw-r--r--runsc/boot/fds.go81
-rw-r--r--runsc/boot/filter/BUILD6
-rw-r--r--runsc/boot/filter/config.go122
-rw-r--r--runsc/boot/filter/config_amd64.go31
-rw-r--r--runsc/boot/filter/config_arm64.go21
-rw-r--r--runsc/boot/filter/config_profile.go34
-rw-r--r--runsc/boot/filter/extra_filters_msan.go2
-rw-r--r--runsc/boot/fs.go172
-rw-r--r--runsc/boot/fs_test.go131
-rw-r--r--runsc/boot/limits.go2
-rw-r--r--runsc/boot/loader.go694
-rw-r--r--runsc/boot/loader_test.go186
-rw-r--r--runsc/boot/network.go130
-rw-r--r--runsc/boot/platforms/BUILD3
-rw-r--r--runsc/boot/pprof/BUILD11
-rw-r--r--runsc/boot/pprof/pprof.go (renamed from runsc/boot/pprof.go)6
-rw-r--r--runsc/boot/vfs.go519
-rw-r--r--runsc/cgroup/BUILD13
-rw-r--r--runsc/cgroup/cgroup.go153
-rw-r--r--runsc/cgroup/cgroup_test.go582
-rw-r--r--runsc/cmd/BUILD22
-rw-r--r--runsc/cmd/boot.go94
-rw-r--r--runsc/cmd/capability_test.go11
-rw-r--r--runsc/cmd/checkpoint.go2
-rw-r--r--runsc/cmd/chroot.go2
-rw-r--r--runsc/cmd/create.go3
-rw-r--r--runsc/cmd/debug.go54
-rw-r--r--runsc/cmd/delete.go2
-rw-r--r--runsc/cmd/do.go85
-rw-r--r--runsc/cmd/events.go2
-rw-r--r--runsc/cmd/exec.go2
-rw-r--r--runsc/cmd/gofer.go26
-rw-r--r--runsc/cmd/help.go16
-rw-r--r--runsc/cmd/install.go2
-rw-r--r--runsc/cmd/kill.go2
-rw-r--r--runsc/cmd/list.go2
-rw-r--r--runsc/cmd/pause.go2
-rw-r--r--runsc/cmd/ps.go2
-rw-r--r--runsc/cmd/restore.go2
-rw-r--r--runsc/cmd/resume.go2
-rw-r--r--runsc/cmd/run.go2
-rw-r--r--runsc/cmd/spec.go224
-rw-r--r--runsc/cmd/start.go3
-rw-r--r--runsc/cmd/state.go2
-rw-r--r--runsc/cmd/statefile.go149
-rw-r--r--runsc/cmd/syscalls.go25
-rw-r--r--runsc/cmd/wait.go2
-rw-r--r--runsc/console/BUILD3
-rw-r--r--runsc/container/BUILD26
-rw-r--r--runsc/container/console_test.go142
-rw-r--r--runsc/container/container.go405
-rw-r--r--runsc/container/container_norace_test.go20
-rw-r--r--runsc/container/container_race_test.go (renamed from pkg/sentry/socket/rpcinet/device.go)7
-rw-r--r--runsc/container/container_test.go2578
-rw-r--r--runsc/container/multi_container_test.go1391
-rw-r--r--runsc/container/shared_volume_test.go22
-rw-r--r--runsc/container/state_file.go185
-rw-r--r--runsc/debian/description6
-rwxr-xr-xrunsc/debian/postinst.sh9
-rw-r--r--runsc/dockerutil/BUILD15
-rw-r--r--runsc/dockerutil/dockerutil.go467
-rw-r--r--runsc/flag/BUILD9
-rw-r--r--runsc/flag/flag.go (renamed from pkg/sentry/kernel/pipe/buffer_test.go)27
-rw-r--r--runsc/fsgofer/BUILD15
-rw-r--r--runsc/fsgofer/filter/BUILD5
-rw-r--r--runsc/fsgofer/filter/config.go20
-rw-r--r--runsc/fsgofer/filter/config_amd64.go33
-rw-r--r--runsc/fsgofer/filter/config_arm64.go (renamed from pkg/fspath/builder_unsafe.go)16
-rw-r--r--runsc/fsgofer/fsgofer.go343
-rw-r--r--runsc/fsgofer/fsgofer_amd64_unsafe.go49
-rw-r--r--runsc/fsgofer/fsgofer_arm64_unsafe.go49
-rw-r--r--runsc/fsgofer/fsgofer_test.go169
-rw-r--r--runsc/fsgofer/fsgofer_unsafe.go25
-rw-r--r--runsc/main.go97
-rw-r--r--runsc/sandbox/BUILD8
-rw-r--r--runsc/sandbox/network.go187
-rw-r--r--runsc/sandbox/sandbox.go221
-rw-r--r--runsc/specutils/BUILD10
-rw-r--r--runsc/specutils/namespace.go19
-rw-r--r--runsc/specutils/specutils.go81
-rwxr-xr-xrunsc/version_test.sh2
-rwxr-xr-xscripts/build.sh79
-rwxr-xr-xscripts/common.sh14
-rwxr-xr-xscripts/common_build.sh (renamed from scripts/common_bazel.sh)53
-rwxr-xr-xscripts/dev.sh4
-rwxr-xr-xscripts/docker_tests.sh5
-rwxr-xr-xscripts/fuse_tests.sh20
-rwxr-xr-xscripts/go.sh2
-rwxr-xr-xscripts/hostnet_tests.sh2
-rwxr-xr-xscripts/iptables_tests.sh26
-rwxr-xr-xscripts/kvm_tests.sh2
-rwxr-xr-xscripts/make_tests.sh5
-rwxr-xr-xscripts/overlay_tests.sh2
-rwxr-xr-xscripts/packetdrill_tests.sh (renamed from kokoro/ubuntu1604/build.sh)9
-rwxr-xr-xscripts/packetimpact_tests.sh (renamed from kokoro/ubuntu1804/build.sh)9
-rwxr-xr-xscripts/release.sh38
-rwxr-xr-xscripts/root_tests.sh12
-rwxr-xr-xscripts/runtime_tests.sh29
-rwxr-xr-xscripts/swgso_tests.sh2
-rwxr-xr-xscripts/syscall_kvm_tests.sh20
-rw-r--r--shim/BUILD15
-rw-r--r--shim/README.md10
-rw-r--r--shim/runsc.toml6
-rw-r--r--shim/v1/BUILD30
-rw-r--r--shim/v1/api.go24
-rw-r--r--shim/v1/config.go40
-rw-r--r--shim/v1/main.go265
-rw-r--r--shim/v2/BUILD18
-rw-r--r--shim/v2/main.go26
-rw-r--r--test/BUILD45
-rw-r--r--test/README.md4
-rw-r--r--test/benchmarks/README.md157
-rw-r--r--test/benchmarks/base/BUILD34
-rw-r--r--test/benchmarks/base/base.go31
-rw-r--r--test/benchmarks/base/size_test.go220
-rw-r--r--test/benchmarks/base/startup_test.go156
-rw-r--r--test/benchmarks/base/sysbench_test.go89
-rw-r--r--test/benchmarks/database/BUILD28
-rw-r--r--test/benchmarks/database/database.go31
-rw-r--r--test/benchmarks/database/redis_test.go123
-rw-r--r--test/benchmarks/fs/BUILD32
-rw-r--r--test/benchmarks/fs/bazel_test.go119
-rw-r--r--test/benchmarks/fs/fio_test.go170
-rw-r--r--test/benchmarks/fs/fs.go31
-rw-r--r--test/benchmarks/harness/BUILD18
-rw-r--r--test/benchmarks/harness/harness.go38
-rw-r--r--test/benchmarks/harness/machine.go81
-rw-r--r--test/benchmarks/harness/util.go48
-rw-r--r--test/benchmarks/media/BUILD22
-rw-r--r--test/benchmarks/media/ffmpeg_test.go53
-rw-r--r--test/benchmarks/media/media.go31
-rw-r--r--test/benchmarks/ml/BUILD22
-rw-r--r--test/benchmarks/ml/ml.go31
-rw-r--r--test/benchmarks/ml/tensorflow_test.go69
-rw-r--r--test/benchmarks/network/BUILD35
-rw-r--r--test/benchmarks/network/httpd_test.go181
-rw-r--r--test/benchmarks/network/iperf_test.go113
-rw-r--r--test/benchmarks/network/network.go31
-rw-r--r--test/benchmarks/network/nginx_test.go104
-rw-r--r--test/benchmarks/network/node_test.go127
-rw-r--r--test/benchmarks/network/ruby_test.go134
-rw-r--r--test/benchmarks/tcp/BUILD41
-rw-r--r--test/benchmarks/tcp/README.md87
-rw-r--r--test/benchmarks/tcp/nsjoin.c47
-rwxr-xr-xtest/benchmarks/tcp/tcp_benchmark.sh392
-rw-r--r--test/benchmarks/tcp/tcp_proxy.go451
-rw-r--r--test/benchmarks/tools/BUILD33
-rw-r--r--test/benchmarks/tools/ab.go94
-rw-r--r--test/benchmarks/tools/ab_test.go90
-rw-r--r--test/benchmarks/tools/fio.go124
-rw-r--r--test/benchmarks/tools/fio_test.go122
-rw-r--r--test/benchmarks/tools/hey.go75
-rw-r--r--test/benchmarks/tools/hey_test.go81
-rw-r--r--test/benchmarks/tools/iperf.go56
-rw-r--r--test/benchmarks/tools/iperf_test.go34
-rw-r--r--test/benchmarks/tools/meminfo.go60
-rw-r--r--test/benchmarks/tools/meminfo_test.go84
-rw-r--r--test/benchmarks/tools/redis.go63
-rw-r--r--test/benchmarks/tools/redis_test.go87
-rw-r--r--test/benchmarks/tools/sysbench.go245
-rw-r--r--test/benchmarks/tools/sysbench_test.go169
-rw-r--r--test/benchmarks/tools/tools.go17
-rw-r--r--test/cmd/test_app/BUILD (renamed from runsc/container/test_app/BUILD)8
-rw-r--r--test/cmd/test_app/fds.go (renamed from runsc/container/test_app/fds.go)4
-rw-r--r--test/cmd/test_app/test_app.go (renamed from runsc/container/test_app/test_app.go)46
-rw-r--r--test/e2e/BUILD10
-rw-r--r--test/e2e/exec_test.go185
-rw-r--r--test/e2e/integration_test.go405
-rw-r--r--test/e2e/regression_test.go20
-rw-r--r--test/fuse/BUILD9
-rw-r--r--test/fuse/README.md103
-rw-r--r--test/fuse/linux/BUILD32
-rw-r--r--test/fuse/linux/fuse_base.cc208
-rw-r--r--test/fuse/linux/fuse_base.h99
-rw-r--r--test/fuse/linux/stat_test.cc169
-rw-r--r--test/image/BUILD9
-rw-r--r--test/image/image_test.go230
-rwxr-xr-x[-rw-r--r--]test/image/ruby.sh0
-rw-r--r--test/iptables/BUILD38
-rw-r--r--test/iptables/README.md54
-rw-r--r--test/iptables/filter_input.go745
-rw-r--r--test/iptables/filter_output.go663
-rw-r--r--test/iptables/iptables.go115
-rw-r--r--test/iptables/iptables_test.go427
-rw-r--r--test/iptables/iptables_unsafe.go63
-rw-r--r--test/iptables/iptables_util.go282
-rw-r--r--test/iptables/nat.go657
-rw-r--r--test/iptables/runner/BUILD12
-rw-r--r--test/iptables/runner/main.go79
-rw-r--r--test/packetdrill/BUILD45
-rw-r--r--test/packetdrill/accept_ack_drop.pkt27
-rw-r--r--test/packetdrill/defs.bzl91
-rw-r--r--test/packetdrill/fin_wait2_timeout.pkt23
-rw-r--r--test/packetdrill/listen_close_before_handshake_complete.pkt31
-rw-r--r--test/packetdrill/no_rst_to_rst.pkt36
-rwxr-xr-xtest/packetdrill/packetdrill_setup.sh26
-rwxr-xr-xtest/packetdrill/packetdrill_test.sh226
-rw-r--r--test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt9
-rw-r--r--test/packetdrill/sanity_test.pkt7
-rw-r--r--test/packetdrill/tcp_defer_accept.pkt48
-rw-r--r--test/packetdrill/tcp_defer_accept_timeout.pkt48
-rw-r--r--test/packetimpact/README.md702
-rw-r--r--test/packetimpact/dut/BUILD18
-rw-r--r--test/packetimpact/dut/posix_server.cc371
-rw-r--r--test/packetimpact/netdevs/BUILD23
-rw-r--r--test/packetimpact/netdevs/netdevs.go115
-rw-r--r--test/packetimpact/netdevs/netdevs_test.go227
-rw-r--r--test/packetimpact/proto/BUILD12
-rw-r--r--test/packetimpact/proto/posix_server.proto230
-rw-r--r--test/packetimpact/runner/BUILD27
-rw-r--r--test/packetimpact/runner/defs.bzl143
-rw-r--r--test/packetimpact/runner/packetimpact_test.go383
-rw-r--r--test/packetimpact/testbench/BUILD46
-rw-r--r--test/packetimpact/testbench/connections.go1205
-rw-r--r--test/packetimpact/testbench/dut.go702
-rw-r--r--test/packetimpact/testbench/dut_client.go28
-rw-r--r--test/packetimpact/testbench/layers.go1506
-rw-r--r--test/packetimpact/testbench/layers_test.go728
-rw-r--r--test/packetimpact/testbench/rawsockets.go188
-rw-r--r--test/packetimpact/testbench/testbench.go128
-rw-r--r--test/packetimpact/tests/BUILD310
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go75
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go78
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go122
-rw-r--r--test/packetimpact/tests/ipv6_fragment_reassembly_test.go168
-rw-r--r--test/packetimpact/tests/ipv6_unknown_options_action_test.go187
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go109
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go84
-rw-r--r--test/packetimpact/tests/tcp_handshake_window_size_test.go66
-rw-r--r--test/packetimpact/tests/tcp_network_unreachable_test.go141
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go42
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go93
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go109
-rw-r--r--test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go132
-rw-r--r--test/packetimpact/tests/tcp_reordering_test.go174
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go84
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go105
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go52
-rw-r--r--test/packetimpact/tests/tcp_synsent_reset_test.go90
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go100
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go73
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go104
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go112
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go98
-rw-r--r--test/packetimpact/tests/udp_any_addr_recv_unicast_test.go51
-rw-r--r--test/packetimpact/tests/udp_discard_mcast_source_addr_test.go94
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go363
-rw-r--r--test/packetimpact/tests/udp_recv_mcast_bcast_test.go110
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go104
-rw-r--r--test/perf/BUILD117
-rw-r--r--test/perf/linux/BUILD356
-rw-r--r--test/perf/linux/clock_getres_benchmark.cc39
-rw-r--r--test/perf/linux/clock_gettime_benchmark.cc60
-rw-r--r--test/perf/linux/death_benchmark.cc36
-rw-r--r--test/perf/linux/epoll_benchmark.cc99
-rw-r--r--test/perf/linux/fork_benchmark.cc350
-rw-r--r--test/perf/linux/futex_benchmark.cc198
-rw-r--r--test/perf/linux/getdents_benchmark.cc149
-rw-r--r--test/perf/linux/getpid_benchmark.cc37
-rw-r--r--test/perf/linux/gettid_benchmark.cc38
-rw-r--r--test/perf/linux/mapping_benchmark.cc163
-rw-r--r--test/perf/linux/open_benchmark.cc56
-rw-r--r--test/perf/linux/pipe_benchmark.cc66
-rw-r--r--test/perf/linux/randread_benchmark.cc100
-rw-r--r--test/perf/linux/read_benchmark.cc53
-rw-r--r--test/perf/linux/sched_yield_benchmark.cc37
-rw-r--r--test/perf/linux/send_recv_benchmark.cc372
-rw-r--r--test/perf/linux/seqwrite_benchmark.cc66
-rw-r--r--test/perf/linux/signal_benchmark.cc61
-rw-r--r--test/perf/linux/sleep_benchmark.cc60
-rw-r--r--test/perf/linux/stat_benchmark.cc62
-rw-r--r--test/perf/linux/unlink_benchmark.cc66
-rw-r--r--test/perf/linux/write_benchmark.cc52
-rw-r--r--test/root/BUILD31
-rw-r--r--test/root/cgroup_test.go255
-rw-r--r--test/root/chroot_test.go29
-rw-r--r--test/root/crictl_test.go549
-rw-r--r--test/root/main_test.go2
-rw-r--r--test/root/oom_score_adj_test.go70
-rw-r--r--test/root/runsc_test.go151
-rw-r--r--test/root/testdata/BUILD19
-rw-r--r--test/root/testdata/containerd_config.go39
-rw-r--r--test/root/testdata/httpd_mount_paths.go53
-rw-r--r--test/runner/BUILD29
-rw-r--r--test/runner/defs.bzl249
-rw-r--r--test/runner/gtest/BUILD9
-rw-r--r--test/runner/gtest/gtest.go170
-rw-r--r--test/runner/runner.go (renamed from test/syscalls/syscall_test_runner.go)170
-rw-r--r--test/runtimes/BUILD51
-rw-r--r--test/runtimes/README.md41
-rw-r--r--test/runtimes/blacklist_go1.12.csv16
-rw-r--r--test/runtimes/blacklist_java11.csv126
-rw-r--r--test/runtimes/blacklist_nodejs12.4.0.csv47
-rw-r--r--test/runtimes/blacklist_python3.7.3.csv27
-rw-r--r--test/runtimes/build_defs.bzl57
-rw-r--r--test/runtimes/defs.bzl90
-rw-r--r--test/runtimes/exclude_go1.12.csv13
-rw-r--r--test/runtimes/exclude_java11.csv208
-rw-r--r--test/runtimes/exclude_nodejs12.4.0.csv55
-rw-r--r--test/runtimes/exclude_php7.3.6.csv (renamed from test/runtimes/blacklist_php7.3.6.csv)21
-rw-r--r--test/runtimes/exclude_python3.7.3.csv21
-rw-r--r--test/runtimes/images/Dockerfile_go1.1210
-rw-r--r--test/runtimes/proctor/BUILD (renamed from test/runtimes/images/proctor/BUILD)10
-rw-r--r--test/runtimes/proctor/go.go (renamed from test/runtimes/images/proctor/go.go)29
-rw-r--r--test/runtimes/proctor/java.go (renamed from test/runtimes/images/proctor/java.go)21
-rw-r--r--test/runtimes/proctor/nodejs.go (renamed from test/runtimes/images/proctor/nodejs.go)8
-rw-r--r--test/runtimes/proctor/php.go (renamed from test/runtimes/images/proctor/php.go)9
-rw-r--r--test/runtimes/proctor/proctor.go (renamed from test/runtimes/images/proctor/proctor.go)41
-rw-r--r--test/runtimes/proctor/proctor_test.go (renamed from test/runtimes/images/proctor/proctor_test.go)14
-rw-r--r--test/runtimes/proctor/python.go (renamed from test/runtimes/images/proctor/python.go)8
-rw-r--r--test/runtimes/runner/BUILD22
-rw-r--r--test/runtimes/runner/exclude_test.go (renamed from test/runtimes/blacklist_test.go)12
-rw-r--r--test/runtimes/runner/main.go (renamed from test/runtimes/runner.go)114
-rw-r--r--test/syscalls/BUILD432
-rw-r--r--test/syscalls/build_defs.bzl136
-rw-r--r--test/syscalls/gtest/BUILD12
-rw-r--r--test/syscalls/gtest/gtest.go93
-rw-r--r--test/syscalls/linux/32bit.cc136
-rw-r--r--test/syscalls/linux/BUILD1004
-rw-r--r--test/syscalls/linux/accept_bind.cc44
-rw-r--r--test/syscalls/linux/accept_bind_stream.cc2
-rw-r--r--test/syscalls/linux/aio.cc14
-rw-r--r--test/syscalls/linux/alarm.cc3
-rw-r--r--test/syscalls/linux/bad.cc12
-rw-r--r--test/syscalls/linux/chmod.cc1
-rw-r--r--test/syscalls/linux/chroot.cc5
-rw-r--r--test/syscalls/linux/clock_gettime.cc6
-rw-r--r--test/syscalls/linux/concurrency.cc6
-rw-r--r--test/syscalls/linux/connect_external.cc12
-rw-r--r--test/syscalls/linux/dev.cc21
-rw-r--r--test/syscalls/linux/epoll.cc26
-rw-r--r--test/syscalls/linux/eventfd.cc42
-rw-r--r--test/syscalls/linux/exceptions.cc183
-rw-r--r--test/syscalls/linux/exec.cc254
-rw-r--r--test/syscalls/linux/exec_binary.cc170
-rw-r--r--test/syscalls/linux/exec_proc_exe_workload.cc6
-rw-r--r--test/syscalls/linux/fallocate.cc57
-rw-r--r--test/syscalls/linux/fault.cc3
-rw-r--r--test/syscalls/linux/fcntl.cc454
-rw-r--r--test/syscalls/linux/file_base.h109
-rw-r--r--test/syscalls/linux/flock.cc76
-rw-r--r--test/syscalls/linux/fork.cc26
-rw-r--r--test/syscalls/linux/fpsig_fork.cc36
-rw-r--r--test/syscalls/linux/fpsig_nested.cc57
-rw-r--r--test/syscalls/linux/futex.cc113
-rw-r--r--test/syscalls/linux/getdents.cc15
-rw-r--r--test/syscalls/linux/getrandom.cc2
-rw-r--r--test/syscalls/linux/getrusage.cc2
-rw-r--r--test/syscalls/linux/inotify.cc885
-rw-r--r--test/syscalls/linux/ioctl.cc3
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc64
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h62
-rw-r--r--test/syscalls/linux/iptables.h2
-rw-r--r--test/syscalls/linux/itimer.cc24
-rw-r--r--test/syscalls/linux/link.cc18
-rw-r--r--test/syscalls/linux/lseek.cc2
-rw-r--r--test/syscalls/linux/madvise.cc12
-rw-r--r--test/syscalls/linux/memfd.cc1
-rw-r--r--test/syscalls/linux/memory_accounting.cc1
-rw-r--r--test/syscalls/linux/mempolicy.cc10
-rw-r--r--test/syscalls/linux/mkdir.cc22
-rw-r--r--test/syscalls/linux/mknod.cc26
-rw-r--r--test/syscalls/linux/mlock.cc6
-rw-r--r--test/syscalls/linux/mmap.cc226
-rw-r--r--test/syscalls/linux/mount.cc37
-rw-r--r--test/syscalls/linux/msync.cc4
-rw-r--r--test/syscalls/linux/network_namespace.cc52
-rw-r--r--test/syscalls/linux/open.cc137
-rw-r--r--test/syscalls/linux/open_create.cc27
-rw-r--r--test/syscalls/linux/packet_socket.cc258
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc370
-rw-r--r--test/syscalls/linux/partial_bad_buffer.cc138
-rw-r--r--test/syscalls/linux/ping_socket.cc91
-rw-r--r--test/syscalls/linux/pipe.cc51
-rw-r--r--test/syscalls/linux/poll.cc11
-rw-r--r--test/syscalls/linux/prctl.cc2
-rw-r--r--test/syscalls/linux/prctl_setuid.cc2
-rw-r--r--test/syscalls/linux/pread64.cc16
-rw-r--r--test/syscalls/linux/preadv.cc1
-rw-r--r--test/syscalls/linux/preadv2.cc4
-rw-r--r--test/syscalls/linux/proc.cc342
-rw-r--r--test/syscalls/linux/proc_net.cc212
-rw-r--r--test/syscalls/linux/proc_net_tcp.cc1
-rw-r--r--test/syscalls/linux/proc_net_udp.cc1
-rw-r--r--test/syscalls/linux/proc_net_unix.cc6
-rw-r--r--test/syscalls/linux/proc_pid_oomscore.cc72
-rw-r--r--test/syscalls/linux/proc_pid_smaps.cc4
-rw-r--r--test/syscalls/linux/ptrace.cc48
-rw-r--r--test/syscalls/linux/pty.cc33
-rw-r--r--test/syscalls/linux/pty_root.cc22
-rw-r--r--test/syscalls/linux/pwrite64.cc21
-rw-r--r--test/syscalls/linux/pwritev2.cc61
-rw-r--r--test/syscalls/linux/raw_socket.cc869
-rw-r--r--test/syscalls/linux/raw_socket_hdrincl.cc101
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc2
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc392
-rw-r--r--test/syscalls/linux/read.cc1
-rw-r--r--test/syscalls/linux/readv.cc4
-rw-r--r--test/syscalls/linux/readv_common.cc45
-rw-r--r--test/syscalls/linux/readv_socket.cc45
-rw-r--r--test/syscalls/linux/rename.cc1
-rw-r--r--test/syscalls/linux/rseq.cc198
-rw-r--r--test/syscalls/linux/rseq/BUILD61
-rw-r--r--test/syscalls/linux/rseq/critical.h39
-rw-r--r--test/syscalls/linux/rseq/critical_amd64.S66
-rw-r--r--test/syscalls/linux/rseq/critical_arm64.S66
-rw-r--r--test/syscalls/linux/rseq/rseq.cc366
-rw-r--r--test/syscalls/linux/rseq/start_amd64.S45
-rw-r--r--test/syscalls/linux/rseq/start_arm64.S45
-rw-r--r--test/syscalls/linux/rseq/syscalls.h69
-rw-r--r--test/syscalls/linux/rseq/test.h43
-rw-r--r--test/syscalls/linux/rseq/types.h31
-rw-r--r--test/syscalls/linux/rseq/uapi.h51
-rw-r--r--test/syscalls/linux/rtsignal.cc3
-rw-r--r--test/syscalls/linux/seccomp.cc57
-rw-r--r--test/syscalls/linux/select.cc3
-rw-r--r--test/syscalls/linux/semaphore.cc5
-rw-r--r--test/syscalls/linux/sendfile.cc51
-rw-r--r--test/syscalls/linux/sendfile_socket.cc107
-rw-r--r--test/syscalls/linux/shm.cc3
-rw-r--r--test/syscalls/linux/sigaction.cc53
-rw-r--r--test/syscalls/linux/sigaltstack.cc14
-rw-r--r--test/syscalls/linux/sigiret.cc7
-rw-r--r--test/syscalls/linux/signalfd.cc120
-rw-r--r--test/syscalls/linux/sigprocmask.cc2
-rw-r--r--test/syscalls/linux/sigstop.cc2
-rw-r--r--test/syscalls/linux/sigtimedwait.cc3
-rw-r--r--test/syscalls/linux/socket.cc38
-rw-r--r--test/syscalls/linux/socket_abstract.cc2
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc25
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc400
-rw-r--r--test/syscalls/linux/socket_blocking.cc1
-rw-r--r--test/syscalls/linux/socket_capability.cc61
-rw-r--r--test/syscalls/linux/socket_filesystem.cc2
-rw-r--r--test/syscalls/linux/socket_generic.cc98
-rw-r--r--test/syscalls/linux/socket_generic_stress.cc130
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc1626
-rw-r--r--test/syscalls/linux/socket_inet_loopback_nogotsan.cc174
-rw-r--r--test/syscalls/linux/socket_ip_loopback_blocking.cc3
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc371
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic_loopback.cc3
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback.cc2
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc3
-rw-r--r--test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc3
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc259
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_blocking.cc2
-rw-r--r--test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_ip_unbound.cc122
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc1
-rw-r--r--test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc6
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc1835
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.h4
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc106
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h12
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc6
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc13
-rw-r--r--test/syscalls/linux/socket_netdevice.cc36
-rw-r--r--test/syscalls/linux/socket_netlink.cc153
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc568
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc162
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h55
-rw-r--r--test/syscalls/linux/socket_netlink_uevent.cc83
-rw-r--r--test/syscalls/linux/socket_netlink_util.cc90
-rw-r--r--test/syscalls/linux/socket_netlink_util.h25
-rw-r--r--test/syscalls/linux/socket_non_stream.cc113
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.cc37
-rw-r--r--test/syscalls/linux/socket_stream.cc55
-rw-r--r--test/syscalls/linux/socket_stream_blocking.cc64
-rw-r--r--test/syscalls/linux/socket_test_util.cc122
-rw-r--r--test/syscalls/linux/socket_test_util.h14
-rw-r--r--test/syscalls/linux/socket_unix.cc20
-rw-r--r--test/syscalls/linux/socket_unix_abstract_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_blocking_local.cc5
-rw-r--r--test/syscalls/linux/socket_unix_cmsg.cc29
-rw-r--r--test/syscalls/linux/socket_unix_dgram.cc1
-rw-r--r--test/syscalls/linux/socket_unix_dgram_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_dgram_non_blocking.cc1
-rw-r--r--test/syscalls/linux/socket_unix_domain.cc2
-rw-r--r--test/syscalls/linux/socket_unix_filesystem_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_non_stream.cc4
-rw-r--r--test/syscalls/linux/socket_unix_non_stream_blocking_local.cc5
-rw-r--r--test/syscalls/linux/socket_unix_pair.cc2
-rw-r--r--test/syscalls/linux/socket_unix_pair_nonblock.cc2
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket.cc19
-rw-r--r--test/syscalls/linux/socket_unix_seqpacket_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_stream.cc14
-rw-r--r--test/syscalls/linux/socket_unix_stream_blocking_local.cc5
-rw-r--r--test/syscalls/linux/socket_unix_stream_local.cc2
-rw-r--r--test/syscalls/linux/socket_unix_stream_nonblock_local.cc5
-rw-r--r--test/syscalls/linux/socket_unix_unbound_abstract.cc1
-rw-r--r--test/syscalls/linux/socket_unix_unbound_filesystem.cc1
-rw-r--r--test/syscalls/linux/socket_unix_unbound_seqpacket.cc1
-rw-r--r--test/syscalls/linux/socket_unix_unbound_stream.cc1
-rw-r--r--test/syscalls/linux/splice.cc106
-rw-r--r--test/syscalls/linux/stat.cc75
-rw-r--r--test/syscalls/linux/sticky.cc68
-rw-r--r--test/syscalls/linux/symlink.cc27
-rw-r--r--test/syscalls/linux/sync.cc3
-rw-r--r--test/syscalls/linux/sysret.cc35
-rw-r--r--test/syscalls/linux/tcp_socket.cc497
-rw-r--r--test/syscalls/linux/time.cc3
-rw-r--r--test/syscalls/linux/timerfd.cc29
-rw-r--r--test/syscalls/linux/timers.cc20
-rw-r--r--test/syscalls/linux/tkill.cc2
-rw-r--r--test/syscalls/linux/truncate.cc1
-rw-r--r--test/syscalls/linux/tuntap.cc422
-rw-r--r--test/syscalls/linux/tuntap_hostinet.cc38
-rw-r--r--test/syscalls/linux/udp_socket.cc1321
-rw-r--r--test/syscalls/linux/udp_socket_errqueue_test_case.cc57
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc1781
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.h82
-rw-r--r--test/syscalls/linux/uidgid.cc21
-rw-r--r--test/syscalls/linux/unix_domain_socket_test_util.cc1
-rw-r--r--test/syscalls/linux/unix_domain_socket_test_util.h1
-rw-r--r--test/syscalls/linux/utimes.cc73
-rw-r--r--test/syscalls/linux/vdso_clock_gettime.cc1
-rw-r--r--test/syscalls/linux/vfork.cc2
-rw-r--r--test/syscalls/linux/vsyscall.cc2
-rw-r--r--test/syscalls/linux/write.cc10
-rw-r--r--test/syscalls/linux/xattr.cc610
-rwxr-xr-xtest/syscalls/syscall_test_runner.sh34
-rw-r--r--test/uds/BUILD3
-rw-r--r--test/util/BUILD65
-rw-r--r--test/util/capability_util.cc12
-rw-r--r--test/util/fs_util.cc17
-rw-r--r--test/util/fs_util.h18
-rw-r--r--test/util/fs_util_test.cc4
-rw-r--r--test/util/mount_util.h9
-rw-r--r--test/util/multiprocess_util.h6
-rw-r--r--test/util/platform_util.cc48
-rw-r--r--test/util/platform_util.h56
-rw-r--r--test/util/posix_error_test.cc1
-rw-r--r--test/util/pty_util.cc10
-rw-r--r--test/util/pty_util.h3
-rw-r--r--test/util/rlimit_util.cc1
-rw-r--r--test/util/save_util_linux.cc18
-rw-r--r--test/util/save_util_other.cc4
-rw-r--r--test/util/signal_util.cc1
-rw-r--r--test/util/signal_util.h15
-rw-r--r--test/util/temp_path.cc3
-rw-r--r--test/util/temp_path.h1
-rw-r--r--test/util/temp_umask.h (renamed from test/syscalls/linux/temp_umask.h)6
-rw-r--r--test/util/test_main.cc2
-rw-r--r--test/util/test_util.cc49
-rw-r--r--test/util/test_util.h62
-rw-r--r--test/util/test_util_impl.cc52
-rw-r--r--test/util/test_util_runfiles.cc50
-rw-r--r--test/util/test_util_test.cc1
-rw-r--r--third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go21
-rw-r--r--third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go16
-rw-r--r--tools/BUILD9
-rw-r--r--tools/bazel.mk181
-rw-r--r--tools/bazeldefs/BUILD106
-rw-r--r--tools/bazeldefs/defs.bzl181
-rw-r--r--tools/bazeldefs/platforms.bzl9
-rw-r--r--tools/bazeldefs/tags.bzl56
-rw-r--r--tools/bigquery/BUILD10
-rw-r--r--tools/bigquery/bigquery.go121
-rw-r--r--tools/checkescape/BUILD16
-rw-r--r--tools/checkescape/checkescape.go726
-rw-r--r--tools/checkescape/test1/BUILD9
-rw-r--r--tools/checkescape/test1/test1.go195
-rw-r--r--tools/checkescape/test2/BUILD9
-rw-r--r--tools/checkescape/test2/test2.go94
-rw-r--r--tools/checkunsafe/BUILD8
-rw-r--r--tools/defs.bzl253
-rwxr-xr-xtools/go_branch.sh57
-rw-r--r--tools/go_generics/BUILD30
-rw-r--r--tools/go_generics/defs.bzl53
-rw-r--r--tools/go_generics/generics.go4
-rw-r--r--tools/go_generics/generics_tests/all_stmts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/all_types/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/anon/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/consts/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/imports/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/remove_typedef/opts.txt1
-rw-r--r--tools/go_generics/generics_tests/simple/opts.txt1
-rw-r--r--tools/go_generics/globals/BUILD4
-rw-r--r--tools/go_generics/globals/scope.go4
-rwxr-xr-xtools/go_generics/go_generics_unittest.sh70
-rw-r--r--tools/go_generics/go_merge/BUILD4
-rw-r--r--tools/go_generics/rules_tests/BUILD2
-rw-r--r--tools/go_generics/tests/BUILD7
-rw-r--r--tools/go_generics/tests/all_stmts/BUILD16
-rw-r--r--tools/go_generics/tests/all_stmts/input.go (renamed from tools/go_generics/generics_tests/all_stmts/input.go)0
-rw-r--r--tools/go_generics/tests/all_stmts/output.go (renamed from tools/go_generics/generics_tests/all_stmts/output/output.go)0
-rw-r--r--tools/go_generics/tests/all_types/BUILD16
-rw-r--r--tools/go_generics/tests/all_types/input.go (renamed from tools/go_generics/generics_tests/all_types/input.go)4
-rw-r--r--tools/go_generics/tests/all_types/lib/lib.go (renamed from tools/go_generics/generics_tests/all_types/lib/lib.go)0
-rw-r--r--tools/go_generics/tests/all_types/output.go (renamed from tools/go_generics/generics_tests/all_types/output/output.go)4
-rw-r--r--tools/go_generics/tests/anon/BUILD18
-rw-r--r--tools/go_generics/tests/anon/input.go (renamed from tools/go_generics/generics_tests/anon/input.go)0
-rw-r--r--tools/go_generics/tests/anon/output.go (renamed from tools/go_generics/generics_tests/anon/output/output.go)4
-rw-r--r--tools/go_generics/tests/consts/BUILD23
-rw-r--r--tools/go_generics/tests/consts/input.go (renamed from tools/go_generics/generics_tests/consts/input.go)0
-rw-r--r--tools/go_generics/tests/consts/output.go (renamed from tools/go_generics/generics_tests/consts/output/output.go)0
-rw-r--r--tools/go_generics/tests/defs.bzl67
-rw-r--r--tools/go_generics/tests/imports/BUILD24
-rw-r--r--tools/go_generics/tests/imports/input.go (renamed from tools/go_generics/generics_tests/imports/input.go)0
-rw-r--r--tools/go_generics/tests/imports/output.go (renamed from tools/go_generics/generics_tests/imports/output/output.go)0
-rw-r--r--tools/go_generics/tests/remove_typedef/BUILD16
-rw-r--r--tools/go_generics/tests/remove_typedef/input.go (renamed from tools/go_generics/generics_tests/remove_typedef/input.go)0
-rw-r--r--tools/go_generics/tests/remove_typedef/output.go (renamed from tools/go_generics/generics_tests/remove_typedef/output/output.go)0
-rw-r--r--tools/go_generics/tests/simple/BUILD17
-rw-r--r--tools/go_generics/tests/simple/input.go (renamed from tools/go_generics/generics_tests/simple/input.go)0
-rw-r--r--tools/go_generics/tests/simple/output.go (renamed from tools/go_generics/generics_tests/simple/output/output.go)0
-rw-r--r--tools/go_marshal/BUILD15
-rw-r--r--tools/go_marshal/README.md60
-rw-r--r--tools/go_marshal/analysis/BUILD5
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go4
-rw-r--r--tools/go_marshal/defs.bzl113
-rw-r--r--tools/go_marshal/gomarshal/BUILD10
-rw-r--r--tools/go_marshal/gomarshal/generator.go253
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go457
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go146
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go289
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces_struct.go622
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go125
-rw-r--r--tools/go_marshal/gomarshal/util.go184
-rw-r--r--tools/go_marshal/main.go13
-rw-r--r--tools/go_marshal/marshal/BUILD9
-rw-r--r--tools/go_marshal/marshal/marshal.go141
-rw-r--r--tools/go_marshal/marshal/marshal_impl_util.go78
-rw-r--r--tools/go_marshal/primitive/BUILD18
-rw-r--r--tools/go_marshal/primitive/primitive.go247
-rw-r--r--tools/go_marshal/test/BUILD25
-rw-r--r--tools/go_marshal/test/benchmark_test.go46
-rw-r--r--tools/go_marshal/test/escape/BUILD14
-rw-r--r--tools/go_marshal/test/escape/escape.go95
-rw-r--r--tools/go_marshal/test/external/BUILD6
-rw-r--r--tools/go_marshal/test/external/external.go8
-rw-r--r--tools/go_marshal/test/marshal_test.go515
-rw-r--r--tools/go_marshal/test/test.go75
-rw-r--r--tools/go_stateify/BUILD11
-rw-r--r--tools/go_stateify/defs.bzl90
-rw-r--r--tools/go_stateify/main.go200
-rw-r--r--tools/installers/BUILD41
-rwxr-xr-xtools/installers/containerd.sh114
-rwxr-xr-xtools/installers/head.sh27
-rwxr-xr-xtools/installers/images.sh24
-rwxr-xr-xtools/installers/master.sh34
-rwxr-xr-xtools/installers/shim.sh33
-rw-r--r--tools/issue_reviver/BUILD12
-rw-r--r--tools/issue_reviver/github/BUILD24
-rw-r--r--tools/issue_reviver/github/github.go176
-rw-r--r--tools/issue_reviver/github/github_test.go55
-rw-r--r--tools/issue_reviver/main.go100
-rw-r--r--tools/issue_reviver/reviver/BUILD18
-rw-r--r--tools/issue_reviver/reviver/reviver.go192
-rw-r--r--tools/issue_reviver/reviver/reviver_test.go88
-rwxr-xr-xtools/make_apt.sh139
-rwxr-xr-xtools/make_release.sh81
-rwxr-xr-xtools/make_repository.sh79
-rw-r--r--tools/nogo.js7
-rw-r--r--tools/nogo/BUILD55
-rw-r--r--tools/nogo/README.md31
-rw-r--r--tools/nogo/build.go40
-rw-r--r--tools/nogo/check/BUILD12
-rw-r--r--tools/nogo/check/main.go24
-rw-r--r--tools/nogo/config.go116
-rw-r--r--tools/nogo/data/BUILD10
-rw-r--r--tools/nogo/data/data.go21
-rw-r--r--tools/nogo/defs.bzl176
-rw-r--r--tools/nogo/io_bazel_rules_go-visibility.patch25
-rw-r--r--tools/nogo/matchers.go143
-rw-r--r--tools/nogo/nogo.go326
-rw-r--r--tools/nogo/register.go64
-rwxr-xr-xtools/tag_release.sh26
-rw-r--r--tools/tags/BUILD11
-rw-r--r--tools/tags/tags.go89
-rw-r--r--tools/vm/BUILD63
-rw-r--r--tools/vm/README.md48
-rwxr-xr-xtools/vm/build.sh (renamed from tools/image_build.sh)77
-rw-r--r--tools/vm/defs.bzl202
-rwxr-xr-xtools/vm/execute.sh160
-rw-r--r--tools/vm/test.cc27
-rwxr-xr-xtools/vm/ubuntu1604/10_core.sh (renamed from kokoro/ubuntu1604/10_core.sh)21
-rwxr-xr-xtools/vm/ubuntu1604/15_gcloud.sh50
-rwxr-xr-xtools/vm/ubuntu1604/20_bazel.sh (renamed from kokoro/ubuntu1604/20_bazel.sh)14
-rwxr-xr-xtools/vm/ubuntu1604/30_docker.sh (renamed from kokoro/ubuntu1604/25_docker.sh)43
-rwxr-xr-xtools/vm/ubuntu1604/40_kokoro.sh (renamed from kokoro/ubuntu1604/40_kokoro.sh)38
-rw-r--r--tools/vm/ubuntu1604/BUILD7
-rw-r--r--tools/vm/ubuntu1804/BUILD7
-rwxr-xr-xtools/vm/zone.sh17
-rwxr-xr-xtools/workspace_status.sh2
-rw-r--r--vdso/BUILD43
-rw-r--r--vdso/syscalls.h33
-rw-r--r--vdso/vdso.cc16
-rw-r--r--vdso/vdso_amd64.lds1
-rw-r--r--website/BUILD188
-rw-r--r--website/_config.yml36
-rw-r--r--website/_includes/byline.html18
-rw-r--r--website/_includes/footer-links.html43
-rw-r--r--website/_includes/footer.html72
-rw-r--r--website/_includes/graph.html205
-rw-r--r--website/_includes/header-links.html19
-rw-r--r--website/_includes/header.html30
-rw-r--r--website/_includes/paginator.html10
-rw-r--r--website/_includes/required_linux.html2
-rw-r--r--website/_layouts/base.html9
-rw-r--r--website/_layouts/blog.html17
-rw-r--r--website/_layouts/default.html14
-rw-r--r--website/_layouts/docs.html54
-rw-r--r--website/_layouts/post.html10
-rw-r--r--website/_plugins/svg_mime_type.rb3
-rw-r--r--website/_sass/footer.scss15
-rw-r--r--website/_sass/front.scss17
-rw-r--r--website/_sass/navbar.scss26
-rw-r--r--website/_sass/sidebar.scss61
-rw-r--r--website/_sass/style.scss154
-rw-r--r--website/archive.key29
-rw-r--r--website/assets/favicons/apple-touch-icon-180x180.pngbin0 -> 18820 bytes
-rw-r--r--website/assets/favicons/favicon-16x16.pngbin0 -> 926 bytes
-rw-r--r--website/assets/favicons/favicon-32x32.pngbin0 -> 2308 bytes
-rw-r--r--website/assets/favicons/favicon.icobin0 -> 1150 bytes
-rw-r--r--website/assets/favicons/pwa-192x192.pngbin0 -> 20666 bytes
-rw-r--r--website/assets/favicons/pwa-512x512.pngbin0 -> 24397 bytes
-rw-r--r--website/assets/favicons/tile150x150.pngbin0 -> 18440 bytes
-rw-r--r--website/assets/favicons/tile310x150.pngbin0 -> 21486 bytes
-rw-r--r--website/assets/favicons/tile310x310.pngbin0 -> 25629 bytes
-rw-r--r--website/assets/favicons/tile70x70.pngbin0 -> 11148 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure1.pngbin0 -> 19088 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure2.pngbin0 -> 17642 bytes
-rw-r--r--website/assets/images/2019-11-18-security-basics-figure3.pngbin0 -> 16471 bytes
-rw-r--r--website/assets/images/2020-04-02-networking-security-figure1.pngbin0 -> 29775 bytes
-rw-r--r--website/assets/images/background.jpgbin0 -> 1070364 bytes
-rw-r--r--website/assets/logos/Makefile13
-rw-r--r--website/assets/logos/README.md10
-rw-r--r--website/assets/logos/logo_solo_monochrome.pngbin0 -> 10483 bytes
-rw-r--r--website/assets/logos/logo_solo_monochrome.svg73
-rw-r--r--website/assets/logos/logo_solo_on_dark-1024.pngbin0 -> 59374 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark-128.pngbin0 -> 5951 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark-16.pngbin0 -> 701 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark.pngbin0 -> 8387 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark.svg73
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-1024.pngbin0 -> 80121 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-128.pngbin0 -> 8616 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full-16.pngbin0 -> 900 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full.pngbin0 -> 17055 bytes
-rw-r--r--website/assets/logos/logo_solo_on_dark_full.svg79
-rw-r--r--website/assets/logos/logo_solo_on_white.pngbin0 -> 10572 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white.svg73
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-1024.pngbin0 -> 95350 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-128.pngbin0 -> 10231 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered-16.pngbin0 -> 960 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered.pngbin0 -> 15330 bytes
-rw-r--r--website/assets/logos/logo_solo_on_white_bordered.svg82
-rw-r--r--website/assets/logos/logo_with_text_monochrome.pngbin0 -> 22220 bytes
-rw-r--r--website/assets/logos/logo_with_text_monochrome.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_dark-1024.pngbin0 -> 30774 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark-128.pngbin0 -> 3129 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark-16.pngbin0 -> 315 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark.pngbin0 -> 17035 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-1024.pngbin0 -> 34866 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-128.pngbin0 -> 3746 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full-16.pngbin0 -> 372 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full.pngbin0 -> 25956 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_dark_full.svg120
-rw-r--r--website/assets/logos/logo_with_text_on_white.pngbin0 -> 22363 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_white.svg116
-rw-r--r--website/assets/logos/logo_with_text_on_white_bordered.pngbin0 -> 27719 bytes
-rw-r--r--website/assets/logos/logo_with_text_on_white_bordered.svg122
-rw-r--r--website/assets/logos/powered-gvisor.pngbin0 -> 5193 bytes
-rw-r--r--website/blog/2019-11-18-security-basics.md306
-rw-r--r--website/blog/2020-04-02-networking-security.md183
-rw-r--r--website/blog/BUILD37
-rw-r--r--website/blog/index.html22
-rw-r--r--website/cmd/server/BUILD10
-rw-r--r--website/cmd/server/main.go215
-rw-r--r--website/cmd/syscalldocs/BUILD9
-rw-r--r--website/cmd/syscalldocs/main.go211
-rw-r--r--website/css/main.scss5
-rw-r--r--website/defs.bzl178
-rwxr-xr-xwebsite/import.sh (renamed from test/runtimes/runner.sh)24
-rw-r--r--website/index.md50
-rw-r--r--website/performance/README.md10
-rw-r--r--website/performance/applications.csv13
-rw-r--r--website/performance/density.csv9
-rw-r--r--website/performance/ffmpeg.csv3
-rw-r--r--website/performance/fio-tmpfs.csv9
-rw-r--r--website/performance/fio.csv9
-rw-r--r--website/performance/httpd100k.csv17
-rw-r--r--website/performance/httpd10240k.csv17
-rw-r--r--website/performance/iperf.csv5
-rw-r--r--website/performance/redis.csv35
-rw-r--r--website/performance/startup.csv7
-rw-r--r--website/performance/sysbench-cpu.csv3
-rw-r--r--website/performance/sysbench-memory.csv3
-rw-r--r--website/performance/syscall.csv4
-rw-r--r--website/performance/tensorflow.csv3
2261 files changed, 202091 insertions, 39456 deletions
diff --git a/.bazelrc b/.bazelrc
index 379fc8328..a2fe95822 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -12,36 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# RBE requires a strong hash function, such as SHA256.
+startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
+
+# Build with C++17.
+build --cxxopt=-std=c++17
+
# Display the current git revision in the info block.
build --stamp --workspace_status_command tools/workspace_status.sh
# Enable remote execution so actions are performed on the remote systems.
build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
-
-# Add a custom platform and toolchain that builds in a privileged docker
-# container, which is required by our syscall tests.
-build:remote --host_platform=//test:rbe_ubuntu1604
-build:remote --extra_toolchains=//test:cc-toolchain-clang-x86_64-default
-build:remote --extra_execution_platforms=//test:rbe_ubuntu1604
-build:remote --platforms=//test:rbe_ubuntu1604
-
-# Use default image for crosstool toolchain.
-build:remote --crosstool_top=@rbe_default//cc:toolchain
-
-# Default parallelism and timeout for remote jobs.
-build:remote --jobs=50
-build:remote --remote_timeout=3600
-
-# RBE requires a strong hash function, such as SHA256.
-startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
+build:remote --bes_backend=buildeventservice.googleapis.com
+build:remote --bes_results_url="https://source.cloud.google.com/results/invocations"
+build:remote --bes_timeout=600s
+build:remote --project_id=gvisor-rbe
+build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance
+build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com
+build:remote3 --project_id=gvisor-rbe
+build:remote3 --bes_backend=buildeventservice.googleapis.com
+build:remote3 --bes_results_url="https://source.cloud.google.com/results/invocations"
+build:remote3 --bes_timeout=600s
+build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance
# Enable authentication. This will pick up application default credentials by
# default. You can use --google_credentials=some_file.json to use a service
# account credential instead.
build:remote --google_default_credentials=true
-
-# Auth scope needed for authentication with RBE.
build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
+build:remote3 --google_default_credentials=true
+build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
+
+# Add a custom platform and toolchain that builds in a privileged docker
+# container, which is required by our syscall tests.
+build:remote --host_platform=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default
+build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
+build:remote --crosstool_top=@rbe_default//cc:toolchain
+build:remote --jobs=100
+build:remote --remote_timeout=3600
+build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3
+build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --crosstool_top=@rbe_default//cc:toolchain
+build:remote3 --jobs=100
+build:remote3 --remote_timeout=3600
# Set flags for uploading to BES in order to view results in the Bazel Build
# Results UI.
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 000000000..49a1ba697
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,31 @@
+---
+name: Bug report
+about: Create a bug report to help us improve
+title:
+labels:
+ - 'type: bug'
+assignees: ''
+---
+
+**Description**
+
+A clear description of what the bug is. If possible, explicitly indicate the
+expected behavior vs. the observed behavior.
+
+**Steps to reproduce**
+
+If available, please include detailed reproduction steps.
+
+If the bug requires software that is not publicly available, see if it can be
+reproduced with software that is publicly available.
+
+**Environment**
+
+Please include the following details of your environment:
+
+* `runsc -v`
+* `docker version` or `docker info` (if available)
+* `kubectl version` and `kubectl get nodes` (if using Kubernetes)
+* `uname -a`
+* `git describe` (if built from source)
+* `runsc` debug logs (if available)
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 000000000..772c9a0ac
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,11 @@
+blank_issues_enabled: false
+contact_links:
+ - name: gVisor Documentation (FAQ)
+ url: https://gvisor.dev/docs/user_guide/faq/
+ about: Please see our documentation for common questions and answers.
+ - name: gVisor Documentation (Debugging)
+ url: https://gvisor.dev/docs/user_guide/debugging/
+ about: Please see our documentation for debugging tips.
+ - name: gVisor User Forum
+ url: https://groups.google.com/g/gvisor-users
+ about: Ask and answer general questions here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 000000000..65f60f385
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,21 @@
+---
+name: Feature request
+about: Suggest an idea or improvement
+title: ''
+labels:
+ - 'type: enhancement'
+assignees: ''
+---
+
+**Description**
+
+A clear description of the feature or enhancement.
+
+**Is this feature related to a specific bug?**
+
+Please include a bug references if yes.
+
+**Do you have a specific solution in mind?**
+
+Please include any details about a solution that you have in mind, including any
+alternatives considered.
diff --git a/.github/issue_template.md b/.github/issue_template.md
deleted file mode 100644
index 77c401d22..000000000
--- a/.github/issue_template.md
+++ /dev/null
@@ -1,20 +0,0 @@
-Before filling an issue, please consult our FAQ:
-https://gvisor.dev/docs/user_guide/faq/
-
-Also check that the issue hasn't been reported before.
-
-If you have a question, please email gvisor-users@googlegroups.com rather than filing a bug.
-
-If you believe you've found a security issue, please email gvisor-security@googlegroups.com rather than filing a bug.
-
-If this is your first time compiling or running gVisor, please make sure that your system meets the minimum requirements: https://github.com/google/gvisor#requirements
-
-For all other issues, please attach debug logs. To get debug logs, follow the
-instructions here: https://gvisor.dev/docs/user_guide/debugging/
-
-Other useful information to include is:
-
-* `runsc -v`
-* `docker version` or `docker info` if more relevant
-* `uname -a` - `git describe`
-* Detailed reproduction steps
diff --git a/.github/labeler.yml b/.github/labeler.yml
new file mode 100644
index 000000000..b6a17051c
--- /dev/null
+++ b/.github/labeler.yml
@@ -0,0 +1,42 @@
+"arch: arm":
+ - "**/*_arm64.*"
+ - "**/*_aarch64.*"
+"arch: x86_64":
+ - "**/*_amd64.*"
+ - "**/*_x86.*"
+"area: bazel":
+ - "**/BUILD"
+ - "**/*.bzl"
+"area: docs":
+ - "**/g3doc/**"
+ - "**/README.md"
+"area: filesystem":
+ - "pkg/sentry/fs/**"
+ - "pkg/sentry/vfs/**"
+ - "pkg/sentry/fsimpl/**"
+"area: hostinet":
+ - "pkg/sentry/socket/hostinet/**"
+"area: networking":
+ - "pkg/tcpip/**"
+ - "pkg/sentry/socket/**"
+"area: kernel":
+ - "pkg/sentry/arch/**"
+ - "pkg/sentry/kernel/**"
+ - "pkg/sentry/syscalls/**"
+"area: mm":
+ - "pkg/sentry/mm/**"
+"area: tests":
+ - "**/tests/**"
+ - "**/*_test.go"
+ - "**/test/**"
+"area: tooling":
+ - "tools/**"
+"dependencies":
+ - "WORKSPACE"
+ - "go.mod"
+ - "go.sum"
+"platform: kvm":
+ - "pkg/sentry/platform/kvm/**"
+ - "pkg/sentry/platform/ring0/**"
+"platform: ptrace":
+ - "pkg/sentry/platform/ptrace/**"
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 000000000..264b4e9fa
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,5 @@
+* [ ] Have you followed the guidelines in [CONTRIBUTING.md](../blob/master/CONTRIBUTING.md)?
+* [ ] Have you formatted and linted your code?
+* [ ] Have you added relevant tests?
+* [ ] Have you added appropriate Fixes & Updates references?
+* [ ] If yes, please erase all these lines!
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 000000000..cf782a580
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,21 @@
+name: "Build"
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ default:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - uses: actions/cache@v1
+ with:
+ path: ~/.cache/bazel
+ key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
+ restore-keys: |
+ ${{ runner.os }}-bazel-
+ - run: make
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
new file mode 100644
index 000000000..4da3853b2
--- /dev/null
+++ b/.github/workflows/go.yml
@@ -0,0 +1,75 @@
+name: "Go"
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+jobs:
+ generate:
+ runs-on: ubuntu-latest
+ steps:
+ - id: setup
+ run: |
+ if ! [[ -z "${{ secrets.GO_TOKEN }}" ]]; then
+ echo ::set-output name=has_token::true
+ else
+ echo ::set-output name=has_token::false
+ fi
+ - run: |
+ jq -nc '{"state": "pending", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
+ if: github.event_name == 'pull_request'
+ - uses: actions/checkout@v2
+ if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true'
+ with:
+ fetch-depth: 0
+ token: '${{ secrets.GO_TOKEN }}'
+ - uses: actions/checkout@v2
+ if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true'
+ with:
+ fetch-depth: 0
+ - uses: actions/setup-go@v2
+ with:
+ go-version: 1.14
+ - uses: actions/cache@v1
+ with:
+ path: ~/go/pkg/mod
+ key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
+ restore-keys: |
+ ${{ runner.os }}-go-
+ - uses: actions/cache@v1
+ with:
+ path: ~/.cache/bazel
+ key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }}
+ restore-keys: |
+ ${{ runner.os }}-bazel-
+ - run: |
+ rm -rf bazel-bin/gopath
+ make build TARGETS="//:gopath"
+ - run: tools/go_branch.sh
+ - run: git checkout go && git clean -f
+ - run: go build ./...
+ - if: github.event_name == 'push'
+ run: |
+ git remote add upstream "https://github.com/${{ github.repository }}"
+ git push upstream go:go
+ - if: ${{ success() && github.event_name == 'pull_request' }}
+ run: |
+ jq -nc '{"state": "success", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
+ - if: ${{ failure() && github.event_name == 'pull_request' }}
+ run: |
+ jq -nc '{"state": "failure", "context": "go tests"}' | \
+ curl -sL -X POST -d @- \
+ -H "Content-Type: application/json" \
+ -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+ "${{ github.event.pull_request.statuses_url }}"
diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml
new file mode 100644
index 000000000..2b399a3f2
--- /dev/null
+++ b/.github/workflows/issue_reviver.yml
@@ -0,0 +1,16 @@
+name: "Issue reviver"
+on:
+ schedule:
+ - cron: '0 0 * * *'
+
+jobs:
+ issue_reviver:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ if: github.repository == 'google/gvisor'
+ - run: make run TARGETS="//tools/issue_reviver"
+ if: github.repository == 'google/gvisor'
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ GITHUB_REPOSITORY: ${{ github.repository }}
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
new file mode 100644
index 000000000..c09f7eb36
--- /dev/null
+++ b/.github/workflows/labeler.yml
@@ -0,0 +1,12 @@
+name: "Labeler"
+on:
+- pull_request
+
+jobs:
+ label:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/labeler@v2
+ if: github.base_ref == null
+ with:
+ repo-token: "${{ secrets.GITHUB_TOKEN }}"
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
new file mode 100644
index 000000000..0b31fecf5
--- /dev/null
+++ b/.github/workflows/stale.yml
@@ -0,0 +1,20 @@
+name: "Close stale issues"
+on:
+ schedule:
+ - cron: "0 0 * * *"
+
+jobs:
+ stale:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/stale@v3
+ with:
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
+ stale-issue-label: 'stale'
+ stale-pr-label: 'stale'
+ exempt-issue-labels: 'exported, type: bug, type: cleanup, type: enhancement, type: process, type: proposal, type: question'
+ exempt-pr-labels: 'ready to pull'
+ stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ stale-pr-message: 'This pull request is stale because it has been open 90 days with no activity. Remove the stale label or comment or this will be closed in 30 days.'
+ days-before-stale: 90
+ days-before-close: 30
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..2d9fa80a1
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,47 @@
+language: shell
+dist: xenial
+git:
+ clone: false # Clone manually in before_install
+before_install:
+ - set -e -o pipefail
+ - |
+ if [ "${TRAVIS_PULL_REQUEST}" = false ]; then
+ # This is not a PR build, fetch and checkout the commit being tested
+ git clone -q --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
+ cd "${TRAVIS_REPO_SLUG}"
+ git fetch origin "${TRAVIS_COMMIT}" --depth 1
+ git checkout -qf "${TRAVIS_COMMIT}"
+ else
+ # This is a PR build, simulate +refs/pull/{num}/merge.
+ # We can do that by fetching +refs/pull/{num}/head and cherry picking it
+ # onto the target branch.
+ git clone -q --branch "${TRAVIS_BRANCH}" --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}"
+ cd "${TRAVIS_REPO_SLUG}"
+ git fetch origin "+refs/pull/${TRAVIS_PULL_REQUEST}/head" --depth 1
+ git config --global user.email "$(git log -1 FETCH_HEAD --pretty="%cE")"
+ git config --global user.name "$(git log -1 FETCH_HEAD --pretty="%aN")"
+ git cherry-pick --strategy=recursive -X theirs --keep-redundant-commits FETCH_HEAD
+ fi
+cache:
+ directories:
+ - /home/travis/.cache/bazel/
+os: linux
+services:
+ - docker
+jobs:
+ include:
+ # AMD64 builds are tested on kokoro, so don't run them in travis to save
+ # capacity for arm64 builds.
+ # - os: linux
+ # arch: amd64
+ - os: linux
+ arch: arm64
+script:
+ # On arm64, we need to create our own pipes for stderr and stdout,
+ # otherwise we will not be able to open /dev/stderr. This is probably
+ # due to AppArmor rules.
+ - bash -xeo pipefail -c 'uname -a && make smoke-tests 2>&1 | cat'
+branches:
+ except:
+ # Skip copybara branches.
+ - /^test\/cl.*$/
diff --git a/BUILD b/BUILD
index de410b008..2639f8169 100644
--- a/BUILD
+++ b/BUILD
@@ -1,13 +1,68 @@
-package(licenses = ["notice"]) # Apache 2.0
+load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
+load("//website:defs.bzl", "doc")
-load("@io_bazel_rules_go//go:def.bzl", "go_path", "nogo")
-load("@bazel_gazelle//:def.bzl", "gazelle")
+package(licenses = ["notice"])
+
+exports_files(["LICENSE"])
+
+doc(
+ name = "contributing",
+ src = "CONTRIBUTING.md",
+ category = "Project",
+ permalink = "/contributing/",
+ visibility = ["//website:__pkg__"],
+ weight = "20",
+)
+
+doc(
+ name = "security",
+ src = "SECURITY.md",
+ category = "Project",
+ permalink = "/security/",
+ visibility = ["//website:__pkg__"],
+ weight = "30",
+)
+
+doc(
+ name = "governance",
+ src = "GOVERNANCE.md",
+ category = "Project",
+ permalink = "/community/governance/",
+ subcategory = "Community",
+ visibility = ["//website:__pkg__"],
+ weight = "20",
+)
+
+doc(
+ name = "code_of_conduct",
+ src = "CODE_OF_CONDUCT.md",
+ category = "Project",
+ permalink = "/community/code_of_conduct/",
+ subcategory = "Community",
+ visibility = ["//website:__pkg__"],
+ weight = "99",
+)
# The sandbox filegroup is used for sandbox-internal dependencies.
package_group(
name = "sandbox",
- packages = [
- "//...",
+ packages = ["//..."],
+)
+
+# For targets that will not normally build internally, we ensure that they are
+# least build by a static BUILD test.
+build_test(
+ name = "build_test",
+ targets = [
+ "//test/e2e:integration_test",
+ "//test/image:image_test",
+ "//test/root:root_test",
+ "//test/benchmarks/base:base_test",
+ "//test/benchmarks/database:database_test",
+ "//test/benchmarks/fs:fs_test",
+ "//test/benchmarks/media:media_test",
+ "//test/benchmarks/ml:ml_test",
+ "//test/benchmarks/network:network_test",
],
)
@@ -20,10 +75,24 @@ go_path(
name = "gopath",
mode = "link",
deps = [
+ # Main binary.
"//runsc",
+ "//shim/v1:gvisor-containerd-shim",
+ "//shim/v2:containerd-shim-runsc-v1",
# Packages that are not dependencies of //runsc.
+ "//pkg/sentry/kernel/memevent",
+ "//pkg/tcpip/adapters/gonet",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/muxed",
+ "//pkg/tcpip/link/sharedmem",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/link/waitable",
+ "//pkg/tcpip/sample/tun_tcp_connect",
+ "//pkg/tcpip/sample/tun_tcp_echo",
+ "//pkg/tcpip/transport/tcpconntrack",
],
)
@@ -32,15 +101,3 @@ go_path(
# To update the WORKSPACE from go.mod, use:
# bazel run //:gazelle -- update-repos -from_file=go.mod
gazelle(name = "gazelle")
-
-# nogo applies checks to all Go source in this repository, enforcing code
-# guidelines and restrictions. Note that the tool libraries themselves should
-# live in the tools subdirectory (unless they are standard).
-nogo(
- name = "nogo",
- config = "tools/nogo.js",
- visibility = ["//visibility:public"],
- deps = [
- "//tools/checkunsafe",
- ],
-)
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index eb6c8edae..fbf517fe5 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -87,6 +87,5 @@ harassment or threats to anyone's safety, we may take action without notice.
## Attribution
-This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
-available at
-https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
+This Code of Conduct is adapted from the
+[Contributor Covenant, version 1.4](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html).
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5d46168bc..89180eb3f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -32,11 +32,13 @@ will need to be added to the appropriate `BUILD` files, and the `:gopath` target
will need to be re-run to generate appropriate symlinks in the `GOPATH`
directory tree.
+Dependencies can be added by using `go mod get`. In order to keep the
+`WORKSPACE` file in sync, run `tools/go_mod.sh` in place of `go mod`.
+
### Coding Guidelines
-All Go code should conform to the [Go style guidelines][gostyle]. C++ code
-should conform to the [Google C++ Style Guide][cppstyle] and the guidelines
-described for [tests][teststyle].
+All code should comply with the [style guide](g3doc/style.md). Note that code
+may be automatically formatted per the guidelines when merged.
As a secure runtime, we need to maintain the safety of all of code included in
gVisor. The following rules help mitigate issues.
@@ -46,7 +48,7 @@ Definitions for the rules below:
`core`:
* `//pkg/sentry/...`
-* Transitive dependencies in `//pkg/...`, `//third_party/...`.
+* Transitive dependencies in `//pkg/...`, etc.
`runsc`:
@@ -104,32 +106,15 @@ ignored.
### Build and test with Docker
-`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a
-new Docker runtime for you. The scripts tries to extract the runtime name from
-your local environment and will print it at the end. You can also customize it.
-The script creates one regular runtime and another with debug flags enabled.
-Here are a few examples:
+Running `make dev` is a convenient way to build and install `runsc` as a Docker
+runtime. The output of this command will show the runtimes installed.
+
+You may use `make refresh` to refresh the binary after any changes. For example:
```bash
-# Default case (inside branch my-branch)
-$ scripts/dev.sh
-...
-Runtimes my-branch and my-branch-d (debug enabled) setup.
-Use --runtime=my-branch with your Docker command.
- docker run --rm --runtime=my-branch --rm hello-world
-
-If you rebuild, use scripts/dev.sh --refresh.
-Logs are in: /tmp/my-branch/logs
-
-# --refresh just updates the runtime binary and doesn't restart docker.
-$ git/my_branch> scripts/dev.sh --refresh
-
-# Using a custom runtime name
-$ git/my_branch> scripts/dev.sh my-runtime
-...
-Runtimes my-runtime and my-runtime-d (debug enabled) setup.
-Use --runtime=my-runtime with your Docker command.
- docker run --rm --runtime=my-runtime --rm hello-world
+make dev
+docker run --rm --runtime=my-branch --rm hello-world
+make refresh
```
### The small print
@@ -138,10 +123,7 @@ Contributions made by corporations are covered by a different agreement than the
one above, the
[Software Grant and Corporate Contributor License Agreement][gccla].
-[cppstyle]: https://google.github.io/styleguide/cppguide.html
[gcla]: https://cla.developers.google.com/about/google-individual
[gccla]: https://cla.developers.google.com/about/google-corporate
[github]: https://github.com/google/gvisor/compare
[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev
-[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments
-[teststyle]: ./test/
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index 6e9d870db..000000000
--- a/Dockerfile
+++ /dev/null
@@ -1,8 +0,0 @@
-FROM ubuntu:bionic
-
-RUN apt-get update && apt-get install -y curl gnupg2 git python3
-RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
- curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
-RUN apt-get update && apt-get install -y bazel && apt-get clean
-
-WORKDIR /gvisor
diff --git a/GOVERNANCE.md b/GOVERNANCE.md
new file mode 100644
index 000000000..40846bc2f
--- /dev/null
+++ b/GOVERNANCE.md
@@ -0,0 +1,113 @@
+# Governance
+
+## Projects
+
+A *project* is the primary unit of collaboration. Each project may have its own
+repository and contribution process.
+
+All projects are covered by the [Code of Conduct](CODE_OF_CONDUCT.md), and
+should include an up-to-date copy in the project repository or a link here.
+
+## Contributors
+
+Anyone can be a *contributor* to a project, provided they have signed relevant
+Contributor License Agreements (CLAs) and follow the project's contribution
+guidelines. Contributions will be reviewed by a maintainer, and must pass all
+applicable tests.
+
+Reviews check for code quality and style, including documentation, and enforce
+other policies. Contributions may be rejected for reasons unrelated to the code
+in question. For example, a change may be too complex to maintain or duplicate
+existing functionality.
+
+Note that contributions are not limited to code alone. Bugs, documentation,
+experience reports or public advocacy are all valuable ways to contribute to a
+project and build trust in the community.
+
+## Maintainers
+
+Each project has one or more *maintainers*. Maintainers set technical direction,
+facilitate contributions and exercise overall stewardship.
+
+Maintainers have write access to the project repository. Maintainers review and
+approve changes. They can also assign issues and add additional reviewers.
+
+Note that some repositories may not allow direct commit access, which is
+reserved for administrators or automated processes. In this case, maintainers
+have approval rights, and a separate process exists for merging a change.
+
+Maintainers are responsible for upholding the code of conduct in interactions
+via project communication channels. If comments or exchanges are in violation,
+they may remove them at their discretion.
+
+### Repositories requiring synchronization
+
+For some projects initiated by Google, the infrastructure which synchronizes and
+merges internal and external changes requires that merges are performed by a
+Google employee. In such cases, Google will initiate a rotation to merge changes
+once they pass tests and are approved by a maintainer. This does not preclude
+non-Google contributors from becoming maintainers, in which case the maintainer
+holds approval rights and the merge is an automated process. In some cases,
+Google-internal tests may fail and have to be fixed: the Google employee will
+work with the submitter to achieve this.
+
+### Becoming a maintainer
+
+The list of maintainers is defined by the list of people with commit access or
+approval authority on a repository, typically via a Gerrit group or a GitHub
+team.
+
+Existing maintainers may elevate a contributor to maintainer status on evidence
+of previous contributions and established trust. This decision is based on lazy
+consensus from existing maintainers. While contributors may ask maintainers to
+make this decision, existing maintainers will also pro-actively identify
+contributors who have demonstrated a sustained track record of technical
+leadership and direct contributions.
+
+## Special Interest Groups (SIGs)
+
+From time-to-time, a SIG may be formed in order to solve larger, more complex
+problems across one or more projects. There are many avenues for collaboration
+outside a SIG, but a SIG can provide structure for collaboration on a single
+topic.
+
+Each group will be established by a charter, and governed by the Code of
+Conduct. Some resources may be provided to the group, such as mailing lists or
+meeting space, and archives will be public.
+
+## Security disclosure
+
+Projects may maintain security mailing lists for vulnerability reports and
+internal project audits may occasionally reveal security issues. Access to these
+lists and audits will be limited to project *maintainers*; individual
+maintainers should opt to participate in these lists based on need and
+expertise. Once maintainers become aware of a potential security issue, they
+will assess the scope and potential impact. If reported externally, maintainers
+will determine a reasonable embargo period with the reporter.
+
+During the embargo period, the maintainers will prioritize a fix for the
+security issue. They may choose to disclose the issue to additional trusted
+contributors in order to facilitate a fix, subjecting them to the embargo, or
+notify affected users in order to give them an advanced opportunity to mitigate
+the issue. The inclusion of specific users in this disclosure is left to the
+discretion of the maintainers and contributors involved, and depends on the
+scale of known project use and exposure.
+
+Once a fix is widely available or the embargo period ends, the maintainers will
+make technical details about the vulnerability and associated fixes available.
+
+## Mailing lists
+
+There are four key mailing lists that span projects.
+
+* [gvisor-users](mailto:gvisor-users@googlegroups.com): general purpose user
+ list.
+* [gvisor-dev](mailto:gvisor-dev@googlegroups.com): general purpose
+ development list.
+* [gvisor-security](mailto:gvisor-security@googlegroups.com): private security
+ list. Access to this list is restricted to maintainers of the core gVisor
+ project, subject to the security disclosure policy described above.
+* [gvisor-syzkaller](mailto:gvisor-syzkaller@googlegroups.com): private
+ syzkaller bug tracking list. Access to this list is not limited to
+ maintainers, but will be granted to those who can credibly contribute to
+ fixes.
diff --git a/LICENSE b/LICENSE
index d64569567..74fddbbd9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -200,3 +200,25 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+
+------------------
+
+Some files carry the following license, noted at the top of each file:
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE. \ No newline at end of file
diff --git a/Makefile b/Makefile
index 1735c07df..fdbc6fb41 100644
--- a/Makefile
+++ b/Makefile
@@ -1,47 +1,371 @@
-UID := $(shell id -u ${USER})
-GID := $(shell id -g ${USER})
-GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/)
+#!/usr/bin/make -f
-all: runsc
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
-docker-build:
- docker build -t gvisor-bazel .
+# Helpful pretty-printer.
+MAKEBANNER := \033[1;34mmake\033[0m
+submake = echo -e '$(MAKEBANNER) $1' >&2; $(MAKE) $1
-bazel-shutdown:
- docker exec -i gvisor-bazel bazel shutdown && \
- docker kill gvisor-bazel
+# Described below.
+OPTIONS :=
+STARTUP_OPTIONS :=
+TARGETS := //runsc
+ARGS :=
-bazel-server-start: docker-build
- mkdir -p "$(GVISOR_BAZEL_CACHE)" && \
- docker run -d --rm --name gvisor-bazel \
- --user 0:0 \
- -v "$(GVISOR_BAZEL_CACHE):$(HOME)/.cache/bazel/" \
- -v "$(CURDIR):$(CURDIR)" \
- --workdir "$(CURDIR)" \
- --tmpfs /tmp:rw,exec \
- --privileged \
- gvisor-bazel \
- sh -c "while :; do sleep 100; done" && \
- docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor"
+default: runsc
+.PHONY: default
-bazel-server:
- docker exec gvisor-bazel true || \
- $(MAKE) bazel-server-start
+## usage: make <target>
+## or
+## make <build|test|copy|run|sudo> STARTUP_OPTIONS="..." OPTIONS="..." TARGETS="..." ARGS="..."
+##
+## Basic targets.
+##
+## This Makefile wraps basic build and test targets for ease-of-use. Bazel
+## is run inside a canonical Docker container in order to simplify up-front
+## requirements.
+##
+## There are common arguments that may be passed to targets. These are:
+## STARTUP_OPTIONS - Bazel startup options.
+## OPTIONS - Build or test options.
+## TARGETS - The bazel targets.
+## ARGS - Arguments for run or sudo.
+##
+## Additionally, the copy target expects a DESTINATION to be provided.
+##
+## For example, to build runsc using this Makefile, you can run:
+## make build OPTIONS="" TARGETS="//runsc"'
+##
+help: ## Shows all targets and help from the Makefile (this message).
+ @grep --no-filename -E '^([a-z.A-Z_-]+:.*?|)##' $(MAKEFILE_LIST) | \
+ awk 'BEGIN {FS = "(:.*?|)## ?"}; { \
+ if (length($$1) > 0) { \
+ printf " \033[36m%-20s\033[0m %s\n", $$1, $$2; \
+ } else { \
+ printf "%s\n", $$2; \
+ } \
+ }'
+build: ## Builds the given $(TARGETS) with the given $(OPTIONS). E.g. make build TARGETS=runsc
+test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test
+copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp
+run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version
+sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v
+.PHONY: help build test copy run sudo
-BAZEL_OPTIONS := build runsc
-bazel: bazel-server
- docker exec -u $(UID):$(GID) -i gvisor-bazel bazel $(BAZEL_OPTIONS)
+# Load all bazel wrappers.
+#
+# This file should define the basic "build", "test", "run" and "sudo" rules, in
+# addition to the $(BRANCH_NAME) variable.
+ifneq (,$(wildcard tools/google.mk))
+include tools/google.mk
+else
+include tools/bazel.mk
+endif
-bazel-alias:
- @echo "alias bazel='docker exec -u $(UID):$(GID) -i gvisor-bazel bazel'"
+##
+## Docker image targets.
+##
+## Images used by the tests must also be built and available locally.
+## The canonical test targets defined below will automatically load
+## relevant images. These can be loaded or built manually via these
+## targets.
+##
+## (*) Note that you may provide an ARCH parameter in order to build
+## and load images from an alternate archiecture (using qemu). When
+## bazel is run as a server, this has the effect of running an full
+## cross-architecture chain, and can produce cross-compiled binaries.
+##
+define images
+$(1)-%: ## Image tool: $(1) a given image (also may use 'all-images').
+ @$(call submake,-C images $$@)
+endef
+rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'.
+$(eval $(call images,rebuild))
+push-...: ## Push the given image. Also may use 'push-all-images'.
+$(eval $(call images,pull))
+pull-...: ## Pull the given image. Also may use 'pull-all-images'.
+$(eval $(call images,push))
+load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'.
+$(eval $(call images,load))
+list-images: ## List all available images.
+ @$(call submake, -C images $$@)
-runsc:
- $(MAKE) BAZEL_OPTIONS="build runsc" bazel
+##
+## Canonical build and test targets.
+##
+## These targets are used by continuous integration and provide
+## convenient entrypoints for testing changes. If you're adding a
+## new subsystem or workflow, consider adding a new target here.
+##
+runsc: ## Builds the runsc binary.
+ @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc")
+.PHONY: runsc
-tests:
- $(MAKE) BAZEL_OPTIONS="test --test_tag_filters runsc_ptrace //test/syscalls/..." bazel
+debian: ## Builds the debian packages.
+ @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc:runsc-debian")
+.PHONY: debian
-unit-tests:
- $(MAKE) BAZEL_OPTIONS="test //pkg/... //runsc/... //tools/..." bazel
+smoke-tests: ## Runs a simple smoke test after build runsc.
+ @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true")
+.PHONY: smoke-tests
-.PHONY: docker-build bazel-shutdown bazel-server-start bazel-server bazel runsc tests
+unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc.
+ @$(call submake,test TARGETS="pkg/... runsc/... tools/...")
+
+tests: ## Runs all unit tests and syscall tests.
+tests: unit-tests
+ @$(call submake,test TARGETS="test/syscalls/...")
+.PHONY: tests
+
+
+integration-tests: ## Run all standard integration tests.
+integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests
+integration-tests: do-tests kvm-tests root-tests containerd-tests
+.PHONY: integration-tests
+
+network-tests: ## Run all networking integration tests.
+network-tests: iptables-tests packetdrill-tests packetimpact-tests
+.PHONY: network-tests
+
+# Standard integration targets.
+INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test
+
+syscall-%-tests:
+ @$(call submake,test OPTIONS="--test_tag_filters runsc_$* test/syscalls/...")
+
+syscall-native-tests:
+ @$(call submake,test OPTIONS="--test_tag_filters native test/syscalls/...")
+.PHONY: syscall-native-tests
+
+syscall-tests: ## Run all system call tests.
+syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests
+.PHONY: syscall-tests
+
+%-runtime-tests: load-runtimes_%
+ @$(call submake,install-test-runtime)
+ @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+
+do-tests: runsc
+ @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
+ @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true")
+ @$(call submake,sudo TARGETS="//runsc" ARGS="do true")
+.PHONY: do-tests
+
+simple-tests: unit-tests # Compatibility target.
+.PHONY: simple-tests
+
+docker-tests: load-basic-images
+ @$(call submake,install-test-runtime RUNTIME="vfs1")
+ @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)")
+ @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2")
+ @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)")
+.PHONY: docker-tests
+
+overlay-tests: load-basic-images
+ @$(call submake,install-test-runtime RUNTIME="overlay" ARGS="--overlay")
+ @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)")
+.PHONY: overlay-tests
+
+swgso-tests: load-basic-images
+ @$(call submake,install-test-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false")
+ @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)")
+.PHONY: swgso-tests
+hostnet-tests: load-basic-images
+ @$(call submake,install-test-runtime RUNTIME="hostnet" ARGS="--network=host")
+ @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)")
+.PHONY: hostnet-tests
+
+kvm-tests: load-basic-images
+ @(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
+ @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi
+ @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test")
+ @$(call submake,install-test-runtime RUNTIME="kvm" ARGS="--platform=kvm")
+ @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)")
+.PHONY: kvm-tests
+
+iptables-tests: load-iptables
+ @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test")
+ @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw")
+ @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
+.PHONY: iptables-tests
+
+packetdrill-tests: load-packetdrill
+ @$(call submake,install-test-runtime RUNTIME="packetdrill")
+ @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')")
+.PHONY: packetdrill-tests
+
+packetimpact-tests: load-packetimpact
+ @sudo modprobe iptable_filter ip6table_filter
+ @$(call submake,install-test-runtime RUNTIME="packetimpact")
+ @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')")
+.PHONY: packetimpact-tests
+
+root-tests: load-basic-images
+ @$(call submake,install-test-runtime)
+ @$(call submake,sudo TARGETS="//test/root:root_test" ARGS="-test.v")
+.PHONY: root-tests
+
+# Specific containerd version tests.
+containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd install-test-runtime
+ @CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd"
+ @$(MAKE) sudo TARGETS="tools/installers:shim"
+ @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="-test.v"
+
+# Note that we can't run containerd-test-1.1.8 tests here.
+#
+# Containerd 1.1.8 should work, but because of a bug in loading images locally
+# (https://github.com/kubernetes-sigs/cri-tools/issues/421), we are unable to
+# actually drive the tests. The v1 API is tested exclusively through 1.2.13.
+containerd-tests: ## Runs all supported containerd version tests.
+containerd-tests: containerd-test-1.2.13
+containerd-tests: containerd-test-1.3.4
+containerd-tests: containerd-test-1.4.0-beta.0
+
+##
+## Website & documentation helpers.
+##
+## The website is built from repository documentation and wrappers, using
+## using a locally-defined Docker image (see images/jekyll). The following
+## variables may be set when using website-push:
+## WEBSITE_IMAGE - The name of the container image.
+## WEBSITE_SERVICE - The backend service.
+## WEBSITE_PROJECT - The project id to use.
+## WEBSITE_REGION - The region to deploy to.
+##
+WEBSITE_IMAGE := gcr.io/gvisordev/gvisordev
+WEBSITE_SERVICE := gvisordev
+WEBSITE_PROJECT := gvisordev
+WEBSITE_REGION := us-central1
+
+website-build: load-jekyll ## Build the site image locally.
+ @$(call submake,run TARGETS="//website:website")
+.PHONY: website-build
+
+website-server: website-build ## Run a local server for development.
+ @docker run -i -p 8080:8080 gvisor.dev/images/website
+.PHONY: website-server
+
+website-push: website-build ## Push a new image and update the service.
+ @docker tag gvisor.dev/images/website $(WEBSITE_IMAGE) && docker push $(WEBSITE_IMAGE)
+.PHONY: website-push
+
+website-deploy: website-push ## Deploy a new version of the website.
+ @gcloud run deploy $(WEBSITE_SERVICE) --platform=managed --region=$(WEBSITE_REGION) --project=$(WEBSITE_PROJECT) --image=$(WEBSITE_IMAGE)
+.PHONY: website-deploy
+
+##
+## Repository builders.
+##
+## This builds a local apt repository. The following variables may be set:
+## RELEASE_ROOT - The repository root (default: "repo" directory).
+## RELEASE_KEY - The repository GPG private key file (default: dummy key is created).
+## RELEASE_NIGHTLY - Set to true if a nightly release (default: false).
+## RELEASE_COMMIT - The commit or Change-Id for the release (needed for tag).
+## RELEASE_NAME - The name of the release in the proper format (needed for tag).
+## RELEASE_NOTES - The file containing release notes (needed for tag).
+##
+RELEASE_ROOT := $(CURDIR)/repo
+RELEASE_KEY := repo.key
+RELEASE_NIGHTLY := false
+RELEASE_COMMIT :=
+RELEASE_NAME :=
+RELEASE_NOTES :=
+
+GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi)
+$(RELEASE_KEY):
+ @echo "WARNING: Generating a key for testing ($@); don't use this."
+ T=$$(mktemp /tmp/keyring.XXXXXX); \
+ C=$$(mktemp /tmp/config.XXXXXX); \
+ echo Key-Type: DSA >> $$C && \
+ echo Key-Length: 1024 >> $$C && \
+ echo Name-Real: Test >> $$C && \
+ echo Name-Email: test@example.com >> $$C && \
+ echo Expire-Date: 0 >> $$C && \
+ echo %commit >> $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \
+ gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \
+ rc=$$?; rm -f $$T $$C; exit $$rc
+
+release: $(RELEASE_KEY) ## Builds a release.
+ @mkdir -p $(RELEASE_ROOT)
+ @T=$$(mktemp -d /tmp/release.XXXXXX); \
+ $(call submake,copy TARGETS="runsc" DESTINATION=$$T) && \
+ $(call submake,copy TARGETS="runsc:runsc-debian" DESTINATION=$$T) && \
+ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \
+ rc=$$?; rm -rf $$T; exit $$rc
+.PHONY: release
+
+tag: ## Creates and pushes a release tag.
+ @tools/tag_release.sh "$(RELEASE_COMMIT)" "$(RELEASE_NAME)" "$(RELEASE_NOTES)"
+.PHONY: tag
+
+##
+## Development helpers and tooling.
+##
+## These targets faciliate local development by automatically
+## installing and configuring a runtime. Several variables may
+## be used here to tweak the installation:
+## RUNTIME - The name of the installed runtime (default: branch).
+## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME).
+## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc).
+## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs).
+## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%).
+##
+ifeq (,$(BRANCH_NAME))
+RUNTIME := runsc
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+else
+RUNTIME := $(BRANCH_NAME)
+RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME)
+endif
+RUNTIME_BIN := $(RUNTIME_DIR)/runsc
+RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs
+RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+
+dev: ## Installs a set of local runtimes. Requires sudo.
+ @$(call submake,refresh ARGS="--net-raw")
+ @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw")
+ @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets")
+ @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile")
+ @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2")
+ @sudo systemctl restart docker
+.PHONY: dev
+
+refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-test-runtime' first.
+ @mkdir -p "$(RUNTIME_DIR)"
+ @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)")
+.PHONY: refresh
+
+install-test-runtime: ## Installs the runtime for testing. Requires sudo.
+ @$(call submake,refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)")
+ @$(call submake,configure RUNTIME_NAME=runsc)
+ @$(call submake,configure RUNTIME_NAME="$(RUNTIME)")
+ @sudo systemctl restart docker
+ @if [[ -f /etc/docker/daemon.json ]]; then \
+ sudo chmod 0755 /etc/docker && \
+ sudo chmod 0644 /etc/docker/daemon.json; \
+ fi
+.PHONY: install-test-runtime
+
+configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-test-runtime.
+ @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS)
+ @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)"
+ @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)"
+ @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)"
+.PHONY: configure
+
+test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided.
+ @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)")
+.PHONY: test-runtime
diff --git a/README.md b/README.md
index 5ac6f9046..ed9e0e92b 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
![gVisor](g3doc/logo.png)
-[![Status](https://storage.googleapis.com/gvisor-build-badges/build.svg)](https://storage.googleapis.com/gvisor-build-badges/build.html)
+![](https://github.com/google/gvisor/workflows/Build/badge.svg)
[![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community)
## What is gVisor?
-**gVisor** is a user-space kernel, written in Go, that implements a substantial
-portion of the Linux system surface. It includes an
+**gVisor** is an application kernel, written in Go, that implements a
+substantial portion of the Linux system surface. It includes an
[Open Container Initiative (OCI)][oci] runtime called `runsc` that provides an
isolation boundary between the application and the host kernel. The `runsc`
runtime integrates with Docker and Kubernetes, making it simple to run sandboxed
@@ -15,16 +15,17 @@ containers.
## Why does gVisor exist?
Containers are not a [**sandbox**][sandbox]. While containers have
-revolutionized how we develop, package, and deploy applications, running
-untrusted or potentially malicious code without additional isolation is not a
-good idea. The efficiency and performance gains from using a single, shared
-kernel also mean that container escape is possible with a single vulnerability.
-
-gVisor is a user-space kernel for containers. It limits the host kernel surface
-accessible to the application while still giving the application access to all
-the features it expects. Unlike most kernels, gVisor does not assume or require
-a fixed set of physical resources; instead, it leverages existing host kernel
-functionality and runs as a normal user-space process. In other words, gVisor
+revolutionized how we develop, package, and deploy applications, using them to
+run untrusted or potentially malicious code without additional isolation is not
+a good idea. While using a single, shared kernel allows for efficiency and
+performance gains, it also means that container escape is possible with a single
+vulnerability.
+
+gVisor is an application kernel for containers. It limits the host kernel
+surface accessible to the application while still giving the application access
+to all the features it expects. Unlike most kernels, gVisor does not assume or
+require a fixed set of physical resources; instead, it leverages existing host
+kernel functionality and runs as a normal process. In other words, gVisor
implements Linux by way of Linux.
gVisor should not be confused with technologies and tools to harden containers
@@ -39,73 +40,43 @@ be found at [gvisor.dev][gvisor-dev].
## Installing from source
-gVisor currently requires x86\_64 Linux to build, though support for other
-architectures may become available in the future.
+gVisor builds on x86_64 and ARM64. Other architectures may become available in
+the future.
+
+For the purposes of these instructions, [bazel][bazel] and other build
+dependencies are wrapped in a build container. It is possible to use
+[bazel][bazel] directly, or type `make help` for standard targets.
### Requirements
Make sure the following dependencies are installed:
* Linux 4.14.77+ ([older linux][old-linux])
-* [git][git]
-* [Bazel][bazel] 0.28.0+
-* [Python][python]
* [Docker version 17.09.0 or greater][docker]
-* Gold linker (e.g. `binutils-gold` package on Ubuntu)
### Building
Build and install the `runsc` binary:
-```
-bazel build runsc
-sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin
-```
-
-If you don't want to install bazel on your system, you can build runsc in a
-Docker container:
-
-```
+```sh
make runsc
sudo cp ./bazel-bin/runsc/linux_amd64_pure_stripped/runsc /usr/local/bin
```
### Testing
-The test suite can be run with Bazel:
+To run standard test suites, you can use:
-```
-bazel test //...
-```
-
-or in a Docker container:
-
-```
+```sh
make unit-tests
make tests
```
-### Using remote execution
-
-If you have a [Remote Build Execution][rbe] environment, you can use it to speed
-up build and test cycles.
-
-You must authenticate with the project first:
-
-```
-gcloud auth application-default login --no-launch-browser
-```
-
-Then invoke bazel with the following flags:
+To run specific tests, you can specify the target:
+```sh
+make test TARGETS="//runsc:version_test"
```
---config=remote
---project_id=$PROJECT
---remote_instance_name=projects/$PROJECT/instances/default_instance
-```
-
-You can also add those flags to your local ~/.bazelrc to avoid needing to
-specify them each time on the command line.
### Using `go get`
@@ -113,12 +84,19 @@ This project uses [bazel][bazel] to build and manage dependencies. A synthetic
`go` branch is maintained that is compatible with standard `go` tooling for
convenience.
-For example, to build `runsc` directly from this branch:
+For example, to build and install `runsc` directly from this branch:
-```
+```sh
echo "module runsc" > go.mod
GO111MODULE=on go get gvisor.dev/gvisor/runsc@go
-CGO_ENABLED=0 GO111MODULE=on go install gvisor.dev/gvisor/runsc
+CGO_ENABLED=0 GO111MODULE=on sudo -E go build -o /usr/local/bin/runsc gvisor.dev/gvisor/runsc
+```
+
+Subsequently, you can build and install the shim binaries for `containerd`:
+
+```sh
+GO111MODULE=on sudo -E go build -o /usr/local/bin/gvisor-containerd-shim gvisor.dev/gvisor/shim/v1
+GO111MODULE=on sudo -E go build -o /usr/local/bin/containerd-shim-runsc-v1 gvisor.dev/gvisor/shim/v2
```
Note that this branch is supported in a best effort capacity, and direct
@@ -127,7 +105,7 @@ development on this branch is not supported. Development should occur on the
## Community & Governance
-The governance model is documented in our [community][community] repository.
+See [GOVERNANCE.md](GOVERNANCE.md) for project governance information.
The [gvisor-users mailing list][gvisor-users-list] and
[gvisor-dev mailing list][gvisor-dev-list] are good starting points for
@@ -142,14 +120,10 @@ See [SECURITY.md](SECURITY.md).
See [Contributing.md](CONTRIBUTING.md).
[bazel]: https://bazel.build
-[community]: https://gvisor.googlesource.com/community
[docker]: https://www.docker.com
-[git]: https://git-scm.com
[gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users
+[gvisor-dev]: https://gvisor.dev
[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev
[oci]: https://www.opencontainers.org
[old-linux]: https://gvisor.dev/docs/user_guide/networking/#gso
-[python]: https://python.org
-[rbe]: https://blog.bazel.build/2018/10/05/remote-build-execution.html
[sandbox]: https://en.wikipedia.org/wiki/Sandbox_(computer_security)
-[gvisor-dev]: https://gvisor.dev
diff --git a/SECURITY.md b/SECURITY.md
index 154d68cb3..a96843895 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -5,7 +5,6 @@ the [gvisor-security mailing list][gvisor-security-list]. You should receive a
prompt response, typically within 48 hours.
Policies for security list access, vulnerability embargo, and vulnerability
-disclosure are outlined in the [community][community] repository.
+disclosure are outlined in the [governance policy](GOVERNANCE.md).
-[community]: https://gvisor.googlesource.com/community
[gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security
diff --git a/WORKSPACE b/WORKSPACE
index 57e6f3558..6dc060bd5 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,37 +1,90 @@
-# Load go bazel rules and gazelle.
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+
+# Bazel/starlark utilities.
+http_archive(
+ name = "bazel_skylib",
+ sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
+ ],
+)
+
+load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
+bazel_skylib_workspace()
+
+# Load go bazel rules and gazelle.
+#
+# Note that this repository actually patches some other Go repositories as it
+# loads it, in order to limit visibility. We hack this process by patching the
+# patch used by the Go rules, turning the trick against itself.
http_archive(
name = "io_bazel_rules_go",
- sha256 = "842ec0e6b4fbfdd3de6150b61af92901eeb73681fd4d185746644c338f51d4c0",
+ patch_args = ["-p1"],
+ patches = [
+ "//tools/nogo:io_bazel_rules_go-visibility.patch",
+ ],
+ sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.20.1/rules_go-v0.20.1.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/v0.20.1/rules_go-v0.20.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "41bff2a0b32b02f20c227d234aa25ef3783998e5453f7eade929704dcff7cd4b",
+ sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10",
+ urls = [
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "io_bazel_rules_go_bazel3", # To replace the above.
+ patch_args = ["-p1"],
+ patches = [
+ "//tools/nogo:io_bazel_rules_go-visibility.patch",
+ ],
+ sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "bazel_gazelle_bazel3", # To replace the above.
+ sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.19.0/bazel-gazelle-v0.19.0.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
],
)
-load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_toolchains")
+load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
go_rules_dependencies()
-go_register_toolchains(
- go_version = "1.13.3",
- nogo = "@//:nogo",
-)
+go_register_toolchains(go_version = "1.14.2")
load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
+# The com_google_protobuf repository below would trigger downloading a older
+# version of org_golang_x_sys. If putting this repository statment in a place
+# after that of the com_google_protobuf, this statement will not work as
+# expectd to download a new version of org_golang_x_sys.
+go_repository(
+ name = "org_golang_x_sys",
+ importpath = "golang.org/x/sys",
+ sum = "h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=",
+ version = "v0.0.0-20200302150141-5c8b2ff67527",
+)
+
# Load C++ rules.
http_archive(
name = "rules_cc",
@@ -45,28 +98,40 @@ http_archive(
# Load protobuf dependencies.
http_archive(
- name = "com_google_protobuf",
- sha256 = "532d2575d8c0992065bb19ec5fba13aa3683499726f6055c11b474f91a00bb0c",
- strip_prefix = "protobuf-7f520092d9050d96fb4b707ad11a51701af4ce49",
+ name = "rules_proto",
+ sha256 = "602e7161d9195e50246177e7c55b2f39950a9cf7366f74ed5f22fd45750cd208",
+ strip_prefix = "rules_proto-97d8af4dc474595af3900dd85cb3a29ad28cc313",
urls = [
- "https://mirror.bazel.build/github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
- "https://github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
+ "https://github.com/bazelbuild/rules_proto/archive/97d8af4dc474595af3900dd85cb3a29ad28cc313.tar.gz",
],
)
-load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
+load("@rules_proto//proto:repositories.bzl", "rules_proto_dependencies", "rules_proto_toolchains")
+
+rules_proto_dependencies()
-protobuf_deps()
+rules_proto_toolchains()
# Load bazel_toolchain to support Remote Build Execution.
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "a019fbd579ce5aed0239de865b2d8281dbb809efd537bf42e0d366783e8dec65",
- strip_prefix = "bazel-toolchains-0.29.2",
+ sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e",
+ strip_prefix = "bazel-toolchains-3.0.1",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "bazel_toolchains_bazel3", # To replace the above.
+ sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48",
+ strip_prefix = "bazel-toolchains-3.1.1",
+ urls = [
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz",
],
)
@@ -85,12 +150,82 @@ load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
rules_pkg_dependencies()
-# External repositories, in sorted order.
+# Load C++ grpc rules.
+http_archive(
+ name = "com_github_grpc_grpc",
+ sha256 = "2fcb7f1ab160d6fd3aaade64520be3e5446fc4c6fa7ba6581afdc4e26094bd81",
+ strip_prefix = "grpc-1.26.0",
+ urls = [
+ "https://github.com/grpc/grpc/archive/v1.26.0.tar.gz",
+ ],
+)
+
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+
+grpc_deps()
+
+load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
+
+grpc_extra_deps()
+
+# System Call test dependencies.
+http_archive(
+ name = "com_google_absl",
+ sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
+ strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "com_google_googletest",
+ sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
+ strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "com_google_benchmark",
+ sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a",
+ strip_prefix = "benchmark-1.5.0",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ "https://github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ ],
+)
+
+# External Go repositories.
+#
+# Unfortunately, gazelle will automatically parse go modules in the
+# repositories and generate new go_repository stanzas. These may not respect
+# pins that we have in go.mod or below. So order actually matters here.
+
+go_repository(
+ name = "com_github_sirupsen_logrus",
+ importpath = "github.com/sirupsen/logrus",
+ replace = "github.com/Sirupsen/logrus",
+ sum = "h1:cWjBmzJnL1sO88XdqJYmq7aiWClqXIQQMJ3Utgy1f+I=",
+ version = "v1.4.2",
+)
+
+go_repository(
+ name = "com_github_containerd_containerd",
+ build_file_proto_mode = "disable",
+ importpath = "github.com/containerd/containerd",
+ sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=",
+ version = "v1.3.4",
+)
+
go_repository(
name = "com_github_cenkalti_backoff",
importpath = "github.com/cenkalti/backoff",
- sum = "h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=",
- version = "v0.0.0-20190506075156-2146c9339422",
+ sum = "h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4=",
+ version = "v1.1.1-0.20190506075156-2146c9339422",
)
go_repository(
@@ -108,38 +243,45 @@ go_repository(
)
go_repository(
- name = "com_github_google_go-cmp",
- importpath = "github.com/google/go-cmp",
- sum = "h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=",
- version = "v0.2.0",
-)
-
-go_repository(
name = "com_github_google_subcommands",
importpath = "github.com/google/subcommands",
- sum = "h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=",
- version = "v0.0.0-20190508160503-636abe8753b8",
+ sum = "h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM=",
+ version = "v1.0.2-0.20190508160503-636abe8753b8",
)
go_repository(
name = "com_github_google_uuid",
importpath = "github.com/google/uuid",
- sum = "h1:rXQlD9GXkjA/PQZhmEaF/8Pj/sJfdZJK7GJG0gkS8I0=",
- version = "v0.0.0-20171129191014-dec09d789f3d",
+ sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_kr_pretty",
+ importpath = "github.com/kr/pretty",
+ sum = "h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=",
+ version = "v0.1.0",
)
go_repository(
name = "com_github_kr_pty",
importpath = "github.com/kr/pty",
- sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=",
- version = "v1.1.1",
+ sum = "h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=",
+ version = "v1.1.4-0.20190131011033-7dc38fb350b1",
)
go_repository(
- name = "com_github_opencontainers_runtime-spec",
- importpath = "github.com/opencontainers/runtime-spec",
- sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=",
- version = "v0.1.2-0.20171211145439-b2d941ef6a78",
+ name = "com_github_kr_text",
+ importpath = "github.com/kr/text",
+ sum = "h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_mohae_deepcopy",
+ importpath = "github.com/mohae/deepcopy",
+ sum = "h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=",
+ version = "v0.0.0-20170308212314-bb9b5e7adda9",
)
go_repository(
@@ -152,68 +294,79 @@ go_repository(
go_repository(
name = "com_github_vishvananda_netlink",
importpath = "github.com/vishvananda/netlink",
- sum = "h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=",
- version = "v1.0.1-0.20190318003149-adb577d4a45e",
+ sum = "h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w=",
+ version = "v1.0.1-0.20190930145447-2ec5bdc52b86",
)
go_repository(
- name = "com_github_vishvananda_netns",
- importpath = "github.com/vishvananda/netns",
- sum = "h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=",
- version = "v0.0.0-20171111001504-be1fbeda1936",
+ name = "org_golang_google_grpc",
+ build_file_proto_mode = "disable",
+ importpath = "google.golang.org/grpc",
+ sum = "h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=",
+ version = "v1.29.0",
)
go_repository(
- name = "org_golang_x_crypto",
- importpath = "golang.org/x/crypto",
- sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=",
- version = "v0.0.0-20190308221718-c2843e01d9a2",
+ name = "in_gopkg_check_v1",
+ importpath = "gopkg.in/check.v1",
+ sum = "h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=",
+ version = "v1.0.0-20180628173108-788fd7840127",
)
go_repository(
- name = "org_golang_x_net",
- importpath = "golang.org/x/net",
- sum = "h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=",
- version = "v0.0.0-20190311183353-d8887717615a",
+ name = "org_golang_x_crypto",
+ importpath = "golang.org/x/crypto",
+ sum = "h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=",
+ version = "v0.0.0-20200622213623-75b288015ac9",
)
go_repository(
- name = "org_golang_x_text",
- importpath = "golang.org/x/text",
- sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=",
+ name = "org_golang_x_mod",
+ importpath = "golang.org/x/mod",
+ sum = "h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=",
version = "v0.3.0",
)
go_repository(
- name = "org_golang_x_tools",
- commit = "36563e24a262",
- importpath = "golang.org/x/tools",
+ name = "org_golang_x_net",
+ importpath = "golang.org/x/net",
+ sum = "h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=",
+ version = "v0.0.0-20200625001655-4c5254603344",
)
go_repository(
name = "org_golang_x_sync",
importpath = "golang.org/x/sync",
- sum = "h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=",
- version = "v0.0.0-20190423024810-112230192c58",
+ sum = "h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=",
+ version = "v0.0.0-20200625203802-6e8e738ad208",
)
go_repository(
- name = "org_golang_x_sys",
- importpath = "golang.org/x/sys",
- sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=",
- version = "v0.0.0-20190215142949-d0b11bdaac8a",
+ name = "org_golang_x_text",
+ importpath = "golang.org/x/text",
+ sum = "h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=",
+ version = "v0.3.2",
)
go_repository(
name = "org_golang_x_time",
- commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
+ sum = "h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=",
+ version = "v0.0.0-20191024005414-555d28b269f0",
)
go_repository(
name = "org_golang_x_tools",
- commit = "aa82965741a9fecd12b026fbb3d3c6ed3231b8f8",
importpath = "golang.org/x/tools",
+ sum = "h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE=",
+ version = "v0.0.0-20200707200213-416e8f4faf8a",
+)
+
+go_repository(
+ name = "org_golang_x_xerrors",
+ importpath = "golang.org/x/xerrors",
+ sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=",
+ version = "v0.0.0-20191204190536-9bdfabe68543",
)
go_repository(
@@ -226,27 +379,700 @@ go_repository(
go_repository(
name = "com_github_golang_protobuf",
importpath = "github.com/golang/protobuf",
- sum = "h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=",
+ sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=",
+ version = "v1.4.2",
+)
+
+go_repository(
+ name = "org_golang_x_oauth2",
+ importpath = "golang.org/x/oauth2",
+ sum = "h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=",
+ version = "v0.0.0-20200107190931-bf48bf16ab8d",
+)
+
+go_repository(
+ name = "com_github_docker_docker",
+ importpath = "github.com/docker/docker",
+ sum = "h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w=",
+ version = "v1.4.2-0.20191028175130-9e7d5ac5ea55",
+)
+
+go_repository(
+ name = "com_github_docker_go_connections",
+ importpath = "github.com/docker/go-connections",
+ sum = "h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o=",
+ version = "v0.3.0",
+)
+
+go_repository(
+ name = "com_github_pkg_errors",
+ importpath = "github.com/pkg/errors",
+ sum = "h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=",
+ version = "v0.9.1",
+)
+
+go_repository(
+ name = "com_github_docker_go_units",
+ importpath = "github.com/docker/go-units",
+ sum = "h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=",
+ version = "v0.4.0",
+)
+
+go_repository(
+ name = "com_github_opencontainers_go_digest",
+ importpath = "github.com/opencontainers/go-digest",
+ sum = "h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_docker_distribution",
+ importpath = "github.com/docker/distribution",
+ sum = "h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE=",
+ version = "v2.7.1-0.20190205005809-0d3efadf0154+incompatible",
+)
+
+go_repository(
+ name = "com_github_davecgh_go_spew",
+ importpath = "github.com/davecgh/go-spew",
+ sum = "h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=",
+ version = "v1.1.1",
+)
+
+go_repository(
+ name = "com_github_konsorten_go_windows_terminal_sequences",
+ importpath = "github.com/konsorten/go-windows-terminal-sequences",
+ sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=",
+ version = "v1.0.2",
+)
+
+go_repository(
+ name = "com_github_pmezard_go_difflib",
+ importpath = "github.com/pmezard/go-difflib",
+ sum = "h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_stretchr_testify",
+ importpath = "github.com/stretchr/testify",
+ sum = "h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=",
+ version = "v1.4.0",
+)
+
+go_repository(
+ name = "com_github_opencontainers_image_spec",
+ importpath = "github.com/opencontainers/image-spec",
+ sum = "h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=",
+ version = "v1.0.1",
+)
+
+go_repository(
+ name = "com_github_microsoft_go_winio",
+ importpath = "github.com/Microsoft/go-winio",
+ sum = "h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA=",
+ version = "v0.4.15-0.20190919025122-fc70bd9a86b5",
+)
+
+go_repository(
+ name = "com_github_stretchr_objx",
+ importpath = "github.com/stretchr/objx",
+ sum = "h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=",
+ version = "v0.1.1",
+)
+
+go_repository(
+ name = "org_uber_go_atomic",
+ importpath = "go.uber.org/atomic",
+ sum = "h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=",
+ version = "v1.6.0",
+)
+
+go_repository(
+ name = "org_uber_go_multierr",
+ importpath = "go.uber.org/multierr",
+ sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=",
+ version = "v1.2.0",
+)
+
+go_repository(
+ name = "com_google_cloud_go",
+ importpath = "cloud.google.com/go",
+ sum = "h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo=",
+ version = "v0.52.1-0.20200122224058-0482b626c726",
+)
+
+go_repository(
+ name = "io_opencensus_go",
+ importpath = "go.opencensus.io",
+ sum = "h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=",
+ version = "v0.22.2",
+)
+
+go_repository(
+ name = "co_honnef_go_tools",
+ importpath = "honnef.co/go/tools",
+ sum = "h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=",
+ version = "v0.0.1-2019.2.3",
+)
+
+go_repository(
+ name = "com_github_burntsushi_toml",
+ importpath = "github.com/BurntSushi/toml",
+ sum = "h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=",
+ version = "v0.3.1",
+)
+
+go_repository(
+ name = "com_github_census_instrumentation_opencensus_proto",
+ importpath = "github.com/census-instrumentation/opencensus-proto",
+ sum = "h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=",
+ version = "v0.2.1",
+)
+
+go_repository(
+ name = "com_github_client9_misspell",
+ importpath = "github.com/client9/misspell",
+ sum = "h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=",
+ version = "v0.3.4",
+)
+
+go_repository(
+ name = "com_github_cncf_udpa_go",
+ importpath = "github.com/cncf/udpa/go",
+ sum = "h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=",
+ version = "v0.0.0-20191209042840-269d4d468f6f",
+)
+
+go_repository(
+ name = "com_github_containerd_cgroups",
+ build_file_proto_mode = "disable",
+ importpath = "github.com/containerd/cgroups",
+ sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=",
+ version = "v0.0.0-20181219155423-39b18af02c41",
+)
+
+go_repository(
+ name = "com_github_containerd_console",
+ importpath = "github.com/containerd/console",
+ sum = "h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=",
+ version = "v0.0.0-20191206165004-02ecf6a7291e",
+)
+
+go_repository(
+ name = "com_github_containerd_continuity",
+ importpath = "github.com/containerd/continuity",
+ sum = "h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=",
+ version = "v0.0.0-20200710164510-efbc4488d8fe",
+)
+
+go_repository(
+ name = "com_github_containerd_fifo",
+ importpath = "github.com/containerd/fifo",
+ sum = "h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=",
+ version = "v0.0.0-20191213151349-ff969a566b00",
+)
+
+go_repository(
+ name = "com_github_containerd_go_runc",
+ importpath = "github.com/containerd/go-runc",
+ sum = "h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds=",
+ version = "v0.0.0-20200220073739-7016d3ce2328",
+)
+
+go_repository(
+ name = "com_github_containerd_ttrpc",
+ importpath = "github.com/containerd/ttrpc",
+ sum = "h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc=",
+ version = "v0.0.0-20200121165050-0be804eadb15",
+)
+
+go_repository(
+ name = "com_github_coreos_go_systemd",
+ importpath = "github.com/coreos/go-systemd",
+ sum = "h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=",
+ version = "v0.0.0-20191104093116-d3cd4ed1dbcf",
+)
+
+go_repository(
+ name = "com_github_docker_go_events",
+ importpath = "github.com/docker/go-events",
+ sum = "h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8=",
+ version = "v0.0.0-20190806004212-e31b211e4f1c",
+)
+
+go_repository(
+ name = "com_github_dustin_go_humanize",
+ importpath = "github.com/dustin/go-humanize",
+ sum = "h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=",
+ version = "v0.0.0-20171111073723-bb3d318650d4",
+)
+
+go_repository(
+ name = "com_github_envoyproxy_go_control_plane",
+ importpath = "github.com/envoyproxy/go-control-plane",
+ sum = "h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=",
+ version = "v0.9.4",
+)
+
+go_repository(
+ name = "com_github_envoyproxy_protoc_gen_validate",
+ importpath = "github.com/envoyproxy/protoc-gen-validate",
+ sum = "h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_fsnotify_fsnotify",
+ importpath = "github.com/fsnotify/fsnotify",
+ sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=",
+ version = "v1.4.7",
+)
+
+go_repository(
+ name = "com_github_godbus_dbus",
+ importpath = "github.com/godbus/dbus",
+ sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=",
+ version = "v0.0.0-20190422162347-ade71ed3457e",
+)
+
+go_repository(
+ name = "com_github_gogo_googleapis",
+ importpath = "github.com/gogo/googleapis",
+ sum = "h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=",
+ version = "v1.4.0",
+)
+
+go_repository(
+ name = "com_github_gogo_protobuf",
+ importpath = "github.com/gogo/protobuf",
+ sum = "h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=",
version = "v1.3.1",
)
-# System Call test dependencies.
-http_archive(
- name = "com_google_absl",
- sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
- strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
- urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
- ],
+go_repository(
+ name = "com_github_golang_glog",
+ importpath = "github.com/golang/glog",
+ sum = "h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=",
+ version = "v0.0.0-20160126235308-23def4e6c14b",
)
-http_archive(
- name = "com_google_googletest",
- sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
- strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
- urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
- "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
- ],
+go_repository(
+ name = "com_github_google_go_cmp",
+ importpath = "github.com/google/go-cmp",
+ sum = "h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=",
+ version = "v0.5.0",
+)
+
+go_repository(
+ name = "com_github_google_go_github_v28",
+ importpath = "github.com/google/go-github/v28",
+ sum = "h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=",
+ version = "v28.1.2-0.20191108005307-e555eab49ce8",
+)
+
+go_repository(
+ name = "com_github_google_go_querystring",
+ importpath = "github.com/google/go-querystring",
+ sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_hashicorp_golang_lru",
+ importpath = "github.com/hashicorp/golang-lru",
+ sum = "h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=",
+ version = "v0.5.1",
+)
+
+go_repository(
+ name = "com_github_hpcloud_tail",
+ importpath = "github.com/hpcloud/tail",
+ sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_inconshreveable_mousetrap",
+ importpath = "github.com/inconshreveable/mousetrap",
+ sum = "h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_kisielk_errcheck",
+ importpath = "github.com/kisielk/errcheck",
+ sum = "h1:reN85Pxc5larApoH1keMBiu2GWtPqXQ1nc9gx+jOU+E=",
+ version = "v1.2.0",
+)
+
+go_repository(
+ name = "com_github_kisielk_gotool",
+ importpath = "github.com/kisielk/gotool",
+ sum = "h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_microsoft_hcsshim",
+ importpath = "github.com/Microsoft/hcsshim",
+ sum = "h1:ZfF0+zZeYdzMIVMZHKtDKJvLHj76XCuVae/jNkjj0IA=",
+ version = "v0.8.6",
+)
+
+go_repository(
+ name = "com_github_onsi_ginkgo",
+ importpath = "github.com/onsi/ginkgo",
+ sum = "h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=",
+ version = "v1.10.1",
+)
+
+go_repository(
+ name = "com_github_onsi_gomega",
+ importpath = "github.com/onsi/gomega",
+ sum = "h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=",
+ version = "v1.7.0",
+)
+
+go_repository(
+ name = "com_github_opencontainers_runc",
+ importpath = "github.com/opencontainers/runc",
+ sum = "h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=",
+ version = "v0.1.1",
+)
+
+go_repository(
+ name = "com_github_opencontainers_runtime_spec",
+ importpath = "github.com/opencontainers/runtime-spec",
+ sum = "h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8=",
+ version = "v1.0.2-0.20181111125026-1722abf79c2f",
+)
+
+go_repository(
+ name = "com_github_pborman_uuid",
+ importpath = "github.com/pborman/uuid",
+ sum = "h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=",
+ version = "v1.2.0",
+)
+
+go_repository(
+ name = "com_github_prometheus_client_model",
+ importpath = "github.com/prometheus/client_model",
+ sum = "h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=",
+ version = "v0.0.0-20190812154241-14fe0d1b01d4",
+)
+
+go_repository(
+ name = "com_github_prometheus_procfs",
+ importpath = "github.com/prometheus/procfs",
+ sum = "h1:Lo6mRUjdS99f3zxYOUalftWHUoOGaDRqFk1+j0Q57/I=",
+ version = "v0.0.0-20190522114515-bc1a522cf7b1",
+)
+
+go_repository(
+ name = "com_github_spf13_cobra",
+ importpath = "github.com/spf13/cobra",
+ sum = "h1:GQkkv3XSnxhAMjdq2wLfEnptEVr+2BNvmHizILHn+d4=",
+ version = "v0.0.2-0.20171109065643-2da4a54c5cee",
+)
+
+go_repository(
+ name = "com_github_spf13_pflag",
+ importpath = "github.com/spf13/pflag",
+ sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=",
+ version = "v1.0.1-0.20171106142849-4c012f6dcd95",
+)
+
+go_repository(
+ name = "com_github_urfave_cli",
+ importpath = "github.com/urfave/cli",
+ sum = "h1:MCfT24H3f//U5+UCrZp1/riVO3B50BovxtDiNn0XKkk=",
+ version = "v0.0.0-20171014202726-7bc6a0acffa5",
+)
+
+go_repository(
+ name = "com_github_yuin_goldmark",
+ importpath = "github.com/yuin/goldmark",
+ sum = "h1:5tjfNdR2ki3yYQ842+eX2sQHeiwpKJ0RnHO4IYOc4V8=",
+ version = "v1.1.32",
+)
+
+go_repository(
+ name = "in_gopkg_airbrake_gobrake_v2",
+ importpath = "gopkg.in/airbrake/gobrake.v2",
+ sum = "h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=",
+ version = "v2.0.9",
+)
+
+go_repository(
+ name = "in_gopkg_fsnotify_v1",
+ importpath = "gopkg.in/fsnotify.v1",
+ sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=",
+ version = "v1.4.7",
+)
+
+go_repository(
+ name = "in_gopkg_gemnasium_logrus_airbrake_hook_v2",
+ importpath = "gopkg.in/gemnasium/logrus-airbrake-hook.v2",
+ sum = "h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0=",
+ version = "v2.1.2",
+)
+
+go_repository(
+ name = "in_gopkg_tomb_v1",
+ importpath = "gopkg.in/tomb.v1",
+ sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=",
+ version = "v1.0.0-20141024135613-dd632973f1e7",
+)
+
+go_repository(
+ name = "in_gopkg_yaml_v2",
+ importpath = "gopkg.in/yaml.v2",
+ sum = "h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=",
+ version = "v2.2.8",
+)
+
+go_repository(
+ name = "org_bazil_fuse",
+ importpath = "bazil.org/fuse",
+ sum = "h1:SC+c6A1qTFstO9qmB86mPV2IpYme/2ZoEQ0hrP+wo+Q=",
+ version = "v0.0.0-20160811212531-371fbbdaa898",
+)
+
+go_repository(
+ name = "org_golang_google_appengine",
+ importpath = "google.golang.org/appengine",
+ sum = "h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=",
+ version = "v1.6.5",
+)
+
+go_repository(
+ name = "org_golang_google_genproto",
+ importpath = "google.golang.org/genproto",
+ sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=",
+ version = "v0.0.0-20200117163144-32f20d992d24",
+)
+
+go_repository(
+ name = "org_golang_google_protobuf",
+ importpath = "google.golang.org/protobuf",
+ sum = "h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=",
+ version = "v1.23.0",
+)
+
+go_repository(
+ name = "org_golang_x_exp",
+ importpath = "golang.org/x/exp",
+ sum = "h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg=",
+ version = "v0.0.0-20191227195350-da58074b4299",
+)
+
+go_repository(
+ name = "org_golang_x_lint",
+ importpath = "golang.org/x/lint",
+ sum = "h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE=",
+ version = "v0.0.0-20191125180803-fdd1cda4f05f",
+)
+
+go_repository(
+ name = "tools_gotest",
+ importpath = "gotest.tools",
+ sum = "h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=",
+ version = "v2.2.0+incompatible",
+)
+
+go_repository(
+ name = "com_github_burntsushi_xgb",
+ importpath = "github.com/BurntSushi/xgb",
+ sum = "h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc=",
+ version = "v0.0.0-20160522181843-27f122750802",
+)
+
+go_repository(
+ name = "com_github_chzyer_logex",
+ importpath = "github.com/chzyer/logex",
+ sum = "h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=",
+ version = "v1.1.10",
+)
+
+go_repository(
+ name = "com_github_chzyer_readline",
+ importpath = "github.com/chzyer/readline",
+ sum = "h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8=",
+ version = "v0.0.0-20180603132655-2972be24d48e",
+)
+
+go_repository(
+ name = "com_github_chzyer_test",
+ importpath = "github.com/chzyer/test",
+ sum = "h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8=",
+ version = "v0.0.0-20180213035817-a1ea475d72b1",
+)
+
+go_repository(
+ name = "com_github_go_gl_glfw_v3_3_glfw",
+ importpath = "github.com/go-gl/glfw/v3.3/glfw",
+ sum = "h1:b+9H1GAsx5RsjvDFLoS5zkNBzIQMuVKUYQDmxU3N5XE=",
+ version = "v0.0.0-20191125211704-12ad95a8df72",
+)
+
+go_repository(
+ name = "com_github_golang_groupcache",
+ importpath = "github.com/golang/groupcache",
+ sum = "h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=",
+ version = "v0.0.0-20191227052852-215e87163ea7",
+)
+
+go_repository(
+ name = "com_github_google_martian",
+ importpath = "github.com/google/martian",
+ sum = "h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no=",
+ version = "v2.1.0+incompatible",
+)
+
+go_repository(
+ name = "com_github_google_pprof",
+ importpath = "github.com/google/pprof",
+ sum = "h1:DLpL8pWq0v4JYoRpEhDfsJhhJyGKCcQM2WPW2TJs31c=",
+ version = "v0.0.0-20191218002539-d4f498aebedc",
+)
+
+go_repository(
+ name = "com_github_google_renameio",
+ importpath = "github.com/google/renameio",
+ sum = "h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=",
+ version = "v0.1.0",
+)
+
+go_repository(
+ name = "com_github_googleapis_gax_go_v2",
+ importpath = "github.com/googleapis/gax-go/v2",
+ sum = "h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM=",
+ version = "v2.0.5",
+)
+
+go_repository(
+ name = "com_github_ianlancetaylor_demangle",
+ importpath = "github.com/ianlancetaylor/demangle",
+ sum = "h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c=",
+ version = "v0.0.0-20181102032728-5e5cf60278f6",
+)
+
+go_repository(
+ name = "com_github_jstemmer_go_junit_report",
+ importpath = "github.com/jstemmer/go-junit-report",
+ sum = "h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=",
+ version = "v0.9.1",
+)
+
+go_repository(
+ name = "com_github_rogpeppe_go_internal",
+ importpath = "github.com/rogpeppe/go-internal",
+ sum = "h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=",
+ version = "v1.3.0",
+)
+
+go_repository(
+ name = "com_shuralyov_dmitri_gpu_mtl",
+ importpath = "dmitri.shuralyov.com/gpu/mtl",
+ sum = "h1:VpgP7xuJadIUuKccphEpTJnWhS2jkQyMt6Y7pJCD7fY=",
+ version = "v0.0.0-20190408044501-666a987793e9",
+)
+
+go_repository(
+ name = "in_gopkg_errgo_v2",
+ importpath = "gopkg.in/errgo.v2",
+ sum = "h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=",
+ version = "v2.1.0",
+)
+
+go_repository(
+ name = "io_rsc_binaryregexp",
+ importpath = "rsc.io/binaryregexp",
+ sum = "h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=",
+ version = "v0.2.0",
+)
+
+go_repository(
+ name = "org_golang_google_api",
+ importpath = "google.golang.org/api",
+ sum = "h1:yzlyyDW/J0w8yNFJIhiAJy4kq74S+1DOLdawELNxFMA=",
+ version = "v0.15.0",
+)
+
+go_repository(
+ name = "org_golang_x_image",
+ importpath = "golang.org/x/image",
+ sum = "h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=",
+ version = "v0.0.0-20190802002840-cff245a6509b",
+)
+
+go_repository(
+ name = "org_golang_x_mobile",
+ importpath = "golang.org/x/mobile",
+ sum = "h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=",
+ version = "v0.0.0-20190719004257-d2bd2a29d028",
+)
+
+go_repository(
+ name = "com_github_containerd_typeurl",
+ importpath = "github.com/containerd/typeurl",
+ sum = "h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o=",
+ version = "v0.0.0-20200205145503-b45ef1f1f737",
+)
+
+go_repository(
+ name = "com_github_vishvananda_netns",
+ importpath = "github.com/vishvananda/netns",
+ sum = "h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so=",
+ version = "v0.0.0-20200520041808-52d707b772fe",
+)
+
+go_repository(
+ name = "com_google_cloud_go_bigquery",
+ importpath = "cloud.google.com/go/bigquery",
+ sum = "h1:hL+ycaJpVE9M7nLoiXb/Pn10ENE2u+oddxbD8uu0ZVU=",
+ version = "v1.0.1",
+)
+
+go_repository(
+ name = "com_google_cloud_go_datastore",
+ importpath = "cloud.google.com/go/datastore",
+ sum = "h1:Kt+gOPPp2LEPWp8CSfxhsM8ik9CcyE/gYu+0r+RnZvM=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_google_cloud_go_pubsub",
+ importpath = "cloud.google.com/go/pubsub",
+ sum = "h1:W9tAK3E57P75u0XLLR82LZyw8VpAnhmyTOxW9qzmyj8=",
+ version = "v1.0.1",
+)
+
+go_repository(
+ name = "com_google_cloud_go_storage",
+ importpath = "cloud.google.com/go/storage",
+ sum = "h1:VV2nUM3wwLLGh9lSABFgZMjInyUbJeaRSE64WuAIQ+4=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_hashicorp_errwrap",
+ importpath = "github.com/hashicorp/errwrap",
+ sum = "h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_hashicorp_go_multierror",
+ importpath = "github.com/hashicorp/go-multierror",
+ sum = "h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=",
+ version = "v1.0.0",
+)
+
+go_repository(
+ name = "com_github_dpjacques_clockwork",
+ importpath = "github.com/dpjacques/clockwork",
+ sum = "h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=",
+ version = "v0.1.1-0.20190114191937-d864eecc357b",
)
diff --git a/g3doc/BUILD b/g3doc/BUILD
new file mode 100644
index 000000000..f91a77b6f
--- /dev/null
+++ b/g3doc/BUILD
@@ -0,0 +1,44 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "index",
+ src = "README.md",
+ category = "Project",
+ data = glob([
+ "*.png",
+ "*.svg",
+ ]),
+ permalink = "/docs/",
+ weight = "0",
+)
+
+doc(
+ name = "roadmap",
+ src = "roadmap.md",
+ category = "Project",
+ permalink = "/roadmap/",
+ weight = "10",
+)
+
+doc(
+ name = "community",
+ src = "community.md",
+ category = "Project",
+ permalink = "/community/",
+ subcategory = "Community",
+ weight = "10",
+)
+
+doc(
+ name = "style",
+ src = "style.md",
+ category = "Project",
+ permalink = "/community/style_guide/",
+ subcategory = "Community",
+ weight = "99",
+)
diff --git a/g3doc/Layers.png b/g3doc/Layers.png
new file mode 100644
index 000000000..308c6c451
--- /dev/null
+++ b/g3doc/Layers.png
Binary files differ
diff --git a/g3doc/Layers.svg b/g3doc/Layers.svg
new file mode 100644
index 000000000..0a366f841
--- /dev/null
+++ b/g3doc/Layers.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 371.8346456692913 255.01574803149606" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l371.83466 0l0 255.01575l-371.83466 0l0 -255.01575z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l371.83466 0l0 255.01575l-371.83466 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m78.206116 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 -1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 73.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 73.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 73.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 73.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 56.702377l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m87.06437 74.471375l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.0312424 0 1.5781174 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.7031174 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321053 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.4062653 0 -2.2812653 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.6562653 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.4531403 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.4062653 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#d9d2e9" d="m36.454067 87.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 87.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m98.35295 119.54864l1.59375 0.234375q0.109375 0.75 0.5625 1.078125q0.609375 0.453125 1.671875 0.453125q1.140625 0 1.75 -0.453125q0.625 -0.453125 0.84375 -1.265625q0.125 -0.5 0.109375 -2.109375q-1.0625 1.265625 -2.671875 1.265625q-2.0 0 -3.09375 -1.4375q-1.09375 -1.4375 -1.09375 -3.453125q0 -1.390625 0.5 -2.5625q0.515625 -1.171875 1.453125 -1.796875q0.953125 -0.640625 2.25 -0.640625q1.703125 0 2.8125 1.375l0 -1.15625l1.515625 0l0 8.359375q0 2.265625 -0.46875 3.203125q-0.453125 0.9375 -1.453125 1.484375q-0.984375 0.546875 -2.453125 0.546875q-1.71875 0 -2.796875 -0.78125q-1.0625 -0.765625 -1.03125 -2.34375zm1.359375 -5.8125q0 1.90625 0.75 2.78125q0.765625 0.875 1.90625 0.875q1.125 0 1.890625 -0.859375q0.765625 -0.875 0.765625 -2.734375q0 -1.78125 -0.796875 -2.671875q-0.78125 -0.90625 -1.890625 -0.90625q-1.09375 0 -1.859375 0.890625q-0.765625 0.875 -0.765625 2.625zm13.344467 5.015625l-5.171875 -13.359375l1.921875 0l3.46875 9.703125q0.421875 1.171875 0.703125 2.1875q0.3125 -1.09375 0.71875 -2.1875l3.609375 -9.703125l1.796875 0l-5.234375 13.359375l-1.8125 0zm8.427948 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm3.4885712 -2.890625l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm9.375 -1.953125q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.281967 4.84375l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m3.6351707 152.91733l48.850395 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m3.6351707 152.91733l48.850395 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m195.25722 152.91733l47.338577 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m195.25722 152.91733l47.338577 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m52.485565 136.88583l142.77165 0l0 32.06299l-142.77165 0z" fill-rule="evenodd"/><path fill="#000000" d="m65.21821 157.71732l0 -9.546875l1.265625 0l0 8.421875l4.703125 0l0 1.125l-5.96875 0zm7.3343506 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.945465 0l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm11.118057 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm5.507965 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm11.006226 4.125l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm9.865463 1.390625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm7.0859375 4.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8124924 0 1.2031174 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1874924 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8124924 0 1.4218674 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0624924 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2499924 0.328125 1.7343674 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531174 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.695305 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m36.454067 167.00784l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 167.00784l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m76.63558 198.35303l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.6249924 -6.625l2.390625 0l-5.5937424 5.421875l5.8437424 7.9375l-2.328125 0l-4.7656174 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943565 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 233.43303l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 233.43303l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 233.43303l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 233.43303l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 217.06717l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m96.04542 237.89867l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.0937424 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.4062424 5.3125l-1.21875 0zm12.859535 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.53514884533539 0.0 0.0 4.53514884533539 0.0 0.0)" spreadMethod="pad" x1="8.21347768339151" y1="37.02644733653771" x2="8.213461293294644" y2="41.56159618184348"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m37.249344 167.92108l173.29134 0l0 20.566925l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 182.06865l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m287.51392 188.73558l1.1823425 -0.82717896q0.51071167 0.6974335 1.1166077 0.9970703q0.5980835 0.28611755 1.4464111 0.1931305q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.9823456q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 -1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079711914q-0.478302 0.07775879 -2.032318 0.54222107q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418701q-0.4295349 -0.74432373 -0.38497925 -1.6361694q0.05029297 -0.9131775 0.6668701 -1.7203827q0.63012695 -0.81500244 1.6313782 -1.3932648q1.1095276 -0.64079285 2.1592712 -0.7598877q1.0419617 -0.1326294 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671814l-1.2036743 0.82147217q-0.64712524 -0.87127686 -1.5022583 -1.0089111q-0.8629761 -0.15116882 -2.013092 0.51304626q-1.1906738 0.68766785 -1.4819641 1.4333191q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.52778625 0.8865051 0.6608429q0.5740967 0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645294 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78492737 0.39904785 1.7543335q-0.040405273 0.9616089 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144836 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431458l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396332q-0.6629944 0.3829193 -1.1454773 0.39089966q-0.4902649 -0.0055389404 -0.8343506 -0.25790405q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.9473114l0.8930054 -0.5157623l-1.0308838 -1.786377l0.79599 -1.434082l1.4526367 2.5171661l1.2177734 -0.7033081l0.5466919 0.94732666l-1.2177734 0.70329285l2.4210815 4.195282q0.30456543 0.5278015 0.4446106 0.645401q0.15356445 0.109802246 0.35705566 0.11857605q0.19570923 -0.004760742 0.4663086 -0.16105652q0.20297241 -0.11721802 0.49645996 -0.35888672zm1.8165283 0.41241455l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.0178833 -1.0001068 0.19332886 -1.4468842q0.21121216 -0.44676208 0.64419556 -0.6968231q0.6088562 -0.35165405 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.92100525q0.17193604 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm2.0899658 -6.006668q-1.1480408 -1.9893646 -0.5930481 -3.5910034q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210815q0.04135132 1.0407104 -0.49923706 1.948349q-0.5348511 0.88630676 -1.4819946 1.4333191q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621765zm1.2583313 -0.72673035q0.79663086 1.3803711 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18502808q0.9065552 -0.5235596 1.1036072 -1.5575867q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17720032q-0.90652466 0.5235748 -1.117096 1.5654144q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494751 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.264862 3.9246216l-1.2177429 0.70329285zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.51579285 0.8666992 0.5640259q0.65527344 0.072631836 1.4400635 -0.38059998q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.306778q-0.1161499 -0.42010498 -0.81121826 -1.6245575q-0.25161743 1.408371 -1.4423218 2.096054q-1.4883423 0.85957336 -2.9171448 0.25932312q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849152 0.30926514 -1.9649353q0.4437561 -0.8878174 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.182724l-0.4998474 -0.8661194l1.1230469 -0.6485901l3.5847168 6.2117157q0.9684448 1.6781158 1.0362854 2.5771942q0.067840576 0.8990936 -0.4420166 1.7348785q-0.5098877 0.8357849 -1.5923462 1.4609375q-1.2854004 0.7423706 -2.4195251 0.62150574q-1.1206055 -0.12869263 -1.7651672 -1.3081818zm-1.4765625 -4.9031525q0.81222534 1.4074402 1.7418518 1.7366486q0.94314575 0.32138062 1.7820435 -0.16311646q0.8388977 -0.4844818 1.0401001 -1.4487457q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.1360321q-0.8118286 0.46887207 -0.9974365 1.4602051q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m294.23132 199.68793l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80441284 1.3939056l-1.2177429 0.7033081zm4.920227 8.525894l-4.147064 -7.1861115l1.2177734 -0.7033081l4.1470337 7.1861115l-1.2177429 0.7033081zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.6760864 1.2029724 0.80285645q0.6922302 0.10542297 1.5176086 -0.3712616q0.8388672 -0.4844818 1.0495605 -1.057251q0.20285034 -0.58628845 -0.062683105 -1.0464172q-0.23431396 -0.4059906 -0.7266846 -0.4464264q-0.35079956 -0.013916016 -1.4790955 0.3129425q-1.5347595 0.43530273 -2.2030334 0.49645996q-0.66256714 0.03982544 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.018219q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.797287q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.61598206q0.7765198 -0.123687744 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.85964966l-1.0969849 0.85006714q-0.4013672 -0.5079651 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.3217163q-0.8388977 0.4844818 -1.0402222 0.9796753q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.34710693q0.2581787 0.103500366 0.64749146 0.05909729q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.40403748 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501831q0.5397949 0.24760437 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73095703 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.6527405q-1.0664673 -0.15994263 -1.933258 -1.193039zm6.1190186 -5.4646606q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47283936 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.85957336 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79663086 1.3803864 0.83795166 2.4210968q0.04135132 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267303q1.0071716 0.33854675 1.9136963 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.90652466 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.957733 -0.3738098l-5.7246704 -9.919815l1.2177734 -0.70329285l5.72464 9.9198l-1.2177429 0.7033081zm7.2713623 -5.390396q-0.34069824 0.972641 -0.8225403 1.5756989q-0.48962402 0.5895386 -1.2067566 1.0036926q-1.1906738 0.6876831 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.52778625 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.49728394 0.64520264 -0.9139252q0.31063843 -0.3057251 0.980896 -0.8010864q1.3733215 -1.02771 1.9363098 -1.6776581q-0.14837646 -0.25712585 -0.18740845 -0.32478333q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14602661 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458069q-0.23410034 0.5321655 0.013824463 1.3994293l-1.2843933 0.52526855q-0.2749939 -0.85162354 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948547q0.92007446 -0.5313873 1.6133118 -0.6430664q0.7067566 -0.11949158 1.1648254 0.04902649q0.45803833 0.16850281 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841675l0.9371643 1.6239929q0.9840393 1.7051697 1.3094177 2.1126862q0.33892822 0.39971924 0.81103516 0.68640137l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 -2.6654663q-0.5067749 0.65356445 -1.7234497 1.6088562q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.3508911 0.16564941 0.66215515q0.28115845 0.48719788 0.83392334 0.6010132q0.5662842 0.10598755 1.269867 -0.30036926q0.70358276 -0.40634155 1.072998 -1.0166473q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3352051l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25790405q-0.3518982 -0.26591492 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94732666l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.434082l1.4526367 2.5171661l1.2177429 -0.7033081l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.52778625 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104126q0.20297241 -0.11721802 0.49645996 -0.35890198zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939209l-1.2177734 0.70329285zm4.920227 8.525894l-4.1470337 -7.1861115l1.2177429 -0.70329285l4.147064 7.186096l-1.2177734 0.7033081zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210968q0.0413208 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959412 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.9065552 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.5226135 -4.371216q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821808q-0.2735901 0.8075714 0.52301025 2.1879578l2.264862 3.9246216l-1.2177429 0.70329285z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 199.94267c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133865c-7.2126007 -5.566925 -11.846542 -11.13385 -23.6931 -11.13385" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 199.94267c-17.003876 0 -26.795105 -5.5669403 -34.007706 -11.133865c-3.6062927 -2.7834625 -6.567932 -5.566925 -10.10881 -7.6545258c-0.4426117 -0.2609558 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.1181488 -0.34950256 -0.17605591l-0.13806152 -0.066833496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 176.93082l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 118.06866l129.5748 -74.83465l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m290.03357 127.53869l-5.72464 -9.9198l1.3124695 -0.75800323l5.72464 9.919807l-1.3124695 0.7579956zm3.470581 -2.0044022l-4.1470337 -7.186104l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012146q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.2648926 3.924614l-1.2177734 0.70329285zm12.360168 -7.1384735l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001221 -1.8564148 0.44062805q-0.9904785 -0.05949402 -1.8883972 -0.67765045q-0.88442993 -0.62597656 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816315 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898376 -2.1887817q-1.0465393 -1.813446 -0.69085693 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm10.854523 3.1318436l-5.740265 -9.946869l1.1094971 -0.6407852l0.5388794 0.9337845q0.07220459 -0.78147125 0.4031067 -1.333458q0.3444214 -0.5597992 1.0480042 -0.9661484q0.92007446 -0.5313797 1.8970032 -0.46406555q0.9769592 0.06730652 1.8285828 0.7302551q0.8573303 0.64160156 1.4508972 1.6701202q0.6404114 1.1097183 0.74212646 2.2238083q0.11526489 1.1062698 -0.3691101 2.01754q-0.4708252 0.90345 -1.3097229 1.3879471q-0.6088867 0.35164642 -1.2443542 0.37583923q-0.62197876 0.01637268 -1.159668 -0.19635773l2.022766 3.5050888l-1.2177429 0.7033005zm-2.551239 -6.952957q0.80441284 1.3939209 1.7418518 1.7366486q0.95095825 0.3349228 1.7492676 -0.12612915q0.8118286 -0.46886444 0.9953308 -1.495079q0.18353271 -1.026207 -0.6443176 -2.4607239q-0.79663086 -1.3803787 -1.7554016 -1.728836q-0.96658325 -0.36198425 -1.7513428 0.09125519q-0.77124023 0.4454193 -0.958374 1.5278625q-0.1736145 1.0746231 0.62298584 2.4550018zm12.239746 -5.408928l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.91428375 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.1887894q-1.0465088 -1.8134384 -0.6908264 -3.3540878q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10826874 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173462 1.4440689q0.96810913 0.27087402 1.861145 -0.24488068q0.6765137 -0.39071655 0.9470215 -1.0160828q0.2705078 -0.62537384 0.095947266 -1.5530472zm-5.1317444 0.32941437l4.0050354 -2.3130722q-0.60443115 -0.85983276 -1.2410278 -1.0876236q-0.98791504 -0.3677063 -1.9079895 0.1636734q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.848114zm9.2612915 0.3710785l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22263336 0.9021301 0.687912q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.2648926 3.924614l-1.2177734 0.7033005zm12.360168 -7.138481l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001221 -1.8564148 0.44062805q-0.9904785 -0.05949402 -1.8883972 -0.67765045q-0.88442993 -0.62596893 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816391 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.4430008 -2.4898682 -2.1887817q-1.0465088 -1.813446 -0.6908264 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm9.261322 0.3710785l-4.147064 -7.186104l1.0959778 -0.6329727l0.5857544 1.0149918q0.105285645 -1.6306229 1.6071777 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45761108q0.67401123 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253387l-1.2177429 0.7033005l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335907 -0.96383667 0.2679596q-0.78479004 0.45323944 -1.0640869 1.2821732q-0.2736206 0.80758667 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005zm9.725006 -7.078125l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396408q-0.6629944 0.38291168 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25791168q-0.3518982 -0.26589966 -0.9844971 -1.3620834l-2.382019 -4.127617l-0.8930054 0.5157547l-0.5466919 -0.94731903l0.8930054 -0.5157547l-1.0308838 -1.786377l0.79599 -1.4340897l1.4526672 2.5171661l1.2177429 -0.7033005l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.5277939 0.4446106 0.645401q0.15356445 0.10979462 0.35708618 0.11856842q0.19567871 -0.004760742 0.46627808 -0.16104889q0.20297241 -0.11721802 0.49645996 -0.35889435z" fill-rule="nonzero"/><path fill="#000000" d="m299.15155 144.21382l-5.72464 -9.919815l1.2177429 -0.70329285l3.2645264 5.6568604l1.1950684 -4.587631l1.5830688 -0.9142914l-1.2081604 4.252365l5.625824 2.7774506l-1.5018921 0.8674011l-4.4921265 -2.3134918l-0.38955688 1.3256378l1.6478882 2.8554993l-1.2177429 0.7033081zm10.503723 -9.151794l1.3442383 -0.57788086q0.34274292 1.2816315 -0.11764526 2.359497q-0.44683838 1.0700378 -1.6916504 1.788971q-1.5830688 0.9142761 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.188797q-1.0465393 -1.8134308 -0.69085693 -3.3540802q0.35568237 -1.5406494 1.8440247 -2.400238q1.4342346 -0.8283386 2.9109192 -0.36398315q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10827637 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173157 1.4440765q0.96813965 0.27087402 1.861145 -0.2448883q0.6765137 -0.39071655 0.947052 -1.0160828q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32940674l4.0050354 -2.3130646q-0.60446167 -0.85983276 -1.2410278 -1.0876312q-0.98791504 -0.3677063 -1.90802 0.16368103q-0.8388672 0.48449707 -1.0926819 1.3889008q-0.2480774 0.8830719 0.23669434 1.848114zm9.247772 0.378891l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.017913818 -1.0001068 0.19329834 -1.4468765q0.21124268 -0.4467697 0.64419556 -0.69683075q0.6088867 -0.35165405 1.4713135 -0.32646942l0.22875977 1.3655014q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765686 0.6577606q-0.17840576 0.40979004 -0.063812256 0.92100525q0.17190552 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm4.6274414 -2.6725311l-4.147064 -7.186104l1.0959778 -0.6329727l0.5857544 1.0149841q0.105285645 -1.6306152 1.6071777 -2.4980164q0.6494751 -0.37509918 1.3234558 -0.45761108q0.67401123 -0.0825119 1.1632996 0.14012909q0.4892578 0.22264099 0.9020996 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335907 -0.96383667 0.2679596q-0.78479004 0.45323944 -1.0640869 1.2821732q-0.2736206 0.80758667 0.52301025 2.1879654l2.264862 3.9246216l-1.2177429 0.70329285zm11.281708 -9.60112l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700378 -1.6916809 1.788971q-1.5830688 0.91428375 -3.0654602 0.47127533q-1.4823914 -0.4430008 -2.4898682 -2.1887817q-1.0465088 -1.8134384 -0.6908264 -3.3540878q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75491333 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32941437l4.0050354 -2.3130722q-0.60443115 -0.85983276 -1.2410278 -1.0876236q-0.98791504 -0.3677063 -1.9079895 0.1636734q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.848114zm9.234253 0.3867035l-5.7246704 -9.919807l1.2177734 -0.7033005l5.72464 9.919807l-1.2177429 0.7033005z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 135.94267c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133858c-7.2126007 -5.5669327 -11.846542 -11.133858 -23.6931 -11.133858" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 135.94267c-17.003876 0 -26.795105 -5.5669403 -34.007706 -11.133865c-3.6062927 -2.7834625 -6.567932 -5.566925 -10.10881 -7.6545258c-0.4426117 -0.26094818 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.11816406 -0.34950256 -0.17607117l-0.13806152 -0.06682587" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 112.93081l-9.563843 1.350235l8.194244 5.113182z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/g3doc/Machine-Virtualization.png b/g3doc/Machine-Virtualization.png
new file mode 100644
index 000000000..1ba2ed6b2
--- /dev/null
+++ b/g3doc/Machine-Virtualization.png
Binary files differ
diff --git a/g3doc/Machine-Virtualization.svg b/g3doc/Machine-Virtualization.svg
new file mode 100644
index 000000000..5352da07b
--- /dev/null
+++ b/g3doc/Machine-Virtualization.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 387.7034120734908 336.4225721784777" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l387.7034 0l0 336.42258l-387.7034 0l0 -336.42258z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l387.7034 0l0 336.42258l-387.7034 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m44.454067 14.643044l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 14.643044l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m86.206116 45.98824l5.125 -13.359375l1.90625 0l5.46875 13.359375l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703125q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921875zm9.849823 9.1875l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.891342 8.484375l0 -13.375l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546875q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.84375 -0.765625 -2.765625q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.84375zm8.844467 4.78125l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0zm4.191696 -11.46875l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.578125 0.515625 -2.75q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.8125q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5625q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 -1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.1875q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -10.0l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.46875l0 -9.671875l1.640625 0l0 9.671875l-1.640625 0zm3.5354462 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.297592 4.84375l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 81.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 81.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m180.45407 81.068245l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m180.45407 81.068245l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m81.43044 64.70238l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m95.06437 82.471375l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.521843 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.32106 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.4062653 0 -2.2812653 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.6562653 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.4531403 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.4062653 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#fff2cc" d="m44.454067 95.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#f1c232" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 95.40656l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m85.11598 121.51739l0 -1.578125l5.65625 0l0 4.953125q-1.296875 1.046875 -2.6875 1.578125q-1.375 0.515625 -2.84375 0.515625q-1.96875 0 -3.578125 -0.84375q-1.609375 -0.84375 -2.421875 -2.4375q-0.8125 -1.59375 -0.8125 -3.5625q0 -1.953125 0.8125 -3.640625q0.8125 -1.6875 2.34375 -2.5q1.53125 -0.828125 3.515625 -0.828125q1.453125 0 2.625 0.46875q1.171875 0.46875 1.828125 1.3125q0.671875 0.828125 1.015625 2.171875l-1.59375 0.4375q-0.296875 -1.015625 -0.75 -1.59375q-0.4375 -0.59375 -1.265625 -0.9375q-0.828125 -0.34375 -1.84375 -0.34375q-1.203125 0 -2.09375 0.375q-0.890625 0.359375 -1.4375 0.96875q-0.53125 0.59375 -0.828125 1.3125q-0.515625 1.234375 -0.515625 2.6875q0 1.78125 0.609375 2.984375q0.625 1.203125 1.796875 1.796875q1.171875 0.578125 2.5 0.578125q1.140625 0 2.234375 -0.4375q1.09375 -0.453125 1.65625 -0.953125l0 -2.484375l-3.921875 0zm14.386429 5.234375l0 -1.421875q-1.125 1.640625 -3.0625 1.640625q-0.859375 0 -1.609375 -0.328125q-0.734375 -0.328125 -1.09375 -0.828125q-0.359375 -0.5 -0.5 -1.21875q-0.109375 -0.46875 -0.109375 -1.53125l0 -5.984375l1.640625 0l0 5.359375q0 1.28125 0.109375 1.734375q0.15625 0.640625 0.65625 1.015625q0.5 0.375 1.234375 0.375q0.734375 0 1.375 -0.375q0.65625 -0.390625 0.921875 -1.03125q0.265625 -0.65625 0.265625 -1.890625l0 -5.1875l1.640625 0l0 9.671875l-1.46875 0zm10.672592 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm8.485092 2.875l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.625 -6.625l2.390625 0l-5.59375 5.421875l5.84375 7.9375l-2.328125 0l-4.765625 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943573 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000717 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 161.83176l57.574802 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 161.83176l57.574802 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m193.71391 161.83176l60.7874 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m193.71391 161.83176l60.7874 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m70.02887 145.80026l123.68504 0l0 32.06299l-123.68504 0z" fill-rule="evenodd"/><path fill="#000000" d="m87.09864 166.63176l-3.6875 -9.546875l1.359375 0l2.484375 6.9375q0.296875 0.828125 0.5 1.5625q0.21875 -0.78125 0.515625 -1.5625l2.578125 -6.9375l1.28125 0l-3.734375 9.546875l-1.296875 0zm6.0303802 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.92984 0l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm7.0164948 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.6717377 1.046875l0 -1.015625q-0.8125 1.171875 -2.1875 1.171875q-0.609375 0 -1.140625 -0.234375q-0.53125 -0.234375 -0.796875 -0.578125q-0.25 -0.359375 -0.359375 -0.875q-0.0625 -0.34375 -0.0625 -1.09375l0 -4.28125l1.171875 0l0 3.828125q0 0.921875 0.0625 1.234375q0.109375 0.46875 0.46875 0.734375q0.359375 0.25 0.890625 0.25q0.515625 0 0.984375 -0.265625q0.46875 -0.265625 0.65625 -0.734375q0.1875 -0.46875 0.1875 -1.34375l0 -3.703125l1.171875 0l0 6.90625l-1.046875 0zm7.3968506 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm6.6797028 0l0 -9.546875l1.171875 0l0 3.421875q0.828125 -0.9375 2.0781174 -0.9375q0.765625 0 1.328125 0.296875q0.5625 0.296875 0.8125 0.84375q0.25 0.53125 0.25 1.546875l0 4.375l-1.171875 0l0 -4.375q0 -0.890625 -0.390625 -1.28125q-0.375 -0.40625 -1.078125 -0.40625q-0.515625 0 -0.9843674 0.28125q-0.453125 0.265625 -0.65625 0.734375q-0.1875 0.453125 -0.1875 1.265625l0 3.78125l-1.171875 0zm11.928093 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859528 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m44.454067 175.40657l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#93c47d" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 175.40657l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m115.3588 206.75175l-5.171875 -13.359375l1.921875 0l3.46875 9.703125q0.421875 1.171875 0.703125 2.1875q0.3125 -1.09375 0.71875 -2.1875l3.609375 -9.703125l1.796875 0l-5.234375 13.359375l-1.8125 0zm8.584198 0l0 -13.359375l2.65625 0l3.1562424 9.453125q0.4375 1.328125 0.640625 1.984375q0.234375 -0.734375 0.703125 -2.140625l3.203125 -9.296875l2.375 0l0 13.359375l-1.703125 0l0 -11.171875l-3.875 11.171875l-1.59375 0l-3.8593674 -11.375l0 11.375l-1.703125 0zm15.540794 0l0 -13.359375l2.65625 0l3.15625 9.453125q0.4375 1.328125 0.640625 1.984375q0.234375 -0.734375 0.703125 -2.140625l3.203125 -9.296875l2.375 0l0 13.359375l-1.703125 0l0 -11.171875l-3.875 11.171875l-1.59375 0l-3.859375 -11.375l0 11.375l-1.703125 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 239.1764l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 239.1764l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m180.45407 239.1764l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m180.45407 239.1764l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m81.43044 222.81055l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m95.06437 240.57954l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.521843 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.32106 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.4062653 0 -2.2812653 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.6562653 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.4531403 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.4062653 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m44.454067 252.7512l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m44.454067 252.7512l174.83464 0l0 48.850388l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m84.63558 284.0964l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.913475 1.46875l0 -13.359375l1.78125 0l0 6.625l6.625 -6.625l2.390625 0l-5.59375 5.421875l5.84375 7.9375l-2.328125 0l-4.765625 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943573 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m12.454068 319.17642l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m12.454068 319.17642l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m180.45407 319.17642l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m180.45407 319.17642l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m81.43044 302.81055l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m104.04542 323.64203l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.938362 0l0 -0.875q-0.65625 1.03125 -1.9374924 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.6249924 0 1.1093674 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.7031174 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.3281174 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.3437424 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.912468 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859543 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.545553100086654 0.0 0.0 4.545553100086654 0.0 0.0)" spreadMethod="pad" x1="9.954639806354566" y1="38.70166210013951" x2="9.95462288989064" y2="43.24721520019468"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m45.249344 175.92108l173.29134 0l0 20.661423l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m280.4455 190.06865l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m295.51392 196.73558l1.1823425 -0.82717896q0.51071167 0.6974335 1.1166077 0.9970703q0.5980835 0.28611755 1.4464111 0.1931305q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.9823456q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 -1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079711914q-0.478302 0.07775879 -2.032318 0.54222107q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418701q-0.4295349 -0.74432373 -0.38497925 -1.6361694q0.05029297 -0.9131775 0.6668701 -1.7203827q0.63012695 -0.81500244 1.6313782 -1.3932648q1.1095276 -0.64079285 2.1592712 -0.7598877q1.0419617 -0.1326294 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671814l-1.2036743 0.82147217q-0.64712524 -0.87127686 -1.5022583 -1.0089111q-0.8629761 -0.15116882 -2.013092 0.51304626q-1.1906738 0.68766785 -1.4819641 1.4333191q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.52778625 0.8865051 0.6608429q0.5740967 0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645294 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78492737 0.39904785 1.7543335q-0.040405273 0.9616089 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144836 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431458l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396332q-0.6629944 0.3829193 -1.1454773 0.39089966q-0.4902649 -0.0055389404 -0.8343506 -0.25790405q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.9473114l0.8930054 -0.5157623l-1.0308838 -1.786377l0.79599 -1.434082l1.4526367 2.5171661l1.2177734 -0.7033081l0.5466919 0.94732666l-1.2177734 0.70329285l2.4210815 4.195282q0.30456543 0.5278015 0.4446106 0.645401q0.15356445 0.109802246 0.35705566 0.11857605q0.19570923 -0.004760742 0.4663086 -0.16105652q0.20297241 -0.11721802 0.49645996 -0.35888672zm1.8165283 0.41241455l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.0178833 -1.0001068 0.19332886 -1.4468842q0.21121216 -0.44676208 0.64419556 -0.6968231q0.6088562 -0.35165405 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.92100525q0.17193604 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm2.0899658 -6.006668q-1.1480408 -1.9893646 -0.5930481 -3.5910034q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210815q0.04135132 1.0407104 -0.49923706 1.948349q-0.5348511 0.88630676 -1.4819946 1.4333191q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621765zm1.2583313 -0.72673035q0.79663086 1.3803711 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18502808q0.9065552 -0.5235596 1.1036072 -1.5575867q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17720032q-0.90652466 0.5235748 -1.117096 1.5654144q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494751 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.264862 3.9246216l-1.2177429 0.70329285zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.51579285 0.8666992 0.5640259q0.65527344 0.072631836 1.4400635 -0.38059998q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.306778q-0.1161499 -0.42010498 -0.81121826 -1.6245575q-0.25161743 1.408371 -1.4423218 2.096054q-1.4883423 0.85957336 -2.9171448 0.25932312q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849152 0.30926514 -1.9649353q0.4437561 -0.8878174 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.182724l-0.4998474 -0.8661194l1.1230469 -0.6485901l3.5847168 6.2117157q0.9684448 1.6781158 1.0362854 2.5771942q0.067840576 0.8990936 -0.4420166 1.7348785q-0.5098877 0.8357849 -1.5923462 1.4609375q-1.2854004 0.7423706 -2.4195251 0.62150574q-1.1206055 -0.12869263 -1.7651672 -1.3081818zm-1.4765625 -4.9031525q0.81222534 1.4074402 1.7418518 1.7366486q0.94314575 0.32138062 1.7820435 -0.16311646q0.8388977 -0.4844818 1.0401001 -1.4487457q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.1360321q-0.8118286 0.46887207 -0.9974365 1.4602051q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m302.23132 207.68793l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80441284 1.3939056l-1.2177429 0.7033081zm4.920227 8.525894l-4.147064 -7.1861115l1.2177734 -0.7033081l4.1470337 7.1861115l-1.2177429 0.7033081zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.6760864 1.2029724 0.80285645q0.6922302 0.10542297 1.5176086 -0.3712616q0.8388672 -0.4844818 1.0495605 -1.057251q0.20285034 -0.58628845 -0.062683105 -1.0464172q-0.23431396 -0.4059906 -0.7266846 -0.4464264q-0.35079956 -0.013916016 -1.4790955 0.3129425q-1.5347595 0.43530273 -2.2030334 0.49645996q-0.66256714 0.03982544 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.018219q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.797287q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.61598206q0.7765198 -0.123687744 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.85964966l-1.0969849 0.85006714q-0.4013672 -0.5079651 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.3217163q-0.8388977 0.4844818 -1.0402222 0.9796753q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.34710693q0.2581787 0.103500366 0.64749146 0.05909729q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.40403748 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501831q0.5397949 0.24760437 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73095703 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.6527405q-1.0664673 -0.15994263 -1.933258 -1.193039zm6.1190186 -5.4646606q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47283936 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.85957336 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79663086 1.3803864 0.83795166 2.4210968q0.04135132 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267303q1.0071716 0.33854675 1.9136963 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.90652466 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.957733 -0.3738098l-5.7246704 -9.919815l1.2177734 -0.70329285l5.72464 9.9198l-1.2177429 0.7033081zm7.2713623 -5.390396q-0.34069824 0.972641 -0.8225403 1.5756989q-0.48962402 0.5895386 -1.2067566 1.0036926q-1.1906738 0.6876831 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.52778625 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.49728394 0.64520264 -0.9139252q0.31063843 -0.3057251 0.980896 -0.8010864q1.3733215 -1.02771 1.9363098 -1.6776581q-0.14837646 -0.25712585 -0.18740845 -0.32478333q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14602661 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458069q-0.23410034 0.5321655 0.013824463 1.3994293l-1.2843933 0.52526855q-0.2749939 -0.85162354 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948547q0.92007446 -0.5313873 1.6133118 -0.6430664q0.7067566 -0.11949158 1.1648254 0.04902649q0.45803833 0.16850281 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841675l0.9371643 1.6239929q0.9840393 1.7051697 1.3094177 2.1126862q0.33892822 0.39971924 0.81103516 0.68640137l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 -2.6654663q-0.5067749 0.65356445 -1.7234497 1.6088562q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.3508911 0.16564941 0.66215515q0.28115845 0.48719788 0.83392334 0.6010132q0.5662842 0.10598755 1.269867 -0.30036926q0.70358276 -0.40634155 1.072998 -1.0166473q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3352051l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25790405q-0.3518982 -0.26591492 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94732666l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.434082l1.4526367 2.5171661l1.2177429 -0.7033081l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.52778625 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104126q0.20297241 -0.11721802 0.49645996 -0.35890198zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939209l-1.2177734 0.70329285zm4.920227 8.525894l-4.1470337 -7.1861115l1.2177429 -0.70329285l4.147064 7.186096l-1.2177734 0.7033081zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210968q0.0413208 1.0406952 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959412 -0.43519592 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.7267303q1.0072021 0.33854675 1.9137268 -0.18501282q0.9065552 -0.5235748 1.1036072 -1.5576019q0.21057129 -1.0418396 -0.60946655 -2.4628143q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635925 -1.9002075 0.17721558q-0.9065552 0.5235596 -1.117096 1.5653992q-0.2048645 1.0204926 0.59173584 2.400879zm8.984772 -0.38945007l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509155 1.3234558 -0.45761108q0.6739807 -0.08250427 1.163269 0.14013672q0.48928833 0.22264099 0.9021301 0.687912q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.5226135 -4.371216q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335144 -0.9638672 0.2679596q-0.7847595 0.45324707 -1.0640869 1.2821808q-0.2735901 0.8075714 0.52301025 2.1879578l2.264862 3.9246216l-1.2177429 0.70329285z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m290.76047 207.94267c-17.003845 0 -26.795105 -5.566925 -34.00769 -11.133865c-7.212616 -5.566925 -11.846558 -11.13385 -23.693115 -11.13385" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m290.76047 207.94267c-17.003876 0 -26.795105 -5.5669403 -34.00769 -11.133865c-3.606308 -2.7834625 -6.5679474 -5.566925 -10.108826 -7.6545258c-0.4426117 -0.2609558 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.1181488 -0.34950256 -0.17605591l-0.13806152 -0.066833496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m245.48381 184.93082l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m280.4455 126.06866l129.5748 -74.83465l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m298.03357 135.5387l-5.72464 -9.919807l1.3124695 -0.75800323l5.72464 9.9198l-1.3124695 0.75801086zm3.470581 -2.0044098l-4.1470337 -7.186104l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012146q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.1879654l2.2648926 3.9246216l-1.2177734 0.70329285zm12.360168 -7.1384735l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001297 -1.8564148 0.44062042q-0.9904785 -0.05949402 -1.8883972 -0.6776428q-0.88442993 -0.62597656 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816315 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898376 -2.1887817q-1.0465393 -1.813446 -0.69085693 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm10.854523 3.1318436l-5.740265 -9.946869l1.1094971 -0.6407852l0.5388794 0.9337845q0.07220459 -0.78147125 0.4031067 -1.333458q0.3444214 -0.5597992 1.0480042 -0.9661484q0.92007446 -0.5313797 1.8970032 -0.46406555q0.9769592 0.06730652 1.8285828 0.7302551q0.8573303 0.64160156 1.4508972 1.6701202q0.6404114 1.1097183 0.74212646 2.2238083q0.11526489 1.1062698 -0.3691101 2.01754q-0.4708252 0.90345 -1.3097229 1.3879471q-0.6088867 0.35164642 -1.2443542 0.37583923q-0.62197876 0.01637268 -1.159668 -0.19635773l2.022766 3.5050888l-1.2177429 0.7033005zm-2.551239 -6.952957q0.80441284 1.3939209 1.7418518 1.7366486q0.95095825 0.3349228 1.7492676 -0.12612915q0.8118286 -0.46886444 0.9953308 -1.495079q0.18353271 -1.026207 -0.6443176 -2.4607239q-0.79663086 -1.3803787 -1.7554016 -1.728836q-0.96658325 -0.36198425 -1.7513428 0.09125519q-0.77124023 0.4454193 -0.958374 1.5278625q-0.1736145 1.0746231 0.62298584 2.4550018zm12.239746 -5.408928l1.3442688 -0.57788086q0.34274292 1.2816391 -0.11764526 2.359497q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.91428375 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.1887894q-1.0465088 -1.8134384 -0.6908264 -3.3540878q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10826874 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173462 1.4440689q0.96810913 0.27087402 1.861145 -0.24488068q0.6765137 -0.39071655 0.9470215 -1.0160828q0.2705078 -0.62537384 0.095947266 -1.5530472zm-5.1317444 0.32941437l4.0050354 -2.3130722q-0.60443115 -0.85983276 -1.2410278 -1.0876236q-0.98791504 -0.3677063 -1.9079895 0.1636734q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.848114zm9.2612915 0.3710785l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22263336 0.9021301 0.687912q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253464l-1.2177429 0.70329285l-2.522583 -4.371208q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078888 -0.8442688 -0.30632782q-0.4767456 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.4532318 -1.0640869 1.2821655q-0.2735901 0.80758667 0.52301025 2.187973l2.2648926 3.924614l-1.2177734 0.7033005zm12.360168 -7.138481l-0.5232849 -0.906723q-0.059539795 1.4598389 -1.3855286 2.2256546q-0.8659668 0.5001221 -1.8564148 0.44062805q-0.9904785 -0.05949402 -1.8883972 -0.67765045q-0.88442993 -0.62596893 -1.5170288 -1.7221603q-0.6247864 -1.0826492 -0.77282715 -2.151947q-0.14230347 -1.0906448 0.30926514 -1.9649353q0.44378662 -0.887825 1.336792 -1.4035797q0.6494751 -0.37509155 1.3063049 -0.39356232q0.6703491 -0.026283264 1.2392883 0.2405777l-2.054016 -3.5592194l1.2177429 -0.7033005l5.7246704 9.919807l-1.1365662 0.6564102zm-5.9122925 -1.3669891q0.79660034 1.3803787 1.7767029 1.7345505q0.97232056 0.3406372 1.7570801 -0.112602234q0.7983093 -0.4610443 0.97817993 -1.4310303q0.18560791 -0.991333 -0.58758545 -2.3311157q-0.8512573 -1.4751129 -1.8178406 -1.8370972q-0.96658325 -0.36198425 -1.805481 0.12251282q-0.8118286 0.46886444 -0.97036743 1.4445648q-0.15853882 0.9757004 0.6693115 2.4102173zm12.53952 -5.5459747l1.3442383 -0.57787323q0.34274292 1.2816391 -0.11764526 2.3594894q-0.4468689 1.0700455 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.4430008 -2.4898682 -2.1887817q-1.0465088 -1.813446 -0.6908264 -3.3540955q0.35568237 -1.5406494 1.8440247 -2.4002304q1.4342346 -0.828331 2.9109192 -0.36397552q1.4823608 0.4430008 2.5054626 2.2158508q0.0625 0.10826111 0.18743896 0.32479095l-5.3580933 3.094513q0.75494385 1.1518478 1.7173462 1.4440689q0.96810913 0.27088165 1.861145 -0.24487305q0.6765137 -0.39072418 0.9470215 -1.0160904q0.2705078 -0.6253662 0.09597778 -1.5530472zm-5.131775 0.32941437l4.005066 -2.3130646q-0.60446167 -0.8598404 -1.2410278 -1.0876312q-0.98794556 -0.36769867 -1.90802 0.16368103q-0.8388977 0.48449707 -1.0926819 1.3889084q-0.2480774 0.88306427 0.23666382 1.8481064zm9.261322 0.3710785l-4.147064 -7.186104l1.0959778 -0.6329727l0.5857544 1.0149918q0.105285645 -1.6306229 1.6071777 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45761108q0.67401123 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26290894 0.29925537 0.74710083 1.1383133l2.553833 4.4253387l-1.2177429 0.7033005l-2.522583 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335907 -0.96383667 0.2679596q-0.78479004 0.45323944 -1.0640869 1.2821732q-0.2736206 0.80758667 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005zm9.725006 -7.078125l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.6396408q-0.6629944 0.38291168 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25791168q-0.3518982 -0.26589966 -0.9844971 -1.3620834l-2.382019 -4.127617l-0.8930054 0.5157547l-0.5466919 -0.94731903l0.8930054 -0.5157547l-1.0308838 -1.786377l0.79599 -1.4340897l1.4526672 2.5171661l1.2177429 -0.7033005l0.5466919 0.94732666l-1.2177429 0.70329285l2.421051 4.195282q0.30459595 0.5277939 0.4446106 0.645401q0.15356445 0.10979462 0.35708618 0.11856842q0.19567871 -0.004760742 0.46627808 -0.16104889q0.20297241 -0.11721802 0.49645996 -0.35889435z" fill-rule="nonzero"/><path fill="#000000" d="m307.15155 152.21382l-5.72464 -9.919815l1.2177429 -0.70329285l3.2645264 5.6568604l1.1950684 -4.587631l1.5830688 -0.9142914l-1.2081604 4.252365l5.625824 2.7774506l-1.5018921 0.8674011l-4.4921265 -2.3134918l-0.38955688 1.3256378l1.6478882 2.8554993l-1.2177429 0.7033081zm10.503723 -9.151794l1.3442383 -0.57788086q0.34274292 1.2816315 -0.11764526 2.359497q-0.44683838 1.0700378 -1.6916504 1.788971q-1.5830688 0.9142761 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.188797q-1.0465393 -1.8134308 -0.69085693 -3.3540802q0.35568237 -1.5406494 1.8440247 -2.400238q1.4342346 -0.8283386 2.9109192 -0.36398315q1.4823914 0.44300842 2.5054932 2.2158508q0.062469482 0.10827637 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518402 1.7173157 1.4440765q0.96813965 0.27087402 1.861145 -0.2448883q0.6765137 -0.39071655 0.947052 -1.0160828q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32940674l4.0050354 -2.3130646q-0.60446167 -0.85983276 -1.2410278 -1.0876312q-0.98791504 -0.3677063 -1.90802 0.16368103q-0.8388672 0.48449707 -1.0926819 1.3889008q-0.2480774 0.8830719 0.23669434 1.848114zm9.247772 0.378891l-4.147064 -7.1861115l1.0959778 -0.6329651l0.6247864 1.0826569q-0.017913818 -1.0001068 0.19329834 -1.4468842q0.21124268 -0.44676208 0.64419556 -0.6968231q0.6088867 -0.35165405 1.4713135 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862122q-0.39239502 0.22662354 -0.5765686 0.6577606q-0.17840576 0.40979004 -0.063812256 0.92100525q0.17190552 0.7667999 0.61709595 1.5381927l2.1711426 3.7622223l-1.2177429 0.70329285zm4.6274414 -2.6725311l-4.147064 -7.1861115l1.0959778 -0.6329651l0.5857544 1.0149841q0.105285645 -1.6306152 1.6071777 -2.4980164q0.6494751 -0.37509155 1.3234558 -0.45761108q0.67401123 -0.08250427 1.1632996 0.14012146q0.4892578 0.22264099 0.9020996 0.68792725q0.26290894 0.29925537 0.74710083 1.1383057l2.553833 4.425354l-1.2177429 0.70329285l-2.522583 -4.371216q-0.4295349 -0.74432373 -0.7892456 -1.0237579q-0.3539734 -0.30078125 -0.8442688 -0.3063202q-0.4767456 -0.01335144 -0.96383667 0.2679596q-0.78479004 0.4532318 -1.0640869 1.2821655q-0.2736206 0.80758667 0.52301025 2.187973l2.264862 3.9246216l-1.2177429 0.70329285zm11.281708 -9.60112l1.3442688 -0.57788086q0.34274292 1.2816315 -0.11764526 2.359497q-0.4468689 1.0700378 -1.6916809 1.788971q-1.5830688 0.9142914 -3.0654602 0.47128296q-1.4823914 -0.44300842 -2.4898682 -2.1887817q-1.0465088 -1.813446 -0.6908264 -3.3540955q0.35565186 -1.5406494 1.8440247 -2.400238q1.4342346 -0.828331 2.9108887 -0.36397552q1.4823914 0.44300842 2.5054932 2.2158432q0.062469482 0.10827637 0.18743896 0.32479858l-5.3580933 3.094513q0.75491333 1.1518555 1.7173462 1.4440765q0.96810913 0.27087402 1.861145 -0.24487305q0.6765137 -0.3907318 0.9470215 -1.016098q0.2705078 -0.6253662 0.095947266 -1.5530396zm-5.1317444 0.32940674l4.0050354 -2.3130646q-0.60443115 -0.85983276 -1.2410278 -1.0876312q-0.98791504 -0.3677063 -1.9079895 0.16368103q-0.8388977 0.48449707 -1.0926819 1.388916q-0.2480774 0.88305664 0.23666382 1.8480988zm9.234253 0.3867035l-5.7246704 -9.9198l1.2177734 -0.7033005l5.72464 9.919807l-1.2177429 0.70329285z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m290.76047 143.94267c-17.003845 0 -26.795105 -5.566925 -34.00769 -11.133865c-7.212616 -5.566925 -11.846558 -11.13385 -23.693115 -11.13385" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m290.76047 143.94267c-17.003876 0 -26.795105 -5.5669403 -34.00769 -11.133865c-3.606308 -2.7834625 -6.5679474 -5.566925 -10.108826 -7.6545258c-0.4426117 -0.26094818 -0.8942871 -0.5110321 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.11816406 -0.34950256 -0.17607117l-0.13806152 -0.06682587" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m245.48381 120.93081l-9.563843 1.350235l8.194244 5.113182z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/g3doc/README.md b/g3doc/README.md
index 49d58cdae..22bfb15f7 100644
--- a/g3doc/README.md
+++ b/g3doc/README.md
@@ -1,2 +1,164 @@
-The gVisor logo files are licensed under CC BY-SA 4.0 (Creative Commons
-Attribution-ShareAlike 4.0 International).
+# What is gVisor?
+
+gVisor is an application kernel, written in Go, that implements a substantial
+portion of the [Linux system call interface][linux]. It provides an additional
+layer of isolation between running applications and the host operating system.
+
+gVisor includes an [Open Container Initiative (OCI)][oci] runtime called `runsc`
+that makes it easy to work with existing container tooling. The `runsc` runtime
+integrates with Docker and Kubernetes, making it simple to run sandboxed
+containers.
+
+gVisor can be used with Docker, Kubernetes, or directly using `runsc`. Use the
+links below to see detailed instructions for each of them:
+
+* [Docker](./user_guide/quick_start/docker.md): The quickest and easiest way
+ to get started.
+* [Kubernetes](./user_guide/quick_start/kubernetes.md): Isolate Pods in your
+ K8s cluster with gVisor.
+* [OCI Quick Start](./user_guide/quick_start/oci.md): Expert mode. Customize
+ gVisor for your environment.
+
+## What does gVisor do?
+
+gVisor provides a virtualized environment in order to sandbox containers. The
+system interfaces normally implemented by the host kernel are moved into a
+distinct, per-sandbox application kernel in order to minimize the risk of an
+container escape exploit. gVisor does not introduce large fixed overheads
+however, and still retains a process-like model with respect to resource
+utilization.
+
+## How is this different?
+
+Two other approaches are commonly taken to provide stronger isolation than
+native containers.
+
+**Machine-level virtualization**, such as [KVM][kvm] and [Xen][xen], exposes
+virtualized hardware to a guest kernel via a Virtual Machine Monitor (VMM). This
+virtualized hardware is generally enlightened (paravirtualized) and additional
+mechanisms can be used to improve the visibility between the guest and host
+(e.g. balloon drivers, paravirtualized spinlocks). Running containers in
+distinct virtual machines can provide great isolation, compatibility and
+performance (though nested virtualization may bring challenges in this area),
+but for containers it often requires additional proxies and agents, and may
+require a larger resource footprint and slower start-up times.
+
+![Machine-level virtualization](Machine-Virtualization.png "Machine-level virtualization")
+
+**Rule-based execution**, such as [seccomp][seccomp], [SELinux][selinux] and
+[AppArmor][apparmor], allows the specification of a fine-grained security policy
+for an application or container. These schemes typically rely on hooks
+implemented inside the host kernel to enforce the rules. If the surface can be
+made small enough, then this is an excellent way to sandbox applications and
+maintain native performance. However, in practice it can be extremely difficult
+(if not impossible) to reliably define a policy for arbitrary, previously
+unknown applications, making this approach challenging to apply universally.
+
+![Rule-based execution](Rule-Based-Execution.png "Rule-based execution")
+
+Rule-based execution is often combined with additional layers for
+defense-in-depth.
+
+**gVisor** provides a third isolation mechanism, distinct from those above.
+
+gVisor intercepts application system calls and acts as the guest kernel, without
+the need for translation through virtualized hardware. gVisor may be thought of
+as either a merged guest kernel and VMM, or as seccomp on steroids. This
+architecture allows it to provide a flexible resource footprint (i.e. one based
+on threads and memory mappings, not fixed guest physical resources) while also
+lowering the fixed costs of virtualization. However, this comes at the price of
+reduced application compatibility and higher per-system call overhead.
+
+![gVisor](Layers.png "gVisor")
+
+On top of this, gVisor employs rule-based execution to provide defense-in-depth
+(details below).
+
+gVisor's approach is similar to [User Mode Linux (UML)][uml], although UML
+virtualizes hardware internally and thus provides a fixed resource footprint.
+
+Each of the above approaches may excel in distinct scenarios. For example,
+machine-level virtualization will face challenges achieving high density, while
+gVisor may provide poor performance for system call heavy workloads.
+
+## Why Go?
+
+gVisor is written in [Go][golang] in order to avoid security pitfalls that can
+plague kernels. With Go, there are strong types, built-in bounds checks, no
+uninitialized variables, no use-after-free, no stack overflow, and a built-in
+race detector. However, the use of Go has its challenges, and the runtime often
+introduces performance overhead.
+
+## What are the different components?
+
+A gVisor sandbox consists of multiple processes. These processes collectively
+comprise an environment in which one or more containers can be run.
+
+Each sandbox has its own isolated instance of:
+
+* The **Sentry**, which is a kernel that runs the containers and intercepts
+ and responds to system calls made by the application.
+
+Each container running in the sandbox has its own isolated instance of:
+
+* A **Gofer** which provides file system access to the containers.
+
+![gVisor architecture diagram](Sentry-Gofer.png "gVisor architecture diagram")
+
+## What is runsc?
+
+The entrypoint to running a sandboxed container is the `runsc` executable.
+`runsc` implements the [Open Container Initiative (OCI)][oci] runtime
+specification, which is used by Docker and Kubernetes. This means that OCI
+compatible _filesystem bundles_ can be run by `runsc`. Filesystem bundles are
+comprised of a `config.json` file containing container configuration, and a root
+filesystem for the container. Please see the [OCI runtime spec][runtime-spec]
+for more information on filesystem bundles. `runsc` implements multiple commands
+that perform various functions such as starting, stopping, listing, and querying
+the status of containers.
+
+### Sentry {#sentry}
+
+The Sentry is the largest component of gVisor. It can be thought of as a
+application kernel. The Sentry implements all the kernel functionality needed by
+the application, including: system calls, signal delivery, memory management and
+page faulting logic, the threading model, and more.
+
+When the application makes a system call, the
+[Platform](./architecture_guide/platforms.md) redirects the call to the Sentry,
+which will do the necessary work to service it. It is important to note that the
+Sentry does not pass system calls through to the host kernel. As a userspace
+application, the Sentry will make some host system calls to support its
+operation, but it does not allow the application to directly control the system
+calls it makes. For example, the Sentry is not able to open files directly; file
+system operations that extend beyond the sandbox (not internal `/proc` files,
+pipes, etc) are sent to the Gofer, described below.
+
+### Gofer {#gofer}
+
+The Gofer is a standard host process which is started with each container and
+communicates with the Sentry via the [9P protocol][9p] over a socket or shared
+memory channel. The Sentry process is started in a restricted seccomp container
+without access to file system resources. The Gofer mediates all access to the
+these resources, providing an additional level of isolation.
+
+### Application {#application}
+
+The application is a normal Linux binary provided to gVisor in an OCI runtime
+bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so
+applications should be able to run unmodified. However, gVisor does not
+presently implement every system call, `/proc` file, or `/sys` file so some
+incompatibilities may occur. See [Compatibility](./user_guide/compatibility.md)
+for more information.
+
+[9p]: https://en.wikipedia.org/wiki/9P_(protocol)
+[apparmor]: https://wiki.ubuntu.com/AppArmor
+[golang]: https://golang.org
+[kvm]: https://www.linux-kvm.org
+[linux]: https://en.wikipedia.org/wiki/Linux_kernel_interfaces
+[oci]: https://www.opencontainers.org
+[runtime-spec]: https://github.com/opencontainers/runtime-spec
+[seccomp]: https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt
+[selinux]: https://selinuxproject.org
+[uml]: http://user-mode-linux.sourceforge.net/
+[xen]: https://www.xenproject.org
diff --git a/g3doc/Rule-Based-Execution.png b/g3doc/Rule-Based-Execution.png
new file mode 100644
index 000000000..b42654a90
--- /dev/null
+++ b/g3doc/Rule-Based-Execution.png
Binary files differ
diff --git a/g3doc/Rule-Based-Execution.svg b/g3doc/Rule-Based-Execution.svg
new file mode 100644
index 000000000..bd6717043
--- /dev/null
+++ b/g3doc/Rule-Based-Execution.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 355.03674540682414 172.5564304461942" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l355.03674 0l0 172.55643l-355.03674 0l0 -172.55643z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l355.03674 0l0 172.55643l-355.03674 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m78.206116 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 -1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m3.6351707 71.39028l48.850395 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m3.6351707 71.39028l48.850395 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m195.25722 71.39028l47.338577 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m195.25722 71.39028l47.338577 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m52.485565 55.358784l142.77165 0l0 32.062992l-142.77165 0z" fill-rule="evenodd"/><path fill="#000000" d="m65.21821 76.19028l0 -9.546875l1.265625 0l0 8.421875l4.703125 0l0 1.125l-5.96875 0zm7.3343506 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm2.945465 0l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm11.118057 -8.1875l0 -1.359375l1.171875 0l0 1.359375l-1.171875 0zm0 8.1875l0 -6.90625l1.171875 0l0 6.90625l-1.171875 0zm5.507965 -1.046875l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm11.006226 4.125l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm9.865463 1.390625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm7.0859375 4.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8124924 0 1.2031174 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1874924 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8124924 0 1.4218674 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0624924 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2499924 0.328125 1.7343674 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531174 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.695305 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m36.454067 85.4808l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 85.4808l174.83464 0l0 48.850395l-174.83464 0z" fill-rule="evenodd"/><path fill="#000000" d="m76.63558 116.82599l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.6249924 -6.625l2.390625 0l-5.5937424 5.421875l5.8437424 7.9375l-2.328125 0l-4.7656174 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943565 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.454068 151.90599l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.454068 151.90599l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m172.45407 151.90599l74.04724 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m172.45407 151.90599l74.04724 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m73.43044 135.54013l100.88189 0l0 32.06299l-100.88189 0z" fill-rule="evenodd"/><path fill="#000000" d="m96.04542 156.37163l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.0937424 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.4062424 5.3125l-1.21875 0zm12.859535 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59376526 0.21875 -1.2812653 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.4218903 -0.171875 2.0937653 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.3437653 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.89064026 0 1.4375153 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.9218903 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.2031403 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.54555197232122 0.0 0.0 4.54555197232122 0.0 0.0)" spreadMethod="pad" x1="8.189483259998303" y1="18.80511284496466" x2="8.189466907412452" y2="23.35066481725647"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m37.225723 85.48025l173.29134 0l0 20.661415l-173.29134 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m272.4455 100.54161l129.5748 -74.83464l20.629913 35.74803l-129.5748 74.83464z" fill-rule="evenodd"/><path fill="#000000" d="m287.51392 107.20854l1.1823425 -0.8271866q0.51071167 0.6974335 1.1166077 0.9970856q0.5980835 0.28610992 1.4464111 0.19311523q0.84054565 -0.10652161 1.6794434 -0.5910187q0.75772095 -0.4376068 1.2010193 -0.98233795q0.44906616 -0.5660858 0.50097656 -1.1013031q0.057678223 -0.55656433 -0.20785522 -1.0166931q-0.27334595 -0.47366333 -0.7392273 -0.6557007q-0.47366333 -0.1955719 -1.2366333 -0.079704285q-0.478302 0.07775116 -2.032318 0.54221344q-1.5618286 0.45092773 -2.2805786 0.48712158q-0.9222717 0.027420044 -1.5864563 -0.31072998q-0.6719971 -0.35168457 -1.0703125 -1.0418777q-0.4295349 -0.74432373 -0.38497925 -1.6361618q0.05029297 -0.9131851 0.6668701 -1.7203751q0.63012695 -0.8150101 1.6313782 -1.39328q1.1095276 -0.6407852 2.1592712 -0.75988007q1.0419617 -0.13263702 1.8867493 0.29968262q0.8583679 0.4244995 1.3987732 1.2671738l-1.2036743 0.8214798q-0.64712524 -0.87127686 -1.5022583 -1.0089188q-0.8629761 -0.15116882 -2.013092 0.5130539q-1.1906738 0.6876755 -1.4819641 1.4333115q-0.2913208 0.745636 0.06793213 1.3681641q0.30459595 0.5277939 0.8865051 0.6608505q0.5740967 0.119522095 2.3815613 -0.43717957q1.8210144 -0.5645218 2.5725403 -0.6376953q1.0924377 -0.107666016 1.857605 0.28042603q0.77090454 0.366745 1.2316895 1.1652069q0.4529724 0.78491974 0.39904785 1.7543335q-0.040405273 0.96160126 -0.6663208 1.8463745q-0.62594604 0.8847656 -1.6813354 1.4942932q-1.3530579 0.78144073 -2.486084 0.9125519q-1.1408386 0.11756897 -2.1214905 -0.3625946q-0.96713257 -0.48797607 -1.5721436 -1.4738007zm13.40155 -4.9431534l0.8006897 0.98106384q-0.45169067 0.4052124 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454773 0.39089966q-0.4902649 -0.00554657 -0.8343506 -0.25791168q-0.3518982 -0.2659073 -0.9844971 -1.3620911l-2.382019 -4.127617l-0.8930054 0.5157547l-0.5466919 -0.94731903l0.8930054 -0.5157547l-1.0308838 -1.786377l0.79599 -1.4340897l1.4526367 2.5171661l1.2177734 -0.70329285l0.5466919 0.94731903l-1.2177734 0.7033005l2.4210815 4.1952744q0.30456543 0.5277939 0.4446106 0.645401q0.15356445 0.10979462 0.35705566 0.11856842q0.19570923 -0.004760742 0.4663086 -0.16104889q0.20297241 -0.11721802 0.49645996 -0.35889435zm1.8165283 0.41242218l-4.147064 -7.186104l1.0959778 -0.6329727l0.6247864 1.0826492q-0.0178833 -1.0000992 0.19332886 -1.4468689q0.21121216 -0.44677734 0.64419556 -0.6968384q0.6088562 -0.35164642 1.471283 -0.3264618l0.22875977 1.3654938q-0.59487915 7.4768066E-4 -1.0413818 0.25862885q-0.39239502 0.2266159 -0.5765381 0.6577606q-0.17843628 0.40979004 -0.06384277 0.9209976q0.17193604 0.76680756 0.61709595 1.5382004l2.1711426 3.7622147l-1.2177429 0.7033005zm2.0899658 -6.0066605q-1.1480408 -1.9893799 -0.5930481 -3.5910187q0.47280884 -1.3376465 1.7988281 -2.1034622q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384125 2.4934998 2.132553q0.79660034 1.3803787 0.83795166 2.4210892q0.04135132 1.0407028 -0.49923706 1.948349q-0.5348511 0.8862915 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.4351883 -2.5502625 -2.2621613zm1.2583313 -0.7267456q0.79663086 1.3803787 1.7902527 1.726738q1.0072021 0.33853912 1.9137268 -0.18502045q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635162 -1.9002075 0.17721558q-0.90652466 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.984772 -0.38944244l-4.1470337 -7.1861115l1.0959778 -0.6329651l0.5857239 1.0149841q0.10531616 -1.6306229 1.6072083 -2.498024q0.6494751 -0.37509155 1.3234558 -0.45760345q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26287842 0.29925537 0.74710083 1.1383057l2.553833 4.4253464l-1.2177429 0.7033005l-2.522583 -4.371216q-0.42956543 -0.74432373 -0.7892456 -1.0237579q-0.3540039 -0.30078125 -0.8442688 -0.30632782q-0.47677612 -0.01335144 -0.9638672 0.26796722q-0.7847595 0.45323944 -1.0640869 1.2821732q-0.2735901 0.80757904 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005zm7.819275 -3.7220154l1.2922058 -0.511734q0.3878479 0.5157852 0.8666992 0.56401825q0.65527344 0.072639465 1.4400635 -0.38059235q0.8388977 -0.48449707 1.1036682 -1.0885162q0.26480103 -0.60401917 0.07571411 -1.3067856q-0.1161499 -0.42009735 -0.81121826 -1.6245499q-0.25161743 1.408371 -1.4423218 2.0960464q-1.4883423 0.859581 -2.9171448 0.25933075q-1.428833 -0.60025024 -2.2957153 -2.1024323q-0.5935669 -1.0285187 -0.72805786 -2.1056366q-0.12097168 -1.0849228 0.30926514 -1.9649353q0.4437561 -0.887825 1.3908997 -1.4348297q1.2718811 -0.7345581 2.690796 -0.18271637l-0.4998474 -0.866127l1.1230469 -0.6485977l3.5847168 6.2117233q0.9684448 1.6781082 1.0362854 2.5771942q0.067840576 0.899086 -0.4420166 1.7348709q-0.5098877 0.83579254 -1.5923462 1.4609451q-1.2854004 0.7423706 -2.4195251 0.6214981q-1.1206055 -0.12869263 -1.7651672 -1.3081741zm-1.4765625 -4.90316q0.81222534 1.4074478 1.7418518 1.7366486q0.94314575 0.32138824 1.7820435 -0.16310883q0.8388977 -0.48448944 1.0401001 -1.4487534q0.19342041 -0.97779846 -0.60317993 -2.3581848q-0.76538086 -1.3262482 -1.7298889 -1.6533508q-0.97229004 -0.3406372 -1.7976685 0.13603973q-0.8118286 0.46886444 -0.9974365 1.4601974q-0.1855774 0.991333 0.56417847 2.290512z" fill-rule="nonzero"/><path fill="#000000" d="m294.23132 118.16088l-0.80441284 -1.3939133l1.2177429 -0.7033005l0.80441284 1.3939209l-1.2177429 0.70329285zm4.920227 8.525894l-4.147064 -7.1861115l1.2177734 -0.70329285l4.1470337 7.186104l-1.2177429 0.7033005zm1.3493347 -3.6482391l1.0948792 -0.88494873q0.51641846 0.67609406 1.2029724 0.8028641q0.6922302 0.10542297 1.5176086 -0.37125397q0.8388672 -0.48449707 1.0495605 -1.0572586q0.20285034 -0.5862961 -0.062683105 -1.0464249q-0.23431396 -0.4059906 -0.7266846 -0.44641113q-0.35079956 -0.013923645 -1.4790955 0.31292725q-1.5347595 0.43530273 -2.2030334 0.4964676q-0.66256714 0.03981781 -1.183075 -0.23695374q-0.5069885 -0.28459167 -0.8115845 -0.8123779q-0.28115845 -0.48719788 -0.2989807 -1.0182266q-0.017791748 -0.5310211 0.2048645 -1.0204926q0.15917969 -0.3806305 0.56817627 -0.79727936q0.4147339 -0.43800354 0.9694824 -0.75839233q0.852417 -0.49230957 1.6289368 -0.6159897q0.7765198 -0.123680115 1.3162842 0.123931885q0.5455017 0.22625732 1.0733948 0.8596573l-1.0969849 0.85006714q-0.4013672 -0.5079727 -0.9733887 -0.59262085q-0.5720215 -0.0846405 -1.2756042 0.32170868q-0.8388977 0.48449707 -1.0402222 0.9796829q-0.19558716 0.47383118 0.015289307 0.8392334q0.14056396 0.24359131 0.39874268 0.3470993q0.2581787 0.103507996 0.64749146 0.05910492q0.22845459 -0.041732788 1.2620544 -0.31388855q1.4806519 -0.4040451 2.127594 -0.470932q0.6390991 -0.08041382 1.1653442 0.17501068q0.5397949 0.247612 0.89904785 0.87013245q0.35144043 0.60899353 0.29852295 1.3613129q-0.047210693 0.73096466 -0.5519409 1.4194183q-0.49118042 0.68063354 -1.3300476 1.1651306q-1.407196 0.81269836 -2.4736633 0.65275574q-1.0664673 -0.15995026 -1.933258 -1.1930542zm6.1190186 -5.4646606q-1.1480408 -1.9893723 -0.5930481 -3.591011q0.47283936 -1.3376541 1.7988281 -2.1034698q1.4883423 -0.859581 2.984253 -0.4243927q1.501648 0.41384888 2.4934998 2.1325607q0.79663086 1.3803787 0.83795166 2.4210815q0.04135132 1.0407028 -0.49923706 1.948349q-0.5348511 0.88629913 -1.4819946 1.4333115q-1.5018921 0.8674011 -2.9899902 0.44573975q-1.4959106 -0.43519592 -2.5502625 -2.262169zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902832 1.7267456q1.0071716 0.33853912 1.9136963 -0.18502808q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647873q-0.99365234 -0.34635925 -1.9002075 0.17720795q-0.90652466 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.957733 -0.3738098l-5.7246704 -9.919807l1.2177734 -0.7033005l5.72464 9.919807l-1.2177429 0.7033005zm7.2713623 -5.390396q-0.34069824 0.9726486 -0.8225403 1.5757141q-0.48962402 0.5895233 -1.2067566 1.003685q-1.1906738 0.6876755 -2.162445 0.47302246q-0.9717407 -0.21464539 -1.4872131 -1.1078339q-0.30456543 -0.5277939 -0.3109436 -1.1015167q-0.014190674 -0.58724976 0.21627808 -1.0631866q0.23620605 -0.4972763 0.64520264 -0.9139328q0.31063843 -0.30571747 0.980896 -0.8010864q1.3733215 -1.0277023 1.9363098 -1.6776505q-0.14837646 -0.25712585 -0.18740845 -0.32479095q-0.42956543 -0.74432373 -0.93963623 -0.84669495q-0.7156677 -0.14601898 -1.6357422 0.38536072q-0.8524475 0.49230957 -1.0922546 1.0458145q-0.23410034 0.5321655 0.013824463 1.3994217l-1.2843933 0.5252762q-0.2749939 -0.85163116 -0.18301392 -1.5362854q0.10549927 -0.6924591 0.66851807 -1.3424072q0.5552063 -0.66348267 1.4752808 -1.1948624q0.92007446 -0.5313797 1.6133118 -0.6430588q0.7067566 -0.11948395 1.1648254 0.04901886q0.45803833 0.16851044 0.8552551 0.60672q0.24728394 0.27218628 0.71588135 1.0841827l0.9371643 1.6239777q0.9840393 1.7051773 1.3094177 2.1127014q0.33892822 0.39970398 0.81103516 0.6863861l-1.2718506 0.7345581q-0.40811157 -0.26953125 -0.7590027 -0.75253296zm-1.6645203 -2.6654587q-0.5067749 0.65356445 -1.7234497 1.6088486q-0.69522095 0.5458679 -0.9283142 0.8609314q-0.23312378 0.31506348 -0.25280762 0.6873169q-0.013977051 0.35090637 0.16564941 0.6621628q0.28115845 0.48719788 0.83392334 0.60100555q0.5662842 0.10598755 1.269867 -0.30036163q0.70358276 -0.40634918 1.072998 -1.016655q0.37512207 -0.63165283 0.3197937 -1.3214569q-0.031341553 -0.5232086 -0.49990845 -1.3351974l-0.25775146 -0.44659424zm7.223419 -0.8156891l0.8006897 0.98106384q-0.45169067 0.40522003 -0.857605 0.63964844q-0.6629944 0.38290405 -1.1454468 0.39089966q-0.4902954 -0.0055389404 -0.8343811 -0.25791168q-0.3518982 -0.26589966 -0.9844971 -1.3620911l-2.382019 -4.1276093l-0.8930054 0.5157471l-0.5466919 -0.94731903l0.8930054 -0.5157471l-1.0308838 -1.786377l0.7960205 -1.4340897l1.4526367 2.5171661l1.2177429 -0.7033005l0.5466919 0.94731903l-1.2177429 0.7033005l2.421051 4.195282q0.30459595 0.5277939 0.4446106 0.645401q0.15356445 0.10978699 0.35708618 0.11856079q0.19567871 -0.004760742 0.46627808 -0.16104889q0.20297241 -0.11721039 0.49645996 -0.35889435zm-3.0901794 -8.121277l-0.80441284 -1.3939209l1.2177429 -0.70329285l0.80444336 1.3939133l-1.2177734 0.7033005zm4.920227 8.525887l-4.1470337 -7.186104l1.2177429 -0.7033005l4.147064 7.186104l-1.2177734 0.7033005zm0.54074097 -5.111908q-1.1480408 -1.9893799 -0.5930481 -3.591011q0.47280884 -1.3376541 1.7988281 -2.1034698q1.4883423 -0.8595886 2.984253 -0.4243927q1.501648 0.41384125 2.4934998 2.132553q0.79660034 1.3803864 0.83795166 2.4210892q0.0413208 1.0407028 -0.49923706 1.948349q-0.5348511 0.88629913 -1.4819946 1.4333038q-1.5018921 0.8674011 -2.9899902 0.44574738q-1.4959412 -0.43519592 -2.5502625 -2.262169zm1.2583313 -0.7267456q0.79663086 1.3803864 1.7902527 1.726738q1.0072021 0.33854675 1.9137268 -0.18502045q0.9065552 -0.5235672 1.1036072 -1.5575943q0.21057129 -1.0418396 -0.60946655 -2.462822q-0.76538086 -1.3262482 -1.7725525 -1.6647949q-0.99365234 -0.34635162 -1.9002075 0.17721558q-0.9065552 0.5235672 -1.117096 1.5654068q-0.2048645 1.0204926 0.59173584 2.4008713zm8.984772 -0.38944244l-4.1470337 -7.1861115l1.0959473 -0.6329651l0.5857544 1.0149841q0.10531616 -1.6306152 1.6072083 -2.4980164q0.6494446 -0.37509918 1.3234558 -0.45761108q0.6739807 -0.0825119 1.163269 0.14012909q0.48928833 0.22264099 0.9021301 0.6879196q0.26287842 0.29925537 0.74710083 1.1383133l2.553833 4.4253387l-1.2177429 0.7033005l-2.5226135 -4.371208q-0.4295349 -0.74432373 -0.7892456 -1.0237656q-0.3539734 -0.30078125 -0.8442383 -0.3063202q-0.47677612 -0.01335907 -0.9638672 0.2679596q-0.7847595 0.45323944 -1.0640869 1.2821732q-0.2735901 0.80757904 0.52301025 2.1879654l2.264862 3.924614l-1.2177429 0.7033005z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m282.76047 118.41563c-17.003845 0 -26.795105 -5.566925 -34.007706 -11.133858c-7.2126007 -5.566925 -11.846542 -11.133858 -23.6931 -11.133858" fill-rule="evenodd"/><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.76047 118.415634c-17.003876 0 -26.795105 -5.5669327 -34.007706 -11.133858c-3.6062927 -2.7834702 -6.567932 -5.5669327 -10.10881 -7.6545334c-0.4426117 -0.26094818 -0.8942871 -0.51101685 -1.3573761 -0.74887085c-0.11578369 -0.0594635 -0.23228455 -0.11816406 -0.34950256 -0.17607117l-0.13806152 -0.06682587" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="2.0" stroke-linecap="butt" d="m237.48381 95.40378l-9.563843 1.350235l8.194244 5.1131744z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/g3doc/Sentry-Gofer.png b/g3doc/Sentry-Gofer.png
new file mode 100644
index 000000000..ca2c27ef7
--- /dev/null
+++ b/g3doc/Sentry-Gofer.png
Binary files differ
diff --git a/g3doc/Sentry-Gofer.svg b/g3doc/Sentry-Gofer.svg
new file mode 100644
index 000000000..5c10750d2
--- /dev/null
+++ b/g3doc/Sentry-Gofer.svg
@@ -0,0 +1 @@
+<svg version="1.1" viewBox="0.0 0.0 358.8556430446194 249.67191601049868" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l358.85565 0l0 249.67192l-358.85565 0l0 -249.67192z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l358.85565 0l0 249.67192l-358.85565 0z" fill-rule="evenodd"/><path fill="#f4cccc" d="m36.454067 6.6430445l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#cc4125" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 6.6430445l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path fill="#000000" d="m48.00139 37.98824l5.125 -13.359373l1.90625 0l5.46875 13.359373l-2.015625 0l-1.546875 -4.046875l-5.59375 0l-1.46875 4.046875l-1.875 0zm3.859375 -5.484375l4.53125 0l-1.40625 -3.703123q-0.625 -1.6875 -0.9375 -2.765625q-0.265625 1.28125 -0.71875 2.546875l-1.46875 3.921873zm9.849823 9.1875l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.891342 8.484375l0 -13.374998l1.484375 0l0 1.25q0.53125 -0.734375 1.1875 -1.09375q0.671875 -0.375 1.625 -0.375q1.234375 0 2.171875 0.640625q0.953125 0.625 1.4375 1.796875q0.484375 1.15625 0.484375 2.546873q0 1.484375 -0.53125 2.671875q-0.53125 1.1875 -1.546875 1.828125q-1.015625 0.625 -2.140625 0.625q-0.8125 0 -1.46875 -0.34375q-0.65625 -0.34375 -1.0625 -0.875l0 4.703125l-1.640625 0zm1.484375 -8.484375q0 1.859375 0.75 2.765625q0.765625 0.890625 1.828125 0.890625q1.09375 0 1.875 -0.921875q0.78125 -0.9375 0.78125 -2.875q0 -1.8437481 -0.765625 -2.765623q-0.75 -0.921875 -1.8125 -0.921875q-1.046875 0 -1.859375 0.984375q-0.796875 0.96875 -0.796875 2.843748zm8.844467 4.78125l0 -13.359373l1.640625 0l0 13.359373l-1.640625 0zm4.191696 -11.468748l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm10.457321 -3.546875l1.609375 0.21875q-0.265625 1.65625 -1.359375 2.609375q-1.078125 0.9375 -2.671875 0.9375q-1.984375 0 -3.1875 -1.296875q-1.203125 -1.296875 -1.203125 -3.71875q0 -1.5781231 0.515625 -2.749998q0.515625 -1.171875 1.578125 -1.75q1.0625 -0.59375 2.3125 -0.59375q1.578125 0 2.578125 0.796875q1.0 0.796875 1.28125 2.265625l-1.59375 0.234375q-0.234375 -0.96875 -0.8125 -1.453125q-0.578125 -0.5 -1.390625 -0.5q-1.234375 0 -2.015625 0.890625q-0.78125 0.890625 -0.78125 2.812498q0 1.953125 0.75 2.84375q0.75 0.875 1.953125 0.875q0.96875 0 1.609375 -0.59375q0.65625 -0.59375 0.828125 -1.828125zm9.328125 2.359375q-0.921875 0.765625 -1.765625 1.09375q-0.828125 0.3125 -1.796875 0.3125q-1.59375 0 -2.453125 -0.78125q-0.859375 -0.78125 -0.859375 -1.984375q0 -0.71875 0.328125 -1.296875q0.328125 -0.59375 0.84375 -0.9375q0.53125 -0.359375 1.1875 -0.546875q0.46875 -0.125 1.453125 -0.25q1.984375 -0.234375 2.921875 -0.5624981q0.015625 -0.34375 0.015625 -0.421875q0 -1.0 -0.46875 -1.421875q-0.625 -0.546875 -1.875 -0.546875q-1.15625 0 -1.703125 0.40625q-0.546875 0.40625 -0.8125 1.421875l-1.609375 -0.21875q0.21875 -1.015625 0.71875 -1.640625q0.5 -0.640625 1.453125 -0.984375q0.953125 -0.34375 2.1875 -0.34375q1.25 0 2.015625 0.296875q0.78125 0.28125 1.140625 0.734375q0.375 0.4375 0.515625 1.109375q0.078125 0.421875 0.078125 1.515625l0 2.187498q0 2.28125 0.109375 2.890625q0.109375 0.59375 0.40625 1.15625l-1.703125 0q-0.265625 -0.515625 -0.328125 -1.1875zm-0.140625 -3.671875q-0.890625 0.375 -2.671875 0.625q-1.015625 0.140625 -1.4375 0.328125q-0.421875 0.1875 -0.65625 0.53125q-0.21875 0.34375 -0.21875 0.78125q0 0.65625 0.5 1.09375q0.5 0.4375 1.453125 0.4375q0.9375 0 1.671875 -0.40625q0.75 -0.421875 1.09375 -1.140625q0.265625 -0.5625 0.265625 -1.640625l0 -0.609375zm7.781967 3.390625l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578123l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671873q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.6051788 -9.999998l0 -1.890625l1.640625 0l0 1.890625l-1.640625 0zm0 11.468748l0 -9.671873l1.640625 0l0 9.671873l-1.640625 0zm3.5354462 -4.84375q0 -2.687498 1.484375 -3.968748q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609373q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.7968731 -0.8125 -2.718748q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765623zm9.297592 4.84375l0 -9.671873l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.5937481l0 5.953125l-1.640625 0l0 -5.890625q0 -0.9999981 -0.203125 -1.4843731q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515623l0 5.28125l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m4.4540663 73.055115l40.47244 0.25196838" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m4.4540663 73.055115l40.47244 0.25196838" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m144.0 74.0l35.27559 -0.94488525" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m144.0 74.0l35.27559 -0.94488525" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m44.926506 57.27559l97.48032 0l0 32.062996l-97.48032 0z" fill-rule="evenodd"/><path fill="#000000" d="m56.859642 75.044586l1.203125 -0.109375q0.078125 0.71875 0.390625 1.1875q0.3125 0.453125 0.953125 0.734375q0.65625 0.28125 1.46875 0.28125q0.71875 0 1.265625 -0.21875q0.5625 -0.21875 0.828125 -0.578125q0.265625 -0.375 0.265625 -0.828125q0 -0.453125 -0.265625 -0.78125q-0.25 -0.328125 -0.84375 -0.5625q-0.390625 -0.15625 -1.703125 -0.46875q-1.3125 -0.3125 -1.84375 -0.59375q-0.671875 -0.359375 -1.015625 -0.890625q-0.328125 -0.53125 -0.328125 -1.1875q0 -0.71875 0.40625 -1.34375q0.40625 -0.625 1.1875 -0.953125q0.796875 -0.328125 1.765625 -0.328125q1.046875 0 1.859375 0.34375q0.8125 0.34375 1.25 1.015625q0.4375 0.65625 0.46875 1.484375l-1.203125 0.09375q-0.109375 -0.90625 -0.671875 -1.359375q-0.5625 -0.46875 -1.65625 -0.46875q-1.140625 0 -1.671875 0.421875q-0.515625 0.421875 -0.515625 1.015625q0 0.515625 0.359375 0.84375q0.375 0.328125 1.90625 0.6875q1.546875 0.34375 2.109375 0.59375q0.84375 0.390625 1.234375 0.984375q0.390625 0.578125 0.390625 1.359375q0 0.75 -0.4375 1.4375q-0.421875 0.671875 -1.25 1.046875q-0.8125 0.359375 -1.828125 0.359375q-1.296875 0 -2.171875 -0.375q-0.875 -0.375 -1.375 -1.125q-0.5 -0.765625 -0.53125 -1.71875zm9.12413 5.71875l-0.125 -1.09375q0.375 0.109375 0.65625 0.109375q0.390625 0 0.625 -0.140625q0.234375 -0.125 0.390625 -0.359375q0.109375 -0.171875 0.359375 -0.875q0.03125 -0.09375 0.109375 -0.28125l-2.625 -6.921875l1.265625 0l1.4375 4.0q0.28125 0.765625 0.5 1.59375q0.203125 -0.796875 0.46875 -1.578125l1.484375 -4.015625l1.171875 0l-2.625 7.015625q-0.421875 1.140625 -0.65625 1.578125q-0.3125 0.578125 -0.71875 0.84375q-0.40625 0.28125 -0.96875 0.28125q-0.328125 0 -0.75 -0.15625zm6.2421875 -4.71875l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125 0 1.203125 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125 0 1.421875 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.25 0.328125 1.734375 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.453125 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625zm9.6953125 1.015625l0.171875 1.03125q-0.5 0.109375 -0.890625 0.109375q-0.640625 0 -1.0 -0.203125q-0.34375 -0.203125 -0.484375 -0.53125q-0.140625 -0.328125 -0.140625 -1.390625l0 -3.96875l-0.859375 0l0 -0.90625l0.859375 0l0 -1.71875l1.171875 -0.703125l0 2.421875l1.171875 0l0 0.90625l-1.171875 0l0 4.046875q0 0.5 0.046875 0.640625q0.0625 0.140625 0.203125 0.234375q0.140625 0.078125 0.40625 0.078125q0.203125 0 0.515625 -0.046875zm5.8748627 -1.171875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375zm6.5218506 4.125l0 -6.90625l1.046875 0l0 0.96875q0.328125 -0.515625 0.859375 -0.8125q0.546875 -0.3125 1.234375 -0.3125q0.78125 0 1.265625 0.3125q0.484375 0.3125 0.6875 0.890625q0.828125 -1.203125 2.140625 -1.203125q1.03125 0 1.578125 0.578125q0.5625 0.5625 0.5625 1.734375l0 4.75l-1.171875 0l0 -4.359375q0 -0.703125 -0.125 -1.0q-0.109375 -0.3125 -0.40625 -0.5q-0.296875 -0.1875 -0.703125 -0.1875q-0.71875 0 -1.203125 0.484375q-0.484375 0.484375 -0.484375 1.546875l0 4.015625l-1.171875 0l0 -4.484375q0 -0.78125 -0.296875 -1.171875q-0.28125 -0.390625 -0.921875 -0.390625q-0.5 0 -0.921875 0.265625q-0.421875 0.25 -0.609375 0.75q-0.1875 0.5 -0.1875 1.453125l0 3.578125l-1.171875 0zm19.321045 -2.53125l1.15625 0.15625q-0.1875 1.1875 -0.96875 1.859375q-0.78125 0.671875 -1.921875 0.671875q-1.40625 0 -2.28125 -0.921875q-0.859375 -0.9375 -0.859375 -2.65625q0 -1.125 0.375 -1.96875q0.375 -0.84375 1.125 -1.25q0.765625 -0.421875 1.65625 -0.421875q1.125 0 1.84375 0.578125q0.71875 0.5625 0.921875 1.609375l-1.140625 0.171875q-0.171875 -0.703125 -0.59375 -1.046875q-0.40625 -0.359375 -0.984375 -0.359375q-0.890625 0 -1.453125 0.640625q-0.546875 0.640625 -0.546875 2.0q0 1.40625 0.53125 2.03125q0.546875 0.625 1.40625 0.625q0.6875 0 1.140625 -0.421875q0.46875 -0.421875 0.59375 -1.296875zm6.6640625 1.671875q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.96109 0l0 -9.546875l1.171875 0l0 9.546875l-1.171875 0zm2.507965 -2.0625l1.15625 -0.1875q0.109375 0.703125 0.546875 1.078125q0.453125 0.359375 1.25 0.359375q0.8125076 0 1.2031326 -0.328125q0.390625 -0.328125 0.390625 -0.765625q0 -0.390625 -0.359375 -0.625q-0.234375 -0.15625 -1.1875076 -0.390625q-1.296875 -0.328125 -1.796875 -0.5625q-0.484375 -0.25 -0.75 -0.65625q-0.25 -0.421875 -0.25 -0.9375q0 -0.453125 0.203125 -0.84375q0.21875 -0.40625 0.578125 -0.671875q0.28125 -0.1875 0.75 -0.328125q0.46875 -0.140625 1.015625 -0.140625q0.8125076 0 1.4218826 0.234375q0.609375 0.234375 0.90625 0.640625q0.296875 0.390625 0.40625 1.0625l-1.140625 0.15625q-0.078125 -0.53125 -0.453125 -0.828125q-0.375 -0.3125 -1.0625076 -0.3125q-0.8125 0 -1.15625 0.265625q-0.34375 0.265625 -0.34375 0.625q0 0.234375 0.140625 0.421875q0.15625 0.1875 0.453125 0.3125q0.171875 0.0625 1.03125 0.296875q1.2500076 0.328125 1.7343826 0.546875q0.5 0.203125 0.78125 0.609375q0.28125 0.40625 0.28125 1.0q0 0.59375 -0.34375 1.109375q-0.34375 0.515625 -1.0 0.796875q-0.640625 0.28125 -1.4531326 0.28125q-1.34375 0 -2.046875 -0.5625q-0.703125 -0.5625 -0.90625 -1.65625z" fill-rule="nonzero"/><path fill="#d9d2e9" d="m36.454067 87.40682l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m36.454067 87.40682l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.55086 114.45515l1.65625 -0.140625q0.125 1.0 0.546875 1.640625q0.4375 0.640625 1.34375 1.046875q0.921875 0.390625 2.0625 0.390625q1.0 0 1.78125 -0.296875q0.78125 -0.296875 1.15625 -0.8125q0.375 -0.53125 0.375 -1.15625q0 -0.625 -0.375 -1.09375q-0.359375 -0.46875 -1.1875 -0.796875q-0.546875 -0.203125 -2.390625 -0.640625q-1.828125 -0.453125 -2.5625 -0.84375q-0.96875 -0.5 -1.4375 -1.234375q-0.46875 -0.75 -0.46875 -1.671875q0 -1.0 0.578125 -1.875q0.578125 -0.890625 1.671875 -1.34375q1.109375 -0.453125 2.453125 -0.453125q1.484375 0 2.609375 0.484375q1.140625 0.46875 1.75 1.40625q0.609375 0.921875 0.65625 2.09375l-1.6875 0.125q-0.140625 -1.265625 -0.9375 -1.90625q-0.78125 -0.65625 -2.3125 -0.65625q-1.609375 0 -2.34375 0.59375q-0.734375 0.59375 -0.734375 1.421875q0 0.71875 0.53125 1.171875q0.5 0.46875 2.65625 0.96875q2.15625 0.484375 2.953125 0.84375q1.171875 0.53125 1.71875 1.359375q0.5625 0.828125 0.5625 1.90625q0 1.0625 -0.609375 2.015625q-0.609375 0.9375 -1.75 1.46875q-1.140625 0.515625 -2.578125 0.515625q-1.8125 0 -3.046875 -0.53125q-1.21875 -0.53125 -1.921875 -1.59375q-0.6875 -1.0625 -0.71875 -2.40625zm19.459198 1.1875l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.141342 5.765625l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm13.953842 -1.46875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm1.5895538 1.46875l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.150177 3.71875l-0.1875 -1.53125q0.546875 0.140625 0.9375 0.140625q0.546875 0 0.875 -0.1875q0.328125 -0.171875 0.546875 -0.5q0.15625 -0.25 0.5 -1.21875q0.046875 -0.140625 0.140625 -0.40625l-3.671875 -9.6875l1.765625 0l2.015625 5.59375q0.390625 1.078125 0.703125 2.25q0.28125 -1.125 0.671875 -2.203125l2.078125 -5.640625l1.640625 0l-3.6875 9.828125q-0.59375 1.609375 -0.921875 2.203125q-0.4375 0.8125 -1.0 1.1875q-0.5625 0.375 -1.34375 0.375q-0.484375 0 -1.0625 -0.203125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m6.8477693 152.91733l48.85039 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m6.8477693 152.91733l48.85039 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m145.08398 152.91733l33.95276 0" fill-rule="evenodd"/><path stroke="#ff0000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m145.08398 152.91733l33.95276 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m55.698162 136.88583l89.38582 0l0 32.06299l-89.38582 0z" fill-rule="evenodd"/><path fill="#000000" d="m81.92428 150.43732l0 -8.59375l1.140625 0l0 7.578125l4.234375 0l0 1.015625l-5.375 0zm6.595703 -7.375l0 -1.21875l1.0625 0l0 1.21875l-1.0625 0zm0 7.375l0 -6.21875l1.0625 0l0 6.21875l-1.0625 0zm2.6660156 0l0 -6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 -0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm9.996094 -7.375l0 -1.21875l1.0625 0l0 1.21875l-1.0625 0zm0 7.375l0 -6.21875l1.0625 0l0 6.21875l-1.0625 0zm4.9628906 -0.9375l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm9.908203 3.703125l0 -0.78125q-0.59375 0.921875 -1.734375 0.921875q-0.75 0 -1.375 -0.40625q-0.625 -0.421875 -0.96875 -1.15625q-0.34375 -0.734375 -0.34375 -1.6875q0 -0.921875 0.3125 -1.6875q0.3125 -0.765625 0.9375 -1.15625q0.625 -0.40625 1.390625 -0.40625q0.5625 0 1.0 0.234375q0.4375 0.234375 0.71875 0.609375l0 -3.078125l1.046875 0l0 8.59375l-0.984375 0zm-3.328125 -3.109375q0 1.203125 0.5 1.796875q0.5 0.578125 1.1875 0.578125q0.6875 0 1.171875 -0.5625q0.484375 -0.5625 0.484375 -1.71875q0 -1.28125 -0.5 -1.875q-0.484375 -0.59375 -1.203125 -0.59375q-0.703125 0 -1.171875 0.578125q-0.46875 0.5625 -0.46875 1.796875z" fill-rule="nonzero"/><path fill="#000000" d="m68.0942 162.57794l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm6.375 4.25l-0.125 -0.984375q0.34375 0.09375 0.609375 0.09375q0.34375 0 0.546875 -0.125q0.21875 -0.109375 0.359375 -0.3125q0.09375 -0.171875 0.328125 -0.796875q0.015625 -0.078125 0.09375 -0.25l-2.375 -6.234375l1.140625 0l1.296875 3.59375q0.25 0.6875 0.453125 1.453125q0.1875 -0.734375 0.4375 -1.421875l1.328125 -3.625l1.046875 0l-2.359375 6.328125q-0.390625 1.015625 -0.59375 1.40625q-0.28125 0.53125 -0.65625 0.765625q-0.359375 0.25 -0.859375 0.25q-0.296875 0 -0.671875 -0.140625zm5.625 -4.25l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm8.71875 0.921875l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm5.876953 3.703125l0 -6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 -0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm17.392578 -2.28125l1.03125 0.140625q-0.171875 1.0625 -0.875 1.671875q-0.703125 0.609375 -1.71875 0.609375q-1.28125 0 -2.0625 -0.828125q-0.765625 -0.84375 -0.765625 -2.40625q0 -1.0 0.328125 -1.75q0.34375 -0.765625 1.015625 -1.140625q0.6875 -0.375 1.5 -0.375q1.0 0 1.640625 0.515625q0.65625 0.5 0.84375 1.453125l-1.03125 0.15625q-0.140625 -0.625 -0.515625 -0.9375q-0.375 -0.328125 -0.90625 -0.328125q-0.796875 0 -1.296875 0.578125q-0.5 0.5625 -0.5 1.796875q0 1.265625 0.484375 1.828125q0.484375 0.5625 1.25 0.5625q0.625 0 1.03125 -0.375q0.421875 -0.375 0.546875 -1.171875zm6.0000076 1.515625q-0.5937576 0.5 -1.1406326 0.703125q-0.53125 0.203125 -1.15625 0.203125q-1.03125 0 -1.578125 -0.5q-0.546875 -0.5 -0.546875 -1.28125q0 -0.453125 0.203125 -0.828125q0.203125 -0.390625 0.546875 -0.609375q0.34375 -0.234375 0.765625 -0.34375q0.296875 -0.09375 0.9375 -0.171875q1.265625 -0.140625 1.8750076 -0.359375q0 -0.21875 0 -0.265625q0 -0.65625 -0.29688263 -0.921875q-0.40625 -0.34375 -1.203125 -0.34375q-0.734375 0 -1.09375 0.265625q-0.359375 0.25 -0.53125 0.90625l-1.03125 -0.140625q0.140625 -0.65625 0.46875 -1.0625q0.328125 -0.40625 0.9375 -0.625q0.609375 -0.21875 1.40625 -0.21875q0.796875 0 1.2968826 0.1875q0.5 0.1875 0.734375 0.46875q0.234375 0.28125 0.328125 0.71875q0.046875 0.265625 0.046875 0.96875l0 1.40625q0 1.46875 0.0625 1.859375q0.078125 0.390625 0.28125 0.75l-1.109375 0q-0.15625 -0.328125 -0.203125 -0.765625zm-0.09375 -2.359375q-0.5781326 0.234375 -1.7187576 0.40625q-0.65625 0.09375 -0.921875 0.21875q-0.265625 0.109375 -0.421875 0.328125q-0.140625 0.21875 -0.140625 0.5q0 0.421875 0.3125 0.703125q0.328125 0.28125 0.9375 0.28125q0.609375 0 1.078125 -0.265625q0.484375 -0.265625 0.703125 -0.734375q0.17188263 -0.359375 0.17188263 -1.046875l0 -0.390625zm2.6738281 3.125l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.6660156 0l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.2753906 -1.859375l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5z" fill-rule="nonzero"/><path fill="#cfe2f3" d="m89.902885 171.04015l174.83463 0l0 48.850388l-174.83463 0z" fill-rule="evenodd"/><path stroke="#6d9eeb" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m89.902885 171.04015l174.83463 0l0 48.850388l-174.83463 0z" fill-rule="evenodd"/><path fill="#000000" d="m130.0844 202.38535l0 -13.359375l1.765625 0l0 5.484375l6.9375 0l0 -5.484375l1.765625 0l0 13.359375l-1.765625 0l0 -6.296875l-6.9375 0l0 6.296875l-1.765625 0zm12.597946 -4.84375q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.046875 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.03125 0 -3.28125 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.8125 0.921875 2.046875 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.046875 0.921875q-0.796875 0.90625 -0.796875 2.765625zm8.641342 1.953125l1.625 -0.25q0.125 0.96875 0.75 1.5q0.625 0.515625 1.75 0.515625q1.125 0 1.671875 -0.453125q0.546875 -0.46875 0.546875 -1.09375q0 -0.546875 -0.484375 -0.875q-0.328125 -0.21875 -1.671875 -0.546875q-1.8125 -0.46875 -2.515625 -0.796875q-0.6875 -0.328125 -1.046875 -0.90625q-0.359375 -0.59375 -0.359375 -1.3125q0 -0.640625 0.296875 -1.1875q0.296875 -0.5625 0.8125 -0.921875q0.375 -0.28125 1.03125 -0.46875q0.671875 -0.203125 1.421875 -0.203125q1.140625 0 2.0 0.328125q0.859375 0.328125 1.265625 0.890625q0.421875 0.5625 0.578125 1.5l-1.609375 0.21875q-0.109375 -0.75 -0.640625 -1.171875q-0.515625 -0.421875 -1.46875 -0.421875q-1.140625 0 -1.625 0.375q-0.46875 0.375 -0.46875 0.875q0 0.3125 0.1875 0.578125q0.203125 0.265625 0.640625 0.4375q0.234375 0.09375 1.4375 0.421875q1.75 0.453125 2.4375 0.75q0.6875 0.296875 1.078125 0.859375q0.390625 0.5625 0.390625 1.40625q0 0.828125 -0.484375 1.546875q-0.46875 0.71875 -1.375 1.125q-0.90625 0.390625 -2.046875 0.390625q-1.875 0 -2.875 -0.78125q-0.984375 -0.78125 -1.25 -2.328125zm13.5625 1.421875l0.234375 1.453125q-0.6875 0.140625 -1.234375 0.140625q-0.890625 0 -1.390625 -0.28125q-0.484375 -0.28125 -0.6875 -0.734375q-0.203125 -0.46875 -0.203125 -1.9375l0 -5.578125l-1.203125 0l0 -1.265625l1.203125 0l0 -2.390625l1.625 -0.984375l0 3.375l1.65625 0l0 1.265625l-1.65625 0l0 5.671875q0 0.6875 0.078125 0.890625q0.09375 0.203125 0.28125 0.328125q0.203125 0.109375 0.578125 0.109375q0.265625 0 0.71875 -0.0625zm6.9134827 1.46875l0 -13.359375l1.78125 0l0 6.625l6.625 -6.625l2.390625 0l-5.59375 5.421875l5.84375 7.9375l-2.328125 0l-4.765625 -6.765625l-2.171875 2.140625l0 4.625l-1.78125 0zm18.943573 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125717 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0zm6.228302 0l0 -9.671875l1.46875 0l0 1.375q1.0625 -1.59375 3.078125 -1.59375q0.875 0 1.609375 0.3125q0.734375 0.3125 1.09375 0.828125q0.375 0.5 0.515625 1.203125q0.09375 0.453125 0.09375 1.59375l0 5.953125l-1.640625 0l0 -5.890625q0 -1.0 -0.203125 -1.484375q-0.1875 -0.5 -0.671875 -0.796875q-0.484375 -0.296875 -1.140625 -0.296875q-1.046875 0 -1.8125 0.671875q-0.75 0.65625 -0.75 2.515625l0 5.28125l-1.640625 0zm17.000732 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.7656403 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375153 0 3.1562653 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.2187653 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.5468903 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.3906403 -2.65625l5.4062653 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.0312653 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.094467 5.765625l0 -13.359375l1.640625 0l0 13.359375l-1.640625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m9.1482525 232.87665l117.79527 -0.34646606" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m9.1482525 232.87665l117.79527 -0.34646606" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.43172 232.53018l120.34645 0.34646606" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.43172 232.53018l120.34645 0.34646606" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m126.94353 216.49869l106.48819 0l0 32.06299l-106.48819 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.36165 237.33018l0 -9.546875l1.265625 0l0 3.921875l4.953125 0l0 -3.921875l1.265625 0l0 9.546875l-1.265625 0l0 -4.5l-4.953125 0l0 4.5l-1.265625 0zm13.953278 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm8.93837 0l0 -0.875q-0.65625 1.03125 -1.9375 1.03125q-0.8125 0 -1.515625 -0.453125q-0.6875 -0.453125 -1.078125 -1.265625q-0.375 -0.828125 -0.375 -1.890625q0 -1.03125 0.34375 -1.875q0.34375 -0.84375 1.03125 -1.28125q0.703125 -0.453125 1.546875 -0.453125q0.625 0 1.109375 0.265625q0.5 0.25 0.796875 0.671875l0 -3.421875l1.171875 0l0 9.546875l-1.09375 0zm-3.703125 -3.453125q0 1.328125 0.5625 1.984375q0.5625 0.65625 1.328125 0.65625q0.765625 0 1.296875 -0.625q0.53125 -0.625 0.53125 -1.90625q0 -1.421875 -0.546875 -2.078125q-0.546875 -0.671875 -1.34375 -0.671875q-0.78125 0 -1.3125 0.640625q-0.515625 0.625 -0.515625 2.0zm7.9124756 3.453125l-2.125 -6.90625l1.21875 0l1.09375 3.984375l0.421875 1.484375q0.015625 -0.109375 0.359375 -1.421875l1.09375 -4.046875l1.203125 0l1.03125 4.0l0.34375 1.328125l0.40625 -1.34375l1.171875 -3.984375l1.140625 0l-2.15625 6.90625l-1.21875 0l-1.09375 -4.140625l-0.265625 -1.171875l-1.40625 5.3125l-1.21875 0zm12.859528 -0.859375q-0.65625 0.5625 -1.265625 0.796875q-0.59375 0.21875 -1.28125 0.21875q-1.140625 0 -1.75 -0.546875q-0.609375 -0.5625 -0.609375 -1.4375q0 -0.5 0.21875 -0.921875q0.234375 -0.421875 0.609375 -0.671875q0.375 -0.25 0.84375 -0.390625q0.34375 -0.078125 1.046875 -0.171875q1.421875 -0.171875 2.09375 -0.40625q0 -0.234375 0 -0.296875q0 -0.71875 -0.328125 -1.015625q-0.453125 -0.390625 -1.34375 -0.390625q-0.8125 0 -1.21875 0.296875q-0.390625 0.28125 -0.578125 1.015625l-1.140625 -0.15625q0.15625 -0.734375 0.515625 -1.1875q0.359375 -0.453125 1.03125 -0.6875q0.671875 -0.25 1.5625 -0.25q0.890625 0 1.4375 0.203125q0.5625 0.203125 0.8125 0.53125q0.265625 0.3125 0.375 0.796875q0.046875 0.296875 0.046875 1.078125l0 1.5625q0 1.625 0.078125 2.0625q0.078125 0.4375 0.296875 0.828125l-1.21875 0q-0.1875 -0.359375 -0.234375 -0.859375zm-0.09375 -2.609375q-0.640625 0.265625 -1.921875 0.4375q-0.71875 0.109375 -1.015625 0.25q-0.296875 0.125 -0.46875 0.375q-0.15625 0.25 -0.15625 0.546875q0 0.46875 0.34375 0.78125q0.359375 0.3125 1.046875 0.3125q0.671875 0 1.203125 -0.296875q0.53125 -0.296875 0.78125 -0.8125q0.1875 -0.390625 0.1875 -1.171875l0 -0.421875zm2.9749756 3.46875l0 -6.90625l1.0625 0l0 1.046875q0.40625 -0.734375 0.734375 -0.96875q0.34375 -0.234375 0.765625 -0.234375q0.59375 0 1.203125 0.375l-0.40625 1.078125q-0.4375 -0.25 -0.859375 -0.25q-0.390625 0 -0.703125 0.234375q-0.296875 0.234375 -0.421875 0.640625q-0.203125 0.625 -0.203125 1.359375l0 3.625l-1.171875 0zm9.18837 -2.21875l1.203125 0.140625q-0.28125 1.0625 -1.0625 1.65625q-0.765625 0.578125 -1.96875 0.578125q-1.515625 0 -2.40625 -0.9375q-0.890625 -0.9375 -0.890625 -2.609375q0 -1.75 0.890625 -2.703125q0.90625 -0.96875 2.34375 -0.96875q1.390625 0 2.265625 0.9375q0.875 0.9375 0.875 2.65625q0 0.109375 0 0.3125l-5.15625 0q0.0625 1.140625 0.640625 1.75q0.578125 0.59375 1.4375 0.59375q0.65625 0 1.109375 -0.328125q0.453125 -0.34375 0.71875 -1.078125zm-3.84375 -1.90625l3.859375 0q-0.078125 -0.859375 -0.4375 -1.296875q-0.5625 -0.6875 -1.453125 -0.6875q-0.8125 0 -1.359375 0.546875q-0.546875 0.53125 -0.609375 1.4375z" fill-rule="nonzero"/><defs><linearGradient id="p.1" gradientUnits="userSpaceOnUse" gradientTransform="matrix(4.500288203456436 0.0 0.0 4.500288203456436 0.0 0.0)" spreadMethod="pad" x1="20.153856515233464" y1="38.20913288577608" x2="20.153840567727556" y2="42.70942108920426"><stop offset="0.0" stop-color="#ff0000"/><stop offset="0.51" stop-color="#dab7a6"/><stop offset="0.99999994" stop-color="#dab7a6" stop-opacity="0.0"/><stop offset="1.0" stop-color="#ffffff" stop-opacity="0.0"/></linearGradient></defs><path fill="url(#p.1)" d="m90.698166 171.95273l173.29134 0l0 20.251968l-173.29134 0z" fill-rule="evenodd"/><path fill="#d9d2e9" d="m203.76447 87.804726l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path stroke="#8e7cc3" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m203.76447 87.804726l114.4252 0l0 48.850395l-114.4252 0z" fill-rule="evenodd"/><path fill="#000000" d="m245.33514 113.91555l0 -1.578125l5.65625 0l0 4.953125q-1.296875 1.046875 -2.6875 1.578125q-1.375 0.515625 -2.84375 0.515625q-1.96875 0 -3.578125 -0.84375q-1.609375 -0.84375 -2.421875 -2.4375q-0.8125 -1.59375 -0.8125 -3.5625q0 -1.953125 0.8125 -3.640625q0.8125 -1.6875 2.34375 -2.5q1.53125 -0.828125 3.515625 -0.828125q1.453125 0 2.625 0.46875q1.171875 0.46875 1.828125 1.3125q0.671875 0.828125 1.015625 2.171875l-1.59375 0.4375q-0.296875 -1.015625 -0.75 -1.59375q-0.4375 -0.59375 -1.265625 -0.9375q-0.828125 -0.34375 -1.84375 -0.34375q-1.203125 0 -2.09375 0.375q-0.890625 0.359375 -1.4375 0.96875q-0.53125 0.59375 -0.828125 1.3125q-0.515625 1.234375 -0.515625 2.6875q0 1.78125 0.609375 2.984375q0.625 1.203125 1.796875 1.796875q1.171875 0.578125 2.5 0.578125q1.140625 0 2.234375 -0.4375q1.09375 -0.453125 1.65625 -0.953125l0 -2.484375l-3.921875 0zm7.448929 0.390625q0 -2.6875 1.484375 -3.96875q1.25 -1.078125 3.0468597 -1.078125q2.0 0 3.265625 1.3125q1.265625 1.296875 1.265625 3.609375q0 1.859375 -0.5625 2.9375q-0.5625 1.0625 -1.640625 1.65625q-1.0625 0.59375 -2.328125 0.59375q-2.0312347 0 -3.2812347 -1.296875q-1.25 -1.3125 -1.25 -3.765625zm1.6875 0q0 1.859375 0.796875 2.796875q0.81248474 0.921875 2.0468597 0.921875q1.21875 0 2.03125 -0.921875q0.8125 -0.9375 0.8125 -2.84375q0 -1.796875 -0.8125 -2.71875q-0.8125 -0.921875 -2.03125 -0.921875q-1.234375 0 -2.0468597 0.921875q-0.796875 0.90625 -0.796875 2.765625zm9.688217 4.84375l0 -8.40625l-1.453125 0l0 -1.265625l1.453125 0l0 -1.03125q0 -0.96875 0.171875 -1.453125q0.234375 -0.640625 0.828125 -1.03125q0.59375 -0.390625 1.671875 -0.390625q0.6875 0 1.53125 0.15625l-0.25 1.4375q-0.5 -0.09375 -0.953125 -0.09375q-0.75 0 -1.0625 0.328125q-0.3125 0.3125 -0.3125 1.1875l0 0.890625l1.890625 0l0 1.265625l-1.890625 0l0 8.40625l-1.625 0zm11.417664 -3.109375l1.6875 0.203125q-0.40625 1.484375 -1.484375 2.3125q-1.078125 0.8125 -2.765625 0.8125q-2.125 0 -3.375 -1.296875q-1.234375 -1.3125 -1.234375 -3.671875q0 -2.453125 1.25 -3.796875q1.265625 -1.34375 3.265625 -1.34375q1.9375 0 3.15625 1.328125q1.234375 1.3125 1.234375 3.703125q0 0.15625 0 0.4375l-7.21875 0q0.09375 1.59375 0.90625 2.453125q0.8125 0.84375 2.015625 0.84375q0.90625 0 1.546875 -0.46875q0.640625 -0.484375 1.015625 -1.515625zm-5.390625 -2.65625l5.40625 0q-0.109375 -1.21875 -0.625 -1.828125q-0.78125 -0.953125 -2.03125 -0.953125q-1.125 0 -1.90625 0.765625q-0.765625 0.75 -0.84375 2.015625zm9.125732 5.765625l0 -9.671875l1.46875 0l0 1.46875q0.5625 -1.03125 1.03125 -1.359375q0.484375 -0.328125 1.0625 -0.328125q0.828125 0 1.6875 0.53125l-0.5625 1.515625q-0.609375 -0.359375 -1.203125 -0.359375q-0.546875 0 -0.96875 0.328125q-0.421875 0.328125 -0.609375 0.890625q-0.28125 0.875 -0.28125 1.921875l0 5.0625l-1.625 0z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m179.05511 152.91733l37.984253 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m179.05511 152.91733l37.984253 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m306.4252 152.91733l47.338593 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m306.4252 152.91733l47.338593 0" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m217.03937 136.88583l89.38583 0l0 32.06299l-89.38583 0z" fill-rule="evenodd"/><path fill="#000000" d="m226.58624 154.67169l1.0625 -0.09375q0.078125 0.65625 0.359375 1.0625q0.28125 0.40625 0.859375 0.671875q0.59375 0.25 1.328125 0.25q0.640625 0 1.140625 -0.1875q0.5 -0.203125 0.734375 -0.53125q0.25 -0.34375 0.25 -0.734375q0 -0.40625 -0.234375 -0.703125q-0.234375 -0.3125 -0.765625 -0.515625q-0.359375 -0.140625 -1.546875 -0.421875q-1.171875 -0.28125 -1.640625 -0.53125q-0.625 -0.328125 -0.921875 -0.796875q-0.296875 -0.484375 -0.296875 -1.078125q0 -0.640625 0.359375 -1.203125q0.375 -0.578125 1.078125 -0.859375q0.71875 -0.296875 1.578125 -0.296875q0.953125 0 1.6875 0.3125q0.734375 0.296875 1.125 0.90625q0.390625 0.59375 0.421875 1.34375l-1.09375 0.078125q-0.09375 -0.8125 -0.609375 -1.21875q-0.5 -0.421875 -1.484375 -0.421875q-1.03125 0 -1.5 0.375q-0.46875 0.375 -0.46875 0.90625q0 0.46875 0.328125 0.765625q0.328125 0.296875 1.703125 0.609375q1.390625 0.3125 1.90625 0.546875q0.75 0.359375 1.109375 0.890625q0.359375 0.515625 0.359375 1.21875q0 0.6875 -0.390625 1.296875q-0.390625 0.59375 -1.125 0.9375q-0.734375 0.328125 -1.65625 0.328125q-1.171875 0 -1.96875 -0.328125q-0.78125 -0.34375 -1.234375 -1.03125q-0.4375 -0.6875 -0.453125 -1.546875zm8.207031 5.15625l-0.125 -0.984375q0.34375 0.09375 0.609375 0.09375q0.34375 0 0.546875 -0.125q0.21875 -0.109375 0.359375 -0.3125q0.09375 -0.171875 0.328125 -0.796875q0.015625 -0.078125 0.09375 -0.25l-2.375 -6.234375l1.140625 0l1.296875 3.59375q0.25 0.6875 0.453125 1.453125q0.1875 -0.734375 0.4375 -1.421875l1.328125 -3.625l1.046875 0l-2.359375 6.328125q-0.390625 1.015625 -0.59375 1.40625q-0.28125 0.53125 -0.65625 0.765625q-0.359375 0.25 -0.859375 0.25q-0.296875 0 -0.671875 -0.140625zm5.625 -4.25l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5zm8.71875 0.921875l0.15625 0.921875q-0.453125 0.09375 -0.796875 0.09375q-0.578125 0 -0.890625 -0.171875q-0.3125 -0.1875 -0.453125 -0.484375q-0.125 -0.296875 -0.125 -1.25l0 -3.578125l-0.765625 0l0 -0.8125l0.765625 0l0 -1.546875l1.046875 -0.625l0 2.171875l1.0625 0l0 0.8125l-1.0625 0l0 3.640625q0 0.453125 0.046875 0.578125q0.0625 0.125 0.1875 0.203125q0.125 0.078125 0.359375 0.078125q0.1875 0 0.46875 -0.03125zm5.2871094 -1.0625l1.09375 0.125q-0.25 0.953125 -0.953125 1.484375q-0.703125 0.53125 -1.78125 0.53125q-1.359375 0 -2.171875 -0.84375q-0.796875 -0.84375 -0.796875 -2.359375q0 -1.5625 0.8125 -2.421875q0.8125 -0.875 2.09375 -0.875q1.25 0 2.03125 0.84375q0.796875 0.84375 0.796875 2.390625q0 0.09375 0 0.28125l-4.640625 0q0.0625 1.03125 0.578125 1.578125q0.515625 0.53125 1.296875 0.53125q0.578125 0 0.984375 -0.296875q0.421875 -0.3125 0.65625 -0.96875zm-3.453125 -1.703125l3.46875 0q-0.0625 -0.796875 -0.390625 -1.1875q-0.515625 -0.609375 -1.3125 -0.609375q-0.734375 0 -1.234375 0.484375q-0.484375 0.484375 -0.53125 1.3125zm5.876953 3.703125l0 -6.21875l0.9375 0l0 0.875q0.296875 -0.46875 0.78125 -0.734375q0.484375 -0.28125 1.109375 -0.28125q0.6875 0 1.125 0.28125q0.453125 0.28125 0.625 0.796875q0.75 -1.078125 1.921875 -1.078125q0.9375 0 1.421875 0.515625q0.5 0.5 0.5 1.578125l0 4.265625l-1.046875 0l0 -3.921875q0 -0.625 -0.109375 -0.90625q-0.09375 -0.28125 -0.359375 -0.453125q-0.265625 -0.171875 -0.640625 -0.171875q-0.65625 0 -1.09375 0.4375q-0.421875 0.4375 -0.421875 1.40625l0 3.609375l-1.0625 0l0 -4.046875q0 -0.703125 -0.265625 -1.046875q-0.25 -0.359375 -0.828125 -0.359375q-0.453125 0 -0.828125 0.234375q-0.375 0.234375 -0.546875 0.6875q-0.171875 0.453125 -0.171875 1.296875l0 3.234375l-1.046875 0zm17.392578 -2.28125l1.03125 0.140625q-0.171875 1.0625 -0.875 1.671875q-0.703125 0.609375 -1.71875 0.609375q-1.28125 0 -2.0625 -0.828125q-0.765625 -0.84375 -0.765625 -2.40625q0 -1.0 0.328125 -1.75q0.34375 -0.765625 1.015625 -1.140625q0.6875 -0.375 1.5 -0.375q1.0 0 1.640625 0.515625q0.65625 0.5 0.84375 1.453125l-1.03125 0.15625q-0.140625 -0.625 -0.515625 -0.9375q-0.375 -0.328125 -0.90625 -0.328125q-0.796875 0 -1.296875 0.578125q-0.5 0.5625 -0.5 1.796875q0 1.265625 0.484375 1.828125q0.484375 0.5625 1.25 0.5625q0.625 0 1.03125 -0.375q0.421875 -0.375 0.546875 -1.171875zm6.0 1.515625q-0.59375 0.5 -1.140625 0.703125q-0.53125 0.203125 -1.15625 0.203125q-1.03125 0 -1.578125 -0.5q-0.546875 -0.5 -0.546875 -1.28125q0 -0.453125 0.203125 -0.828125q0.203125 -0.390625 0.546875 -0.609375q0.34375 -0.234375 0.765625 -0.34375q0.296875 -0.09375 0.9375 -0.171875q1.265625 -0.140625 1.875 -0.359375q0 -0.21875 0 -0.265625q0 -0.65625 -0.296875 -0.921875q-0.40625 -0.34375 -1.203125 -0.34375q-0.734375 0 -1.09375 0.265625q-0.359375 0.25 -0.53125 0.90625l-1.03125 -0.140625q0.140625 -0.65625 0.46875 -1.0625q0.328125 -0.40625 0.9375 -0.625q0.609375 -0.21875 1.40625 -0.21875q0.796875 0 1.296875 0.1875q0.5 0.1875 0.734375 0.46875q0.234375 0.28125 0.328125 0.71875q0.046875 0.265625 0.046875 0.96875l0 1.40625q0 1.46875 0.0625 1.859375q0.078125 0.390625 0.28125 0.75l-1.109375 0q-0.15625 -0.328125 -0.203125 -0.765625zm-0.09375 -2.359375q-0.578125 0.234375 -1.71875 0.40625q-0.65625 0.09375 -0.921875 0.21875q-0.265625 0.109375 -0.421875 0.328125q-0.140625 0.21875 -0.140625 0.5q0 0.421875 0.3125 0.703125q0.328125 0.28125 0.9375 0.28125q0.609375 0 1.078125 -0.265625q0.484375 -0.265625 0.703125 -0.734375q0.171875 -0.359375 0.171875 -1.046875l0 -0.390625zm2.6738281 3.125l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.6660156 0l0 -8.59375l1.0625 0l0 8.59375l-1.0625 0zm2.2753906 -1.859375l1.03125 -0.15625q0.09375 0.625 0.484375 0.953125q0.40625 0.328125 1.140625 0.328125q0.71875 0 1.0625 -0.28125q0.359375 -0.296875 0.359375 -0.703125q0 -0.359375 -0.3125 -0.5625q-0.21875 -0.140625 -1.078125 -0.359375q-1.15625 -0.296875 -1.609375 -0.5q-0.4375 -0.21875 -0.671875 -0.59375q-0.234375 -0.375 -0.234375 -0.84375q0 -0.40625 0.1875 -0.765625q0.1875 -0.359375 0.515625 -0.59375q0.25 -0.171875 0.671875 -0.296875q0.421875 -0.125 0.921875 -0.125q0.71875 0 1.265625 0.21875q0.5625 0.203125 0.828125 0.5625q0.265625 0.359375 0.359375 0.953125l-1.03125 0.140625q-0.0625 -0.46875 -0.40625 -0.734375q-0.328125 -0.28125 -0.953125 -0.28125q-0.71875 0 -1.03125 0.25q-0.3125 0.234375 -0.3125 0.5625q0 0.203125 0.125 0.359375q0.140625 0.171875 0.40625 0.28125q0.15625 0.0625 0.9375 0.265625q1.125 0.3125 1.5625 0.5q0.4375 0.1875 0.6875 0.546875q0.25 0.359375 0.25 0.90625q0 0.53125 -0.3125 1.0q-0.296875 0.453125 -0.875 0.71875q-0.578125 0.25 -1.3125 0.25q-1.21875 0 -1.859375 -0.5q-0.625 -0.515625 -0.796875 -1.5z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m150.87927 111.83202l52.88188 0.40944672" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m154.30624 111.85855l46.027924 0.35638428" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m154.30624 111.85856l1.133255 -1.1158447l-3.0983734 1.1006241l3.0809631 1.1484756z" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m200.33417 112.214935l-1.133255 1.1158447l3.0983887 -1.1006317l-3.0809784 -1.148468z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m159.04854 85.32021l37.417328 0l0 32.06299l-37.417328 0z" fill-rule="evenodd"/><path fill="#000000" d="m168.78291 104.91708l1.125 -0.109375q0.140625 0.796875 0.546875 1.15625q0.40625 0.359375 1.03125 0.359375q0.53125 0 0.9375 -0.25q0.421875 -0.25 0.671875 -0.65625q0.265625 -0.421875 0.4375 -1.125q0.171875 -0.703125 0.171875 -1.421875q0 -0.078125 0 -0.234375q-0.359375 0.546875 -0.96875 0.90625q-0.59375 0.34375 -1.3125 0.34375q-1.1875 0 -2.015625 -0.859375q-0.8125 -0.859375 -0.8125 -2.265625q0 -1.453125 0.859375 -2.328125q0.859375 -0.890625 2.140625 -0.890625q0.9375 0 1.703125 0.5q0.78125 0.5 1.171875 1.4375q0.40625 0.921875 0.40625 2.671875q0 1.828125 -0.40625 2.921875q-0.390625 1.078125 -1.171875 1.640625q-0.78125 0.5625 -1.84375 0.5625q-1.109375 0 -1.828125 -0.609375q-0.703125 -0.625 -0.84375 -1.75zm4.796875 -4.21875q0 -1.0 -0.546875 -1.59375q-0.53125 -0.59375 -1.28125 -0.59375q-0.78125 0 -1.375 0.640625q-0.578125 0.640625 -0.578125 1.65625q0 0.90625 0.546875 1.484375q0.5625 0.5625 1.359375 0.5625q0.828125 0 1.34375 -0.5625q0.53125 -0.578125 0.53125 -1.59375zm2.9124756 6.421875l0 -9.546875l3.59375 0q0.953125 0 1.453125 0.09375q0.703125 0.125 1.171875 0.453125q0.484375 0.328125 0.765625 0.921875q0.296875 0.59375 0.296875 1.296875q0 1.21875 -0.78125 2.0625q-0.765625 0.84375 -2.796875 0.84375l-2.4375 0l0 3.875l-1.265625 0zm1.265625 -5.0l2.453125 0q1.234375 0 1.75 -0.453125q0.515625 -0.46875 0.515625 -1.28125q0 -0.609375 -0.3125 -1.03125q-0.296875 -0.421875 -0.796875 -0.5625q-0.3125 -0.09375 -1.171875 -0.09375l-2.4375 0l0 3.421875z" fill-rule="nonzero"/></g></svg> \ No newline at end of file
diff --git a/g3doc/architecture_guide/BUILD b/g3doc/architecture_guide/BUILD
new file mode 100644
index 000000000..404f627a4
--- /dev/null
+++ b/g3doc/architecture_guide/BUILD
@@ -0,0 +1,50 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "platforms",
+ src = "platforms.md",
+ category = "Architecture Guide",
+ data = [
+ "platforms.png",
+ "platforms.svg",
+ ],
+ permalink = "/docs/architecture_guide/platforms/",
+ weight = "40",
+)
+
+doc(
+ name = "resources",
+ src = "resources.md",
+ category = "Architecture Guide",
+ data = [
+ "resources.png",
+ "resources.svg",
+ ],
+ permalink = "/docs/architecture_guide/resources/",
+ weight = "30",
+)
+
+doc(
+ name = "security",
+ src = "security.md",
+ category = "Architecture Guide",
+ data = [
+ "security.png",
+ "security.svg",
+ ],
+ permalink = "/docs/architecture_guide/security/",
+ weight = "10",
+)
+
+doc(
+ name = "performance",
+ src = "performance.md",
+ category = "Architecture Guide",
+ permalink = "/docs/architecture_guide/performance/",
+ weight = "20",
+)
diff --git a/g3doc/architecture_guide/performance.md b/g3doc/architecture_guide/performance.md
new file mode 100644
index 000000000..39dbb0045
--- /dev/null
+++ b/g3doc/architecture_guide/performance.md
@@ -0,0 +1,277 @@
+# Performance Guide
+
+[TOC]
+
+gVisor is designed to provide a secure, virtualized environment while preserving
+key benefits of containerization, such as small fixed overheads and a dynamic
+resource footprint. For containerized infrastructure, this can provide a
+turn-key solution for sandboxing untrusted workloads: there are no changes to
+the fundamental resource model.
+
+gVisor imposes runtime costs over native containers. These costs come in two
+forms: additional cycles and memory usage, which may manifest as increased
+latency, reduced throughput or density, or not at all. In general, these costs
+come from two different sources.
+
+First, the existence of the [Sentry](../README.md#sentry) means that additional
+memory will be required, and application system calls must traverse additional
+layers of software. The design emphasizes
+[security](/docs/architecture_guide/security/) and therefore we chose to use a
+language for the Sentry that provides benefits in this domain but may not yet
+offer the raw performance of other choices. Costs imposed by these design
+choices are **structural costs**.
+
+Second, as gVisor is an independent implementation of the system call surface,
+many of the subsystems or specific calls are not as optimized as more mature
+implementations. A good example here is the network stack, which is continuing
+to evolve but does not support all the advanced recovery mechanisms offered by
+other stacks and is less CPU efficient. This is an **implementation cost** and
+is distinct from **structural costs**. Improvements here are ongoing and driven
+by the workloads that matter to gVisor users and contributors.
+
+This page provides a guide for understanding baseline performance, and calls out
+distint **structural costs** and **implementation costs**, highlighting where
+improvements are possible and not possible.
+
+While we include a variety of workloads here, it’s worth emphasizing that gVisor
+may not be an appropriate solution for every workload, for reasons other than
+performance. For example, a sandbox may provide minimal benefit for a trusted
+database, since _user data would already be inside the sandbox_ and there is no
+need for an attacker to break out in the first place.
+
+## Methodology
+
+All data below was generated using the [benchmark tools][benchmark-tools]
+repository, and the machines under test are uniform [Google Compute Engine][gce]
+Virtual Machines (VMs) with the following specifications:
+
+ Machine type: n1-standard-4 (broadwell)
+ Image: Debian GNU/Linux 9 (stretch) 4.19.0-0
+ BootDisk: 2048GB SSD persistent disk
+
+Through this document, `runsc` is used to indicate the runtime provided by
+gVisor. When relevant, we use the name `runsc-platform` to describe a specific
+[platform choice](/docs/architecture_guide/platforms/).
+
+**Except where specified, all tests below are conducted with the `ptrace`
+platform. The `ptrace` platform works everywhere and does not require hardware
+virtualization or kernel modifications but suffers from the highest structural
+costs by far. This platform is used to provide a clear understanding of the
+performance model, but in no way represents an ideal scenario. In the future,
+this guide will be extended to bare metal environments and include additional
+platforms.**
+
+## Memory access
+
+gVisor does not introduce any additional costs with respect to raw memory
+accesses. Page faults and other Operating System (OS) mechanisms are translated
+through the Sentry, but once mappings are installed and available to the
+application, there is no additional overhead.
+
+{% include graph.html id="sysbench-memory"
+url="/performance/sysbench-memory.csv" title="perf.py sysbench.memory
+--runtime=runc --runtime=runsc" %}
+
+The above figure demonstrates the memory transfer rate as measured by
+`sysbench`.
+
+## Memory usage
+
+The Sentry provides an additional layer of indirection, and it requires memory
+in order to store state associated with the application. This memory generally
+consists of a fixed component, plus an amount that varies with the usage of
+operating system resources (e.g. how many sockets or files are opened).
+
+For many use cases, fixed memory overheads are a primary concern. This may be
+because sandboxed containers handle a low volume of requests, and it is
+therefore important to achieve high densities for efficiency.
+
+{% include graph.html id="density" url="/performance/density.csv" title="perf.py
+density --runtime=runc --runtime=runsc" log="true" y_min="100000" %}
+
+The above figure demonstrates these costs based on three sample applications.
+This test is the result of running many instances of a container (50, or 5 in
+the case of redis) and calculating available memory on the host before and
+afterwards, and dividing the difference by the number of containers. This
+technique is used for measuring memory usage over the `usage_in_bytes` value of
+the container cgroup because we found that some container runtimes, other than
+`runc` and `runsc`, do not use an individual container cgroup.
+
+The first application is an instance of `sleep`: a trivial application that does
+nothing. The second application is a synthetic `node` application which imports
+a number of modules and listens for requests. The third application is a similar
+synthetic `ruby` application which does the same. Finally, we include an
+instance of `redis` storing approximately 1GB of data. In all cases, the sandbox
+itself is responsible for a small, mostly fixed amount of memory overhead.
+
+## CPU performance
+
+gVisor does not perform emulation or otherwise interfere with the raw execution
+of CPU instructions by the application. Therefore, there is no runtime cost
+imposed for CPU operations.
+
+{% include graph.html id="sysbench-cpu" url="/performance/sysbench-cpu.csv"
+title="perf.py sysbench.cpu --runtime=runc --runtime=runsc" %}
+
+The above figure demonstrates the `sysbench` measurement of CPU events per
+second. Events per second is based on a CPU-bound loop that calculates all prime
+numbers in a specified range. We note that `runsc` does not impose a performance
+penalty, as the code is executing natively in both cases.
+
+This has important consequences for classes of workloads that are often
+CPU-bound, such as data processing or machine learning. In these cases, `runsc`
+will similarly impose minimal runtime overhead.
+
+{% include graph.html id="tensorflow" url="/performance/tensorflow.csv"
+title="perf.py tensorflow --runtime=runc --runtime=runsc" %}
+
+For example, the above figure shows a sample TensorFlow workload, the
+[convolutional neural network example][cnn]. The time indicated includes the
+full start-up and run time for the workload, which trains a model.
+
+## System calls
+
+Some **structural costs** of gVisor are heavily influenced by the
+[platform choice](/docs/architecture_guide/platforms/), which implements system
+call interception. Today, gVisor supports a variety of platforms. These
+platforms present distinct performance, compatibility and security trade-offs.
+For example, the KVM platform has low overhead system call interception but runs
+poorly with nested virtualization.
+
+{% include graph.html id="syscall" url="/performance/syscall.csv" title="perf.py
+syscall --runtime=runc --runtime=runsc-ptrace --runtime=runsc-kvm" y_min="100"
+log="true" %}
+
+The above figure demonstrates the time required for a raw system call on various
+platforms. The test is implemented by a custom binary which performs a large
+number of system calls and calculates the average time required.
+
+This cost will principally impact applications that are system call bound, which
+tend to be high-performance data stores and static network services. In general,
+the impact of system call interception will be lower the more work an
+application does.
+
+{% include graph.html id="redis" url="/performance/redis.csv" title="perf.py
+redis --runtime=runc --runtime=runsc" %}
+
+For example, `redis` is an application that performs relatively little work in
+userspace: in general it reads from a connected socket, reads or modifies some
+data, and writes a result back to the socket. The above figure shows the results
+of running [comprehensive set of benchmarks][redis-benchmark]. We can see that
+small operations impose a large overhead, while larger operations, such as
+`LRANGE`, where more work is done in the application, have a smaller relative
+overhead.
+
+Some of these costs above are **structural costs**, and `redis` is likely to
+remain a challenging performance scenario. However, optimizing the
+[platform](/docs/architecture_guide/platforms/) will also have a dramatic
+impact.
+
+## Start-up time
+
+For many use cases, the ability to spin-up containers quickly and efficiently is
+important. A sandbox may be short-lived and perform minimal user work (e.g. a
+function invocation).
+
+{% include graph.html id="startup" url="/performance/startup.csv" title="perf.py
+startup --runtime=runc --runtime=runsc" %}
+
+The above figure indicates how total time required to start a container through
+[Docker][docker]. This benchmark uses three different applications. First, an
+alpine Linux-container that executes `true`. Second, a `node` application that
+loads a number of modules and binds an HTTP server. The time is measured by a
+successful request to the bound port. Finally, a `ruby` application that
+similarly loads a number of modules and binds an HTTP server.
+
+> Note: most of the time overhead above is associated Docker itself. This is
+> evident with the empty `runc` benchmark. To avoid these costs with `runsc`,
+> you may also consider using `runsc do` mode or invoking the
+> [OCI runtime](../user_guide/quick_start/oci.md) directly.
+
+## Network
+
+Networking is mostly bound by **implementation costs**, and gVisor's network
+stack is improving quickly.
+
+While typically not an important metric in practice for common sandbox use
+cases, nevertheless `iperf` is a common microbenchmark used to measure raw
+throughput.
+
+{% include graph.html id="iperf" url="/performance/iperf.csv" title="perf.py
+iperf --runtime=runc --runtime=runsc" %}
+
+The above figure shows the result of an `iperf` test between two instances. For
+the upload case, the specified runtime is used for the `iperf` client, and in
+the download case, the specified runtime is the server. A native runtime is
+always used for the other endpoint in the test.
+
+{% include graph.html id="applications" metric="requests_per_second"
+url="/performance/applications.csv" title="perf.py http.(node|ruby)
+--connections=25 --runtime=runc --runtime=runsc" %}
+
+The above figure shows the result of simple `node` and `ruby` web services that
+render a template upon receiving a request. Because these synthetic benchmarks
+do minimal work per request, must like the `redis` case, they suffer from high
+overheads. In practice, the more work an application does the smaller the impact
+of **structural costs** become.
+
+## File system
+
+Some aspects of file system performance are also reflective of **implementation
+costs**, and an area where gVisor's implementation is improving quickly.
+
+In terms of raw disk I/O, gVisor does not introduce significant fundamental
+overhead. For general file operations, gVisor introduces a small fixed overhead
+for data that transitions across the sandbox boundary. This manifests as
+**structural costs** in some cases, since these operations must be routed
+through the [Gofer](../README.md#gofer) as a result of our
+[Security Model](/docs/architecture_guide/security/), but in most cases are
+dominated by **implementation costs**, due to an internal
+[Virtual File System][vfs] (VFS) implementation that needs improvement.
+
+{% include graph.html id="fio-bw" url="/performance/fio.csv" title="perf.py fio
+--engine=sync --runtime=runc --runtime=runsc" log="true" %}
+
+The above figures demonstrate the results of `fio` for reads and writes to and
+from the disk. In this case, the disk quickly becomes the bottleneck and
+dominates other costs.
+
+{% include graph.html id="fio-tmpfs-bw" url="/performance/fio-tmpfs.csv"
+title="perf.py fio --engine=sync --runtime=runc --tmpfs=True --runtime=runsc"
+log="true" %}
+
+The above figure shows the raw I/O performance of using a `tmpfs` mount which is
+sandbox-internal in the case of `runsc`. Generally these operations are
+similarly bound to the cost of copying around data in-memory, and we don't see
+the cost of VFS operations.
+
+{% include graph.html id="httpd100k" metric="transfer_rate"
+url="/performance/httpd100k.csv" title="perf.py http.httpd --connections=1
+--connections=5 --connections=10 --connections=25 --runtime=runc
+--runtime=runsc" %}
+
+The high costs of VFS operations can manifest in benchmarks that execute many
+such operations in the hot path for serving requests, for example. The above
+figure shows the result of using gVisor to serve small pieces of static content
+with predictably poor results. This workload represents `apache` serving a
+single file sized 100k from the container image to a client running
+[ApacheBench][ab] with varying levels of concurrency. The high overhead comes
+principally from the VFS implementation that needs improvement, with several
+internal serialization points (since all requests are reading the same file).
+Note that some of some of network stack performance issues also impact this
+benchmark.
+
+{% include graph.html id="ffmpeg" url="/performance/ffmpeg.csv" title="perf.py
+media.ffmpeg --runtime=runc --runtime=runsc" %}
+
+For benchmarks that are bound by raw disk I/O and a mix of compute, file system
+operations are less of an issue. The above figure shows the total time required
+for an `ffmpeg` container to start, load and transcode a 27MB input video.
+
+[ab]: https://en.wikipedia.org/wiki/ApacheBench
+[benchmark-tools]: https://github.com/google/gvisor/tree/master/benchmarks
+[gce]: https://cloud.google.com/compute/
+[cnn]: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/convolutional_network.py
+[docker]: https://docker.io
+[redis-benchmark]: https://redis.io/topics/benchmarks
+[vfs]: https://en.wikipedia.org/wiki/Virtual_file_system
diff --git a/g3doc/architecture_guide/platforms.md b/g3doc/architecture_guide/platforms.md
new file mode 100644
index 000000000..d112c9a28
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.md
@@ -0,0 +1,61 @@
+# Platform Guide
+
+[TOC]
+
+gVisor requires a platform to implement interception of syscalls, basic context
+switching, and memory mapping functionality. Internally, gVisor uses an
+abstraction sensibly called [Platform][platform]. A simplified version of this
+interface looks like:
+
+```golang
+type Platform interface {
+ NewAddressSpace() (AddressSpace, error)
+ NewContext() Context
+}
+
+type Context interface {
+ Switch(as AddressSpace, ac arch.Context) (..., error)
+}
+
+type AddressSpace interface {
+ MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, ...) error
+ Unmap(addr usermem.Addr, length uint64)
+}
+```
+
+There are a number of different ways to implement this interface that come with
+various trade-offs, generally around performance and hardware requirements.
+
+## Implementations
+
+The choice of platform depends on the context in which `runsc` is executing. In
+general, virtualized platforms may be limited to platforms that do not require
+hardware virtualized support (since the hardware is already in use):
+
+![Platforms](platforms.png "Platform examples.")
+
+### ptrace
+
+The ptrace platform uses [PTRACE_SYSEMU][ptrace] to execute user code without
+allowing it to execute host system calls. This platform can run anywhere that
+`ptrace` works (even VMs without nested virtualization), which is ubiquitous.
+
+Unfortunately, the ptrace platform has high context switch overhead, so system
+call-heavy applications may pay a [performance penalty](./performance.md).
+
+### KVM
+
+The KVM platform uses the kernel's [KVM][kvm] functionality to allow the Sentry
+to act as both guest OS and VMM. The KVM platform can run on bare-metal or in a
+VM with nested virtualization enabled. While there is no virtualized hardware
+layer -- the sandbox retains a process model -- gVisor leverages virtualization
+extensions available on modern processors in order to improve isolation and
+performance of address space switches.
+
+## Changing Platforms
+
+See [Changing Platforms](../user_guide/platforms.md).
+
+[kvm]: https://www.kernel.org/doc/Documentation/virtual/kvm/api.txt
+[platform]: https://cs.opensource.google/gvisor/gvisor/+/release-20190304.1:pkg/sentry/platform/platform.go;l=33
+[ptrace]: http://man7.org/linux/man-pages/man2/ptrace.2.html
diff --git a/g3doc/architecture_guide/platforms.png b/g3doc/architecture_guide/platforms.png
new file mode 100644
index 000000000..005d56feb
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.png
Binary files differ
diff --git a/g3doc/architecture_guide/platforms.svg b/g3doc/architecture_guide/platforms.svg
new file mode 100644
index 000000000..b0bac9ba7
--- /dev/null
+++ b/g3doc/architecture_guide/platforms.svg
@@ -0,0 +1,334 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="142.67763mm"
+ height="67.063133mm"
+ viewBox="0 0 142.67763 67.063134"
+ version="1.1"
+ id="svg8"
+ inkscape:export-filename="/home/ascannell/resources.png"
+ inkscape:export-xdpi="53.50127"
+ inkscape:export-ydpi="53.50127"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="platforms.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="86.443612"
+ inkscape:cy="102.88104"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer1"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ transform="translate(-36.081387,-98.953278)">
+ <rect
+ id="rect10"
+ width="33.408691"
+ height="33.408691"
+ x="36.081387"
+ y="120.06757"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.23881446"
+ id="rect16"
+ width="142.45465"
+ height="10.423517"
+ x="36.08139"
+ y="155.5929" />
+ <rect
+ id="rect10-7"
+ width="30.52453"
+ height="18.976137"
+ x="37.416695"
+ y="121.65508"
+ style="fill:#ff8080;stroke-width:0.19060372" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="41.03727"
+ y="148.58765"
+ id="text65"><tspan
+ sodipodi:role="line"
+ id="tspan63"
+ x="41.03727"
+ y="148.58765"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="45.473087"
+ y="132.50232"
+ id="text123"><tspan
+ sodipodi:role="line"
+ id="tspan121"
+ x="45.473087"
+ y="132.50232"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.43922186px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.16098055"
+ x="97.768547"
+ y="163.15665"
+ id="text163"><tspan
+ sodipodi:role="line"
+ id="tspan161"
+ x="97.768547"
+ y="163.15665"
+ style="stroke-width:0.16098055">host</tspan></text>
+ <rect
+ style="fill:#e9afdd;stroke-width:0.39185274"
+ id="rect16-7"
+ width="72.9646"
+ height="54.79026"
+ x="105.79441"
+ y="98.953278" />
+ <rect
+ id="rect10-5"
+ width="33.408691"
+ height="33.408691"
+ x="108.24348"
+ y="100.53072"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-6"
+ width="30.52453"
+ height="20.045216"
+ x="109.57877"
+ y="102.11823"
+ style="fill:#ff8080;stroke-width:0.19589928" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="112.86765"
+ y="129.01863"
+ id="text65-2"><tspan
+ sodipodi:role="line"
+ id="tspan63-9"
+ x="112.86765"
+ y="129.01863"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="117.63519"
+ y="114.02371"
+ id="text123-1"><tspan
+ sodipodi:role="line"
+ id="tspan121-2"
+ x="117.63519"
+ y="114.02371"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7"
+ width="11.815663"
+ height="8.0126781"
+ x="54.538059"
+ y="143.27702"
+ style="fill:#aaccff;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:4.35074377px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.10876859"
+ x="55.931114"
+ y="148.90578"
+ id="text144"><tspan
+ sodipodi:role="line"
+ id="tspan142"
+ x="55.931114"
+ y="148.90578"
+ style="stroke-width:0.10876859">KVM</tspan></text>
+ <rect
+ id="rect10-6"
+ width="33.408691"
+ height="33.408691"
+ x="71.044685"
+ y="119.73112"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-0"
+ width="30.52453"
+ height="18.976137"
+ x="72.37999"
+ y="121.31865"
+ style="fill:#ff8080;stroke-width:0.19060372" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="76.000565"
+ y="148.25128"
+ id="text65-6"><tspan
+ sodipodi:role="line"
+ id="tspan63-2"
+ x="76.000565"
+ y="148.25128"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="80.436386"
+ y="132.16595"
+ id="text123-6"><tspan
+ sodipodi:role="line"
+ id="tspan121-1"
+ x="80.436386"
+ y="132.16595"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7-8"
+ width="11.815664"
+ height="8.0126781"
+ x="89.501358"
+ y="142.94067"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="89.92292"
+ y="147.89806"
+ id="text144-7"><tspan
+ sodipodi:role="line"
+ id="tspan142-9"
+ x="89.92292"
+ y="147.89806"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ <rect
+ id="rect10-7-7-8-3"
+ width="11.815665"
+ height="8.0126781"
+ x="127.08897"
+ y="123.97878"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="127.51052"
+ y="128.9362"
+ id="text144-7-7"><tspan
+ sodipodi:role="line"
+ id="tspan142-9-5"
+ x="127.51052"
+ y="128.9362"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:5.45061255px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.13626531"
+ x="138.49318"
+ y="152.11841"
+ id="text229"><tspan
+ sodipodi:role="line"
+ id="tspan227"
+ x="138.49318"
+ y="152.11841"
+ style="stroke-width:0.13626531">VM</tspan></text>
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.16518368"
+ id="rect16-9"
+ width="68.15374"
+ height="10.423517"
+ x="108.24348"
+ y="134.99774" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.17854786px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.15446369"
+ x="132.91473"
+ y="142.07658"
+ id="text248"><tspan
+ sodipodi:role="line"
+ id="tspan246"
+ x="132.91473"
+ y="142.07658"
+ style="stroke-width:0.15446369">guest</tspan></text>
+ <rect
+ id="rect10-5-2"
+ width="33.408691"
+ height="33.408691"
+ x="143.32402"
+ y="100.35877"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <rect
+ id="rect10-7-6-2"
+ width="30.52453"
+ height="20.045216"
+ x="144.65933"
+ y="101.94627"
+ style="fill:#ff8080;stroke-width:0.19589929" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="147.94815"
+ y="128.84665"
+ id="text65-2-8"><tspan
+ sodipodi:role="line"
+ id="tspan63-9-9"
+ x="147.94815"
+ y="128.84665"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="152.71565"
+ y="113.85176"
+ id="text123-1-7"><tspan
+ sodipodi:role="line"
+ id="tspan121-2-3"
+ x="152.71565"
+ y="113.85176"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <rect
+ id="rect10-7-7-8-3-6"
+ width="11.815666"
+ height="8.0126781"
+ x="162.16933"
+ y="123.80682"
+ style="fill:#ffeeaa;stroke-width:0.07705856" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.39456654px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08486416"
+ x="162.59088"
+ y="128.76421"
+ id="text144-7-7-1"><tspan
+ sodipodi:role="line"
+ id="tspan142-9-5-2"
+ x="162.59088"
+ y="128.76421"
+ style="stroke-width:0.08486416">ptrace</tspan></text>
+ </g>
+</svg>
diff --git a/g3doc/architecture_guide/resources.md b/g3doc/architecture_guide/resources.md
new file mode 100644
index 000000000..1dec37bd1
--- /dev/null
+++ b/g3doc/architecture_guide/resources.md
@@ -0,0 +1,144 @@
+# Resource Model
+
+[TOC]
+
+The resource model for gVisor does not assume a fixed number of threads of
+execution (i.e. vCPUs) or amount of physical memory. Where possible, decisions
+about underlying physical resources are delegated to the host system, where
+optimizations can be made with global information. This delegation allows the
+sandbox to be highly dynamic in terms of resource usage: spanning a large number
+of cores and large amount of memory when busy, and yielding those resources back
+to the host when not.
+
+In order words, the shape of the sandbox should closely track the shape of the
+sandboxed process:
+
+![Resource model](resources.png "Workloads of different shapes.")
+
+## Processes
+
+Much like a Virtual Machine (VM), a gVisor sandbox appears as an opaque process
+on the system. Processes within the sandbox do not manifest as processes on the
+host system, and process-level interactions within the sandbox requires entering
+the sandbox (e.g. via a [Docker exec][exec]).
+
+## Networking
+
+The sandbox attaches a network endpoint to the system, but runs it's own network
+stack. All network resources, other than packets in flight on the host, exist
+only inside the sandbox, bound by relevant resource limits.
+
+You can interact with network endpoints exposed by the sandbox, just as you
+would any other container, but network introspection similarly requires entering
+the sandbox.
+
+## Files
+
+Files in the sandbox may be backed by different implementations. For host-native
+files (where a file descriptor is available), the Gofer may return a file
+descriptor to the Sentry via [SCM_RIGHTS][scmrights][^1].
+
+These files may be read from and written to through standard system calls, and
+also mapped into the associated application's address space. This allows the
+same host memory to be shared across multiple sandboxes, although this mechanism
+does not preclude the use of side-channels (see [Security Model](./security.md).
+
+Note that some file systems exist only within the context of the sandbox. For
+example, in many cases a `tmpfs` mount will be available at `/tmp` or
+`/dev/shm`, which allocates memory directly from the sandbox memory file (see
+below). Ultimately, these will be accounted against relevant limits in a similar
+way as the host native case.
+
+## Threads
+
+The Sentry models individual task threads with [goroutines][goroutine]. As a
+result, each task thread is a lightweight [green thread][greenthread], and may
+not correspond to an underlying host thread.
+
+However, application execution is modelled as a blocking system call with the
+Sentry. This means that additional host threads may be created, *depending on
+the number of active application threads*. In practice, a busy application will
+converge on the number of active threads, and the host will be able to make
+scheduling decisions about all application threads.
+
+## Time
+
+Time in the sandbox is provided by the Sentry, through its own [vDSO][vdso] and
+time-keeping implementation. This is distinct from the host time, and no state
+is shared with the host, although the time will be initialized with the host
+clock.
+
+The Sentry runs timers to note the passage of time, much like a kernel running
+on hardware (though the timers are software timers, in this case). These timers
+provide updates to the vDSO, the time returned through system calls, and the
+time recorded for usage or limit tracking (e.g. [RLIMIT_CPU][rlimit]).
+
+When all application threads are idle, the Sentry disables timers until an event
+occurs that wakes either the Sentry or an application thread, similar to a
+[tickless kernel][tickless]. This allows the Sentry to achieve near zero CPU
+usage for idle applications.
+
+## Memory
+
+The Sentry implements its own memory management, including demand-paging and a
+Sentry internal page cache for files that cannot be used natively. A single
+[memfd][memfd] backs all application memory.
+
+### Address spaces
+
+The creation of address spaces is platform-specific. For some platforms,
+additional "stub" processes may be created on the host in order to support
+additional address spaces. These stubs are subject to various limits applied at
+the sandbox level (e.g. PID limits).
+
+### Physical memory
+
+The host is able to manage physical memory using regular means (e.g. tracking
+working sets, reclaiming and swapping under pressure). The Sentry lazily
+populates host mappings for applications, and allow the host to demand-page
+those regions, which is critical for the functioning of those mechanisms.
+
+In order to avoid excessive overhead, the Sentry does not demand-page individual
+pages. Instead, it selects appropriate regions based on heuristics. There is a
+trade-off here: the Sentry is unable to trivially determine which pages are
+active and which are not. Even if pages were individually faulted, the host may
+select pages to be reclaimed or swapped without the Sentry's knowledge.
+
+Therefore, memory usage statistics within the sandbox (e.g. via `proc`) are
+approximations. The Sentry maintains an internal breakdown of memory usage, and
+can collect accurate information but only through a relatively expensive API
+call. In any case, it would likely be considered unwise to share precise
+information about how the host is managing memory with the sandbox.
+
+Finally, when an application marks a region of memory as no longer needed, for
+example via a call to [madvise][madvise], the Sentry *releases this memory back
+to the host*. There can be performance penalties for this, since it may be
+cheaper in many cases to retain the memory and use it to satisfy some other
+request. However, releasing it immediately to the host allows the host to more
+effectively multiplex resources and apply an efficient global policy.
+
+## Limits
+
+All Sentry threads and Sentry memory are subject to a container cgroup. However,
+application usage will not appear as anonymous memory usage, and will instead be
+accounted to the `memfd`. All anonymous memory will correspond to Sentry usage,
+and host memory charged to the container will work as standard.
+
+The cgroups can be monitored for standard signals: pressure indicators,
+threshold notifiers, etc. and can also be adjusted dynamically. Note that the
+Sentry itself may listen for pressure signals in its containing cgroup, in order
+to purge internal caches.
+
+[goroutine]: https://tour.golang.org/concurrency/1
+[greenthread]: https://en.wikipedia.org/wiki/Green_threads
+[scheduler]: https://morsmachine.dk/go-scheduler
+[vdso]: https://en.wikipedia.org/wiki/VDSO
+[rlimit]: http://man7.org/linux/man-pages/man2/getrlimit.2.html
+[tickless]: https://en.wikipedia.org/wiki/Tickless_kernel
+[memfd]: http://man7.org/linux/man-pages/man2/memfd_create.2.html
+[scmrights]: http://man7.org/linux/man-pages/man7/unix.7.html
+[madvise]: http://man7.org/linux/man-pages/man2/madvise.2.html
+[exec]: https://docs.docker.com/engine/reference/commandline/exec/
+[^1]: Unless host networking is enabled, the Sentry is not able to create or
+ open host file descriptors itself, it can only receive them in this way
+ from the Gofer.
diff --git a/g3doc/architecture_guide/resources.png b/g3doc/architecture_guide/resources.png
new file mode 100644
index 000000000..f715008ec
--- /dev/null
+++ b/g3doc/architecture_guide/resources.png
Binary files differ
diff --git a/g3doc/architecture_guide/resources.svg b/g3doc/architecture_guide/resources.svg
new file mode 100644
index 000000000..fd7805d90
--- /dev/null
+++ b/g3doc/architecture_guide/resources.svg
@@ -0,0 +1,208 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="108.24417mm"
+ height="47.513165mm"
+ viewBox="0 0 108.24417 47.513165"
+ version="1.1"
+ id="svg8"
+ inkscape:export-filename="/home/ascannell/resources.png"
+ inkscape:export-xdpi="53.50127"
+ inkscape:export-ydpi="53.50127"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="resources.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="16.897058"
+ inkscape:cy="41.261746"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer1"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ transform="translate(-36.081387,-118.50325)">
+ <rect
+ id="rect10"
+ width="33.408691"
+ height="33.408691"
+ x="36.081387"
+ y="120.06757"
+ style="fill:#44aa00;stroke-width:0.26458332" />
+ <circle
+ style="fill:#44aa00;stroke-width:0.21849461"
+ id="path12"
+ cx="87.958534"
+ cy="136.63828"
+ r="17.105247" />
+ <path
+ sodipodi:type="star"
+ style="fill:#44aa00;stroke-width:0.26458332"
+ id="path14"
+ sodipodi:sides="3"
+ sodipodi:cx="124.13387"
+ sodipodi:cy="141.81859"
+ sodipodi:r1="23.31534"
+ sodipodi:r2="11.65767"
+ sodipodi:arg1="0.52359878"
+ sodipodi:arg2="1.5707963"
+ inkscape:flatsided="false"
+ inkscape:rounded="0"
+ inkscape:randomized="0"
+ d="m 144.32555,153.47626 -20.19168,0 -20.19167,0 10.09583,-17.48651 10.09584,-17.4865 10.09584,17.4865 z"
+ inkscape:transform-center-x="1.8384776e-06"
+ inkscape:transform-center-y="-5.8288369" />
+ <rect
+ style="fill:#b3b3b3;stroke-width:0.20817307"
+ id="rect16"
+ width="108.24416"
+ height="10.423517"
+ x="36.08139"
+ y="155.5929" />
+ <path
+ sodipodi:type="star"
+ style="fill:#ff8080;stroke-width:0.20018946"
+ id="path14-3"
+ sodipodi:sides="3"
+ sodipodi:cx="124.13387"
+ sodipodi:cy="139.31911"
+ sodipodi:r1="17.640888"
+ sodipodi:r2="8.8204451"
+ sodipodi:arg1="0.52359878"
+ sodipodi:arg2="1.5707963"
+ inkscape:flatsided="false"
+ inkscape:rounded="0"
+ inkscape:randomized="0"
+ d="m 139.41133,148.13955 -15.27746,0 -15.27745,0 7.63872,-13.23067 7.63873,-13.23066 7.63873,13.23066 z"
+ inkscape:transform-center-x="3.9117172e-06"
+ inkscape:transform-center-y="-4.4102243" />
+ <circle
+ style="fill:#ff8080;stroke-width:0.18094084"
+ id="path12-6"
+ cx="87.93705"
+ cy="134.75125"
+ r="14.165282" />
+ <rect
+ id="rect10-7"
+ width="30.52453"
+ height="25.657875"
+ x="37.416695"
+ y="121.65508"
+ style="fill:#ff8080;stroke-width:0.22163473" />
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="47.387276"
+ y="151.7626"
+ id="text65"><tspan
+ sodipodi:role="line"
+ id="tspan63"
+ x="47.387276"
+ y="151.7626"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="82.156319"
+ y="151.71547"
+ id="text65-5"><tspan
+ sodipodi:role="line"
+ id="tspan63-3"
+ x="82.156319"
+ y="151.71547"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.40292525px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08507314"
+ x="118.66879"
+ y="151.71547"
+ id="text65-5-5"><tspan
+ sodipodi:role="line"
+ id="tspan63-3-6"
+ x="118.66879"
+ y="151.71547"
+ style="stroke-width:0.08507314">gVisor</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="45.473087"
+ y="136.20644"
+ id="text123"><tspan
+ sodipodi:role="line"
+ id="tspan121"
+ x="45.473087"
+ y="136.20644"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="80.153076"
+ y="136.00925"
+ id="text123-1"><tspan
+ sodipodi:role="line"
+ id="tspan121-2"
+ x="80.153076"
+ y="136.00925"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:3.33113885px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.08327847"
+ x="116.50173"
+ y="138.68195"
+ id="text123-1-7"><tspan
+ sodipodi:role="line"
+ id="tspan121-2-0"
+ x="116.50173"
+ y="138.68195"
+ style="stroke-width:0.08327847">workload</tspan></text>
+ <text
+ xml:space="preserve"
+ style="font-style:normal;font-weight:normal;font-size:6.43922186px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.16098055"
+ x="81.893562"
+ y="163.15665"
+ id="text163"><tspan
+ sodipodi:role="line"
+ id="tspan161"
+ x="81.893562"
+ y="163.15665"
+ style="stroke-width:0.16098055">host</tspan></text>
+ </g>
+</svg>
diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md
new file mode 100644
index 000000000..b99b86332
--- /dev/null
+++ b/g3doc/architecture_guide/security.md
@@ -0,0 +1,255 @@
+# Security Model
+
+[TOC]
+
+gVisor was created in order to provide additional defense against the
+exploitation of kernel bugs by untrusted userspace code. In order to understand
+how gVisor achieves this goal, it is first necessary to understand the basic
+threat model.
+
+## Threats: The Anatomy of an Exploit
+
+An exploit takes advantage of a software or hardware bug in order to escalate
+privileges, gain access to privileged data, or disrupt services. All of the
+possible interactions that a malicious application can have with the rest of the
+system (attack vectors) define the attack surface. We categorize these attack
+vectors into several common classes.
+
+### System API
+
+An operating system or hypervisor exposes an abstract System API in the form of
+system calls and traps. This API may be documented and stable, as with Linux, or
+it may be abstracted behind a library, as with Windows (i.e. win32.dll or
+ntdll.dll). The System API includes all standard interfaces that application
+code uses to interact with the system. This includes high-level abstractions
+that are derived from low-level system calls, such as system files, sockets and
+namespaces.
+
+Although the System API is exposed to applications by design, bugs and race
+conditions within the kernel or hypervisor may occasionally be exploitable via
+the API. This is common in part due to the fact that most kernels and
+hypervisors are written in [C][clang], which is well-suited to interfacing with
+hardware but often prone to security issues. In order to exploit these issues, a
+typical attack might involve some combination of the following:
+
+1. Opening or creating some combination of files, sockets or other descriptors.
+1. Passing crafted, malicious arguments, structures or packets.
+1. Racing with multiple threads in order to hit specific code paths.
+
+For example, for the [Dirty Cow][dirtycow] privilege escalation bug, an
+application would open a specific file in `/proc` or use a specific `ptrace`
+system call, and use multiple threads in order to trigger a race condition when
+touching a fresh page of memory. The attacker then gains control over a page of
+memory belonging to the system. With additional privileges or access to
+privileged data in the kernel, an attacker will often be able to employ
+additional techniques to gain full access to the rest of the system.
+
+While bugs in the implementation of the System API are readily fixed, they are
+also the most common form of exploit. The exposure created by this class of
+exploit is what gVisor aims to minimize and control, described in detail below.
+
+### System ABI
+
+Hardware and software exploits occasionally exist in execution paths that are
+not part of an intended System API. In this case, exploits may be found as part
+of implicit actions the hardware or privileged system code takes in response to
+certain events, such as traps or interrupts. For example, the recent
+[POPSS][popss] flaw required only native code execution (no specific system call
+or file access). In that case, the Xen hypervisor was similarly vulnerable,
+highlighting that hypervisors are not immune to this vector.
+
+### Side Channels
+
+Hardware side channels may be exploitable by any code running on a system:
+native, sandboxed, or virtualized. However, many host-level mitigations against
+hardware side channels are still effective with a sandbox. For example, kernels
+built with retpoline protect against some speculative execution attacks
+(Spectre) and frame poisoning may protect against L1 terminal fault (L1TF)
+attacks. Hypervisors may introduce additional complications in this regard, as
+there is no mitigation against an application in a normally functioning Virtual
+Machine (VM) exploiting the L1TF vulnerability for another VM on the sibling
+hyperthread.
+
+### Other Vectors
+
+The above categories in no way represent an exhaustive list of exploits, as we
+focus only on running untrusted code from within the operating system or
+hypervisor. We do not consider other ways that a more generic adversary may
+interact with a system, such as inserting a portable storage device with a
+malicious filesystem image, using a combination of crafted keyboard or touch
+inputs, or saturating a network device with ill-formed packets.
+
+Furthermore, high-level systems may contain exploitable components. An attacker
+need not escalate privileges within a container if there’s an exploitable
+network-accessible service on the host or some other API path. *A sandbox is not
+a substitute for a secure architecture*.
+
+## Goals: Limiting Exposure
+
+![Threat model](security.png "Threat model.")
+
+gVisor’s primary design goal is to minimize the System API attack vector through
+multiple layers of defense, while still providing a process model. There are two
+primary security principles that inform this design. First, the application’s
+direct interactions with the host System API are intercepted by the Sentry,
+which implements the System API instead. Second, the System API accessible to
+the Sentry itself is minimized to a safer, restricted set. The first principle
+minimizes the possibility of direct exploitation of the host System API by
+applications, and the second principle minimizes indirect exploitability, which
+is the exploitation by an exploited or buggy Sentry (e.g. chaining an exploit).
+
+The first principle is similar to the security basis for a Virtual Machine (VM).
+With a VM, an application’s interactions with the host are replaced by
+interactions with a guest operating system and a set of virtualized hardware
+devices. These hardware devices are then implemented via the host System API by
+a Virtual Machine Monitor (VMM). The Sentry similarly prevents direct
+interactions by providing its own implementation of the System API that the
+application must interact with. Applications are not able to to directly craft
+specific arguments or flags for the host System API, or interact directly with
+host primitives.
+
+For both the Sentry and a VMM, it’s worth noting that while direct interactions
+are not possible, indirect interactions are still possible. For example, a read
+on a host-backed file in the Sentry may ultimately result in a host read system
+call (made by the Sentry, not by passing through arguments from the
+application), similar to how a read on a block device in a VM may result in the
+VMM issuing a corresponding host read system call from a backing file.
+
+An important distinction from a VM is that the Sentry implements a System API
+based directly on host System API primitives instead of relying on virtualized
+hardware and a guest operating system. This selects a distinct set of
+trade-offs, largely in the performance, efficiency and compatibility domains.
+Since transitions in and out of the sandbox are relatively expensive, a guest
+operating system will typically take ownership of resources. For example, in the
+above case, the guest operating system may read the block device data in a local
+page cache, to avoid subsequent reads. This may lead to better performance but
+lower efficiency, since memory may be wasted or duplicated. The Sentry opts
+instead to defer to the host for many operations during runtime, for improved
+efficiency but lower performance in some use cases.
+
+### What can a sandbox do?
+
+An application in a gVisor sandbox is permitted to do most things a standard
+container can do: for example, applications can read and write files mapped
+within the container, make network connections, etc. As described above,
+gVisor's primary goal is to limit exposure to bugs and exploits while still
+allowing most applications to run. Even so, gVisor will limit some operations
+that might be permitted with a standard container. Even with appropriate
+capabilities, a user in a gVisor sandbox will only be able to manipulate
+virtualized system resources (e.g. the system time, kernel settings or
+filesystem attributes) and not underlying host system resources.
+
+While the sandbox virtualizes many operations for the application, we limit the
+sandbox's own interactions with the host to the following high-level operations:
+
+1. Communicate with a Gofer process via a connected socket. The sandbox may
+ receive new file descriptors from the Gofer process, corresponding to opened
+ files. These files can then be read from and written to by the sandbox.
+1. Make a minimal set of host system calls. The calls do not include the
+ creation of new sockets (unless host networking mode is enabled) or opening
+ files. The calls include duplication and closing of file descriptors,
+ synchronization, timers and signal management.
+1. Read and write packets to a virtual ethernet device. This is not required if
+ host networking is enabled (or networking is disabled).
+
+### System ABI, Side Channels and Other Vectors
+
+gVisor relies on the host operating system and the platform for defense against
+hardware-based attacks. Given the nature of these vulnerabilities, there is
+little defense that gVisor can provide (there’s no guarantee that additional
+hardware measures, such as virtualization, memory encryption, etc. would
+actually decrease the attack surface). Note that this is true even when using
+hardware virtualization for acceleration, as the host kernel or hypervisor is
+ultimately responsible for defending against attacks from within malicious
+guests.
+
+gVisor similarly relies on the host resource mechanisms (cgroups) for defense
+against resource exhaustion and denial of service attacks. Network policy
+controls should be applied at the container level to ensure appropriate network
+policy enforcement. Note that the sandbox itself is not capable of altering or
+configuring these mechanisms, and the sandbox itself should make an attacker
+less likely to exploit or override these controls through other means.
+
+## Principles: Defense-in-Depth
+
+For gVisor development, there are several engineering principles that are
+employed in order to ensure that the system meets its design goals.
+
+1. No system call is passed through directly to the host. Every supported call
+ has an independent implementation in the Sentry, that is unlikely to suffer
+ from identical vulnerabilities that may appear in the host. This has the
+ consequence that all kernel features used by applications require an
+ implementation within the Sentry.
+1. Only common, universal functionality is implemented. Some filesystems,
+ network devices or modules may expose specialized functionality to user
+ space applications via mechanisms such as extended attributes, raw sockets
+ or ioctls. Since the Sentry is responsible for implementing the full system
+ call surface, we do not implement or pass through these specialized APIs.
+1. The host surface exposed to the Sentry is minimized. While the system call
+ surface is not trivial, it is explicitly enumerated and controlled. The
+ Sentry is not permitted to open new files, create new sockets or do many
+ other interesting things on the host.
+
+Additionally, we have practical restrictions that are imposed on the project to
+minimize the risk of Sentry exploitability. For example:
+
+1. Unsafe code is carefully controlled. All unsafe code is isolated in files
+ that end with "unsafe.go", in order to facilitate validation and auditing.
+ No file without the unsafe suffix may import the unsafe package.
+1. No CGo is allowed. The Sentry must be a pure Go binary.
+1. External imports are not generally allowed within the core packages. Only
+ limited external imports are used within the setup code. The code available
+ inside the Sentry is carefully controlled, to ensure that the above rules
+ are effective.
+
+Finally, we recognize that security is a process, and that vigilance is
+critical. Beyond our security disclosure process, the Sentry is fuzzed
+continuously to identify potential bugs and races proactively, and production
+crashes are recorded and triaged to similarly identify material issues.
+
+## FAQ
+
+### Is this more or less secure than a Virtual Machine?
+
+The security of a VM depends to a large extent on what is exposed from the host
+kernel and userspace support code. For example, device emulation code in the
+host kernel (e.g. APIC) or optimizations (e.g. vhost) can be more complex than a
+simple system call, and exploits carry the same risks. Similarly, the userspace
+support code is frequently unsandboxed, and exploits, while rare, may allow
+unfettered access to the system.
+
+Some platforms leverage the same virtualization hardware as VMs in order to
+provide better system call interception performance. However, gVisor does not
+implement any device emulation, and instead opts to use a sandboxed host System
+API directly. Both approaches significantly reduce the original attack surface.
+Ultimately, since gVisor is capable of using the same hardware mechanism, one
+should not assume that the mere use of virtualization hardware makes a system
+more or less secure, just as it would be a mistake to make the claim that the
+use of a unibody alone makes a car safe.
+
+### Does this stop hardware side channels?
+
+In general, gVisor does not provide protection against hardware side channels,
+although it may make exploits that rely on direct access to the host System API
+more difficult to use. To minimize exposure, you should follow relevant guidance
+from vendors and keep your host kernel and firmware up-to-date.
+
+### Is this just a ptrace sandbox?
+
+No: the term “ptrace sandbox” generally refers to software that uses the Linux
+ptrace facility to inspect and authorize system calls made by applications,
+enforcing a specific policy. These commonly suffer from two issues. First,
+vulnerable system calls may be authorized by the sandbox, as the application
+still has direct access to some System API. Second, it’s impossible to avoid
+time-of-check, time-of-use race conditions without disabling multi-threading.
+
+In gVisor, the platforms that use ptrace operate differently. The stubs that are
+traced are never allowed to continue execution into the host kernel and complete
+a call directly. Instead, all system calls are interpreted and handled by the
+Sentry itself, who reflects resulting register state back into the tracee before
+continuing execution in userspace. This is very similar to the mechanism used by
+User-Mode Linux (UML).
+
+[dirtycow]: https://en.wikipedia.org/wiki/Dirty_COW
+[clang]: https://en.wikipedia.org/wiki/C_(programming_language)
+[popss]: https://nvd.nist.gov/vuln/detail/CVE-2018-8897
diff --git a/g3doc/architecture_guide/security.png b/g3doc/architecture_guide/security.png
new file mode 100644
index 000000000..c29befbf6
--- /dev/null
+++ b/g3doc/architecture_guide/security.png
Binary files differ
diff --git a/g3doc/architecture_guide/security.svg b/g3doc/architecture_guide/security.svg
new file mode 100644
index 000000000..0575e2dec
--- /dev/null
+++ b/g3doc/architecture_guide/security.svg
@@ -0,0 +1,153 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ width="92.963379mm"
+ height="107.18885mm"
+ viewBox="0 0 92.963379 107.18885"
+ version="1.1"
+ id="svg8"
+ inkscape:version="0.92.4 (5da689c313, 2019-01-14)"
+ sodipodi:docname="defense.svg">
+ <defs
+ id="defs2" />
+ <sodipodi:namedview
+ id="base"
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1.0"
+ inkscape:pageopacity="0.0"
+ inkscape:pageshadow="2"
+ inkscape:zoom="0.98994949"
+ inkscape:cx="-242.99254"
+ inkscape:cy="136.90181"
+ inkscape:document-units="mm"
+ inkscape:current-layer="layer4"
+ showgrid="false"
+ inkscape:object-nodes="true"
+ inkscape:window-width="1920"
+ inkscape:window-height="1005"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0" />
+ <metadata
+ id="metadata5">
+ <rdf:RDF>
+ <cc:Work
+ rdf:about="">
+ <dc:format>image/svg+xml</dc:format>
+ <dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" />
+ <dc:title></dc:title>
+ </cc:Work>
+ </rdf:RDF>
+ </metadata>
+ <g
+ inkscape:groupmode="layer"
+ id="layer2"
+ inkscape:label="Layer 2"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ id="g4644"
+ style="fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ transform="matrix(1,0,0,-1,2.138671,277.94235)">
+ <path
+ transform="scale(0.26458333)"
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:3.77952766;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ d="M 398.57227,351.84766 224.7832,452.18359 398.57227,552.51953 572.35938,452.18359 Z"
+ id="path4638" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:3.77952766;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ d="M 572.35938,452.18359 398.57227,552.51953 V 753.19141 L 572.35938,652.85547 Z"
+ transform="scale(0.26458333)"
+ id="path4640" />
+ <path
+ id="path4642"
+ d="m 59.473888,119.64024 45.981172,26.54722 v 53.09443 L 59.473888,172.73467 Z"
+ style="opacity:1;fill:none;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:0.25572576"
+ inkscape:connector-curvature="0" />
+ </g>
+ </g>
+ <g
+ inkscape:groupmode="layer"
+ id="layer3"
+ inkscape:label="Layer 3"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ id="g4554"
+ transform="matrix(-0.39771468,0.69855937,-0.69855937,-0.39771468,366.58103,126.65261)">
+ <g
+ id="g4662"
+ transform="translate(59.46839,130.66062)">
+ <path
+ inkscape:connector-curvature="0"
+ id="path4548"
+ transform="scale(0.26458333)"
+ d="M 398.57227,351.84766 224.7832,452.18359 398.57227,552.51953 572.35938,452.18359 Z"
+ style="opacity:1;fill:#0066ff;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:4.70182848;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" />
+ <path
+ inkscape:connector-curvature="0"
+ id="path4550"
+ transform="scale(0.26458333)"
+ d="M 572.35938,452.18359 398.57227,552.51953 V 753.19141 L 572.35938,652.85547 Z"
+ style="opacity:1;fill:#0044aa;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:4.29276943;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:#5599ff;fill-opacity:0.34509804;stroke:#00a5ff;stroke-width:1.24402535;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ d="m 59.473888,119.64024 45.981172,26.54722 v 53.09443 L 59.473888,172.73467 Z"
+ id="path4552" />
+ </g>
+ </g>
+ </g>
+ <g
+ inkscape:groupmode="layer"
+ id="layer4"
+ inkscape:label="Layer 4"
+ transform="translate(-61.112559,-78.160466)">
+ <path
+ style="fill:#e000ae;fill-opacity:1;stroke-width:0.12476727"
+ d="m 84.610811,107.36071 v 2.55773 2.55772 h 2.49535 2.49534 v -2.55772 -2.55773 h -2.49534 z m 40.674129,0 v 2.55773 2.55772 h 2.49535 2.49534 v -2.55772 -2.55773 h -2.49534 z m -35.558669,5.11545 v 2.55773 2.55773 h 2.49535 2.49534 v -2.55773 -2.55773 h -2.49534 z m 4.99069,5.11546 v 2.55773 2.55773 h -2.49534 -2.49535 v 2.49534 2.49535 h -2.55773 -2.55773 v 2.55773 2.55773 h -2.55773 -2.55773 v 10.16853 10.16853 h 2.55773 2.55773 v -7.67562 -7.67587 l 2.52654,0.0339 2.52654,0.0336 0.0327,5.08427 0.0327,5.08426 h 2.49388 2.49388 v 2.55919 2.5592 l 5.08427,-0.0327 5.084269,-0.0326 v -2.49534 -2.49535 l -5.084269,-0.0324 -5.08427,-0.0327 v -2.55626 -2.55651 h 12.726269 12.72626 v 2.55651 2.55626 l -5.05868,0.0327 -5.05893,0.0324 v 2.49535 2.49534 l 5.05893,0.0326 5.05868,0.0327 v -2.55919 -2.55919 h 2.49388 2.49413 l 0.0324,-5.08426 0.0327,-5.08427 2.52653,-0.0336 2.52654,-0.0339 v 7.67586 7.67563 h 2.55773 2.55773 v -10.16854 -10.16853 h -2.55773 -2.55773 v -2.55773 -2.55773 h -2.55773 -2.55773 v -2.49535 -2.49534 h -2.49535 -2.49534 v -2.55773 -2.55773 h -2.55773 -2.55773 v 2.55773 2.55773 h -7.6108 -7.610809 v -2.55773 -2.55773 h -2.55774 z m 25.452519,0 h 2.49535 2.49535 v -2.55773 -2.55773 h -2.49535 -2.49535 v 2.55773 z m -25.452519,10.10615 h 5.11546 5.115459 v 2.55773 2.55773 h -5.115459 -5.11546 v -2.55773 z m 15.221609,0 h 5.11546 5.11545 v 2.55773 2.55773 h -5.11545 -5.11546 v -2.55773 z"
+ id="path4732"
+ inkscape:connector-curvature="0" />
+ </g>
+ <g
+ inkscape:label="Layer 1"
+ inkscape:groupmode="layer"
+ id="layer1"
+ style="display:inline"
+ transform="translate(-61.112559,-78.160466)">
+ <g
+ transform="translate(-131.49557,42.495842)"
+ style="fill:#007200;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ id="g4628">
+ <path
+ id="path4529"
+ d="m 239.09034,36.164616 -45.98169,26.547215 45.98169,26.547217 45.98117,-26.547217 z"
+ style="opacity:1;fill:#4aba19;fill-opacity:0.34509804;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ inkscape:connector-curvature="0" />
+ <path
+ id="path4531"
+ d="m 285.07151,62.711828 -45.98117,26.54722 v 53.094432 l 45.98117,-26.54722 z"
+ style="opacity:1;fill:#007900;fill-opacity:0.34351148;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ inkscape:connector-curvature="0" />
+ <path
+ inkscape:connector-curvature="0"
+ style="opacity:1;fill:#003d00;fill-opacity:0.34509804;stroke:#00a500;stroke-width:1;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1"
+ d="m 193.10865,62.711831 45.98117,26.54722 v 53.094429 l -45.98117,-26.54722 z"
+ id="path4541" />
+ </g>
+ </g>
+</svg>
diff --git a/g3doc/community.md b/g3doc/community.md
new file mode 100644
index 000000000..76f4d87c3
--- /dev/null
+++ b/g3doc/community.md
@@ -0,0 +1,31 @@
+# Participation
+
+To contribute code, please read the [contributing guide](../CONTRIBUTING.md).
+
+Please note that the [Code of Conduct](../CODE_OF_CONDUCT.md) applies to
+community forums as well as technical participation.
+
+## Communication channels
+
+The project maintains two mailing lists:
+
+* [gvisor-users][gvisor-users] for accouncements and general discussion.
+* [gvisor-dev][gvisor-dev] for development and contribution.
+
+We also have a [chat room hosted on Gitter][gitter-chat].
+
+We'd love to hear from you!
+
+## Community meetings
+
+The community calendar shows upcoming public meetings and opportunities to
+collaborate or discuss the project. Meetings are planned and announced ahead of
+time via the [gvisor-users][gvisor-users] mailing list.
+
+These meetings are public: anyone can join.
+
+<iframe src="https://calendar.google.com/calendar/b/1/embed?showTitle=0&amp;height=600&amp;wkst=1&amp;bgcolor=%23FFFFFF&amp;src=bd6f4k210u3ukmlj9b8vl053fk%40group.calendar.google.com&amp;color=%23AB8B00&amp;ctz=America%2FLos_Angeles" style="border-width:0" width="600" height="400" frameborder="0" scrolling="no"></iframe>
+
+[gitter-chat]: https://gitter.im/gvisor/community
+[gvisor-dev]: https://groups.google.com/forum/#!forum/gvisor-dev
+[gvisor-users]: https://groups.google.com/forum/#!forum/gvisor-users
diff --git a/g3doc/logo.txt b/g3doc/logo.txt
new file mode 100644
index 000000000..92f9cad5f
--- /dev/null
+++ b/g3doc/logo.txt
@@ -0,0 +1 @@
+The gVisor logo files are licensed under CC BY-SA 4.0 (Creative Commons Attribution-ShareAlike 4.0 International).
diff --git a/g3doc/roadmap.md b/g3doc/roadmap.md
new file mode 100644
index 000000000..06ea25a8b
--- /dev/null
+++ b/g3doc/roadmap.md
@@ -0,0 +1,49 @@
+# Roadmap
+
+gVisor [GitHub Issues][issues] serve as the source-of-truth for most work in
+flight. Specific performance and compatibility issues are generally tracked
+there. [GitHub Milestones][milestones] may be used to track larger features that
+span many issues. However, labels are also used to aggregate cross-cutting
+feature work.
+
+## Core Improvements
+
+Most gVisor work is focused on four areas.
+
+* [Performance][performance]: overall sandbox performance, including platform
+ performance, is a critical area for investment. This includes: network
+ performance (throughput and latency), file system performance (metadata and
+ data I/O), application switch and fault costs, etc. The goal of gVisor is to
+ provide sandboxing without a material performance or efficiency impact on
+ all but the most performance-sensitive applications.
+
+* [Compatibility][compatibility]: supporting a wide range of applications
+ requires supporting a large system API, including special system files (e.g.
+ proc, sys, dev, etc.). The goal of gVisor is to support the broad set of
+ applications that depend on a generic Linux API, rather than a specific
+ kernel version.
+
+* [Infrastructure & tooling][infrastructure]: the above goals require
+ aggressive testing and coverage, and well-established processes. This
+ includes adding appropriate system call coverage, end-to-end suites and
+ runtime tests.
+
+* [Integration][integration]: Container infrastructure is evolving rapidly and
+ becoming more complex, and gVisor must continuously implement relevant and
+ popular features to ensure that integration points remain robust and
+ feature-complete while preserving security guarantees.
+
+## Releases
+
+Releases are available on [GitHub][releases].
+
+As a convenience, binary packages are also published. Instructions for their use
+are available via the [Installation instructions](./user_guide/install.md).
+
+[issues]: https://github.com/google/gvisor/issues
+[milestones]: https://github.com/google/gvisor/milestones
+[releases]: https://github.com/google/gvisor/releases
+[performance]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+performance%22
+[integration]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+integration%22
+[compatibility]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+compatibility%22
+[infrastructure]: https://github.com/google/gvisor/issues?q=is%3Aopen+is%3Aissue+label%3A%22area%3A+tooling%22
diff --git a/g3doc/style.md b/g3doc/style.md
new file mode 100644
index 000000000..d10549fe9
--- /dev/null
+++ b/g3doc/style.md
@@ -0,0 +1,88 @@
+# Provisional style guide
+
+> These guidelines are new and may change. This note will be removed when
+> consensus is reached.
+
+Not all existing code will comply with this style guide, but new code should.
+Further, it is a goal to eventually update all existing code to be in
+compliance.
+
+## All code
+
+### Early exit
+
+All code, unless it substantially increases the line count or complexity, should
+use early exits from loops and functions where possible.
+
+## Go specific
+
+All Go code should comply with the [Go Code Review Comments][gostyle] and
+[Effective Go][effective_go] guides, as well as the additional guidelines
+described below.
+
+### Mutexes
+
+#### Naming
+
+Mutexes should be named mu or xxxMu. Mutexes as a general rule should not be
+exported. Instead, export methods which use the mutexes to avoid leaky
+abstractions.
+
+#### Location
+
+Mutexes should be sibling fields to the fields that they protect. Mutexes should
+not be declared as global variables, instead use a struct (anonymous ok, but
+naming conventions still apply).
+
+Mutexes should be ordered before the fields that they protect.
+
+#### Comments
+
+Mutexes should have a comment on their declaration explaining any ordering
+requirements (or pointing to where this information can be found), if
+applicable. There is no need for a comment explaining which fields are
+protected.
+
+Each field or variable protected by a mutex should state as such in a comment on
+the field or variable declaration.
+
+### Unused returns
+
+Unused returns should be explicitly ignored with underscores. If there is a
+function which is commonly used without using its return(s), a wrapper function
+should be declared which explicitly ignores the returns. That said, in many
+cases, it may make sense for the wrapper to check the returns.
+
+### Formatting verbs
+
+Built-in types should use their associated verbs (e.g. %d for integral types),
+but other types should use a %v variant, even if they implement fmt.Stringer.
+The built-in `error` type should use %w when formatted with `fmt.Errorf`, but
+only then.
+
+### Wrapping
+
+Comments should be wrapped at 80 columns with a 2 space tab size.
+
+Code does not need to be wrapped, but if wrapping would make it more readable,
+it should be wrapped with each subcomponent of the thing being wrapped on its
+own line. For example, if a struct is split between lines, each field should be
+on its own line.
+
+#### Example
+
+```go
+_ = exec.Cmd{
+ Path: "/foo/bar",
+ Args: []string{"-baz"},
+}
+```
+
+## C++ specific
+
+C++ code should conform to the [Google C++ Style Guide][cppstyle] and the
+guidelines described for tests.
+
+[cppstyle]: https://google.github.io/styleguide/cppguide.html
+[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments
+[effective_go]: https://golang.org/doc/effective_go.html
diff --git a/g3doc/user_guide/BUILD b/g3doc/user_guide/BUILD
new file mode 100644
index 000000000..b69aee12c
--- /dev/null
+++ b/g3doc/user_guide/BUILD
@@ -0,0 +1,70 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "compatibility",
+ src = "compatibility.md",
+ category = "Compatibility",
+ permalink = "/docs/user_guide/compatibility/",
+ weight = "0",
+)
+
+doc(
+ name = "checkpoint_restore",
+ src = "checkpoint_restore.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/checkpoint_restore/",
+ weight = "60",
+)
+
+doc(
+ name = "debugging",
+ src = "debugging.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/debugging/",
+ weight = "70",
+)
+
+doc(
+ name = "FAQ",
+ src = "FAQ.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/faq/",
+ weight = "90",
+)
+
+doc(
+ name = "filesystem",
+ src = "filesystem.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/filesystem/",
+ weight = "40",
+)
+
+doc(
+ name = "networking",
+ src = "networking.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/networking/",
+ weight = "50",
+)
+
+doc(
+ name = "install",
+ src = "install.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/install/",
+ weight = "10",
+)
+
+doc(
+ name = "platforms",
+ src = "platforms.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/platforms/",
+ weight = "30",
+)
diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md
new file mode 100644
index 000000000..89df65e99
--- /dev/null
+++ b/g3doc/user_guide/FAQ.md
@@ -0,0 +1,122 @@
+# FAQ
+
+[TOC]
+
+### What operating systems are supported? {#supported-os}
+
+Today, gVisor requires Linux.
+
+### What CPU architectures are supported? {#supported-cpus}
+
+gVisor currently supports [x86_64/AMD64](https://en.wikipedia.org/wiki/X86-64)
+compatible processors. Preliminary support is also available for
+[ARM64](https://en.wikipedia.org/wiki/ARM_architecture#AArch64).
+
+### Do I need to modify my Linux application to use gVisor? {#modify-app}
+
+No. gVisor is capable of running unmodified Linux binaries.
+
+### What binary formats does gVisor support? {#supported-binaries}
+
+gVisor supports Linux
+[ELF](https://en.wikipedia.org/wiki/Executable_and_Linkable_Format) binaries.
+
+Binaries run in gVisor should be built for the
+[AMD64](https://en.wikipedia.org/wiki/X86-64) or
+[AArch64](https://en.wikipedia.org/wiki/ARM_architecture#AArch64) CPU
+architectures.
+
+### Can I run Docker images using gVisor? {#docker-images}
+
+Yes. Please see the [Docker Quick Start][docker].
+
+### Can I run Kubernetes pods using gVisor? {#k8s-pods}
+
+Yes. Please see the [Kubernetes Quick Start][k8s].
+
+### What's the security model? {#security-model}
+
+See the [Security Model][security-model].
+
+## Troubleshooting
+
+### My container runs fine with `runc` but fails with `runsc` {#app-compatibility}
+
+If you’re having problems running a container with `runsc` it’s most likely due
+to a compatibility issue or a missing feature in gVisor. See
+[Debugging][debugging].
+
+### When I run my container, docker fails with: `open /run/containerd/.../<containerid>/log.json: no such file or directory` {#memfd-create}
+
+You are using an older version of Linux which doesn't support `memfd_create`.
+
+This is tracked in [bug #268](https://gvisor.dev/issue/268).
+
+### When I run my container, docker fails with: `flag provided but not defined: -console` {#old-docker}
+
+You're using an old version of Docker. See [Docker Quick Start][docker].
+
+### I can’t see a file copied with: `docker cp` {#fs-cache}
+
+For performance reasons, gVisor caches directory contents, and therefore it may
+not realize a new file was copied to a given directory. To invalidate the cache
+and force a refresh, create a file under the directory in question and list the
+contents again.
+
+As a workaround, shared root filesystem can be enabled. See
+[Filesystem][filesystem].
+
+This bug is tracked in [bug #4](https://gvisor.dev/issue/4).
+
+Note that `kubectl cp` works because it does the copy by exec'ing inside the
+sandbox, and thus gVisor's internal cache is made aware of the new files and
+directories.
+
+### I'm getting an error like: `panic: unable to attach: operation not permitted` or `fork/exec /proc/self/exe: invalid argument: unknown` {#runsc-perms}
+
+Make sure that permissions and the owner is correct on the `runsc` binary.
+
+```bash
+sudo chown root:root /usr/local/bin/runsc
+sudo chmod 0755 /usr/local/bin/runsc
+```
+
+### I'm getting an error like `mount submount "/etc/hostname": creating mount with source ".../hostname": input/output error: unknown.` {#memlock}
+
+There is a bug in Linux kernel versions 5.1 to 5.3.15, 5.4.2, and 5.5. Upgrade
+to a newer kernel or add the following to
+`/lib/systemd/system/containerd.service` as a workaround.
+
+```
+LimitMEMLOCK=infinity
+```
+
+And run `systemctl daemon-reload && systemctl restart containerd` to restart
+containerd.
+
+See [issue #1765](https://gvisor.dev/issue/1765) for more details.
+
+### My container cannot resolve another container's name when using Docker user defined bridge {#docker-bridge}
+
+This is normally indicated by errors like `bad address 'container-name'` when
+trying to communicate to another container in the same network.
+
+Docker user defined bridge uses an embedded DNS server bound to the loopback
+interface on address 127.0.0.10. This requires access to the host network in
+order to communicate to the DNS server. runsc network is isolated from the host
+and cannot access the DNS server on the host network without breaking the
+sandbox isolation. There are a few different workarounds you can try:
+
+* Use default bridge network with `--link` to connect containers. Default
+ bridge doesn't use embedded DNS.
+* Use [`--network=host`][host-net] option in runsc, however beware that it
+ will use the host network stack and is less secure.
+* Use IPs instead of container names.
+* Use [Kubernetes][k8s]. Container name lookup works fine in Kubernetes.
+
+[security-model]: /docs/architecture_guide/security/
+[host-net]: /docs/user_guide/networking/#network-passthrough
+[debugging]: /docs/user_guide/debugging/
+[filesystem]: /docs/user_guide/filesystem/
+[docker]: /docs/user_guide/quick_start/docker/
+[k8s]: /docs/user_guide/quick_start/kubernetes/
diff --git a/g3doc/user_guide/checkpoint_restore.md b/g3doc/user_guide/checkpoint_restore.md
new file mode 100644
index 000000000..0ab0911b0
--- /dev/null
+++ b/g3doc/user_guide/checkpoint_restore.md
@@ -0,0 +1,101 @@
+# Checkpoint/Restore
+
+[TOC]
+
+gVisor has the ability to checkpoint a process, save its current state in a
+state file, and restore into a new container using the state file.
+
+## How to use checkpoint/restore
+
+Checkpoint/restore functionality is currently available via raw `runsc`
+commands. To use the checkpoint command, first run a container.
+
+```bash
+runsc run <container id>
+```
+
+To checkpoint the container, the `--image-path` flag must be provided. This is
+the directory path within which the checkpoint state-file will be created. The
+file will be called `checkpoint.img` and necessary directories will be created
+if they do not yet exist.
+
+> Note: Two checkpoints cannot be saved to the same directory; every image-path
+> provided must be unique.
+
+```bash
+runsc checkpoint --image-path=<path> <container id>
+```
+
+There is also an optional `--leave-running` flag that allows the container to
+continue to run after the checkpoint has been made. (By default, containers stop
+their processes after committing a checkpoint.)
+
+> Note: All top-level runsc flags needed when calling run must be provided to
+> checkpoint if --leave-running is used.
+
+> Note: --leave-running functions by causing an immediate restore so the
+> container, although will maintain its given container id, may have a different
+> process id.
+
+```bash
+runsc checkpoint --image-path=<path> --leave-running <container id>
+```
+
+To restore, provide the image path to the `checkpoint.img` file created during
+the checkpoint. Because containers stop by default after checkpointing, restore
+needs to happen in a new container (restore is a command which parallels start).
+
+```bash
+runsc create <container id>
+
+runsc restore --image-path=<path> <container id>
+```
+
+## How to use checkpoint/restore in Docker:
+
+Currently checkpoint/restore through `runsc` is not entirely compatible with
+Docker, although there has been progress made from both gVisor and Docker to
+enable compatibility. Here, we document the ideal workflow.
+
+Run a container:
+
+```bash
+docker run [options] --runtime=runsc <image>`
+```
+
+Checkpoint a container:
+
+```bash
+docker checkpoint create <container> <checkpoint_name>`
+```
+
+Create a new container into which to restore:
+
+```bash
+docker create [options] --runtime=runsc <image>
+```
+
+Restore a container:
+
+```bash
+docker start --checkpoint --checkpoint-dir=<directory> <container>
+```
+
+### Issues Preventing Compatibility with Docker
+
+- **[Moby #37360][leave-running]:** Docker version 18.03.0-ce and earlier
+ hangs when checkpointing and does not create the checkpoint. To successfully
+ use this feature, install a custom version of docker-ce from the moby
+ repository. This issue is caused by an improper implementation of the
+ `--leave-running` flag. This issue is fixed in newer releases.
+- **Docker does not support restoration into new containers:** Docker
+ currently expects the container which created the checkpoint to be the same
+ container used to restore which is not possible in runsc. When Docker
+ supports container migration and therefore restoration into new containers,
+ this will be the flow.
+- **[Moby #37344][checkpoint-dir]:** Docker does not currently support the
+ `--checkpoint-dir` flag but this will be required when restoring from a
+ checkpoint made in another container.
+
+[leave-running]: https://github.com/moby/moby/pull/37360
+[checkpoint-dir]: https://github.com/moby/moby/issues/37344
diff --git a/g3doc/user_guide/compatibility.md b/g3doc/user_guide/compatibility.md
new file mode 100644
index 000000000..9d3e3680f
--- /dev/null
+++ b/g3doc/user_guide/compatibility.md
@@ -0,0 +1,93 @@
+# Applications
+
+[TOC]
+
+gVisor implements a large portion of the Linux surface and while we strive to
+make it broadly compatible, there are (and always will be) unimplemented
+features and bugs. The only real way to know if it will work is to try. If you
+find a container that doesn’t work and there is no known issue, please
+[file a bug][bug] indicating the full command you used to run the image. You can
+view open issues related to compatibility [here][issues].
+
+If you're able to provide the [debug logs](../debugging/), the problem likely to
+be fixed much faster.
+
+## What works?
+
+The following applications/images have been tested:
+
+* elasticsearch
+* golang
+* httpd
+* java8
+* jenkins
+* mariadb
+* memcached
+* mongo
+* mysql
+* nginx
+* node
+* php
+* postgres
+* prometheus
+* python
+* redis
+* registry
+* tomcat
+* wordpress
+
+## Utilities
+
+Most common utilities work. Note that:
+
+* Some tools, such as `tcpdump` and old versions of `ping`, require explicitly
+ enabling raw sockets via the unsafe `--net-raw` runsc flag.
+* Different Docker images can behave differently. For example, Alpine Linux
+ and Ubuntu have different `ip` binaries.
+
+ Specific tools include:
+
+<!-- mdformat off(don't wrap the table) -->
+
+| Tool | Status |
+|:--------:|:-----------------------------------------:|
+| apt-get | Working. |
+| bundle | Working. |
+| cat | Working. |
+| curl | Working. |
+| dd | Working. |
+| df | Working. |
+| dig | Working. |
+| drill | Working. |
+| env | Working. |
+| find | Working. |
+| gdb | Working. |
+| gosu | Working. |
+| grep | Working (unless stdin is a pipe and stdout is /dev/null). |
+| ifconfig | Works partially, like ip. Full support [in progress](https://gvisor.dev/issue/578). |
+| ip | Some subcommands work (e.g. addr, route). Full support [in progress](https://gvisor.dev/issue/578). |
+| less | Working. |
+| ls | Working. |
+| lsof | Working. |
+| mount | Works in readonly mode. gVisor doesn't currently support creating new mounts at runtime. |
+| nc | Working. |
+| nmap | Not working. |
+| netstat | [In progress](https://gvisor.dev/issue/2112). |
+| nslookup | Working. |
+| ping | Working. |
+| ps | Working. |
+| route | Working. |
+| ss | [In progress](https://gvisor.dev/issue/2114). |
+| sshd | Partially working. Job control [in progress](https://gvisor.dev/issue/154). |
+| strace | Working. |
+| tar | Working. |
+| tcpdump | [In progress](https://gvisor.dev/issue/173). |
+| top | Working. |
+| uptime | Working. |
+| vim | Working. |
+| wget | Working. |
+
+<!-- mdformat on -->
+
+[bug]: https://github.com/google/gvisor/issues/new?title=Compatibility%20Issue:
+[issues]: https://github.com/google/gvisor/issues?q=is%3Aissue+is%3Aopen+label%3A%22area%3A+compatibility%22
diff --git a/g3doc/user_guide/containerd/BUILD b/g3doc/user_guide/containerd/BUILD
new file mode 100644
index 000000000..979d46105
--- /dev/null
+++ b/g3doc/user_guide/containerd/BUILD
@@ -0,0 +1,33 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "quick_start",
+ src = "quick_start.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/containerd/quick_start/",
+ subcategory = "Containerd",
+ weight = "10",
+)
+
+doc(
+ name = "configuration",
+ src = "configuration.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/containerd/configuration/",
+ subcategory = "Containerd",
+ weight = "90",
+)
+
+doc(
+ name = "containerd_11",
+ src = "containerd_11.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/containerd/containerd_11/",
+ subcategory = "Containerd",
+ weight = "99",
+)
diff --git a/g3doc/user_guide/containerd/configuration.md b/g3doc/user_guide/containerd/configuration.md
new file mode 100644
index 000000000..5d485c24b
--- /dev/null
+++ b/g3doc/user_guide/containerd/configuration.md
@@ -0,0 +1,70 @@
+# Containerd Advanced Configuration
+
+This document describes how to configure runtime options for
+`containerd-shim-runsc-v1`. This follows the
+[Containerd Quick Start](./quick_start.md) and requires containerd 1.2 or later.
+
+### Update `/etc/containerd/config.toml` to point to a configuration file for `containerd-shim-runsc-v1`.
+
+`containerd-shim-runsc-v1` supports a few different configuration options based
+on the version of containerd that is used. For versions >= 1.3, it supports a
+configurable `ConfigPath` in the containerd runtime configuration.
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/config.toml
+disabled_plugins = ["restart"]
+[plugins.linux]
+ shim_debug = true
+[plugins.cri.containerd.runtimes.runsc]
+ runtime_type = "io.containerd.runsc.v1"
+[plugins.cri.containerd.runtimes.runsc.options]
+ TypeUrl = "io.containerd.runsc.v1.options"
+ # containerd 1.3 only!
+ ConfigPath = "/etc/containerd/runsc.toml"
+EOF
+```
+
+When you are done restart containerd to pick up the new configuration files.
+
+```shell
+sudo systemctl restart containerd
+```
+
+### Configure `/etc/containerd/runsc.toml`
+
+> Note: For containerd 1.2, the config file should named `config.toml` and
+> located in the runtime root. By default, this is `/run/containerd/runsc`.
+
+The set of options that can be configured can be found in
+[options.go](https://github.com/google/gvisor/blob/master/pkg/shim/v2/options/options.go).
+
+#### Example: Enable the KVM platform
+
+gVisor enables the use of a number of platforms. This example shows how to
+configure `containerd-shim-runsc-v1` to use gvisor with the KVM platform.
+
+Find out more about platform in the
+[Platforms Guide](../../architecture_guide/platforms.md).
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/runsc.toml
+[runsc_config]
+platform = "kvm"
+EOF
+```
+
+### Example: Enable gVisor debug logging
+
+gVisor debug logging can be enabled by setting the `debug` and `debug-log` flag.
+The shim will replace "%ID%" with the container ID, and "%COMMAND%" with the
+runsc command (run, boot, etc.) in the path of the `debug-log` flag.
+
+Find out more about debugging in the [debugging guide](../debugging.md).
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/runsc.toml
+[runsc_config]
+ debug=true
+ debug-log=/var/log/%ID%/gvisor.%COMMAND%.log
+EOF
+```
diff --git a/g3doc/user_guide/containerd/containerd_11.md b/g3doc/user_guide/containerd/containerd_11.md
new file mode 100644
index 000000000..50befbdf4
--- /dev/null
+++ b/g3doc/user_guide/containerd/containerd_11.md
@@ -0,0 +1,163 @@
+# Older Versions (containerd 1.1)
+
+This document describes how to install and run the `gvisor-containerd-shim`
+using the untrusted workload CRI extension. This requires `containerd` 1.1 or
+later.
+
+*Note: The untrusted workload CRI extension is deprecated by containerd and
+`gvisor-containerd-shim` is maintained on a best-effort basis. If you are using
+containerd 1.2+, please see the
+[containerd 1.2+ documentation](./quick_start.md) and use
+`containerd-shim-runsc-v1`.*
+
+## Requirements
+
+- **runsc** and **gvisor-containerd-shim**: See the
+ [installation guide](/docs/user_guide/install/).
+- **containerd**: See the [containerd website](https://containerd.io/) for
+ information on how to install containerd.
+
+## Configure containerd
+
+Create the configuration for the gvisor shim in
+`/etc/containerd/gvisor-containerd-shim.toml`:
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/gvisor-containerd-shim.toml
+# This is the path to the default runc containerd-shim.
+runc_shim = "/usr/local/bin/containerd-shim"
+EOF
+```
+
+Update `/etc/containerd/config.toml`. Be sure to update the path to
+`gvisor-containerd-shim` and `runsc` if necessary:
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/config.toml
+disabled_plugins = ["restart"]
+[plugins.linux]
+ shim = "/usr/local/bin/gvisor-containerd-shim"
+ shim_debug = true
+[plugins.cri.containerd.untrusted_workload_runtime]
+ runtime_type = "io.containerd.runtime.v1.linux"
+ runtime_engine = "/usr/local/bin/runsc"
+ runtime_root = "/run/containerd/runsc"
+EOF
+```
+
+Restart `containerd`:
+
+```shell
+sudo systemctl restart containerd
+```
+
+## Usage
+
+You can run containers in gVisor via containerd's CRI.
+
+### Install crictl
+
+Download and install the `crictl` binary:
+
+```shell
+{
+wget https://github.com/kubernetes-sigs/cri-tools/releases/download/v1.13.0/crictl-v1.13.0-linux-amd64.tar.gz
+tar xf crictl-v1.13.0-linux-amd64.tar.gz
+sudo mv crictl /usr/local/bin
+}
+```
+
+Write the `crictl` configuration file:
+
+```shell
+cat <<EOF | sudo tee /etc/crictl.yaml
+runtime-endpoint: unix:///run/containerd/containerd.sock
+EOF
+```
+
+### Create the nginx Sandbox in gVisor
+
+Pull the nginx image:
+
+```shell
+sudo crictl pull nginx
+```
+
+Create the sandbox creation request:
+
+```shell
+cat <<EOF | tee sandbox.json
+{
+ "metadata": {
+ "name": "nginx-sandbox",
+ "namespace": "default",
+ "attempt": 1,
+ "uid": "hdishd83djaidwnduwk28bcsb"
+ },
+ "annotations": {
+ "io.kubernetes.cri.untrusted-workload": "true"
+ },
+ "linux": {
+ },
+ "log_directory": "/tmp"
+}
+EOF
+```
+
+Create the pod in gVisor:
+
+```shell
+SANDBOX_ID=$(sudo crictl runp sandbox.json)
+```
+
+### Run the nginx Container in the Sandbox
+
+Create the nginx container creation request:
+
+```shell
+cat <<EOF | tee container.json
+{
+ "metadata": {
+ "name": "nginx"
+ },
+ "image":{
+ "image": "nginx"
+ },
+ "log_path":"nginx.0.log",
+ "linux": {
+ }
+}
+EOF
+```
+
+Create the nginx container:
+
+```shell
+CONTAINER_ID=$(sudo crictl create ${SANDBOX_ID} container.json sandbox.json)
+```
+
+Start the nginx container:
+
+```shell
+sudo crictl start ${CONTAINER_ID}
+```
+
+### Validate the container
+
+Inspect the created pod:
+
+```shell
+sudo crictl inspectp ${SANDBOX_ID}
+```
+
+Inspect the nginx container:
+
+```shell
+sudo crictl inspect ${CONTAINER_ID}
+```
+
+Verify that nginx is running in gVisor:
+
+```shell
+sudo crictl exec ${CONTAINER_ID} dmesg | grep -i gvisor
+```
diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md
new file mode 100644
index 000000000..2f67eecb3
--- /dev/null
+++ b/g3doc/user_guide/containerd/quick_start.md
@@ -0,0 +1,176 @@
+# Containerd Quick Start
+
+This document describes how to install and configure `containerd-shim-runsc-v1`
+using the containerd runtime handler support on `containerd` 1.2 or later.
+
+## Requirements
+
+- **runsc** and **containerd-shim-runsc-v1**: See the
+ [installation guide](/docs/user_guide/install/).
+- **containerd**: See the [containerd website](https://containerd.io/) for
+ information on how to install containerd.
+
+## Configure containerd
+
+Update `/etc/containerd/config.toml`. Make sure `containerd-shim-runsc-v1` is in
+`${PATH}` or in the same directory as `containerd` binary.
+
+```shell
+cat <<EOF | sudo tee /etc/containerd/config.toml
+disabled_plugins = ["restart"]
+[plugins.linux]
+ shim_debug = true
+[plugins.cri.containerd.runtimes.runsc]
+ runtime_type = "io.containerd.runsc.v1"
+EOF
+```
+
+Restart `containerd`:
+
+```shell
+sudo systemctl restart containerd
+```
+
+## Usage
+
+You can run containers in gVisor via containerd's CRI.
+
+### Install crictl
+
+Download and install the `crictl`` binary:
+
+```shell
+{
+wget https://github.com/kubernetes-sigs/cri-tools/releases/download/v1.13.0/crictl-v1.13.0-linux-amd64.tar.gz
+tar xf crictl-v1.13.0-linux-amd64.tar.gz
+sudo mv crictl /usr/local/bin
+}
+```
+
+Write the `crictl` configuration file:
+
+```shell
+cat <<EOF | sudo tee /etc/crictl.yaml
+runtime-endpoint: unix:///run/containerd/containerd.sock
+EOF
+```
+
+### Create the nginx sandbox in gVisor
+
+Pull the nginx image:
+
+```shell
+sudo crictl pull nginx
+```
+
+Create the sandbox creation request:
+
+```shell
+cat <<EOF | tee sandbox.json
+{
+ "metadata": {
+ "name": "nginx-sandbox",
+ "namespace": "default",
+ "attempt": 1,
+ "uid": "hdishd83djaidwnduwk28bcsb"
+ },
+ "linux": {
+ },
+ "log_directory": "/tmp"
+}
+EOF
+```
+
+Create the pod in gVisor:
+
+```shell
+SANDBOX_ID=$(sudo crictl runp --runtime runsc sandbox.json)
+```
+
+### Run the nginx container in the sandbox
+
+Create the nginx container creation request:
+
+```shell
+cat <<EOF | tee container.json
+{
+ "metadata": {
+ "name": "nginx"
+ },
+ "image":{
+ "image": "nginx"
+ },
+ "log_path":"nginx.0.log",
+ "linux": {
+ }
+}
+EOF
+```
+
+Create the nginx container:
+
+```shell
+CONTAINER_ID=$(sudo crictl create ${SANDBOX_ID} container.json sandbox.json)
+```
+
+Start the nginx container:
+
+```shell
+sudo crictl start ${CONTAINER_ID}
+```
+
+### Validate the container
+
+Inspect the created pod:
+
+```shell
+sudo crictl inspectp ${SANDBOX_ID}
+```
+
+Inspect the nginx container:
+
+```shell
+sudo crictl inspect ${CONTAINER_ID}
+```
+
+Verify that nginx is running in gVisor:
+
+```shell
+sudo crictl exec ${CONTAINER_ID} dmesg | grep -i gvisor
+```
+
+### Set up the Kubernetes RuntimeClass
+
+Install the RuntimeClass for gVisor:
+
+```shell
+cat <<EOF | kubectl apply -f -
+apiVersion: node.k8s.io/v1beta1
+kind: RuntimeClass
+metadata:
+ name: gvisor
+handler: runsc
+EOF
+```
+
+Create a Pod with the gVisor RuntimeClass:
+
+```shell
+cat <<EOF | kubectl apply -f -
+apiVersion: v1
+kind: Pod
+metadata:
+ name: nginx-gvisor
+spec:
+ runtimeClassName: gvisor
+ containers:
+ - name: nginx
+ image: nginx
+EOF
+```
+
+Verify that the Pod is running:
+
+```shell
+kubectl get pod nginx-gvisor -o wide
+```
diff --git a/g3doc/user_guide/debugging.md b/g3doc/user_guide/debugging.md
new file mode 100644
index 000000000..54fdce34f
--- /dev/null
+++ b/g3doc/user_guide/debugging.md
@@ -0,0 +1,141 @@
+# Debugging
+
+[TOC]
+
+To enable debug and system call logging, add the `runtimeArgs` below to your
+[Docker](../quick_start/docker/) configuration (`/etc/docker/daemon.json`):
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--debug-log=/tmp/runsc/",
+ "--debug",
+ "--strace"
+ ]
+ }
+ }
+}
+```
+
+> Note: the last `/` in `--debug-log` is needed to interpret it as a directory.
+> Then each `runsc` command executed will create a separate log file. Otherwise,
+> log messages from all commands will be appended to the same file.
+
+You may also want to pass `--log-packets` to troubleshoot network problems. Then
+restart the Docker daemon:
+
+```bash
+sudo systemctl restart docker
+```
+
+Run your container again, and inspect the files under `/tmp/runsc`. The log file
+ending with `.boot` will contain the strace logs from your application, which
+can be useful for identifying missing or broken system calls in gVisor. If you
+are having problems starting the container, the log file ending with `.create`
+may have the reason for the failure.
+
+## Stack traces
+
+The command `runsc debug --stacks` collects stack traces while the sandbox is
+running which can be useful to troubleshoot issues or just to learn more about
+gVisor. It connects to the sandbox process, collects a stack dump, and writes it
+to the console. For example:
+
+```bash
+docker run --runtime=runsc --rm -d alpine sh -c "while true; do echo running; sleep 1; done"
+63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+
+sudo runsc --root /var/run/docker/runtime-runsc/moby debug --stacks 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+```
+
+> Note: `--root` variable is provided by docker and is normally set to
+> `/var/run/docker/runtime-[runtime-name]/moby`. If in doubt, `--root` is logged
+> to `runsc` logs.
+
+## Debugger
+
+You can debug gVisor like any other Golang program. If you're running with
+Docker, you'll need to find the sandbox PID and attach the debugger as root.
+Here is an example:
+
+```bash
+# Get a runsc with debug symbols (download nightly or build with symbols).
+bazel build -c dbg //runsc:runsc
+
+# Start the container you want to debug.
+docker run --runtime=runsc --rm --name=test -d alpine sleep 1000
+
+# Find the sandbox PID.
+docker inspect test | grep Pid | head -n 1
+
+# Attach your favorite debugger.
+sudo dlv attach <PID>
+
+# Set a breakpoint and resume.
+break mm.MemoryManager.MMap
+continue
+```
+
+## Profiling
+
+`runsc` integrates with Go profiling tools and gives you easy commands to
+profile CPU and heap usage. First you need to enable `--profile` in the command
+line options before starting the container:
+
+```json
+{
+ "runtimes": {
+ "runsc-prof": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--profile"
+ ]
+ }
+ }
+}
+```
+
+> Note: Enabling profiling loosens the seccomp protection added to the sandbox,
+> and should not be run in production under normal circumstances.
+
+Then restart docker to refresh the runtime options. While the container is
+running, execute `runsc debug` to collect profile information and save to a
+file. Here are the options available:
+
+* **--profile-heap:** Generates heap profile to the speficied file.
+* **--profile-cpu:** Enables CPU profiler, waits for `--duration` seconds and
+ generates CPU profile to the speficied file.
+
+For example:
+
+```bash
+docker run --runtime=runsc-prof --rm -d alpine sh -c "while true; do echo running; sleep 1; done"
+63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+
+sudo runsc --root /var/run/docker/runtime-runsc-prof/moby debug --profile-heap=/tmp/heap.prof 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+sudo runsc --root /var/run/docker/runtime-runsc-prof/moby debug --profile-cpu=/tmp/cpu.prof --duration=30s 63254c6ab3a6989623fa1fb53616951eed31ac605a2637bb9ddba5d8d404b35b
+```
+
+The resulting files can be opened using `go tool pprof` or [pprof][]. The
+examples below create image file (`.svg`) with the heap profile and writes the
+top functions using CPU to the console:
+
+```bash
+go tool pprof -svg /usr/local/bin/runsc /tmp/heap.prof
+go tool pprof -top /usr/local/bin/runsc /tmp/cpu.prof
+```
+
+[pprof]: https://github.com/google/pprof/blob/master/doc/README.md
+
+### Docker Proxy
+
+When forwarding a port to the container, Docker will likely route traffic
+through the [docker-proxy][]. This proxy may make profiling noisy, so it can be
+helpful to bypass it. Do so by sending traffic directly to the container IP and
+port. e.g., if the `docker0` IP is `192.168.9.1`, the container IP is likely a
+subsequent IP, such as `192.168.9.2`.
+
+[docker-proxy]: https://windsock.io/the-docker-proxy/
diff --git a/g3doc/user_guide/filesystem.md b/g3doc/user_guide/filesystem.md
new file mode 100644
index 000000000..cd00762dd
--- /dev/null
+++ b/g3doc/user_guide/filesystem.md
@@ -0,0 +1,60 @@
+# Filesystem
+
+[TOC]
+
+gVisor accesses the filesystem through a file proxy, called the Gofer. The gofer
+runs as a separate process, that is isolated from the sandbox. Gofer instances
+communicate with their respective sentry using the 9P protocol. For another
+explanation see [What is gVisor?](../README.md).
+
+## Sandbox overlay
+
+To isolate the host filesystem from the sandbox, you can set a writable tmpfs
+overlay on top of the entire filesystem. All modifications are made to the
+overlay, keeping the host filesystem unmodified.
+
+> Note: All created and modified files are stored in memory inside the sandbox.
+
+To use the tmpfs overlay, add the following `runtimeArgs` to your Docker
+configuration (`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--overlay"
+ ]
+ }
+ }
+}
+```
+
+## Shared root filesystem
+
+The root filesystem is where the image is extracted and is not generally
+modified from outside the sandbox. This allows for some optimizations, like
+skipping checks to determine if a directory has changed since the last time it
+was cached, thus missing updates that may have happened. If you need to `docker
+cp` files inside the root filesystem, you may want to enable shared mode. Just
+be aware that file system access will be slower due to the extra checks that are
+required.
+
+> Note: External mounts are always shared.
+
+To use set the root filesystem shared, add the following `runtimeArgs` to your
+Docker configuration (`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--file-access=shared"
+ ]
+ }
+ }
+}
+```
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md
new file mode 100644
index 000000000..9afdd264d
--- /dev/null
+++ b/g3doc/user_guide/install.md
@@ -0,0 +1,157 @@
+# Installation
+
+[TOC]
+
+> Note: gVisor supports only x86\_64 and requires Linux 4.14.77+
+> ([older Linux](./networking.md#gso)).
+
+## Versions
+
+The `runsc` binaries and repositories are available in multiple versions and
+release channels. You should pick the version you'd like to install. For
+experimentation, the nightly release is recommended. For production use, the
+latest release is recommended.
+
+After selecting an appropriate release channel from the options below, proceed
+to the preferred installation mechanism: manual or from an `apt` repository.
+
+### HEAD
+
+Binaries are available for every commit on the `master` branch, and are
+available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/master/latest/runsc`
+
+Checksums for the release binary are at:
+
+`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512`
+
+For `apt` installation, use the `master` as the `${DIST}` below.
+
+### Nightly
+
+Nightly releases are built most nights from the master branch, and are available
+at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc`
+
+Checksums for the release binary are at:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc.sha512`
+
+Specific nightly releases can be found at:
+
+`https://storage.googleapis.com/gvisor/releases/nightly/${yyyy-mm-dd}/runsc`
+
+Note that a release may not be available for every day.
+
+For `apt` installation, use the `nightly` as the `${DIST}` below.
+
+### Latest release
+
+The latest official release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/latest`
+
+For `apt` installation, use the `release` as the `${DIST}` below.
+
+### Specific release
+
+A given release release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}`
+
+See the [releases][releases] page for information about specific releases.
+
+For `apt` installation of a specific release, which may include point updates,
+use the date of the release, e.g. `${yyyymmdd}`, as the `${DIST}` below.
+
+> Note: only newer releases may be available as `apt` repositories.
+
+### Point release
+
+A given point release is available at the following URL:
+
+`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}`
+
+Note that `apt` installation of a specific point release is not supported.
+
+## Install from an `apt` repository
+
+First, appropriate dependencies must be installed to allow `apt` to install
+packages via https:
+
+```bash
+sudo apt-get update && \
+sudo apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common
+```
+
+Next, the key used to sign archives should be added to your `apt` keychain:
+
+```bash
+curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
+```
+
+Based on the release type, you will need to substitute `${DIST}` below, using
+one of:
+
+* `master`: For HEAD.
+* `nightly`: For nightly releases.
+* `release`: For the latest release.
+* `${yyyymmdd}`: For a specific releases (see above).
+
+The repository for the release you wish to install should be added:
+
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases ${DIST} main"
+```
+
+For example, to install the latest official release, you can use:
+
+```bash
+sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+```
+
+Now the runsc package can be installed:
+
+```bash
+sudo apt-get update && sudo apt-get install -y runsc
+```
+
+If you have Docker installed, it will be automatically configured.
+
+## Install directly
+
+The binary URLs provided above can be used to install directly. For example, the
+latest nightly binary can be downloaded, validated, and placed in an appropriate
+location by running:
+
+```bash
+(
+ set -e
+ URL=https://storage.googleapis.com/gvisor/releases/nightly/latest
+ wget ${URL}/runsc
+ wget ${URL}/runsc.sha512
+ sha512sum -c runsc.sha512
+ rm -f runsc.sha512
+ sudo mv runsc /usr/local/bin
+ sudo chown root:root /usr/local/bin/runsc
+ sudo chmod 0755 /usr/local/bin/runsc
+)
+```
+
+**It is important to copy this binary to a location that is accessible to all
+users, and ensure it is executable by all users**, since `runsc` executes itself
+as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory
+is a good place to put the `runsc` binary.
+
+After installation, try out `runsc` by following the
+[Docker Quick Start](./quick_start/docker.md) or
+[OCI Quick Start](./quick_start/oci.md).
+
+[releases]: https://github.com/google/gvisor/releases
diff --git a/g3doc/user_guide/networking.md b/g3doc/user_guide/networking.md
new file mode 100644
index 000000000..4aa394c91
--- /dev/null
+++ b/g3doc/user_guide/networking.md
@@ -0,0 +1,85 @@
+# Networking
+
+[TOC]
+
+gVisor implements its own network stack called [netstack][netstack]. All aspects
+of the network stack are handled inside the Sentry — including TCP connection
+state, control messages, and packet assembly — keeping it isolated from the host
+network stack. Data link layer packets are written directly to the virtual
+device inside the network namespace setup by Docker or Kubernetes.
+
+The IP address and routes configured for the device are transferred inside the
+sandbox. The loopback device runs exclusively inside the sandbox and does not
+use the host. You can inspect them by running:
+
+```bash
+docker run --rm --runtime=runsc alpine ip addr
+```
+
+## Network passthrough
+
+For high-performance networking applications, you may choose to disable the user
+space network stack and instead use the host network stack, including the
+loopback. Note that this mode decreases the isolation to the host.
+
+Add the following `runtimeArgs` to your Docker configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--network=host"
+ ]
+ }
+ }
+}
+```
+
+## Disabling external networking
+
+To completely isolate the host and network from the sandbox, external networking
+can be disabled. The sandbox will still contain a loopback provided by netstack.
+
+Add the following `runtimeArgs` to your Docker configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--network=none"
+ ]
+ }
+ }
+}
+```
+
+### Disable GSO {#gso}
+
+If your Linux is older than 4.14.17, you can disable Generic Segmentation
+Offload (GSO) to run with a kernel that is newer than 3.17. Add the
+`--gso=false` flag to your Docker runtime configuration
+(`/etc/docker/daemon.json`) and restart the Docker daemon:
+
+> Note: Network performance, especially for large payloads, will be greatly
+> reduced.
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--gso=false"
+ ]
+ }
+ }
+}
+```
+
+[netstack]: https://github.com/google/netstack
diff --git a/g3doc/user_guide/platforms.md b/g3doc/user_guide/platforms.md
new file mode 100644
index 000000000..752025881
--- /dev/null
+++ b/g3doc/user_guide/platforms.md
@@ -0,0 +1,95 @@
+# Changing Platforms
+
+[TOC]
+
+This guide described how to change the
+[platform](../architecture_guide/platforms.md) used by `runsc`.
+
+## Prerequisites
+
+If you intend to run the KVM platform, you will also to have KVM installed on
+your system. If you are running a Debian based system like Debian or Ubuntu you
+can usually do this by ensuring the module is loaded, and permissions are
+appropriately set on the `/dev/kvm` device.
+
+If you have an Intel CPU:
+
+```bash
+sudo modprobe kvm-intel && sudo chmod a+rw /dev/kvm
+```
+
+If you have an AMD CPU:
+
+```bash
+sudo modprobe kvm-amd && sudo chmod a+rw /dev/kvm
+```
+
+If you are using a virtual machine you will need to make sure that nested
+virtualization is configured. Here are links to documents on how to set up
+nested virtualization in several popular environments:
+
+* Google Cloud: [Enabling Nested Virtualization for VM Instances][nested-gcp]
+* Microsoft Azure:
+ [How to enable nested virtualization in an Azure VM][nested-azure]
+* VirtualBox: [Nested Virtualization][nested-virtualbox]
+* KVM: [Nested Guests][nested-kvm]
+
+***Note: nested virtualization will have poor performance and is historically a
+cause of security issues (e.g.
+[CVE-2018-12904](https://nvd.nist.gov/vuln/detail/CVE-2018-12904)). It is not
+recommended for production.***
+
+## Configuring Docker
+
+The platform is selected by the `--platform` command line flag passed to
+`runsc`. By default, the ptrace platform is selected. For example, to select the
+KVM platform, modify your Docker configuration (`/etc/docker/daemon.json`) to
+pass the `--platform` argument:
+
+```json
+{
+ "runtimes": {
+ "runsc": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=kvm"
+ ]
+ }
+ }
+}
+```
+
+You must restart the Docker daemon after making changes to this file, typically
+this is done via `systemd`:
+
+```bash
+sudo systemctl restart docker
+```
+
+Note that you may configure multiple runtimes using different platforms. For
+example, the following configuration has one configuration for ptrace and one
+for the KVM platform:
+
+```json
+{
+ "runtimes": {
+ "runsc-ptrace": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=ptrace"
+ ]
+ },
+ "runsc-kvm": {
+ "path": "/usr/local/bin/runsc",
+ "runtimeArgs": [
+ "--platform=kvm"
+ ]
+ }
+ }
+}
+```
+
+[nested-azure]: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/nested-virtualization
+[nested-gcp]: https://cloud.google.com/compute/docs/instances/enable-nested-virtualization-vm-instances
+[nested-virtualbox]: https://www.virtualbox.org/manual/UserManual.html#nested-virt
+[nested-kvm]: https://www.linux-kvm.org/page/Nested_Guests
diff --git a/g3doc/user_guide/quick_start/BUILD b/g3doc/user_guide/quick_start/BUILD
new file mode 100644
index 000000000..63f17f9cb
--- /dev/null
+++ b/g3doc/user_guide/quick_start/BUILD
@@ -0,0 +1,33 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "docker",
+ src = "docker.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/docker/",
+ subcategory = "Quick Start",
+ weight = "11",
+)
+
+doc(
+ name = "oci",
+ src = "oci.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/oci/",
+ subcategory = "Quick Start",
+ weight = "12",
+)
+
+doc(
+ name = "kubernetes",
+ src = "kubernetes.md",
+ category = "User Guide",
+ permalink = "/docs/user_guide/quick_start/kubernetes/",
+ subcategory = "Quick Start",
+ weight = "13",
+)
diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md
new file mode 100644
index 000000000..6ad594ecc
--- /dev/null
+++ b/g3doc/user_guide/quick_start/docker.md
@@ -0,0 +1,96 @@
+# Docker Quick Start
+
+> Note: This guide requires Docker version 17.09.0 or greater. Refer to the
+> [Docker documentation][docker] for how to install it.
+
+This guide will help you quickly get started running Docker containers using
+gVisor.
+
+First, follow the [Installation guide][install].
+
+If you use the `apt` repository or the `automated` install, then you can skip
+the next section and proceed straight to running a container.
+
+## Configuring Docker
+
+First you will need to configure Docker to use `runsc` by adding a runtime entry
+to your Docker configuration (e.g. `/etc/docker/daemon.json`). The easiest way
+to this is via the `runsc install` command. This will install a docker runtime
+named "runsc" by default.
+
+```bash
+sudo runsc install
+```
+
+You may also wish to install a runtime entry for debugging. The `runsc install`
+command can accept options that will be passed to the runtime when it is invoked
+by Docker.
+
+```bash
+sudo runsc install --runtime runsc-debug -- \
+ --debug \
+ --debug-log=/tmp/runsc-debug.log \
+ --strace \
+ --log-packets
+```
+
+You must restart the Docker daemon after installing the runtime. Typically this
+is done via `systemd`:
+
+```bash
+sudo systemctl restart docker
+```
+
+## Running a container
+
+Now run your container using the `runsc` runtime:
+
+```bash
+docker run --runtime=runsc --rm hello-world
+```
+
+You can also run a terminal to explore the container.
+
+```bash
+docker run --runtime=runsc --rm -it ubuntu /bin/bash
+```
+
+Many docker options are compatible with gVisor, try them out. Here is an
+example:
+
+```bash
+docker run --runtime=runsc --rm --link backend:database -v ~/bin:/tools:ro -p 8080:80 --cpus=0.5 -it busybox telnet towel.blinkenlights.nl
+```
+
+## Verify the runtime
+
+You can verify that you are running in gVisor using the `dmesg` command.
+
+```text
+$ docker run --runtime=runsc -it ubuntu dmesg
+[ 0.000000] Starting gVisor...
+[ 0.354495] Daemonizing children...
+[ 0.564053] Constructing home...
+[ 0.976710] Preparing for the zombie uprising...
+[ 1.299083] Creating process schedule...
+[ 1.479987] Committing treasure map to memory...
+[ 1.704109] Searching for socket adapter...
+[ 1.748935] Generating random numbers by fair dice roll...
+[ 2.059747] Digging up root...
+[ 2.259327] Checking naughty and nice process list...
+[ 2.610538] Rewriting operating system in Javascript...
+[ 2.613217] Ready!
+```
+
+Note that this is easily replicated by an attacker so applications should never
+use `dmesg` to verify the runtime in a security sensitive context.
+
+Next, look at the different options available for gVisor: [platform][platforms],
+[network][networking], [filesystem][filesystem].
+
+[docker]: https://docs.docker.com/install/
+[storage-driver]: https://docs.docker.com/engine/reference/commandline/dockerd/#daemon-storage-driver
+[install]: /docs/user_guide/install/
+[filesystem]: /docs/user_guide/filesystem/
+[networking]: /docs/user_guide/networking/
+[platforms]: /docs/user_guide/platforms/
diff --git a/g3doc/user_guide/quick_start/kubernetes.md b/g3doc/user_guide/quick_start/kubernetes.md
new file mode 100644
index 000000000..395cd4b71
--- /dev/null
+++ b/g3doc/user_guide/quick_start/kubernetes.md
@@ -0,0 +1,34 @@
+# Kubernetes Quick Start
+
+gVisor can be used to run Kubernetes pods and has several integration points
+with Kubernetes.
+
+## Using Minikube
+
+gVisor can run sandboxed containers in a Kubernetes cluster with Minikube. After
+the gVisor addon is enabled, pods with a `gvisor` [Runtime Class][runtimeclass]
+set to true will execute with `runsc`. Follow [these instructions][minikube] to
+enable gVisor addon.
+
+## Using Containerd
+
+You can also setup Kubernetes nodes to run pods in gVisor using
+[containerd][containerd] and the gVisor containerd shim. You can find
+instructions in the [Containerd Quick Start][gvisor-containerd].
+
+## Using GKE Sandbox
+
+[GKE Sandbox][gke-sandbox] is available in [Google Kubernetes Engine][gke]. You
+just need to deploy a node pool with gVisor enabled in your cluster, and it will
+run pods annotated with `runtimeClassName: gvisor` inside a gVisor sandbox for
+you. [Here][wordpress-quick] is a quick example showing how to deploy a
+WordPress site. You can view the full documentation [here][gke-sandbox-docs].
+
+[containerd]: https://containerd.io/
+[minikube]: https://github.com/kubernetes/minikube/blob/master/deploy/addons/gvisor/README.md
+[gke]: https://cloud.google.com/kubernetes-engine/
+[gke-sandbox]: https://cloud.google.com/kubernetes-engine/sandbox/
+[gke-sandbox-docs]: https://cloud.google.com/kubernetes-engine/docs/how-to/sandbox-pods
+[gvisor-containerd]: /docs/user_guide/containerd/quick_start/
+[runtimeclass]: https://kubernetes.io/docs/concepts/containers/runtime-class/
+[wordpress-quick]: /docs/tutorials/kubernetes/
diff --git a/g3doc/user_guide/quick_start/oci.md b/g3doc/user_guide/quick_start/oci.md
new file mode 100644
index 000000000..e7768946b
--- /dev/null
+++ b/g3doc/user_guide/quick_start/oci.md
@@ -0,0 +1,43 @@
+# OCI Quick Start
+
+This guide will quickly get you started running your first gVisor sandbox
+container using the runtime directly with the default platform.
+
+First, follow the [Installation guide][install].
+
+## Run an OCI compatible container
+
+Now we will create an [OCI][oci] container bundle to run our container. First we
+will create a root directory for our bundle.
+
+```bash
+mkdir bundle
+cd bundle
+```
+
+Create a root file system for the container. We will use the Docker
+`hello-world` image as the basis for our container.
+
+```bash
+mkdir rootfs
+docker export $(docker create hello-world) | tar -xf - -C rootfs
+```
+
+Next, create an specification file called `config.json` that contains our
+container specification. We tell the container to run the `/hello` program.
+
+```bash
+runsc spec -- /hello
+```
+
+Finally run the container.
+
+```bash
+sudo runsc run hello
+```
+
+Next try [using CNI to set up networking](../../../tutorials/cni/) or
+[running gVisor using Docker](../docker/).
+
+[oci]: https://opencontainers.org/
+[install]: /docs/user_guide/install
diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD
new file mode 100644
index 000000000..405026a33
--- /dev/null
+++ b/g3doc/user_guide/tutorials/BUILD
@@ -0,0 +1,37 @@
+load("//website:defs.bzl", "doc")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+doc(
+ name = "docker",
+ src = "docker.md",
+ category = "User Guide",
+ permalink = "/docs/tutorials/docker/",
+ subcategory = "Tutorials",
+ weight = "10",
+)
+
+doc(
+ name = "kubernetes",
+ src = "kubernetes.md",
+ category = "User Guide",
+ data = [
+ "add-node-pool.png",
+ "node-pool-button.png",
+ ],
+ permalink = "/docs/tutorials/kubernetes/",
+ subcategory = "Tutorials",
+ weight = "20",
+)
+
+doc(
+ name = "cni",
+ src = "cni.md",
+ category = "User Guide",
+ permalink = "/docs/tutorials/cni/",
+ subcategory = "Tutorials",
+ weight = "30",
+)
diff --git a/g3doc/user_guide/tutorials/add-node-pool.png b/g3doc/user_guide/tutorials/add-node-pool.png
new file mode 100644
index 000000000..e4560359b
--- /dev/null
+++ b/g3doc/user_guide/tutorials/add-node-pool.png
Binary files differ
diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md
new file mode 100644
index 000000000..ce2fd09a8
--- /dev/null
+++ b/g3doc/user_guide/tutorials/cni.md
@@ -0,0 +1,174 @@
+# Using CNI
+
+This tutorial will show you how to set up networking for a gVisor sandbox using
+the
+[Container Networking Interface (CNI)](https://github.com/containernetworking/cni).
+
+## Install CNI Plugins
+
+First you will need to install the CNI plugins. CNI plugins are used to set up a
+network namespace that `runsc` can use with the sandbox.
+
+Start by creating the directories for CNI plugin binaries:
+
+```
+sudo mkdir -p /opt/cni/bin
+```
+
+Download the CNI plugins:
+
+```
+wget https://github.com/containernetworking/plugins/releases/download/v0.8.3/cni-plugins-linux-amd64-v0.8.3.tgz
+```
+
+Next, unpack the plugins into the CNI binary directory:
+
+```
+sudo tar -xvf cni-plugins-linux-amd64-v0.8.3.tgz -C /opt/cni/bin/
+```
+
+## Configure CNI Plugins
+
+This section will show you how to configure CNI plugins. This tutorial will use
+the "bridge" and "loopback" plugins which will create the necessary bridge and
+loopback devices in our network namespace. However, you should be able to use
+any CNI compatible plugin to set up networking for gVisor sandboxes.
+
+The bridge plugin configuration specifies the IP address subnet range for IP
+addresses that will be assigned to sandboxes as well as the network routing
+configuration. This tutorial will assign IP addresses from the `10.22.0.0/16`
+range and allow all outbound traffic, however you can modify this configuration
+to suit your use case.
+
+Create the bridge and loopback plugin configurations:
+
+```
+sudo mkdir -p /etc/cni/net.d
+
+sudo sh -c 'cat > /etc/cni/net.d/10-bridge.conf << EOF
+{
+ "cniVersion": "0.4.0",
+ "name": "mynet",
+ "type": "bridge",
+ "bridge": "cni0",
+ "isGateway": true,
+ "ipMasq": true,
+ "ipam": {
+ "type": "host-local",
+ "subnet": "10.22.0.0/16",
+ "routes": [
+ { "dst": "0.0.0.0/0" }
+ ]
+ }
+}
+EOF'
+
+sudo sh -c 'cat > /etc/cni/net.d/99-loopback.conf << EOF
+{
+ "cniVersion": "0.4.0",
+ "name": "lo",
+ "type": "loopback"
+}
+EOF'
+```
+
+## Create a Network Namespace
+
+For each gVisor sandbox you will create a network namespace and configure it
+using CNI. First, create a random network namespace name and then create the
+namespace.
+
+The network namespace path will then be `/var/run/netns/${CNI_CONTAINERID}`.
+
+```
+export CNI_PATH=/opt/cni/bin
+export CNI_CONTAINERID=$(printf '%x%x%x%x' $RANDOM $RANDOM $RANDOM $RANDOM)
+export CNI_COMMAND=ADD
+export CNI_NETNS=/var/run/netns/${CNI_CONTAINERID}
+
+sudo ip netns add ${CNI_CONTAINERID}
+```
+
+Next, run the bridge and loopback plugins to apply the configuration that was
+created earlier to the namespace. Each plugin outputs some JSON indicating the
+results of executing the plugin. For example, The bridge plugin's response
+includes the IP address assigned to the ethernet device created in the network
+namespace. Take note of the IP address for use later.
+
+```
+export CNI_IFNAME="eth0"
+sudo -E /opt/cni/bin/bridge < /etc/cni/net.d/10-bridge.conf
+export CNI_IFNAME="lo"
+sudo -E /opt/cni/bin/loopback < /etc/cni/net.d/99-loopback.conf
+```
+
+Get the IP address assigned to our sandbox:
+
+```
+POD_IP=$(sudo ip netns exec ${CNI_CONTAINERID} ip -4 addr show eth0 | grep -oP '(?<=inet\s)\d+(\.\d+){3}')
+```
+
+## Create the OCI Bundle
+
+Now that our network namespace is created and configured, we can create the OCI
+bundle for our container. As part of the bundle's `config.json` we will specify
+that the container use the network namespace that we created.
+
+The container will run a simple python webserver that we will be able to connect
+to via the IP address assigned to it via the bridge CNI plugin.
+
+Create the bundle and root filesystem directories:
+
+```
+sudo mkdir -p bundle
+cd bundle
+sudo mkdir rootfs
+sudo docker export $(docker create python) | sudo tar --same-owner -pxf - -C rootfs
+sudo mkdir -p rootfs/var/www/html
+sudo sh -c 'echo "Hello World!" > rootfs/var/www/html/index.html'
+```
+
+Next create the `config.json` specifying the network namespace.
+
+```
+sudo /usr/local/bin/runsc spec \
+ --cwd /var/www/html \
+ --netns /var/run/netns/${CNI_CONTAINERID} \
+ -- python -m http.server
+```
+
+## Run the Container
+
+Now we can run and connect to the webserver. Run the container in gVisor. Use
+the same ID used for the network namespace to be consistent:
+
+```
+sudo runsc run -detach ${CNI_CONTAINERID}
+```
+
+Connect to the server via the sandbox's IP address:
+
+```
+curl http://${POD_IP}:8000/
+```
+
+You should see the server returning `Hello World!`.
+
+## Cleanup
+
+After you are finished running the container, you can clean up the network
+namespace .
+
+```
+sudo runsc kill ${CNI_CONTAINERID}
+sudo runsc delete ${CNI_CONTAINERID}
+
+export CNI_COMMAND=DEL
+
+export CNI_IFNAME="lo"
+sudo -E /opt/cni/bin/loopback < /etc/cni/net.d/99-loopback.conf
+export CNI_IFNAME="eth0"
+sudo -E /opt/cni/bin/bridge < /etc/cni/net.d/10-bridge.conf
+
+sudo ip netns delete ${CNI_CONTAINERID}
+```
diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md
new file mode 100644
index 000000000..705560038
--- /dev/null
+++ b/g3doc/user_guide/tutorials/docker.md
@@ -0,0 +1,68 @@
+# WordPress with Docker
+
+This page shows you how to deploy a sample [WordPress][wordpress] site using
+[Docker][docker].
+
+### Before you begin
+
+[Follow these instructions][docker-install] to install runsc with Docker. This
+document assumes that the runtime name chosen is `runsc`.
+
+### Running WordPress
+
+Now, let's deploy a WordPress site using Docker. WordPress site requires two
+containers: web server in the frontend, MySQL database in the backend.
+
+First, let's define a few environment variables that are shared between both
+containers:
+
+```bash
+export MYSQL_PASSWORD=${YOUR_SECRET_PASSWORD_HERE?}
+export MYSQL_DB=wordpress
+export MYSQL_USER=wordpress
+```
+
+Next, let's start the database container running MySQL and wait until the
+database is initialized:
+
+```bash
+docker run --runtime=runsc --name mysql -d \
+ -e MYSQL_RANDOM_ROOT_PASSWORD=1 \
+ -e MYSQL_PASSWORD="${MYSQL_PASSWORD}" \
+ -e MYSQL_DATABASE="${MYSQL_DB}" \
+ -e MYSQL_USER="${MYSQL_USER}" \
+ mysql:5.7
+
+# Wait until this message appears in the log.
+docker logs mysql |& grep 'port: 3306 MySQL Community Server (GPL)'
+```
+
+Once the database is running, you can start the WordPress frontend. We use the
+`--link` option to connect the frontend to the database, and expose the
+WordPress to port 8080 on the localhost.
+
+```bash
+docker run --runtime=runsc --name wordpress -d \
+ --link mysql:mysql \
+ -p 8080:80 \
+ -e WORDPRESS_DB_HOST=mysql \
+ -e WORDPRESS_DB_USER="${MYSQL_USER}" \
+ -e WORDPRESS_DB_PASSWORD="${MYSQL_PASSWORD}" \
+ -e WORDPRESS_DB_NAME="${MYSQL_DB}" \
+ -e WORDPRESS_TABLE_PREFIX=wp_ \
+ wordpress
+```
+
+Now, you can access the WordPress website pointing your favorite browser to
+<http://localhost:8080>.
+
+Congratulations! You have just deployed a WordPress site using Docker.
+
+### What's next
+
+[Learn how to deploy WordPress with Kubernetes][wordpress-k8s].
+
+[docker]: https://www.docker.com/
+[docker-install]: /docs/user_guide/quick_start/docker/
+[wordpress]: https://wordpress.com/
+[wordpress-k8s]: /docs/tutorials/kubernetes/
diff --git a/g3doc/user_guide/tutorials/kubernetes.md b/g3doc/user_guide/tutorials/kubernetes.md
new file mode 100644
index 000000000..d2a94b1b7
--- /dev/null
+++ b/g3doc/user_guide/tutorials/kubernetes.md
@@ -0,0 +1,134 @@
+# WordPress with Kubernetes
+
+This page shows you how to deploy a sample [WordPress][wordpress] site using
+[GKE Sandbox][gke-sandbox].
+
+### Before you begin
+
+Take the following steps to enable the Kubernetes Engine API:
+
+1. Visit the [Kubernetes Engine page][project-selector] in the Google Cloud
+ Platform Console.
+1. Create or select a project.
+
+### Creating a node pool with gVisor enabled
+
+Create a node pool inside your cluster with option `--sandbox type=gvisor` added
+to the command, like below:
+
+```bash
+gcloud beta container node-pools create sandbox-pool --cluster=${CLUSTER_NAME} --image-type=cos_containerd --sandbox type=gvisor
+```
+
+If you prefer to use the console, select your cluster and select the **ADD NODE
+POOL** button:
+
+![+ ADD NODE POOL](./node-pool-button.png)
+
+Then select the **Image type** with **Containerd** and select **Enable sandbox
+with gVisor** option. Select other options as you like:
+
+![+ NODE POOL](./add-node-pool.png)
+
+### Check that gVisor is enabled
+
+The gvisor RuntimeClass is instantiated during node creation. You can check for
+the existence of the gvisor RuntimeClass using the following command:
+
+```bash
+kubectl get runtimeclasses
+```
+
+### Wordpress deployment
+
+Now, let's deploy a WordPress site using GKE Sandbox. WordPress site requires
+two pods: web server in the frontend, MySQL database in the backend. Both
+applications use PersistentVolumes to store the site data data. In addition,
+they use secret store to share MySQL password between them.
+
+First, let's download the deployment configuration files to add the runtime
+class annotation to them:
+
+```bash
+curl -LO https://k8s.io/examples/application/wordpress/wordpress-deployment.yaml
+curl -LO https://k8s.io/examples/application/wordpress/mysql-deployment.yaml
+```
+
+Add a **spec.template.spec.runtimeClassName** set to **gvisor** to both files,
+as shown below:
+
+**wordpress-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata:
+name: wordpress labels: app: wordpress spec: ports: - port: 80 selector: app:
+wordpress tier: frontend
+
+## type: LoadBalancer
+
+apiVersion: v1 kind: PersistentVolumeClaim metadata: name: wp-pv-claim labels:
+app: wordpress spec: accessModes: - ReadWriteOnce resources: requests:
+
+## storage: 20Gi
+
+apiVersion: apps/v1 kind: Deployment metadata: name: wordpress labels: app:
+wordpress spec: selector: matchLabels: app: wordpress tier: frontend strategy:
+type: Recreate template: metadata: labels: app: wordpress tier: frontend spec:
+runtimeClassName: gvisor # ADD THIS LINE containers: - image:
+wordpress:4.8-apache name: wordpress env: - name: WORDPRESS_DB_HOST value:
+wordpress-mysql - name: WORDPRESS_DB_PASSWORD valueFrom: secretKeyRef: name:
+mysql-pass key: password ports: - containerPort: 80 name: wordpress
+volumeMounts: - name: wordpress-persistent-storage mountPath: /var/www/html
+volumes: - name: wordpress-persistent-storage persistentVolumeClaim: claimName:
+wp-pv-claim ```
+
+**mysql-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: name:
+wordpress-mysql labels: app: wordpress spec: ports: - port: 3306 selector: app:
+wordpress tier: mysql
+
+## clusterIP: None
+
+apiVersion: v1 kind: PersistentVolumeClaim metadata: name: mysql-pv-claim
+labels: app: wordpress spec: accessModes: - ReadWriteOnce resources: requests:
+
+## storage: 20Gi
+
+apiVersion: apps/v1 kind: Deployment metadata: name: wordpress-mysql labels:
+app: wordpress spec: selector: matchLabels: app: wordpress tier: mysql strategy:
+type: Recreate template: metadata: labels: app: wordpress tier: mysql spec:
+runtimeClassName: gvisor # ADD THIS LINE containers: - image: mysql:5.6 name:
+mysql env: - name: MYSQL_ROOT_PASSWORD valueFrom: secretKeyRef: name: mysql-pass
+key: password ports: - containerPort: 3306 name: mysql volumeMounts: - name:
+mysql-persistent-storage mountPath: /var/lib/mysql volumes: - name:
+mysql-persistent-storage persistentVolumeClaim: claimName: mysql-pv-claim ```
+
+Note that apart from `runtimeClassName: gvisor`, nothing else about the
+Deployment has is changed.
+
+You are now ready to deploy the entire application. Just create a secret to
+store MySQL's password and *apply* both deployments:
+
+```bash
+kubectl create secret generic mysql-pass --from-literal=password=${YOUR_SECRET_PASSWORD_HERE?}
+kubectl apply -f mysql-deployment.yaml
+kubectl apply -f wordpress-deployment.yaml
+```
+
+Wait for the deployments to be ready and an external IP to be assigned to the
+Wordpress service:
+
+```bash
+watch kubectl get service wordpress
+```
+
+Now, copy the service `EXTERNAL-IP` from above to your favorite browser to view
+and configure your new WordPress site.
+
+Congratulations! You have just deployed a WordPress site using GKE Sandbox.
+
+### What's next
+
+To learn more about GKE Sandbox and how to run your deployment securely, take a
+look at the [documentation][gke-sandbox-docs].
+
+[gke-sandbox-docs]: https://cloud.google.com/kubernetes-engine/docs/how-to/sandbox-pods
+[gke-sandbox]: https://cloud.google.com/kubernetes-engine/sandbox/
+[project-selector]: https://console.cloud.google.com/projectselector/kubernetes
+[wordpress]: https://wordpress.com/
diff --git a/g3doc/user_guide/tutorials/node-pool-button.png b/g3doc/user_guide/tutorials/node-pool-button.png
new file mode 100644
index 000000000..bee0c11dc
--- /dev/null
+++ b/g3doc/user_guide/tutorials/node-pool-button.png
Binary files differ
diff --git a/go.mod b/go.mod
index 821273d22..2fcba5cc9 100644
--- a/go.mod
+++ b/go.mod
@@ -1,21 +1,52 @@
module gvisor.dev/gvisor
-go 1.12
+go 1.14
+
+replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0
require (
- github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422
- github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079
- github.com/golang/mock v1.3.1
- github.com/golang/protobuf v1.3.1
- github.com/google/btree v1.0.0
- github.com/google/go-cmp v0.2.0
- github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8
- github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d
- github.com/kr/pty v1.1.1
- github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78
- github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2
- github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e
- github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936
- golang.org/x/net v0.0.0-20190311183353-d8887717615a
- golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
+ cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 // indirect
+ github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 // indirect
+ github.com/Microsoft/hcsshim v0.8.6 // indirect
+ github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect
+ github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect
+ github.com/containerd/containerd v1.3.4 // indirect
+ github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect
+ github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect
+ github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect
+ github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect
+ github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 // indirect
+ github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect
+ github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible // indirect
+ github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect
+ github.com/docker/go-connections v0.3.0 // indirect
+ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect
+ github.com/docker/go-units v0.4.0 // indirect
+ github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b // indirect
+ github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect
+ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect
+ github.com/gogo/googleapis v1.4.0 // indirect
+ github.com/golang/protobuf v1.4.2 // indirect
+ github.com/google/go-cmp v0.5.0 // indirect
+ github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect
+ github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect
+ github.com/hashicorp/go-multierror v1.0.0 // indirect
+ github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect
+ github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect
+ github.com/opencontainers/go-digest v1.0.0 // indirect
+ github.com/opencontainers/image-spec v1.0.1 // indirect
+ github.com/opencontainers/runc v0.1.1 // indirect
+ github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect
+ github.com/pborman/uuid v1.2.0 // indirect
+ github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect
+ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5 // indirect
+ github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect
+ github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe // indirect
+ go.uber.org/atomic v1.6.0 // indirect
+ go.uber.org/multierr v1.2.0 // indirect
+ golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
+ golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a // indirect
+ google.golang.org/grpc v1.29.0 // indirect
+ gopkg.in/yaml.v2 v2.2.8 // indirect
+ gotest.tools v2.2.0+incompatible // indirect
)
diff --git a/go.sum b/go.sum
index 7a0bc175a..f98132971 100644
--- a/go.sum
+++ b/go.sum
@@ -1,19 +1,387 @@
-github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
+bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8=
+cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU=
+cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU=
+cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY=
+cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc=
+cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0=
+cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo=
+cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4=
+cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o=
+cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE=
+cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I=
+cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
+dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
+github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
+github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU=
+github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA=
+github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA=
+github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw=
+github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg=
+github.com/Microsoft/hcsshim v0.8.7/go.mod h1:OHd7sQqRFrYd3RmSgbgji+ctCwkbq2wbEYNSzOYtcBQ=
+github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8=
+github.com/Microsoft/hcsshim v0.8.9 h1:VrfodqvztU8YSOvygU+DN1BGaSGxmrNfqOv5oOuX2Bk=
+github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8=
+github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
+github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4=
+github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM=
+github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
+github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
+github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
+github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
+github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
+github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=
+github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI=
+github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f h1:tSNMc+rJDfmYntojat8lljbt1mgKNpTxUZJsSzJ9Y1s=
+github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f/go.mod h1:OApqhQ4XNSNC13gXIwDjhOQxjWa/NxkwZXJ1EvqT0ko=
+github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw=
+github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=
+github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE=
+github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
+github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=
+github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA=
+github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y=
+github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=
+github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe/go.mod h1:cECdGN1O8G9bgKTlLhuPJimka6Xb/Gg7vYzCTNVxhvo=
+github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI=
+github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=
+github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0=
+github.com/containerd/go-runc v0.0.0-20180907222934-5a6d9f37cfa3/go.mod h1:IV7qH3hrUgRmyYrtgEeGWJfWbgcHL9CSRruz2Vqcph0=
+github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds=
+github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g=
+github.com/containerd/ttrpc v0.0.0-20190828154514-0e0f228740de/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X5DkEJB8o=
+github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc=
+github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y=
+github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc=
+github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o=
+github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg=
+github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
+github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=
+github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE=
+github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
+github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w=
+github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/go-connections v0.3.0 h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o=
+github.com/docker/go-connections v0.3.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
+github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8=
+github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA=
+github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
+github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
+github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=
+github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y=
+github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
+github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
+github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
+github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
+github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=
+github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
+github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
+github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=
+github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
+github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
+github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=
+github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
+github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
+github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
+github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
+github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
+github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s=
github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y=
+github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
+github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
+github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
+github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
+github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
+github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
+github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
+github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
+github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
+github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
-github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
-github.com/google/uuid v0.0.0-20171129191014-dec09d789f3d/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
+github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=
+github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=
+github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ=
+github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
+github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
+github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
+github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
+github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
+github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
+github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM=
+github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
+github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
+github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
+github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
+github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
+github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
+github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=
+github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk=
+github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
+github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
+github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
+github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
+github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
+github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
+github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
+github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
+github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
+github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
+github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=
+github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=
+github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
+github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=
+github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
+github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
+github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
+github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
+github.com/opencontainers/go-digest v1.0.0-rc1/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s=
+github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
+github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
+github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=
+github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0=
+github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U=
+github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=
+github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U=
+github.com/opencontainers/runtime-spec v0.1.2-0.20190507144316-5b71a03e2700/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8=
+github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
+github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=
+github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
+github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
+github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
+github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA=
+github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
+github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc=
+github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
+github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
+github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
+github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
+github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
+github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
+github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8=
github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
-github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
-github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI=
+github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
+github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w=
+github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
+github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so=
+github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
+github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
+go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
+go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
+go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=
+go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
+go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk=
+go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
+go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=
+go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
+golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
+golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
+golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
+golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
+golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek=
+golang.org/x/exp v0.0.0-20191227195350-da58074b4299 h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg=
+golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
+golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
+golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
+golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
+golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
+golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE=
+golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs=
+golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
+golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o=
+golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
+golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
+golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
+golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=
+golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
+golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
+golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
+golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=
+golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
+golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
+golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=
+golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
+golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=
+golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190514135907-3a4b5fb9f71f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o=
+golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
+golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
+golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
+golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
+golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
+golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
+golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE=
+golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
+golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
+google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
+google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
+google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg=
+google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI=
+google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
+google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
+google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
+google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
+google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
+google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
+google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
+google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
+google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
+google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8=
+google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=
+google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc=
+google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
+google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
+google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=
+google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
+google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
+google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
+google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
+google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=
+google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
+google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
+google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
+google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
+google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
+google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
+google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
+google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
+gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
+gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
+gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
+gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
+honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
+honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=
+honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
+rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
diff --git a/images/BUILD b/images/BUILD
new file mode 100644
index 000000000..a50f388e9
--- /dev/null
+++ b/images/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+# The images filegroup is definitely not a hermetic target, and requires Make
+# to do anything meaningful with. However, this will be slurped up and used by
+# the tools/installer/images.sh installer, which will ensure that all required
+# images are available locally when running vm_tests.
+filegroup(
+ name = "images",
+ srcs = glob(["**"]),
+ visibility = ["//tools/installers:__pkg__"],
+)
diff --git a/images/Makefile b/images/Makefile
new file mode 100644
index 000000000..278dec02f
--- /dev/null
+++ b/images/Makefile
@@ -0,0 +1,100 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ARCH is the architecture used for the build. This may be overriden at the
+# command line in order to perform a cross-build (in a limited capacity).
+ARCH := $(shell uname -m)
+
+# Note that the image prefixes used here must match the image mangling in
+# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all
+# tests are using locally-defined images (that are consistent and idempotent).
+REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit
+LOCAL_IMAGE_PREFIX ?= gvisor.dev/images
+ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -exec dirname {} \;)))
+ifneq ($(ARCH),$(shell uname -m))
+DOCKER_PLATFORM_ARGS := --platform=$(ARCH)
+else
+DOCKER_PLATFORM_ARGS :=
+endif
+
+list-all-images:
+ @for image in $(ALL_IMAGES); do echo $${image}; done
+.PHONY: list-build-images
+
+# Handy wrapper to allow load-all-images, push-all-images, etc.
+%-all-images:
+ @$(MAKE) $(patsubst %,$*-%,$(ALL_IMAGES))
+load-all-images:
+ @$(MAKE) $(patsubst %,load-%,$(ALL_IMAGES))
+
+# Handy wrapper to load specified "groups", e.g. load-basic-images, etc.
+load-%-images:
+ @$(MAKE) $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;))))
+
+# tag is a function that returns the tag name, given an image.
+#
+# The tag constructed is used to memoize the image generated (see README.md).
+# This scheme is used to enable aggressive caching in a central repository, but
+# ensuring that images will always be sourced using the local files if there
+# are changes.
+path = $(subst _,/,$(1))
+tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16)
+remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1))
+local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1))
+
+# rebuild builds the image locally. Only the "remote" tag will be applied. Note
+# we need to explicitly repull the base layer in order to ensure that the
+# architecture is correct. Note that we use the term "rebuild" here to avoid
+# conflicting with the bazel "build" terminology, which is used elsewhere.
+rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile } cut -d' ' -f2)
+rebuild-%: register-cross
+ $(foreach IMAGE,$(FROM),docker $(DOCKER_PLATFORM_ARGS) $(IMAGE); &&) true
+ T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \
+ docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \
+ rm -rf $$T
+
+# pull will check the "remote" image and pull if necessary. If the remote image
+# must be pulled, then it will tag with the latest local target. Note that pull
+# may fail if the remote image is not available.
+pull-%:
+ docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*)
+
+# load will either pull the "remote" or build it locally. This is the preferred
+# entrypoint, as it should never file. The local tag should always be set after
+# this returns (either by the pull or the build).
+load-%:
+ docker inspect $(call remote_image,$*) >/dev/null 2>&1 || $(MAKE) pull-$* || $(MAKE) rebuild-$*
+ docker tag $(call remote_image,$*) $(call local_image,$*)
+
+# push pushes the remote image, after either pulling (to validate that the tag
+# already exists) or building manually.
+push-%: load-%
+ docker push $(call remote_image,$*)
+
+# register-cross registers the necessary qemu binaries for cross-compilation.
+# This may be used by any target that may execute containers that are not the
+# native format.
+register-cross:
+ifneq ($(ARCH),$(shell uname -m))
+ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*))
+ docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes
+else
+ @true # Already registered.
+endif
+else
+ @true # No cross required.
+endif
+.PHONY: register-cross
diff --git a/images/README.md b/images/README.md
new file mode 100644
index 000000000..9880946a6
--- /dev/null
+++ b/images/README.md
@@ -0,0 +1,70 @@
+# Container Images
+
+This directory contains all images used by tests.
+
+Note that all these images must be pushed to the testing project hosted on
+[Google Container Registry][gcr]. This will happen automatically as part of
+continuous integration. This will speed up loading as images will not need to be
+built from scratch for each test run.
+
+Image tooling is accessible via `make`, specifically via `images/Makefile`.
+
+## Why make?
+
+Make is used because it can bootstrap the `default` image, which contains
+`bazel` and all other parts of the toolchain.
+
+## Listing images
+
+To list all images, use `make list-all-images` from the top-level directory.
+
+## Loading and referencing images
+
+To build a specific image, use `make load-<image>` from the top-level directory.
+This will ensure that an image `gvisor.dev/images/<image>:latest` is available.
+
+Images should always be referred to via the `gvisor.dev/images` canonical path.
+This tag exists only locally, but serves to decouple tests from the underlying
+image infrastructure.
+
+The continuous integration system can either take fine-grained dependencies on
+single images via individual `load` targets, or pull all images via a single
+`load-all-images` invocation.
+
+## Adding new images
+
+To add a new image, create a new directory under `images` containing a
+Dockerfile and any other files that the image requires. You may choose to add to
+an existing subdirectory if applicable, or create a new one.
+
+All images will be tagged and memoized using a hash of the directory contents.
+As a result, every image should be made completely reproducible if possible.
+This means using fixed tags and fixed versions whenever feasible.
+
+Notes that images should also be made architecture-independent if possible. The
+build scripts will handling loading the appropriate architecture onto the
+machine and tagging it with the single canonical tag.
+
+Add a `load-<image>` dependency in the Makefile if the image is required for a
+particular set of tests. This target will pull the tag from the image repository
+if available.
+
+## Building and pushing images
+
+All images can be built manually by running `build-<image>` and pushed using
+`push-<image>`. Note that you can also use `build-all-images` and
+`push-all-images`. Note that pushing will require appropriate permissions in the
+project.
+
+The continuous integration system can either take fine-grained dependencies on
+individual `push` targets, or ensure all images are up-to-date with a single
+`push-all-images` invocation.
+
+## Multi-Arch images
+
+By default, the image is built for host architecture. Cross-building can be
+achieved by specifying `ARCH` variable to make. For example:
+
+```
+$ make ARCH=aarch64 rebuild-default
+```
diff --git a/images/basic/alpine/Dockerfile b/images/basic/alpine/Dockerfile
new file mode 100644
index 000000000..12b26040a
--- /dev/null
+++ b/images/basic/alpine/Dockerfile
@@ -0,0 +1 @@
+FROM alpine:3.11.5
diff --git a/images/basic/busybox/Dockerfile b/images/basic/busybox/Dockerfile
new file mode 100644
index 000000000..79b3f683a
--- /dev/null
+++ b/images/basic/busybox/Dockerfile
@@ -0,0 +1 @@
+FROM busybox:1.31.1
diff --git a/images/basic/hostoverlaytest/Dockerfile b/images/basic/hostoverlaytest/Dockerfile
new file mode 100644
index 000000000..6cef1a542
--- /dev/null
+++ b/images/basic/hostoverlaytest/Dockerfile
@@ -0,0 +1,8 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY . .
+
+RUN apt-get update && apt-get install -y gcc
+RUN gcc -O2 -o test_copy_up test_copy_up.c
+RUN gcc -O2 -o test_rewinddir test_rewinddir.c
diff --git a/images/basic/hostoverlaytest/copy_up_testfile.txt b/images/basic/hostoverlaytest/copy_up_testfile.txt
new file mode 100644
index 000000000..e4188c841
--- /dev/null
+++ b/images/basic/hostoverlaytest/copy_up_testfile.txt
@@ -0,0 +1 @@
+old data
diff --git a/images/basic/hostoverlaytest/test_copy_up.c b/images/basic/hostoverlaytest/test_copy_up.c
new file mode 100644
index 000000000..010b261dc
--- /dev/null
+++ b/images/basic/hostoverlaytest/test_copy_up.c
@@ -0,0 +1,88 @@
+#include <err.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+int main(int argc, char** argv) {
+ const char kTestFilePath[] = "copy_up_testfile.txt";
+ const char kOldFileData[] = "old data\n";
+ const char kNewFileData[] = "new data\n";
+ const size_t kPageSize = sysconf(_SC_PAGE_SIZE);
+
+ // Open a file that already exists in a host overlayfs lower layer.
+ const int fd_rdonly = open(kTestFilePath, O_RDONLY);
+ if (fd_rdonly < 0) {
+ err(1, "open(%s, O_RDONLY)", kTestFilePath);
+ }
+
+ // Check that the file's initial contents are what we expect when read via
+ // syscall.
+ char oldbuf[sizeof(kOldFileData)] = {};
+ ssize_t n = pread(fd_rdonly, oldbuf, sizeof(oldbuf), 0);
+ if (n < 0) {
+ err(1, "initial pread");
+ }
+ if (n != strlen(kOldFileData)) {
+ errx(1, "short initial pread (%ld/%lu bytes)", n, strlen(kOldFileData));
+ }
+ if (strcmp(oldbuf, kOldFileData) != 0) {
+ errx(1, "initial pread returned wrong data: %s", oldbuf);
+ }
+
+ // Check that the file's initial contents are what we expect when read via
+ // memory mapping.
+ void* page = mmap(NULL, kPageSize, PROT_READ, MAP_SHARED, fd_rdonly, 0);
+ if (page == MAP_FAILED) {
+ err(1, "mmap");
+ }
+ if (strcmp(page, kOldFileData) != 0) {
+ errx(1, "mapping contains wrong initial data: %s", (const char*)page);
+ }
+
+ // Open the same file writably, causing host overlayfs to copy it up, and
+ // replace its contents.
+ const int fd_rdwr = open(kTestFilePath, O_RDWR);
+ if (fd_rdwr < 0) {
+ err(1, "open(%s, O_RDWR)", kTestFilePath);
+ }
+ n = write(fd_rdwr, kNewFileData, strlen(kNewFileData));
+ if (n < 0) {
+ err(1, "write");
+ }
+ if (n != strlen(kNewFileData)) {
+ errx(1, "short write (%ld/%lu bytes)", n, strlen(kNewFileData));
+ }
+ if (ftruncate(fd_rdwr, strlen(kNewFileData)) < 0) {
+ err(1, "truncate");
+ }
+
+ int failed = 0;
+
+ // Check that syscalls on the old FD return updated contents. (Before Linux
+ // 4.18, this requires that runsc use a post-copy-up FD to service the read.)
+ char newbuf[sizeof(kNewFileData)] = {};
+ n = pread(fd_rdonly, newbuf, sizeof(newbuf), 0);
+ if (n < 0) {
+ err(1, "final pread");
+ }
+ if (n != strlen(kNewFileData)) {
+ warnx("short final pread (%ld/%lu bytes)", n, strlen(kNewFileData));
+ failed = 1;
+ } else if (strcmp(newbuf, kNewFileData) != 0) {
+ warnx("final pread returned wrong data: %s", newbuf);
+ failed = 1;
+ }
+
+ // Check that the memory mapping of the old FD has been updated. (Linux
+ // overlayfs does not do this, so regardless of kernel version this requires
+ // that runsc replace existing memory mappings with mappings of a
+ // post-copy-up FD.)
+ if (strcmp(page, kNewFileData) != 0) {
+ warnx("mapping contains wrong final data: %s", (const char*)page);
+ failed = 1;
+ }
+
+ return failed;
+}
diff --git a/images/basic/hostoverlaytest/test_rewinddir.c b/images/basic/hostoverlaytest/test_rewinddir.c
new file mode 100644
index 000000000..f1a4085e1
--- /dev/null
+++ b/images/basic/hostoverlaytest/test_rewinddir.c
@@ -0,0 +1,78 @@
+#include <dirent.h>
+#include <err.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+int main(int argc, char** argv) {
+ const char kDirPath[] = "rewinddir_test_dir";
+ const char kFileBasename[] = "rewinddir_test_file";
+
+ // Create the test directory.
+ if (mkdir(kDirPath, 0755) < 0) {
+ err(1, "mkdir(%s)", kDirPath);
+ }
+
+ // The test directory should initially be empty.
+ DIR* dir = opendir(kDirPath);
+ if (!dir) {
+ err(1, "opendir(%s)", kDirPath);
+ }
+ int failed = 0;
+ while (1) {
+ errno = 0;
+ struct dirent* d = readdir(dir);
+ if (!d) {
+ if (errno != 0) {
+ err(1, "readdir");
+ }
+ break;
+ }
+ if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) {
+ warnx("unexpected file %s in new directory", d->d_name);
+ failed = 1;
+ }
+ }
+
+ // Create a file in the test directory.
+ char* file_path = malloc(strlen(kDirPath) + 1 + strlen(kFileBasename));
+ if (!file_path) {
+ errx(1, "malloc");
+ }
+ strcpy(file_path, kDirPath);
+ file_path[strlen(kDirPath)] = '/';
+ strcpy(file_path + strlen(kDirPath) + 1, kFileBasename);
+ if (mknod(file_path, 0644, 0) < 0) {
+ err(1, "mknod(%s)", file_path);
+ }
+
+ // After rewinddir(), re-reading the directory stream should yield the new
+ // file.
+ rewinddir(dir);
+ size_t found_file = 0;
+ while (1) {
+ errno = 0;
+ struct dirent* d = readdir(dir);
+ if (!d) {
+ if (errno != 0) {
+ err(1, "readdir");
+ }
+ break;
+ }
+ if (strcmp(d->d_name, kFileBasename) == 0) {
+ found_file++;
+ } else if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) {
+ warnx("unexpected file %s in new directory", d->d_name);
+ failed = 1;
+ }
+ }
+ if (found_file != 1) {
+ warnx("readdir returned file %s %zu times, wanted 1", kFileBasename,
+ found_file);
+ failed = 1;
+ }
+
+ return failed;
+}
diff --git a/images/basic/httpd/Dockerfile b/images/basic/httpd/Dockerfile
new file mode 100644
index 000000000..83bc0ed88
--- /dev/null
+++ b/images/basic/httpd/Dockerfile
@@ -0,0 +1 @@
+FROM httpd:2.4.43
diff --git a/images/basic/linktest/Dockerfile b/images/basic/linktest/Dockerfile
new file mode 100644
index 000000000..baebc9b76
--- /dev/null
+++ b/images/basic/linktest/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:bionic
+
+WORKDIR /root
+COPY . .
+
+RUN apt-get update && apt-get install -y gcc
+RUN gcc -O2 -o link_test link_test.c
diff --git a/images/basic/linktest/link_test.c b/images/basic/linktest/link_test.c
new file mode 100644
index 000000000..45ab00abe
--- /dev/null
+++ b/images/basic/linktest/link_test.c
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <err.h>
+#include <fcntl.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it
+// cannot use tricks like userns as root. For this reason, run a basic link test
+// to ensure some coverage.
+int main(int argc, char** argv) {
+ const char kOldPath[] = "old.txt";
+ int fd = open(kOldPath, O_RDWR | O_CREAT | O_TRUNC, 0600);
+ if (fd < 0) {
+ errx(1, "open(%s) failed", kOldPath);
+ }
+ const char kData[] = "some random content";
+ if (write(fd, kData, sizeof(kData)) < 0) {
+ err(1, "write failed");
+ }
+ close(fd);
+
+ struct stat old_stat;
+ if (stat(kOldPath, &old_stat)) {
+ errx(1, "stat(%s) failed", kOldPath);
+ }
+
+ const char kNewPath[] = "new.txt";
+ if (link(kOldPath, kNewPath)) {
+ errx(1, "link(%s, %s) failed", kOldPath, kNewPath);
+ }
+
+ struct stat new_stat;
+ if (stat(kNewPath, &new_stat)) {
+ errx(1, "stat(%s) failed", kNewPath);
+ }
+
+ // Check that files are the same.
+ if (old_stat.st_dev != new_stat.st_dev) {
+ errx(1, "files st_dev is different, want: %lu, got: %lu", old_stat.st_dev,
+ new_stat.st_dev);
+ }
+ if (old_stat.st_ino != new_stat.st_ino) {
+ errx(1, "files st_ino is different, want: %lu, got: %lu", old_stat.st_ino,
+ new_stat.st_ino);
+ }
+
+ // Check that link count is correct.
+ if (new_stat.st_nlink != old_stat.st_nlink + 1) {
+ errx(1, "wrong nlink, want: %lu, got: %lu", old_stat.st_nlink + 1,
+ new_stat.st_nlink);
+ }
+
+ // Check taht contents are the same.
+ fd = open(kNewPath, O_RDONLY);
+ if (fd < 0) {
+ errx(1, "open(%s) failed", kNewPath);
+ }
+ char buf[sizeof(kData)] = {};
+ if (read(fd, buf, sizeof(buf)) < 0) {
+ err(1, "read failed");
+ }
+ close(fd);
+
+ if (strcmp(buf, kData) != 0) {
+ errx(1, "file content mismatch: %s", buf);
+ }
+
+ // Cleanup.
+ if (unlink(kNewPath)) {
+ errx(1, "unlink(%s) failed", kNewPath);
+ }
+ if (unlink(kOldPath)) {
+ errx(1, "unlink(%s) failed", kOldPath);
+ }
+
+ // Success!
+ return 0;
+}
diff --git a/images/basic/mysql/Dockerfile b/images/basic/mysql/Dockerfile
new file mode 100644
index 000000000..95da9c48d
--- /dev/null
+++ b/images/basic/mysql/Dockerfile
@@ -0,0 +1 @@
+FROM mysql:8.0.19
diff --git a/images/basic/nginx/Dockerfile b/images/basic/nginx/Dockerfile
new file mode 100644
index 000000000..af2e62526
--- /dev/null
+++ b/images/basic/nginx/Dockerfile
@@ -0,0 +1 @@
+FROM nginx:1.17.9
diff --git a/images/basic/python/Dockerfile b/images/basic/python/Dockerfile
new file mode 100644
index 000000000..acf07cca9
--- /dev/null
+++ b/images/basic/python/Dockerfile
@@ -0,0 +1,2 @@
+FROM python:3
+ENTRYPOINT ["python", "-m", "http.server", "8080"]
diff --git a/images/basic/resolv/Dockerfile b/images/basic/resolv/Dockerfile
new file mode 100644
index 000000000..13665bdaf
--- /dev/null
+++ b/images/basic/resolv/Dockerfile
@@ -0,0 +1 @@
+FROM k8s.gcr.io/busybox:latest
diff --git a/images/basic/ruby/Dockerfile b/images/basic/ruby/Dockerfile
new file mode 100644
index 000000000..d290418fb
--- /dev/null
+++ b/images/basic/ruby/Dockerfile
@@ -0,0 +1 @@
+FROM ruby:2.7.1
diff --git a/images/basic/tmpfile/Dockerfile b/images/basic/tmpfile/Dockerfile
new file mode 100644
index 000000000..e3816c8cb
--- /dev/null
+++ b/images/basic/tmpfile/Dockerfile
@@ -0,0 +1,4 @@
+# Create file under /tmp to ensure files inside '/tmp' are not overridden.
+FROM alpine:3.11.5
+RUN mkdir -p /tmp/foo \
+ && echo 123 > /tmp/foo/file.txt
diff --git a/images/basic/tomcat/Dockerfile b/images/basic/tomcat/Dockerfile
new file mode 100644
index 000000000..c7db39a36
--- /dev/null
+++ b/images/basic/tomcat/Dockerfile
@@ -0,0 +1 @@
+FROM tomcat:8.0
diff --git a/images/basic/ubuntu/Dockerfile b/images/basic/ubuntu/Dockerfile
new file mode 100644
index 000000000..331b71343
--- /dev/null
+++ b/images/basic/ubuntu/Dockerfile
@@ -0,0 +1 @@
+FROM ubuntu:trusty
diff --git a/images/benchmarks/ab/Dockerfile b/images/benchmarks/ab/Dockerfile
new file mode 100644
index 000000000..10544639b
--- /dev/null
+++ b/images/benchmarks/ab/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ apache2-utils \
+ && rm -rf /var/lib/apt/lists/*
diff --git a/images/benchmarks/absl/Dockerfile b/images/benchmarks/absl/Dockerfile
new file mode 100644
index 000000000..b0dd97695
--- /dev/null
+++ b/images/benchmarks/absl/Dockerfile
@@ -0,0 +1,21 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ wget \
+ git \
+ pkg-config \
+ zip \
+ g++ \
+ zlib1g-dev \
+ unzip \
+ python3 \
+ && rm -rf /var/lib/apt/lists/*
+RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh
+RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh
+RUN ./bazel-0.27.0-installer-linux-x86_64.sh
+
+RUN mkdir abseil-cpp && cd abseil-cpp \
+ && git init && git remote add origin https://github.com/abseil/abseil-cpp.git \
+ && git fetch --depth 1 origin 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f && git checkout FETCH_HEAD
diff --git a/images/benchmarks/alpine/Dockerfile b/images/benchmarks/alpine/Dockerfile
new file mode 100644
index 000000000..b09b037ca
--- /dev/null
+++ b/images/benchmarks/alpine/Dockerfile
@@ -0,0 +1 @@
+FROM alpine:latest
diff --git a/images/benchmarks/ffmpeg/Dockerfile b/images/benchmarks/ffmpeg/Dockerfile
new file mode 100644
index 000000000..7108df64f
--- /dev/null
+++ b/images/benchmarks/ffmpeg/Dockerfile
@@ -0,0 +1,9 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ ffmpeg \
+ && rm -rf /var/lib/apt/lists/*
+WORKDIR /media
+ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4
diff --git a/images/benchmarks/fio/Dockerfile b/images/benchmarks/fio/Dockerfile
new file mode 100644
index 000000000..9531df7fa
--- /dev/null
+++ b/images/benchmarks/fio/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ fio \
+ && rm -rf /var/lib/apt/lists/*
diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile
new file mode 100644
index 000000000..f586978b6
--- /dev/null
+++ b/images/benchmarks/hey/Dockerfile
@@ -0,0 +1,12 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ wget \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \
+ && chmod 777 hey_linux_amd64 \
+ && cp hey_linux_amd64 /bin/hey \
+ && rm hey_linux_amd64
diff --git a/images/benchmarks/httpd/Dockerfile b/images/benchmarks/httpd/Dockerfile
new file mode 100644
index 000000000..b72406012
--- /dev/null
+++ b/images/benchmarks/httpd/Dockerfile
@@ -0,0 +1,17 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ apache2 \
+ && rm -rf /var/lib/apt/lists/*
+
+# Generate a bunch of relevant files.
+RUN mkdir -p /local && \
+ for size in 1 10 100 1000 1024 10240; do \
+ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \
+ done
+
+# Rewrite DocumentRoot to point to /tmp/html instead of the default path.
+RUN sed -i 's/DocumentRoot.*\/var\/www\/html$/DocumentRoot \/tmp\/html/' /etc/apache2/sites-enabled/000-default.conf
+COPY ./apache2-tmpdir.conf /etc/apache2/sites-enabled/apache2-tmpdir.conf
diff --git a/images/benchmarks/httpd/apache2-tmpdir.conf b/images/benchmarks/httpd/apache2-tmpdir.conf
new file mode 100644
index 000000000..e33f8d9bb
--- /dev/null
+++ b/images/benchmarks/httpd/apache2-tmpdir.conf
@@ -0,0 +1,5 @@
+<Directory /tmp/html/>
+ Options Indexes FollowSymLinks
+ AllowOverride None
+ Require all granted
+</Directory> \ No newline at end of file
diff --git a/images/benchmarks/iperf/Dockerfile b/images/benchmarks/iperf/Dockerfile
new file mode 100644
index 000000000..4cbfd0d70
--- /dev/null
+++ b/images/benchmarks/iperf/Dockerfile
@@ -0,0 +1,8 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ iperf \
+ && rm -rf /var/lib/apt/lists/*
+
diff --git a/images/benchmarks/nginx/Dockerfile b/images/benchmarks/nginx/Dockerfile
new file mode 100644
index 000000000..b64eb52ae
--- /dev/null
+++ b/images/benchmarks/nginx/Dockerfile
@@ -0,0 +1 @@
+FROM nginx:1.15.10
diff --git a/images/benchmarks/node/Dockerfile b/images/benchmarks/node/Dockerfile
new file mode 100644
index 000000000..bf45650a0
--- /dev/null
+++ b/images/benchmarks/node/Dockerfile
@@ -0,0 +1 @@
+FROM node:onbuild
diff --git a/images/benchmarks/node/index.hbs b/images/benchmarks/node/index.hbs
new file mode 100644
index 000000000..03feceb75
--- /dev/null
+++ b/images/benchmarks/node/index.hbs
@@ -0,0 +1,8 @@
+<!DOCTYPE html>
+<html>
+<body>
+ {{#each text}}
+ <p>{{this}}</p>
+ {{/each}}
+</body>
+</html>
diff --git a/images/benchmarks/node/index.js b/images/benchmarks/node/index.js
new file mode 100644
index 000000000..831015d18
--- /dev/null
+++ b/images/benchmarks/node/index.js
@@ -0,0 +1,42 @@
+const app = require('express')();
+const path = require('path');
+const redis = require('redis');
+const srs = require('secure-random-string');
+
+// The hostname is the first argument.
+const host_name = process.argv[2];
+
+var client = redis.createClient({host: host_name, detect_buffers: true});
+
+app.set('views', __dirname);
+app.set('view engine', 'hbs');
+
+app.get('/', (req, res) => {
+ var tmp = [];
+ /* Pull four random keys from the redis server. */
+ for (i = 0; i < 4; i++) {
+ client.get(Math.floor(Math.random() * (100)), function(err, reply) {
+ tmp.push(reply.toString());
+ });
+ }
+ res.render('index', {text: tmp});
+});
+
+/**
+ * Securely generate a random string.
+ * @param {number} len
+ * @return {string}
+ */
+function randomBody(len) {
+ return srs({alphanumeric: true, length: len});
+}
+
+/** Mutates one hundred keys randomly. */
+function generateText() {
+ for (i = 0; i < 100; i++) {
+ client.set(i, randomBody(1024));
+ }
+}
+
+generateText();
+app.listen(8080);
diff --git a/images/benchmarks/node/package-lock.json b/images/benchmarks/node/package-lock.json
new file mode 100644
index 000000000..580e68aa5
--- /dev/null
+++ b/images/benchmarks/node/package-lock.json
@@ -0,0 +1,486 @@
+{
+ "name": "nodedum",
+ "version": "1.0.0",
+ "lockfileVersion": 1,
+ "requires": true,
+ "dependencies": {
+ "accepts": {
+ "version": "1.3.5",
+ "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz",
+ "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=",
+ "requires": {
+ "mime-types": "~2.1.18",
+ "negotiator": "0.6.1"
+ }
+ },
+ "array-flatten": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz",
+ "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI="
+ },
+ "async": {
+ "version": "2.6.2",
+ "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz",
+ "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==",
+ "requires": {
+ "lodash": "^4.17.11"
+ }
+ },
+ "body-parser": {
+ "version": "1.18.3",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz",
+ "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=",
+ "requires": {
+ "bytes": "3.0.0",
+ "content-type": "~1.0.4",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "http-errors": "~1.6.3",
+ "iconv-lite": "0.4.23",
+ "on-finished": "~2.3.0",
+ "qs": "6.5.2",
+ "raw-body": "2.3.3",
+ "type-is": "~1.6.16"
+ }
+ },
+ "bytes": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz",
+ "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg="
+ },
+ "commander": {
+ "version": "2.20.0",
+ "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz",
+ "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==",
+ "optional": true
+ },
+ "content-disposition": {
+ "version": "0.5.2",
+ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz",
+ "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ="
+ },
+ "content-type": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz",
+ "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA=="
+ },
+ "cookie": {
+ "version": "0.3.1",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz",
+ "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s="
+ },
+ "cookie-signature": {
+ "version": "1.0.6",
+ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz",
+ "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw="
+ },
+ "debug": {
+ "version": "2.6.9",
+ "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz",
+ "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==",
+ "requires": {
+ "ms": "2.0.0"
+ }
+ },
+ "depd": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz",
+ "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak="
+ },
+ "destroy": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz",
+ "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA="
+ },
+ "double-ended-queue": {
+ "version": "2.1.0-0",
+ "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz",
+ "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw="
+ },
+ "ee-first": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
+ "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0="
+ },
+ "encodeurl": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz",
+ "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k="
+ },
+ "escape-html": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz",
+ "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg="
+ },
+ "etag": {
+ "version": "1.8.1",
+ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz",
+ "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc="
+ },
+ "express": {
+ "version": "4.16.4",
+ "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz",
+ "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==",
+ "requires": {
+ "accepts": "~1.3.5",
+ "array-flatten": "1.1.1",
+ "body-parser": "1.18.3",
+ "content-disposition": "0.5.2",
+ "content-type": "~1.0.4",
+ "cookie": "0.3.1",
+ "cookie-signature": "1.0.6",
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "finalhandler": "1.1.1",
+ "fresh": "0.5.2",
+ "merge-descriptors": "1.0.1",
+ "methods": "~1.1.2",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.2",
+ "path-to-regexp": "0.1.7",
+ "proxy-addr": "~2.0.4",
+ "qs": "6.5.2",
+ "range-parser": "~1.2.0",
+ "safe-buffer": "5.1.2",
+ "send": "0.16.2",
+ "serve-static": "1.13.2",
+ "setprototypeof": "1.1.0",
+ "statuses": "~1.4.0",
+ "type-is": "~1.6.16",
+ "utils-merge": "1.0.1",
+ "vary": "~1.1.2"
+ }
+ },
+ "finalhandler": {
+ "version": "1.1.1",
+ "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz",
+ "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==",
+ "requires": {
+ "debug": "2.6.9",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "on-finished": "~2.3.0",
+ "parseurl": "~1.3.2",
+ "statuses": "~1.4.0",
+ "unpipe": "~1.0.0"
+ }
+ },
+ "foreachasync": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz",
+ "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY="
+ },
+ "forwarded": {
+ "version": "0.1.2",
+ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz",
+ "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ="
+ },
+ "fresh": {
+ "version": "0.5.2",
+ "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz",
+ "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac="
+ },
+ "handlebars": {
+ "version": "4.0.14",
+ "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz",
+ "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==",
+ "requires": {
+ "async": "^2.5.0",
+ "optimist": "^0.6.1",
+ "source-map": "^0.6.1",
+ "uglify-js": "^3.1.4"
+ }
+ },
+ "hbs": {
+ "version": "4.0.4",
+ "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz",
+ "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==",
+ "requires": {
+ "handlebars": "4.0.14",
+ "walk": "2.3.9"
+ }
+ },
+ "http-errors": {
+ "version": "1.6.3",
+ "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz",
+ "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=",
+ "requires": {
+ "depd": "~1.1.2",
+ "inherits": "2.0.3",
+ "setprototypeof": "1.1.0",
+ "statuses": ">= 1.4.0 < 2"
+ }
+ },
+ "iconv-lite": {
+ "version": "0.4.23",
+ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz",
+ "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==",
+ "requires": {
+ "safer-buffer": ">= 2.1.2 < 3"
+ }
+ },
+ "inherits": {
+ "version": "2.0.3",
+ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz",
+ "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4="
+ },
+ "ipaddr.js": {
+ "version": "1.8.0",
+ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz",
+ "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4="
+ },
+ "lodash": {
+ "version": "4.17.15",
+ "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz",
+ "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A=="
+ },
+ "media-typer": {
+ "version": "0.3.0",
+ "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz",
+ "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g="
+ },
+ "merge-descriptors": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz",
+ "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E="
+ },
+ "methods": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz",
+ "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4="
+ },
+ "mime": {
+ "version": "1.4.1",
+ "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz",
+ "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ=="
+ },
+ "mime-db": {
+ "version": "1.37.0",
+ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz",
+ "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg=="
+ },
+ "mime-types": {
+ "version": "2.1.21",
+ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz",
+ "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==",
+ "requires": {
+ "mime-db": "~1.37.0"
+ }
+ },
+ "minimist": {
+ "version": "0.0.10",
+ "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz",
+ "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8="
+ },
+ "ms": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
+ "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g="
+ },
+ "negotiator": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz",
+ "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk="
+ },
+ "on-finished": {
+ "version": "2.3.0",
+ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz",
+ "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=",
+ "requires": {
+ "ee-first": "1.1.1"
+ }
+ },
+ "optimist": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz",
+ "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=",
+ "requires": {
+ "minimist": "~0.0.1",
+ "wordwrap": "~0.0.2"
+ }
+ },
+ "parseurl": {
+ "version": "1.3.2",
+ "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz",
+ "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M="
+ },
+ "path-to-regexp": {
+ "version": "0.1.7",
+ "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz",
+ "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w="
+ },
+ "proxy-addr": {
+ "version": "2.0.4",
+ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz",
+ "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==",
+ "requires": {
+ "forwarded": "~0.1.2",
+ "ipaddr.js": "1.8.0"
+ }
+ },
+ "qs": {
+ "version": "6.5.2",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz",
+ "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA=="
+ },
+ "range-parser": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz",
+ "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4="
+ },
+ "raw-body": {
+ "version": "2.3.3",
+ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz",
+ "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==",
+ "requires": {
+ "bytes": "3.0.0",
+ "http-errors": "1.6.3",
+ "iconv-lite": "0.4.23",
+ "unpipe": "1.0.0"
+ }
+ },
+ "redis": {
+ "version": "2.8.0",
+ "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz",
+ "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==",
+ "requires": {
+ "double-ended-queue": "^2.1.0-0",
+ "redis-commands": "^1.2.0",
+ "redis-parser": "^2.6.0"
+ },
+ "dependencies": {
+ "redis-commands": {
+ "version": "1.4.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz",
+ "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw=="
+ },
+ "redis-parser": {
+ "version": "2.6.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
+ "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs="
+ }
+ }
+ },
+ "redis-commands": {
+ "version": "1.5.0",
+ "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz",
+ "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg=="
+ },
+ "redis-parser": {
+ "version": "2.6.0",
+ "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz",
+ "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs="
+ },
+ "safe-buffer": {
+ "version": "5.1.2",
+ "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz",
+ "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g=="
+ },
+ "safer-buffer": {
+ "version": "2.1.2",
+ "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
+ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
+ },
+ "secure-random-string": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz",
+ "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg=="
+ },
+ "send": {
+ "version": "0.16.2",
+ "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz",
+ "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==",
+ "requires": {
+ "debug": "2.6.9",
+ "depd": "~1.1.2",
+ "destroy": "~1.0.4",
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "etag": "~1.8.1",
+ "fresh": "0.5.2",
+ "http-errors": "~1.6.2",
+ "mime": "1.4.1",
+ "ms": "2.0.0",
+ "on-finished": "~2.3.0",
+ "range-parser": "~1.2.0",
+ "statuses": "~1.4.0"
+ }
+ },
+ "serve-static": {
+ "version": "1.13.2",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz",
+ "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==",
+ "requires": {
+ "encodeurl": "~1.0.2",
+ "escape-html": "~1.0.3",
+ "parseurl": "~1.3.2",
+ "send": "0.16.2"
+ }
+ },
+ "setprototypeof": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz",
+ "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ=="
+ },
+ "source-map": {
+ "version": "0.6.1",
+ "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
+ "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g=="
+ },
+ "statuses": {
+ "version": "1.4.0",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz",
+ "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew=="
+ },
+ "type-is": {
+ "version": "1.6.16",
+ "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz",
+ "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==",
+ "requires": {
+ "media-typer": "0.3.0",
+ "mime-types": "~2.1.18"
+ }
+ },
+ "uglify-js": {
+ "version": "3.5.9",
+ "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz",
+ "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==",
+ "optional": true,
+ "requires": {
+ "commander": "~2.20.0",
+ "source-map": "~0.6.1"
+ }
+ },
+ "unpipe": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz",
+ "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw="
+ },
+ "utils-merge": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz",
+ "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM="
+ },
+ "vary": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz",
+ "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw="
+ },
+ "walk": {
+ "version": "2.3.9",
+ "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz",
+ "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=",
+ "requires": {
+ "foreachasync": "^3.0.0"
+ }
+ },
+ "wordwrap": {
+ "version": "0.0.3",
+ "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz",
+ "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc="
+ }
+ }
+}
diff --git a/images/benchmarks/node/package.json b/images/benchmarks/node/package.json
new file mode 100644
index 000000000..7dcadd523
--- /dev/null
+++ b/images/benchmarks/node/package.json
@@ -0,0 +1,19 @@
+{
+ "name": "nodedum",
+ "version": "1.0.0",
+ "description": "",
+ "main": "index.js",
+ "scripts": {
+ "test": "echo \"Error: no test specified\" && exit 1"
+ },
+ "author": "",
+ "license": "ISC",
+ "dependencies": {
+ "express": "^4.16.4",
+ "hbs": "^4.0.4",
+ "redis": "^2.8.0",
+ "redis-commands": "^1.2.0",
+ "redis-parser": "^2.6.0",
+ "secure-random-string": "^1.1.0"
+ }
+}
diff --git a/images/benchmarks/redis/Dockerfile b/images/benchmarks/redis/Dockerfile
new file mode 100644
index 000000000..0f17249af
--- /dev/null
+++ b/images/benchmarks/redis/Dockerfile
@@ -0,0 +1 @@
+FROM redis:5.0.4
diff --git a/images/benchmarks/ruby/Dockerfile b/images/benchmarks/ruby/Dockerfile
new file mode 100755
index 000000000..13c4f6eed
--- /dev/null
+++ b/images/benchmarks/ruby/Dockerfile
@@ -0,0 +1,27 @@
+# example based on https://github.com/errm/fib
+FROM alpine:3.9 as build
+
+COPY Gemfile Gemfile.lock ./
+
+RUN apk add --no-cache ruby ruby-dev ruby-bundler ruby-json build-base bash \
+ && bundle install --frozen -j4 -r3 --no-cache --without development \
+ && apk del --no-cache ruby-bundler \
+ && rm -rf /usr/lib/ruby/gems/*/cache
+
+FROM alpine:3.9 as prod
+
+COPY --from=build /usr/lib/ruby/gems /usr/lib/ruby/gems
+RUN apk add --no-cache ruby ruby-json ruby-etc redis apache2-utils \
+ && ruby -e "Gem::Specification.map.each do |spec| \
+ Gem::Installer.for_spec( \
+ spec, \
+ wrappers: true, \
+ force: true, \
+ install_dir: spec.base_dir, \
+ build_args: spec.build_args, \
+ ).generate_bin \
+ end"
+
+COPY . /app/.
+
+STOPSIGNAL SIGINT
diff --git a/images/benchmarks/ruby/Gemfile b/images/benchmarks/ruby/Gemfile
new file mode 100755
index 000000000..ac521b32c
--- /dev/null
+++ b/images/benchmarks/ruby/Gemfile
@@ -0,0 +1,5 @@
+source "https://rubygems.org"
+
+gem "sinatra"
+gem "puma"
+gem "redis" \ No newline at end of file
diff --git a/images/benchmarks/ruby/Gemfile.lock b/images/benchmarks/ruby/Gemfile.lock
new file mode 100644
index 000000000..041778e02
--- /dev/null
+++ b/images/benchmarks/ruby/Gemfile.lock
@@ -0,0 +1,26 @@
+GEM
+ remote: https://rubygems.org/
+ specs:
+ mustermann (1.0.3)
+ puma (3.4.0)
+ rack (2.0.6)
+ rack-protection (2.0.5)
+ rack
+ redis (4.1.0)
+ sinatra (2.0.5)
+ mustermann (~> 1.0)
+ rack (~> 2.0)
+ rack-protection (= 2.0.5)
+ tilt (~> 2.0)
+ tilt (2.0.9)
+
+PLATFORMS
+ ruby
+
+DEPENDENCIES
+ puma
+ redis
+ sinatra
+
+BUNDLED WITH
+ 1.17.1 \ No newline at end of file
diff --git a/images/benchmarks/ruby/config.ru b/images/benchmarks/ruby/config.ru
new file mode 100755
index 000000000..b2d135cc0
--- /dev/null
+++ b/images/benchmarks/ruby/config.ru
@@ -0,0 +1,2 @@
+require './main'
+run Sinatra::Application \ No newline at end of file
diff --git a/images/benchmarks/ruby/index.erb b/images/benchmarks/ruby/index.erb
new file mode 100755
index 000000000..7f7300e80
--- /dev/null
+++ b/images/benchmarks/ruby/index.erb
@@ -0,0 +1,8 @@
+<!DOCTYPE html>
+<html>
+<body>
+ <% text.each do |t| %>
+ <p><%= t %></p>
+ <% end %>
+</body>
+</html>
diff --git a/images/benchmarks/ruby/main.rb b/images/benchmarks/ruby/main.rb
new file mode 100755
index 000000000..b998f004e
--- /dev/null
+++ b/images/benchmarks/ruby/main.rb
@@ -0,0 +1,27 @@
+require "sinatra"
+require "securerandom"
+require "redis"
+
+redis_host = ENV["HOST"]
+$redis = Redis.new(host: redis_host)
+
+def generateText
+ for i in 0..99
+ $redis.set(i, randomBody(1024))
+ end
+end
+
+def randomBody(length)
+ return SecureRandom.alphanumeric(length)
+end
+
+generateText
+template = ERB.new(File.read('./index.erb'))
+
+get "/" do
+ texts = Array.new
+ for i in 0..4
+ texts.push($redis.get(rand(0..99)))
+ end
+ template.result_with_hash(text: texts)
+end
diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile
new file mode 100644
index 000000000..6c3aafa57
--- /dev/null
+++ b/images/benchmarks/runsc/Dockerfile
@@ -0,0 +1,24 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ wget \
+ git \
+ pkg-config \
+ zip \
+ g++ \
+ zlib1g-dev \
+ unzip \
+ python-minimal \
+ python3 \
+ python3-pip \
+ && rm -rf /var/lib/apt/lists/*
+RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh
+RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh
+RUN ./bazel-3.4.1-installer-linux-x86_64.sh
+
+# Download release-20200601.0
+RUN mkdir gvisor && cd gvisor \
+ && git init && git remote add origin https://github.com/google/gvisor.git \
+ && git fetch --depth 1 origin a9b47390c821942d60784e308f681f213645049c && git checkout FETCH_HEAD
diff --git a/images/benchmarks/sysbench/Dockerfile b/images/benchmarks/sysbench/Dockerfile
new file mode 100644
index 000000000..55e865f43
--- /dev/null
+++ b/images/benchmarks/sysbench/Dockerfile
@@ -0,0 +1,7 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ sysbench \
+ && rm -rf /var/lib/apt/lists/*
diff --git a/images/benchmarks/tensorflow/Dockerfile b/images/benchmarks/tensorflow/Dockerfile
new file mode 100644
index 000000000..7564a4ee5
--- /dev/null
+++ b/images/benchmarks/tensorflow/Dockerfile
@@ -0,0 +1,7 @@
+FROM tensorflow/tensorflow:1.13.2
+
+RUN apt-get update \
+ && apt-get install -y git
+RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git
+RUN python -m pip install -U pip setuptools
+RUN python -m pip install matplotlib
diff --git a/images/benchmarks/util/Dockerfile b/images/benchmarks/util/Dockerfile
new file mode 100644
index 000000000..f2799b3e6
--- /dev/null
+++ b/images/benchmarks/util/Dockerfile
@@ -0,0 +1,3 @@
+FROM ubuntu:bionic
+
+RUN apt-get update && apt-get install -y wget
diff --git a/images/default/Dockerfile b/images/default/Dockerfile
new file mode 100644
index 000000000..d058b83cb
--- /dev/null
+++ b/images/default/Dockerfile
@@ -0,0 +1,16 @@
+FROM fedora:31
+# Install bazel.
+RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
+RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils
+RUN pip install --no-cache-dir pycparser
+RUN dnf install -y bazel3
+# Install gcloud.
+RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \
+ tar zxvf - google-cloud-sdk && \
+ google-cloud-sdk/install.sh && \
+ ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud
+# Install Docker client for the website build.
+RUN dnf config-manager --add-repo https://download.docker.com/linux/fedora/docker-ce.repo
+RUN dnf install -y docker-ce-cli
+WORKDIR /workspace
+ENTRYPOINT ["/usr/bin/bazel"]
diff --git a/images/iptables/Dockerfile b/images/iptables/Dockerfile
new file mode 100644
index 000000000..efd91cb80
--- /dev/null
+++ b/images/iptables/Dockerfile
@@ -0,0 +1,2 @@
+FROM ubuntu
+RUN apt update && apt install -y iptables
diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile
new file mode 100644
index 000000000..ba039ba15
--- /dev/null
+++ b/images/jekyll/Dockerfile
@@ -0,0 +1,14 @@
+FROM jekyll/jekyll:4.0.0
+USER root
+RUN gem install \
+ html-proofer:3.10.2 \
+ nokogiri:1.10.1 \
+ jekyll-autoprefixer:1.0.2 \
+ jekyll-inline-svg:1.1.4 \
+ jekyll-paginate:1.1.0 \
+ kramdown-parser-gfm:1.1.0 \
+ jekyll-relative-links:0.6.1 \
+ jekyll-feed:0.13.0 \
+ jekyll-sitemap:1.4.0
+COPY checks.rb /checks.rb
+CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"]
diff --git a/images/jekyll/checks.rb b/images/jekyll/checks.rb
new file mode 100644
index 000000000..fc7e6b5a8
--- /dev/null
+++ b/images/jekyll/checks.rb
@@ -0,0 +1,36 @@
+#!/usr/local/bin/ruby
+#
+# HTMLProofer checks for the gVisor website.
+#
+require 'html-proofer'
+
+# NoOpenerCheck checks to make sure links with target=_blank include the
+# rel=noopener attribute.
+class NoOpenerCheck < ::HTMLProofer::Check
+ def run
+ @html.css('a').each do |node|
+ link = create_element(node)
+ line = node.line
+
+ rel = link.respond_to?(:rel) ? link.rel.split(' ') : []
+
+ if link.respond_to?(:target) && link.target == "_blank" && !rel.include?("noopener")
+ return add_issue("You should set rel=noopener for links with target=_blank", line: line)
+ end
+ end
+ end
+end
+
+def main()
+ options = {
+ :check_html => true,
+ :check_favicon => true,
+ :disable_external => true,
+ }
+
+ HTMLProofer.check_directories(ARGV, options).run
+end
+
+if __FILE__ == $0
+ main
+end
diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile
new file mode 100644
index 000000000..01296dbaf
--- /dev/null
+++ b/images/packetdrill/Dockerfile
@@ -0,0 +1,8 @@
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \
+ netcat tcpdump jq tar bison flex make
+RUN hash -r
+RUN git clone --depth 1 --branch packetdrill-v2.0 \
+ https://github.com/google/packetdrill.git
+RUN cd packetdrill/gtests/net/packetdrill && ./configure && make
+CMD /bin/bash
diff --git a/images/packetimpact/Dockerfile b/images/packetimpact/Dockerfile
new file mode 100644
index 000000000..87aa99ef2
--- /dev/null
+++ b/images/packetimpact/Dockerfile
@@ -0,0 +1,16 @@
+FROM ubuntu:bionic
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ # iptables to disable OS native packet processing.
+ iptables \
+ # nc to check that the posix_server is running.
+ netcat \
+ # tcpdump to log brief packet sniffing.
+ tcpdump \
+ # ip link show to display MAC addresses.
+ iproute2 \
+ # tshark to log verbose packet sniffing.
+ tshark \
+ # killall for cleanup.
+ psmisc
+RUN hash -r
+CMD /bin/bash
diff --git a/images/runtimes/go1.12/Dockerfile b/images/runtimes/go1.12/Dockerfile
new file mode 100644
index 000000000..cb2944062
--- /dev/null
+++ b/images/runtimes/go1.12/Dockerfile
@@ -0,0 +1,4 @@
+# Go is easy, since we already have everything we need to compile the proctor
+# binary and run the tests in the golang Docker image.
+FROM golang:1.12
+RUN ["go", "tool", "dist", "test", "-compile-only"]
diff --git a/test/runtimes/images/Dockerfile_java11 b/images/runtimes/java11/Dockerfile
index 9b7c3d5a3..03bc8aaf1 100644
--- a/test/runtimes/images/Dockerfile_java11
+++ b/images/runtimes/java11/Dockerfile
@@ -1,8 +1,3 @@
-# Compile the proctor binary.
-FROM golang:1.12 AS golang
-ADD ["proctor/", "/go/src/proctor/"]
-RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
-
FROM ubuntu:bionic
RUN apt-get update && apt-get install -y \
autoconf \
@@ -25,6 +20,3 @@ RUN set -ex \
RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
RUN tar -xzf jtreg.tar.gz
ENV PATH="/root/jtreg/bin:$PATH"
-
-COPY --from=golang /proctor /proctor
-ENTRYPOINT ["/proctor", "--runtime=java"]
diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/images/runtimes/nodejs12.4.0/Dockerfile
index 26f68b487..d17924b62 100644
--- a/test/runtimes/images/Dockerfile_nodejs12.4.0
+++ b/images/runtimes/nodejs12.4.0/Dockerfile
@@ -1,8 +1,3 @@
-# Compile the proctor binary.
-FROM golang:1.12 AS golang
-ADD ["proctor/", "/go/src/proctor/"]
-RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
-
FROM ubuntu:bionic
RUN apt-get update && apt-get install -y \
curl \
@@ -21,8 +16,6 @@ RUN ./configure
RUN make
RUN make test-build
-COPY --from=golang /proctor /proctor
-
# Including dumb-init emulates the Linux "init" process, preventing the failure
# of tests involving worker processes.
-ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"]
+ENTRYPOINT ["/usr/bin/dumb-init"]
diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/images/runtimes/php7.3.6/Dockerfile
index e6b4c6329..e5f67f79c 100644
--- a/test/runtimes/images/Dockerfile_php7.3.6
+++ b/images/runtimes/php7.3.6/Dockerfile
@@ -1,8 +1,3 @@
-# Compile the proctor binary.
-FROM golang:1.12 AS golang
-ADD ["proctor/", "/go/src/proctor/"]
-RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
-
FROM ubuntu:bionic
RUN apt-get update && apt-get install -y \
autoconf \
@@ -22,6 +17,3 @@ RUN tar -zxf php-${VERSION}.tar.gz
WORKDIR /root/php-${VERSION}
RUN ./configure
RUN make
-
-COPY --from=golang /proctor /proctor
-ENTRYPOINT ["/proctor", "--runtime=php"]
diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/images/runtimes/python3.7.3/Dockerfile
index 905cd22d7..4d1e1e221 100644
--- a/test/runtimes/images/Dockerfile_python3.7.3
+++ b/images/runtimes/python3.7.3/Dockerfile
@@ -1,10 +1,4 @@
-# Compile the proctor binary.
-FROM golang:1.12 AS golang
-ADD ["proctor/", "/go/src/proctor/"]
-RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
-
FROM ubuntu:bionic
-
RUN apt-get update && apt-get install -y \
curl \
gcc \
@@ -25,6 +19,3 @@ RUN tar -zxf cpython-${VERSION}.tar.gz
WORKDIR /root/cpython-${VERSION}
RUN ./configure --with-pydebug
RUN make -s -j2
-
-COPY --from=golang /proctor /proctor
-ENTRYPOINT ["/proctor", "--runtime=python"]
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
deleted file mode 100644
index 6c1d262d4..000000000
--- a/kokoro/build.cfg
+++ /dev/null
@@ -1,23 +0,0 @@
-build_file: "repo/scripts/build.sh"
-
-before_action {
- fetch_keystore {
- keystore_resource {
- keystore_config_id: 73898
- keyname: "kokoro-repo-key"
- }
- }
-}
-
-env_vars {
- key: "KOKORO_REPO_KEY"
- value: "73898_kokoro-repo-key"
-}
-
-action {
- define_artifacts {
- regex: "**/runsc"
- regex: "**/runsc.*"
- regex: "**/dists/**"
- }
-}
diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg
deleted file mode 100644
index c64b7e679..000000000
--- a/kokoro/build_tests.cfg
+++ /dev/null
@@ -1 +0,0 @@
-build_file: "repo/scripts/build.sh"
diff --git a/kokoro/common.cfg b/kokoro/common.cfg
deleted file mode 100644
index 669a2e458..000000000
--- a/kokoro/common.cfg
+++ /dev/null
@@ -1,29 +0,0 @@
-# Give Kokoro access to Remote Build Executor (RBE) service account key.
-before_action {
- fetch_keystore {
- keystore_resource {
- keystore_config_id: 73898
- keyname: "kokoro-rbe-service-account"
- }
- }
-}
-
-# Configure bazel to access RBE.
-bazel_setting {
- # Our GCP project name.
- project_id: "gvisor-rbe"
-
- # Use RBE for execution as well as caching.
- local_execution: false
-
- # This must match the values in the job config.
- auth_credential: {
- keystore_config_id: 73898
- keyname: "kokoro-rbe-service-account"
- }
-
- # Do not change unless you know what you are doing.
- bes_backend_address: "buildeventservice.googleapis.com"
- foundry_backend_address: "remotebuildexecution.googleapis.com"
- upsalite_frontend_address: "https://source.cloud.google.com"
-}
diff --git a/kokoro/do_tests.cfg b/kokoro/do_tests.cfg
deleted file mode 100644
index b45ec0b42..000000000
--- a/kokoro/do_tests.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-build_file: "repo/scripts/do_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg
deleted file mode 100644
index 0a0ef87ed..000000000
--- a/kokoro/docker_tests.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-build_file: "repo/scripts/docker_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc_logs_*.tar.gz"
- }
-}
diff --git a/kokoro/go.cfg b/kokoro/go.cfg
deleted file mode 100644
index b9c1fcb12..000000000
--- a/kokoro/go.cfg
+++ /dev/null
@@ -1,20 +0,0 @@
-build_file: "repo/scripts/go.sh"
-
-before_action {
- fetch_keystore {
- keystore_resource {
- keystore_config_id: 73898
- keyname: "kokoro-github-access-token"
- }
- }
-}
-
-env_vars {
- key: "KOKORO_GITHUB_ACCESS_TOKEN"
- value: "73898_kokoro-github-access-token"
-}
-
-env_vars {
- key: "KOKORO_GO_PUSH"
- value: "true"
-}
diff --git a/kokoro/go_tests.cfg b/kokoro/go_tests.cfg
deleted file mode 100644
index 5eb51041a..000000000
--- a/kokoro/go_tests.cfg
+++ /dev/null
@@ -1 +0,0 @@
-build_file: "repo/scripts/go.sh"
diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg
deleted file mode 100644
index 520dc55a3..000000000
--- a/kokoro/hostnet_tests.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-build_file: "repo/scripts/hostnet_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc_logs_*.tar.gz"
- }
-}
diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg
deleted file mode 100644
index 1feb60c8a..000000000
--- a/kokoro/kvm_tests.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-build_file: "repo/scripts/kvm_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc_logs_*.tar.gz"
- }
-}
diff --git a/kokoro/make_tests.cfg b/kokoro/make_tests.cfg
deleted file mode 100644
index d973130ff..000000000
--- a/kokoro/make_tests.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-build_file: "repo/scripts/make_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg
deleted file mode 100644
index 6a2ddbd03..000000000
--- a/kokoro/overlay_tests.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-build_file: "repo/scripts/overlay_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc_logs_*.tar.gz"
- }
-}
diff --git a/kokoro/release.cfg b/kokoro/release.cfg
deleted file mode 100644
index b9d35bc51..000000000
--- a/kokoro/release.cfg
+++ /dev/null
@@ -1 +0,0 @@
-build_file: "repo/scripts/release.sh"
diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg
deleted file mode 100644
index 28351695c..000000000
--- a/kokoro/root_tests.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-build_file: "repo/scripts/root_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc_logs_*.tar.gz"
- }
-}
diff --git a/kokoro/simple_tests.cfg b/kokoro/simple_tests.cfg
deleted file mode 100644
index 32e0a9431..000000000
--- a/kokoro/simple_tests.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-build_file: "repo/scripts/simple_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/swgso_tests.cfg b/kokoro/swgso_tests.cfg
deleted file mode 100644
index 101a9c607..000000000
--- a/kokoro/swgso_tests.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-build_file: "repo/scripts/swgso_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/syscall_tests.cfg b/kokoro/syscall_tests.cfg
deleted file mode 100644
index ee6e4a3a4..000000000
--- a/kokoro/syscall_tests.cfg
+++ /dev/null
@@ -1,9 +0,0 @@
-build_file: "repo/scripts/syscall_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/ubuntu1604/30_containerd.sh b/kokoro/ubuntu1604/30_containerd.sh
deleted file mode 100755
index a7472bd1c..000000000
--- a/kokoro/ubuntu1604/30_containerd.sh
+++ /dev/null
@@ -1,76 +0,0 @@
-#!/bin/bash
-
-# Copyright 2019 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-set -xeo pipefail
-
-# Helper for Go packages below.
-install_helper() {
- PACKAGE="${1}"
- TAG="${2}"
- GOPATH="${3}"
-
- # Clone the repository.
- mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
- git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
-
- # Checkout and build the repository.
- (cd "${GOPATH}"/src/"${PACKAGE}" && \
- git checkout "${TAG}" && \
- GOPATH="${GOPATH}" make && \
- GOPATH="${GOPATH}" make install)
-}
-
-# Install dependencies for the crictl tests.
-apt-get install -y btrfs-tools libseccomp-dev
-
-# Install containerd & cri-tools.
-GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
-install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
-install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
-
-# Install gvisor-containerd-shim.
-declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
-declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
-declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
-wget --no-verbose "${base}"/latest -O ${latest}
-wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
-chmod +x ${shim_path}
-mv ${shim_path} /usr/local/bin
-
-# Configure containerd-shim.
-declare -r shim_config_path=/etc/containerd
-declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml)
-mkdir -p ${shim_config_path}
-cat > ${shim_config_tmp_path} <<-EOF
- runc_shim = "/usr/local/bin/containerd-shim"
-
-[runsc_config]
- debug = "true"
- debug-log = "/tmp/runsc-logs/"
- strace = "true"
- file-access = "shared"
-EOF
-mv ${shim_config_tmp_path} ${shim_config_path}
-
-# Configure CNI.
-(cd "${GOPATH}" && GOPATH="${GOPATH}" \
- src/github.com/containerd/containerd/script/setup/install-cni)
-
-# Cleanup the above.
-rm -rf "${GOPATH}"
-rm -rf "${latest}"
-rm -rf "${shim_path}"
-rm -rf "${shim_config_tmp_path}"
diff --git a/kokoro/ubuntu1804/10_core.sh b/kokoro/ubuntu1804/10_core.sh
deleted file mode 120000
index 6facceeee..000000000
--- a/kokoro/ubuntu1804/10_core.sh
+++ /dev/null
@@ -1 +0,0 @@
-../ubuntu1604/10_core.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/20_bazel.sh b/kokoro/ubuntu1804/20_bazel.sh
deleted file mode 120000
index 39194c0f5..000000000
--- a/kokoro/ubuntu1804/20_bazel.sh
+++ /dev/null
@@ -1 +0,0 @@
-../ubuntu1604/20_bazel.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/25_docker.sh b/kokoro/ubuntu1804/25_docker.sh
deleted file mode 120000
index 63269bd83..000000000
--- a/kokoro/ubuntu1804/25_docker.sh
+++ /dev/null
@@ -1 +0,0 @@
-../ubuntu1604/25_docker.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/30_containerd.sh b/kokoro/ubuntu1804/30_containerd.sh
deleted file mode 120000
index 6ac2377ed..000000000
--- a/kokoro/ubuntu1804/30_containerd.sh
+++ /dev/null
@@ -1 +0,0 @@
-../ubuntu1604/30_containerd.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/40_kokoro.sh b/kokoro/ubuntu1804/40_kokoro.sh
deleted file mode 120000
index e861fb5e1..000000000
--- a/kokoro/ubuntu1804/40_kokoro.sh
+++ /dev/null
@@ -1 +0,0 @@
-../ubuntu1604/40_kokoro.sh \ No newline at end of file
diff --git a/pkg/abi/BUILD b/pkg/abi/BUILD
index f5c08ea06..839f822eb 100644
--- a/pkg/abi/BUILD
+++ b/pkg/abi/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,6 +9,5 @@ go_library(
"abi_linux.go",
"flag.go",
],
- importpath = "gvisor.dev/gvisor/pkg/abi",
visibility = ["//:sandbox"],
)
diff --git a/pkg/abi/abi.go b/pkg/abi/abi.go
index d56c481c9..e6be93c3a 100644
--- a/pkg/abi/abi.go
+++ b/pkg/abi/abi.go
@@ -39,3 +39,7 @@ func (o OS) String() string {
return fmt.Sprintf("OS(%d)", o)
}
}
+
+// ABI is an interface that defines OS-specific interactions.
+type ABI interface {
+}
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index 51774c6b6..b5c5cc20b 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
# Package linux contains the constants and types needed to interface with a
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
@@ -11,6 +10,7 @@ go_library(
name = "linux",
srcs = [
"aio.go",
+ "arch_amd64.go",
"audit.go",
"bpf.go",
"capability.go",
@@ -18,17 +18,22 @@ go_library(
"dev.go",
"elf.go",
"epoll.go",
+ "epoll_amd64.go",
+ "epoll_arm64.go",
"errors.go",
"eventfd.go",
"exec.go",
+ "fadvise.go",
"fcntl.go",
"file.go",
"file_amd64.go",
"file_arm64.go",
"fs.go",
+ "fuse.go",
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
@@ -36,11 +41,15 @@ go_library(
"mm.go",
"netdevice.go",
"netfilter.go",
+ "netfilter_ipv6.go",
"netlink.go",
"netlink_route.go",
"poll.go",
"prctl.go",
"ptrace.go",
+ "ptrace_amd64.go",
+ "ptrace_arm64.go",
+ "rseq.go",
"rusage.go",
"sched.go",
"seccomp.go",
@@ -57,13 +66,17 @@ go_library(
"uio.go",
"utsname.go",
"wait.go",
+ "xattr.go",
],
- importpath = "gvisor.dev/gvisor/pkg/abi/linux",
+ marshal = True,
visibility = ["//visibility:public"],
deps = [
"//pkg/abi",
"//pkg/binary",
"//pkg/bits",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
@@ -71,7 +84,7 @@ go_test(
name = "linux_test",
size = "small",
srcs = ["netfilter_test.go"],
- embed = [":linux"],
+ library = ":linux",
deps = [
"//pkg/binary",
],
diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go
index 3c6e0079d..86ee3f8b5 100644
--- a/pkg/abi/linux/aio.go
+++ b/pkg/abi/linux/aio.go
@@ -14,7 +14,63 @@
package linux
+import "encoding/binary"
+
+// AIORingSize is sizeof(struct aio_ring).
+const AIORingSize = 32
+
+// I/O commands.
const (
- // AIORingSize is sizeof(struct aio_ring).
- AIORingSize = 32
+ IOCB_CMD_PREAD = 0
+ IOCB_CMD_PWRITE = 1
+ IOCB_CMD_FSYNC = 2
+ IOCB_CMD_FDSYNC = 3
+ // 4 was the experimental IOCB_CMD_PREADX.
+ IOCB_CMD_POLL = 5
+ IOCB_CMD_NOOP = 6
+ IOCB_CMD_PREADV = 7
+ IOCB_CMD_PWRITEV = 8
)
+
+// I/O flags.
+const (
+ IOCB_FLAG_RESFD = 1
+ IOCB_FLAG_IOPRIO = 2
+)
+
+// IOCallback describes an I/O request.
+//
+// The priority field is currently ignored in the implementation below. Also
+// note that the IOCB_FLAG_RESFD feature is not supported.
+type IOCallback struct {
+ Data uint64
+ Key uint32
+ _ uint32
+
+ OpCode uint16
+ ReqPrio int16
+ FD int32
+
+ Buf uint64
+ Bytes uint64
+ Offset int64
+
+ Reserved2 uint64
+ Flags uint32
+
+ // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
+ ResFD int32
+}
+
+// IOEvent describes an I/O result.
+//
+// +stateify savable
+type IOEvent struct {
+ Data uint64
+ Obj uint64
+ Result int64
+ Result2 int64
+}
+
+// IOEventSize is the size of an ioEvent encoded.
+var IOEventSize = binary.Size(IOEvent{})
diff --git a/pkg/abi/linux/arch_amd64.go b/pkg/abi/linux/arch_amd64.go
new file mode 100644
index 000000000..0be31e755
--- /dev/null
+++ b/pkg/abi/linux/arch_amd64.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// Start and end addresses of the vsyscall page.
+const (
+ VSyscallStartAddr uint64 = 0xffffffffff600000
+ VSyscallEndAddr uint64 = 0xffffffffff601000
+)
diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go
index 421e11256..192e2093b 100644
--- a/pkg/abi/linux/dev.go
+++ b/pkg/abi/linux/dev.go
@@ -36,9 +36,20 @@ func DecodeDeviceID(rdev uint32) (uint16, uint32) {
//
// See Documentations/devices.txt and uapi/linux/major.h.
const (
+ // UNNAMED_MAJOR is the major device number for "unnamed" devices, whose
+ // minor numbers are dynamically allocated by the kernel.
+ UNNAMED_MAJOR = 0
+
+ // MEM_MAJOR is the major device number for "memory" character devices.
+ MEM_MAJOR = 1
+
// TTYAUX_MAJOR is the major device number for alternate TTY devices.
TTYAUX_MAJOR = 5
+ // MISC_MAJOR is the major device number for non-serial mice, misc feature
+ // devices.
+ MISC_MAJOR = 10
+
// UNIX98_PTY_MASTER_MAJOR is the initial major device number for
// Unix98 PTY masters.
UNIX98_PTY_MASTER_MAJOR = 128
diff --git a/pkg/abi/linux/elf.go b/pkg/abi/linux/elf.go
index 40f0459a0..7c9a02f20 100644
--- a/pkg/abi/linux/elf.go
+++ b/pkg/abi/linux/elf.go
@@ -102,4 +102,7 @@ const (
// NT_X86_XSTATE is for x86 extended state using xsave.
NT_X86_XSTATE = 0x202
+
+ // NT_ARM_TLS is for ARM TLS register.
+ NT_ARM_TLS = 0x401
)
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 72083b604..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,12 +14,9 @@
package linux
-// EpollEvent is equivalent to struct epoll_event from epoll(2).
-type EpollEvent struct {
- Events uint32
- Fd int32
- Data int32
-}
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
// Event masks.
const (
@@ -38,8 +35,14 @@ const (
// Per-file descriptor flags.
const (
- EPOLLET = 0x80000000
- EPOLLONESHOT = 0x40000000
+ EPOLLEXCLUSIVE = 1 << 28
+ EPOLLWAKEUP = 1 << 29
+ EPOLLONESHOT = 1 << 30
+ EPOLLET = 1 << 31
+
+ // EP_PRIVATE_BITS is fs/eventpoll.c:EP_PRIVATE_BITS, the set of all bits
+ // in an epoll event mask that correspond to flags rather than I/O events.
+ EP_PRIVATE_BITS = EPOLLEXCLUSIVE | EPOLLWAKEUP | EPOLLONESHOT | EPOLLET
)
// Operation flags.
@@ -54,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/sentry/fsimpl/proc/mounts.go b/pkg/abi/linux/epoll_amd64.go
index e81b1e910..7e74b1143 100644
--- a/pkg/sentry/fsimpl/proc/mounts.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -12,22 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package proc
+// +build amd64
-import "gvisor.dev/gvisor/pkg/sentry/kernel"
+package linux
-// TODO(b/138862512): Implement mountInfoFile and mountsFile.
-
-// mountInfoFile implements vfs.DynamicBytesSource for /proc/[pid]/mountinfo.
-//
-// +stateify savable
-type mountInfoFile struct {
- t *kernel.Task
-}
-
-// mountsFile implements vfs.DynamicBytesSource for /proc/[pid]/mounts.
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
//
-// +stateify savable
-type mountsFile struct {
- t *kernel.Task
+// +marshal slice:EpollEventSlice
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event::data a __u64. We represent it as
+ // [2]int32 because, on amd64, Linux also makes struct epoll_event
+ // __attribute__((packed)), such that there is no padding between Events
+ // and Data.
+ Data [2]int32
}
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
new file mode 100644
index 000000000..a35939cc9
--- /dev/null
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal slice:EpollEventSlice
+type EpollEvent struct {
+ Events uint32
+ // Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
+ // here.
+ _ int32
+ Data [2]int32
+}
diff --git a/pkg/abi/linux/fadvise.go b/pkg/abi/linux/fadvise.go
new file mode 100644
index 000000000..b06ff9964
--- /dev/null
+++ b/pkg/abi/linux/fadvise.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ POSIX_FADV_NORMAL = 0
+ POSIX_FADV_RANDOM = 1
+ POSIX_FADV_SEQUENTIAL = 2
+ POSIX_FADV_WILLNEED = 3
+ POSIX_FADV_DONTNEED = 4
+ POSIX_FADV_NOREUSE = 5
+)
diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go
index f78315ebf..9242e80a5 100644
--- a/pkg/abi/linux/fcntl.go
+++ b/pkg/abi/linux/fcntl.go
@@ -16,15 +16,17 @@ package linux
// Commands from linux/fcntl.h.
const (
- F_DUPFD = 0x0
- F_GETFD = 0x1
- F_SETFD = 0x2
- F_GETFL = 0x3
- F_SETFL = 0x4
- F_SETLK = 0x6
- F_SETLKW = 0x7
- F_SETOWN = 0x8
- F_GETOWN = 0x9
+ F_DUPFD = 0
+ F_GETFD = 1
+ F_SETFD = 2
+ F_GETFL = 3
+ F_SETFL = 4
+ F_SETLK = 6
+ F_SETLKW = 7
+ F_SETOWN = 8
+ F_GETOWN = 9
+ F_SETOWN_EX = 15
+ F_GETOWN_EX = 16
F_DUPFD_CLOEXEC = 1024 + 6
F_SETPIPE_SZ = 1024 + 7
F_GETPIPE_SZ = 1024 + 8
@@ -32,9 +34,9 @@ const (
// Commands for F_SETLK.
const (
- F_RDLCK = 0x0
- F_WRLCK = 0x1
- F_UNLCK = 0x2
+ F_RDLCK = 0
+ F_WRLCK = 1
+ F_UNLCK = 2
)
// Flags for fcntl.
@@ -42,7 +44,7 @@ const (
FD_CLOEXEC = 00000001
)
-// Lock structure for F_SETLK.
+// Flock is the lock structure for F_SETLK.
type Flock struct {
Type int16
Whence int16
@@ -52,3 +54,16 @@ type Flock struct {
Pid int32
_ [4]byte
}
+
+// Owner types for F_SETOWN_EX and F_GETOWN_EX.
+const (
+ F_OWNER_TID = 0
+ F_OWNER_PID = 1
+ F_OWNER_PGRP = 2
+)
+
+// FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX.
+type FOwnerEx struct {
+ Type int32
+ PID int32
+}
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c9ee098f4..e11ca2d62 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -24,27 +24,23 @@ import (
// Constants for open(2).
const (
- O_ACCMODE = 000000003
- O_RDONLY = 000000000
- O_WRONLY = 000000001
- O_RDWR = 000000002
- O_CREAT = 000000100
- O_EXCL = 000000200
- O_NOCTTY = 000000400
- O_TRUNC = 000001000
- O_APPEND = 000002000
- O_NONBLOCK = 000004000
- O_DSYNC = 000010000
- O_ASYNC = 000020000
- O_DIRECT = 000040000
- O_LARGEFILE = 000100000
- O_DIRECTORY = 000200000
- O_NOFOLLOW = 000400000
- O_NOATIME = 001000000
- O_CLOEXEC = 002000000
- O_SYNC = 004000000 // __O_SYNC in Linux
- O_PATH = 010000000
- O_TMPFILE = 020000000 // __O_TMPFILE in Linux
+ O_ACCMODE = 000000003
+ O_RDONLY = 000000000
+ O_WRONLY = 000000001
+ O_RDWR = 000000002
+ O_CREAT = 000000100
+ O_EXCL = 000000200
+ O_NOCTTY = 000000400
+ O_TRUNC = 000001000
+ O_APPEND = 000002000
+ O_NONBLOCK = 000004000
+ O_DSYNC = 000010000
+ O_ASYNC = 000020000
+ O_NOATIME = 001000000
+ O_CLOEXEC = 002000000
+ O_SYNC = 004000000 // __O_SYNC in Linux
+ O_PATH = 010000000
+ O_TMPFILE = 020000000 // __O_TMPFILE in Linux
)
// Constants for fstatat(2).
@@ -144,9 +140,13 @@ const (
ModeCharacterDevice = S_IFCHR
ModeNamedPipe = S_IFIFO
- ModeSetUID = 04000
- ModeSetGID = 02000
- ModeSticky = 01000
+ S_ISUID = 04000
+ S_ISGID = 02000
+ S_ISVTX = 01000
+
+ ModeSetUID = S_ISUID
+ ModeSetGID = S_ISGID
+ ModeSticky = S_ISVTX
ModeUserAll = 0700
ModeUserRead = 0400
@@ -176,10 +176,24 @@ const (
DT_WHT = 14
)
+// DirentType are the friendly strings for linux_dirent64.d_type.
+var DirentType = abi.ValueSet{
+ DT_UNKNOWN: "DT_UNKNOWN",
+ DT_FIFO: "DT_FIFO",
+ DT_CHR: "DT_CHR",
+ DT_DIR: "DT_DIR",
+ DT_BLK: "DT_BLK",
+ DT_REG: "DT_REG",
+ DT_LNK: "DT_LNK",
+ DT_SOCK: "DT_SOCK",
+ DT_WHT: "DT_WHT",
+}
+
// Values for preadv2/pwritev2.
const (
- // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is
- // accepted as a valid flag argument for preadv2/pwritev2.
+ // NOTE(b/120162627): gVisor does not implement the RWF_HIPRI feature, but
+ // the flag is accepted as a valid flag argument for preadv2/pwritev2 and
+ // silently ignored.
RWF_HIPRI = 0x00000001
RWF_DSYNC = 0x00000002
RWF_SYNC = 0x00000004
@@ -228,6 +242,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
@@ -251,6 +267,9 @@ type Statx struct {
DevMinor uint32
}
+// SizeOfStatx is the size of a Statx struct.
+var SizeOfStatx = binary.Size(Statx{})
+
// FileMode represents a mode_t.
type FileMode uint16
@@ -269,6 +288,11 @@ func (m FileMode) ExtraBits() FileMode {
return m &^ (PermissionsMask | FileTypeMask)
}
+// IsDir returns true if file type represents a directory.
+func (m FileMode) IsDir() bool {
+ return m.FileType() == S_IFDIR
+}
+
// String returns a string representation of m.
func (m FileMode) String() string {
var s []string
@@ -282,6 +306,29 @@ func (m FileMode) String() string {
return strings.Join(s, "|")
}
+// DirentType maps file types to dirent types appropriate for (struct
+// dirent)::d_type.
+func (m FileMode) DirentType() uint8 {
+ switch m.FileType() {
+ case ModeSocket:
+ return DT_SOCK
+ case ModeSymlink:
+ return DT_LNK
+ case ModeRegular:
+ return DT_REG
+ case ModeBlockDevice:
+ return DT_BLK
+ case ModeDirectory:
+ return DT_DIR
+ case ModeCharacterDevice:
+ return DT_CHR
+ case ModeNamedPipe:
+ return DT_FIFO
+ default:
+ return DT_UNKNOWN
+ }
+}
+
var modeExtraBits = abi.FlagSet{
{
Flag: ModeSetUID,
diff --git a/pkg/abi/linux/file_amd64.go b/pkg/abi/linux/file_amd64.go
index 74c554be6..6b72364ea 100644
--- a/pkg/abi/linux/file_amd64.go
+++ b/pkg/abi/linux/file_amd64.go
@@ -12,9 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package linux
+// Constants for open(2).
+const (
+ O_DIRECT = 000040000
+ O_LARGEFILE = 000100000
+ O_DIRECTORY = 000200000
+ O_NOFOLLOW = 000400000
+)
+
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/file_arm64.go b/pkg/abi/linux/file_arm64.go
index f16c07589..6492c9038 100644
--- a/pkg/abi/linux/file_arm64.go
+++ b/pkg/abi/linux/file_arm64.go
@@ -12,9 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package linux
+// Constants for open(2).
+const (
+ O_DIRECTORY = 000040000
+ O_NOFOLLOW = 000100000
+ O_DIRECT = 000200000
+ O_LARGEFILE = 000400000
+)
+
// Stat represents struct stat.
+//
+// +marshal
type Stat struct {
Dev uint64
Ino uint64
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index b416e3472..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
@@ -92,3 +94,10 @@ const (
SYNC_FILE_RANGE_WRITE = 2
SYNC_FILE_RANGE_WAIT_AFTER = 4
)
+
+// Flag argument to renameat2(2), from include/uapi/linux/fs.h.
+const (
+ RENAME_NOREPLACE = (1 << 0) // Don't overwrite target.
+ RENAME_EXCHANGE = (1 << 1) // Exchange src and dst.
+ RENAME_WHITEOUT = (1 << 2) // Whiteout src.
+)
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
new file mode 100644
index 000000000..7e30483ee
--- /dev/null
+++ b/pkg/abi/linux/fuse.go
@@ -0,0 +1,303 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// +marshal
+type FUSEOpcode uint32
+
+// +marshal
+type FUSEOpID uint64
+
+// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+const (
+ FUSE_LOOKUP FUSEOpcode = 1
+ FUSE_FORGET = 2 /* no reply */
+ FUSE_GETATTR = 3
+ FUSE_SETATTR = 4
+ FUSE_READLINK = 5
+ FUSE_SYMLINK = 6
+ _
+ FUSE_MKNOD = 8
+ FUSE_MKDIR = 9
+ FUSE_UNLINK = 10
+ FUSE_RMDIR = 11
+ FUSE_RENAME = 12
+ FUSE_LINK = 13
+ FUSE_OPEN = 14
+ FUSE_READ = 15
+ FUSE_WRITE = 16
+ FUSE_STATFS = 17
+ FUSE_RELEASE = 18
+ _
+ FUSE_FSYNC = 20
+ FUSE_SETXATTR = 21
+ FUSE_GETXATTR = 22
+ FUSE_LISTXATTR = 23
+ FUSE_REMOVEXATTR = 24
+ FUSE_FLUSH = 25
+ FUSE_INIT = 26
+ FUSE_OPENDIR = 27
+ FUSE_READDIR = 28
+ FUSE_RELEASEDIR = 29
+ FUSE_FSYNCDIR = 30
+ FUSE_GETLK = 31
+ FUSE_SETLK = 32
+ FUSE_SETLKW = 33
+ FUSE_ACCESS = 34
+ FUSE_CREATE = 35
+ FUSE_INTERRUPT = 36
+ FUSE_BMAP = 37
+ FUSE_DESTROY = 38
+ FUSE_IOCTL = 39
+ FUSE_POLL = 40
+ FUSE_NOTIFY_REPLY = 41
+ FUSE_BATCH_FORGET = 42
+)
+
+const (
+ // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem.
+ // This is the minimum size Linux supports. See linux.fuse.h.
+ FUSE_MIN_READ_BUFFER uint32 = 8192
+)
+
+// FUSEHeaderIn is the header read by the daemon with each request.
+//
+// +marshal
+type FUSEHeaderIn struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Opcode specifies the kind of operation of the request.
+ Opcode FUSEOpcode
+
+ // Unique specifies the unique identifier for this request.
+ Unique FUSEOpID
+
+ // NodeID is the ID of the filesystem object being operated on.
+ NodeID uint64
+
+ // UID is the UID of the requesting process.
+ UID uint32
+
+ // GID is the GID of the requesting process.
+ GID uint32
+
+ // PID is the PID of the requesting process.
+ PID uint32
+
+ _ uint32
+}
+
+// FUSEHeaderOut is the header written by the daemon when it processes
+// a request and wants to send a reply (almost all operations require a
+// reply; if they do not, this will be explicitly documented).
+//
+// +marshal
+type FUSEHeaderOut struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Error specifies the error that occurred (0 if none).
+ Error int32
+
+ // Unique specifies the unique identifier of the corresponding request.
+ Unique FUSEOpID
+}
+
+// FUSEWriteIn is the header written by a daemon when it makes a
+// write request to the FUSE filesystem.
+//
+// +marshal
+type FUSEWriteIn struct {
+ // Fh specifies the file handle that is being written to.
+ Fh uint64
+
+ // Offset is the offset of the write.
+ Offset uint64
+
+ // Size is the size of data being written.
+ Size uint32
+
+ // WriteFlags is the flags used during the write.
+ WriteFlags uint32
+
+ // LockOwner is the ID of the lock owner.
+ LockOwner uint64
+
+ // Flags is the flags for the request.
+ Flags uint32
+
+ _ uint32
+}
+
+// FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h.
+const (
+ FUSE_ASYNC_READ = 1 << 0
+ FUSE_POSIX_LOCKS = 1 << 1
+ FUSE_FILE_OPS = 1 << 2
+ FUSE_ATOMIC_O_TRUNC = 1 << 3
+ FUSE_EXPORT_SUPPORT = 1 << 4
+ FUSE_BIG_WRITES = 1 << 5
+ FUSE_DONT_MASK = 1 << 6
+ FUSE_SPLICE_WRITE = 1 << 7
+ FUSE_SPLICE_MOVE = 1 << 8
+ FUSE_SPLICE_READ = 1 << 9
+ FUSE_FLOCK_LOCKS = 1 << 10
+ FUSE_HAS_IOCTL_DIR = 1 << 11
+ FUSE_AUTO_INVAL_DATA = 1 << 12
+ FUSE_DO_READDIRPLUS = 1 << 13
+ FUSE_READDIRPLUS_AUTO = 1 << 14
+ FUSE_ASYNC_DIO = 1 << 15
+ FUSE_WRITEBACK_CACHE = 1 << 16
+ FUSE_NO_OPEN_SUPPORT = 1 << 17
+ FUSE_PARALLEL_DIROPS = 1 << 18
+ FUSE_HANDLE_KILLPRIV = 1 << 19
+ FUSE_POSIX_ACL = 1 << 20
+ FUSE_ABORT_ERROR = 1 << 21
+ FUSE_MAX_PAGES = 1 << 22
+ FUSE_CACHE_SYMLINKS = 1 << 23
+ FUSE_NO_OPENDIR_SUPPORT = 1 << 24
+ FUSE_EXPLICIT_INVAL_DATA = 1 << 25
+ FUSE_MAP_ALIGNMENT = 1 << 26
+)
+
+// currently supported FUSE protocol version numbers.
+const (
+ FUSE_KERNEL_VERSION = 7
+ FUSE_KERNEL_MINOR_VERSION = 31
+)
+
+// FUSEInitIn is the request sent by the kernel to the daemon,
+// to negotiate the version and flags.
+//
+// +marshal
+type FUSEInitIn struct {
+ // Major version supported by kernel.
+ Major uint32
+
+ // Minor version supported by the kernel.
+ Minor uint32
+
+ // MaxReadahead is the maximum number of bytes to read-ahead
+ // decided by the kernel.
+ MaxReadahead uint32
+
+ // Flags of this init request.
+ Flags uint32
+}
+
+// FUSEInitOut is the reply sent by the daemon to the kernel
+// for FUSEInitIn.
+//
+// +marshal
+type FUSEInitOut struct {
+ // Major version supported by daemon.
+ Major uint32
+
+ // Minor version supported by daemon.
+ Minor uint32
+
+ // MaxReadahead is the maximum number of bytes to read-ahead.
+ // Decided by the daemon, after receiving the value from kernel.
+ MaxReadahead uint32
+
+ // Flags of this init reply.
+ Flags uint32
+
+ // MaxBackground is the maximum number of pending background requests
+ // that the daemon wants.
+ MaxBackground uint16
+
+ // CongestionThreshold is the daemon-decided threshold for
+ // the number of the pending background requests.
+ CongestionThreshold uint16
+
+ // MaxWrite is the daemon's maximum size of a write buffer.
+ // Kernel adjusts it to the minimum (fuse/init.go:fuseMinMaxWrite).
+ // if the value from daemon is too small.
+ MaxWrite uint32
+
+ // TimeGran is the daemon's time granularity for mtime and ctime metadata.
+ // The unit is nanosecond.
+ // Value should be power of 10.
+ // 1 indicates full nanosecond granularity support.
+ TimeGran uint32
+
+ // MaxPages is the daemon's maximum number of pages for one write operation.
+ // Kernel adjusts it to the maximum (fuse/init.go:FUSE_MAX_MAX_PAGES).
+ // if the value from daemon is too large.
+ MaxPages uint16
+
+ // MapAlignment is an unknown field and not used by this package at this moment.
+ // Use as a placeholder to be consistent with the FUSE protocol.
+ MapAlignment uint16
+
+ _ [8]uint32
+}
+
+// FUSEGetAttrIn is the request sent by the kernel to the daemon,
+// to get the attribute of a inode.
+//
+// +marshal
+type FUSEGetAttrIn struct {
+ // GetAttrFlags specifies whether getattr request is sent with a nodeid or
+ // with a file handle.
+ GetAttrFlags uint32
+
+ _ uint32
+
+ // Fh is the file handler when GetAttrFlags has FUSE_GETATTR_FH bit. If
+ // used, the operation is analogous to fstat(2).
+ Fh uint64
+}
+
+// FUSEAttr is the struct used in the reponse FUSEGetAttrOut.
+//
+// +marshal
+type FUSEAttr struct {
+ Ino uint64
+ Size uint64
+ Blocks uint64
+ Atime uint64
+ Mtime uint64
+ Ctime uint64
+ AtimeNsec uint32
+ MtimeNsec uint32
+ CtimeNsec uint32
+ Mode uint32
+ Nlink uint32
+ UID uint32
+ GID uint32
+ Rdev uint32
+ BlkSize uint32
+ _ uint32
+}
+
+// FUSEGetAttrOut is the reply sent by the daemon to the kernel
+// for FUSEGetAttrIn.
+//
+// +marshal
+type FUSEGetAttrOut struct {
+ // AttrValid and AttrValidNsec describe the attribute cache duration
+ AttrValid uint64
+
+ // AttrValidNsec is the nanosecond part of the attribute cache duration
+ AttrValidNsec uint32
+
+ _ uint32
+
+ // Attr contains the metadata returned from the FUSE server
+ Attr FUSEAttr
+}
diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go
index 08bfde3b5..8138088a6 100644
--- a/pkg/abi/linux/futex.go
+++ b/pkg/abi/linux/futex.go
@@ -60,3 +60,21 @@ const (
FUTEX_WAITERS = 0x80000000
FUTEX_OWNER_DIED = 0x40000000
)
+
+// FUTEX_BITSET_MATCH_ANY has all bits set.
+const FUTEX_BITSET_MATCH_ANY = 0xffffffff
+
+// ROBUST_LIST_LIMIT protects against a deliberately circular list.
+const ROBUST_LIST_LIMIT = 2048
+
+// RobustListHead corresponds to Linux's struct robust_list_head.
+//
+// +marshal
+type RobustListHead struct {
+ List uint64
+ FutexOffset uint64
+ ListOpPending uint64
+}
+
+// SizeOfRobustListHead is the size of a RobustListHead struct.
+var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes()
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2c5e56ae5 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -67,8 +67,53 @@ const (
// ioctl(2) requests provided by uapi/linux/sockios.h
const (
- SIOCGIFMEM = 0x891f
- SIOCGIFPFLAGS = 0x8935
- SIOCGMIIPHY = 0x8947
- SIOCGMIIREG = 0x8948
+ SIOCGIFNAME = 0x8910
+ SIOCGIFCONF = 0x8912
+ SIOCGIFFLAGS = 0x8913
+ SIOCGIFADDR = 0x8915
+ SIOCGIFDSTADDR = 0x8917
+ SIOCGIFBRDADDR = 0x8919
+ SIOCGIFNETMASK = 0x891b
+ SIOCGIFMETRIC = 0x891d
+ SIOCGIFMTU = 0x8921
+ SIOCGIFMEM = 0x891f
+ SIOCGIFHWADDR = 0x8927
+ SIOCGIFINDEX = 0x8933
+ SIOCGIFPFLAGS = 0x8935
+ SIOCGIFTXQLEN = 0x8942
+ SIOCETHTOOL = 0x8946
+ SIOCGMIIPHY = 0x8947
+ SIOCGMIIREG = 0x8948
+ SIOCGIFMAP = 0x8970
)
+
+// ioctl(2) requests provided by uapi/asm-generic/sockios.h
+const (
+ SIOCGSTAMP = 0x8906
+)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/abi/linux/ioctl_tun.go b/pkg/abi/linux/ioctl_tun.go
new file mode 100644
index 000000000..c59c9c136
--- /dev/null
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
+)
+
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/ip.go b/pkg/abi/linux/ip.go
index 31e56ffa6..ef6d1093e 100644
--- a/pkg/abi/linux/ip.go
+++ b/pkg/abi/linux/ip.go
@@ -92,6 +92,16 @@ const (
IP_UNICAST_IF = 50
)
+// IP_MTU_DISCOVER values from uapi/linux/in.h
+const (
+ IP_PMTUDISC_DONT = 0
+ IP_PMTUDISC_WANT = 1
+ IP_PMTUDISC_DO = 2
+ IP_PMTUDISC_PROBE = 3
+ IP_PMTUDISC_INTERFACE = 4
+ IP_PMTUDISC_OMIT = 5
+)
+
// Socket options from uapi/linux/in6.h
const (
IPV6_ADDRFORM = 1
diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go
index cd043dac3..07cc1895e 100644
--- a/pkg/abi/linux/mm.go
+++ b/pkg/abi/linux/mm.go
@@ -90,14 +90,19 @@ const (
MS_SYNC = 1 << 2
)
+// NumaPolicy is the NUMA memory policy for a memory range. See numa(7).
+//
+// +marshal
+type NumaPolicy int32
+
// Policies for get_mempolicy(2)/set_mempolicy(2).
const (
- MPOL_DEFAULT = 0
- MPOL_PREFERRED = 1
- MPOL_BIND = 2
- MPOL_INTERLEAVE = 3
- MPOL_LOCAL = 4
- MPOL_MAX = 5
+ MPOL_DEFAULT NumaPolicy = 0
+ MPOL_PREFERRED NumaPolicy = 1
+ MPOL_BIND NumaPolicy = 2
+ MPOL_INTERLEAVE NumaPolicy = 3
+ MPOL_LOCAL NumaPolicy = 4
+ MPOL_MAX NumaPolicy = 5
)
// Flags for get_mempolicy(2).
diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go
index 7866352b4..0faf015c7 100644
--- a/pkg/abi/linux/netdevice.go
+++ b/pkg/abi/linux/netdevice.go
@@ -22,6 +22,8 @@ const (
)
// IFReq is an interface request.
+//
+// +marshal
type IFReq struct {
// IFName is an encoded name, normally null-terminated. This should be
// accessed via the Name and SetName functions.
@@ -79,6 +81,8 @@ type IFMap struct {
// IFConf is used to return a list of interfaces and their addresses. See
// netdevice(7) and struct ifconf for more detail on its use.
+//
+// +marshal
type IFConf struct {
Len int32
_ [4]byte // Pad to sizeof(struct ifconf).
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 269ba5567..91e35366f 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -14,6 +14,14 @@
package linux
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
// This file contains structures required to support netfilter, specifically
// the iptables tool.
@@ -42,7 +50,16 @@ const (
NF_RETURN = -NF_REPEAT - 1
)
-// Socket options. These correspond to values in
+// VerdictStrings maps int verdicts to the strings they represent. It is used
+// for debugging.
+var VerdictStrings = map[int32]string{
+ -NF_DROP - 1: "DROP",
+ -NF_ACCEPT - 1: "ACCEPT",
+ -NF_QUEUE - 1: "QUEUE",
+ NF_RETURN: "RETURN",
+}
+
+// Socket options for SOL_SOCKET. These correspond to values in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
const (
IPT_BASE_CTL = 64
@@ -57,6 +74,12 @@ const (
IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET
)
+// Socket option for SOL_IP. This corresponds to the value in
+// include/uapi/linux/netfilter_ipv4.h.
+const (
+ SO_ORIGINAL_DST = 80
+)
+
// Name lengths. These correspond to values in
// include/uapi/linux/netfilter/x_tables.h.
const (
@@ -67,6 +90,8 @@ const (
// IPTEntry is an iptable rule. It corresponds to struct ipt_entry in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTEntry struct {
// IP is used to filter packets based on the IP header.
IP IPTIP
@@ -103,21 +128,41 @@ type IPTEntry struct {
// SizeOfIPTEntry is the size of an IPTEntry.
const SizeOfIPTEntry = 112
-// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This
-// struct marshaled via the binary package to write an IPTEntry to userspace.
+// KernelIPTEntry is identical to IPTEntry, but includes the Elems field.
+// KernelIPTEntry itself is not Marshallable but it implements some methods of
+// marshal.Marshallable that help in other implementations of Marshallable.
type KernelIPTEntry struct {
- IPTEntry
+ Entry IPTEntry
// Elems holds the data for all this rule's matches followed by the
// target. It is variable length -- users have to iterate over any
// matches and use TargetOffset and NextOffset to make sense of the
// data.
- Elems []byte
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
}
// IPTIP contains information for matching a packet's IP header.
// It corresponds to struct ipt_ip in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTIP struct {
// Src is the source IP address.
Src InetAddr
@@ -137,7 +182,7 @@ type IPTIP struct {
// OutputInterface is the output network interface.
OutputInterface [IFNAMSIZ]byte
- // InputInterfaceMask is the intput interface mask.
+ // InputInterfaceMask is the input interface mask.
InputInterfaceMask [IFNAMSIZ]byte
// OuputInterfaceMask is the output interface mask.
@@ -149,15 +194,39 @@ type IPTIP struct {
// Flags define matching behavior for the IP header.
Flags uint8
- // InverseFlags invert the meaning of fields in struct IPTIP.
+ // InverseFlags invert the meaning of fields in struct IPTIP. See the
+ // IPT_INV_* flags.
InverseFlags uint8
}
+// Flags in IPTIP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+const (
+ // Invert the meaning of InputInterface.
+ IPT_INV_VIA_IN = 0x01
+ // Invert the meaning of OutputInterface.
+ IPT_INV_VIA_OUT = 0x02
+ // Unclear what this is, as no references to it exist in the kernel.
+ IPT_INV_TOS = 0x04
+ // Invert the meaning of Src.
+ IPT_INV_SRCIP = 0x08
+ // Invert the meaning of Dst.
+ IPT_INV_DSTIP = 0x10
+ // Invert the meaning of the IPT_F_FRAG flag.
+ IPT_INV_FRAG = 0x20
+ // Invert the meaning of the Protocol field.
+ IPT_INV_PROTO = 0x40
+ // Enable all flags.
+ IPT_INV_MASK = 0x7F
+)
+
// SizeOfIPTIP is the size of an IPTIP.
const SizeOfIPTIP = 84
// XTCounters holds packet and byte counts for a rule. It corresponds to struct
// xt_counters in include/uapi/linux/netfilter/x_tables.h.
+//
+// +marshal
type XTCounters struct {
// Pcnt is the packet count.
Pcnt uint64
@@ -179,7 +248,7 @@ const SizeOfXTCounters = 16
// the user data.
type XTEntryMatch struct {
MatchSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryMatch to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -189,6 +258,13 @@ type XTEntryMatch struct {
// SizeOfXTEntryMatch is the size of an XTEntryMatch.
const SizeOfXTEntryMatch = 32
+// KernelXTEntryMatch is identical to XTEntryMatch, but contains
+// variable-length Data field.
+type KernelXTEntryMatch struct {
+ XTEntryMatch
+ Data []byte
+}
+
// XTEntryTarget holds a target for a rule. For example, it can specify that
// packets matching the rule should DROP, ACCEPT, or use an extension target.
// iptables-extension(8) has a list of possible targets.
@@ -199,7 +275,7 @@ const SizeOfXTEntryMatch = 32
// the user data.
type XTEntryTarget struct {
TargetSize uint16
- Name [XT_EXTENSION_MAXNAMELEN]byte
+ Name ExtensionName
Revision uint8
// Data is omitted here because it would cause XTEntryTarget to be an
// extra byte larger (see http://www.catb.org/esr/structure-packing/).
@@ -209,11 +285,14 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
-// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
-// RETURN. It corresponds to struct xt_standard_target in
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
+// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTStandardTarget struct {
- Target XTEntryTarget
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
Verdict int32
_ [4]byte
}
@@ -226,18 +305,64 @@ const SizeOfXTStandardTarget = 40
// ErrorName. It corresponds to struct xt_error_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTErrorTarget struct {
- Target XTEntryTarget
- ErrorName [XT_FUNCTION_MAXNAMELEN]byte
- _ [2]byte
+ Target XTEntryTarget
+ Name ErrorName
+ _ [2]byte
}
// SizeOfXTErrorTarget is the size of an XTErrorTarget.
const SizeOfXTErrorTarget = 64
+// Flag values for NfNATIPV4Range. The values indicate whether to map
+// protocol specific part(ports) or IPs. It corresponds to values in
+// include/uapi/linux/netfilter/nf_nat.h.
+const (
+ NF_NAT_RANGE_MAP_IPS = 1 << 0
+ NF_NAT_RANGE_PROTO_SPECIFIED = 1 << 1
+ NF_NAT_RANGE_PROTO_RANDOM = 1 << 2
+ NF_NAT_RANGE_PERSISTENT = 1 << 3
+ NF_NAT_RANGE_PROTO_RANDOM_FULLY = 1 << 4
+ NF_NAT_RANGE_PROTO_RANDOM_ALL = (NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+ NF_NAT_RANGE_MASK = (NF_NAT_RANGE_MAP_IPS |
+ NF_NAT_RANGE_PROTO_SPECIFIED | NF_NAT_RANGE_PROTO_RANDOM |
+ NF_NAT_RANGE_PERSISTENT | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
+)
+
+// NfNATIPV4Range corresponds to struct nf_nat_ipv4_range
+// in include/uapi/linux/netfilter/nf_nat.h. The fields are in
+// network byte order.
+type NfNATIPV4Range struct {
+ Flags uint32
+ MinIP [4]byte
+ MaxIP [4]byte
+ MinPort uint16
+ MaxPort uint16
+}
+
+// NfNATIPV4MultiRangeCompat corresponds to struct
+// nf_nat_ipv4_multi_range_compat in include/uapi/linux/netfilter/nf_nat.h.
+type NfNATIPV4MultiRangeCompat struct {
+ RangeSize uint32
+ RangeIPV4 NfNATIPV4Range
+}
+
+// XTRedirectTarget triggers a redirect when reached.
+// Adding 4 bytes of padding to make the struct 8 byte aligned.
+type XTRedirectTarget struct {
+ Target XTEntryTarget
+ NfRange NfNATIPV4MultiRangeCompat
+ _ [4]byte
+}
+
+// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
+const SizeOfXTRedirectTarget = 56
+
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetinfo struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
HookEntry [NF_INET_NUMHOOKS]uint32
Underflow [NF_INET_NUMHOOKS]uint32
@@ -248,16 +373,13 @@ type IPTGetinfo struct {
// SizeOfIPTGetinfo is the size of an IPTGetinfo.
const SizeOfIPTGetinfo = 84
-// TableName returns the table name.
-func (info *IPTGetinfo) TableName() string {
- return tableName(info.Name[:])
-}
-
// IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It
// corresponds to struct ipt_get_entries in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
+//
+// +marshal
type IPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
Size uint32
_ [4]byte
// Entrytable is omitted here because it would cause IPTGetEntries to
@@ -266,34 +388,112 @@ type IPTGetEntries struct {
// Entrytable [0]IPTEntry
}
-// TableName returns the entries' table name.
-func (entries *IPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
-}
-
// SizeOfIPTGetEntries is the size of an IPTGetEntries.
const SizeOfIPTGetEntries = 40
-// KernelIPTGetEntries is identical to IPTEntry, but includes the Elems field.
-// This struct marshaled via the binary package to write an KernelIPTGetEntries
-// to userspace.
+// KernelIPTGetEntries is identical to IPTGetEntries, but includes the
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
type KernelIPTGetEntries struct {
- Name [XT_TABLE_MAXNAMELEN]byte
- Size uint32
- _ [4]byte
+ IPTGetEntries
Entrytable []KernelIPTEntry
}
-// TableName returns the entries' table name.
-func (entries *KernelIPTGetEntries) TableName() string {
- return tableName(entries.Name[:])
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIPTGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := range ke.Entrytable {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := range ke.Entrytable {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIPTGetEntries) Packed() bool {
+ // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an
+ // indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIPTGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type in not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type in not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results in a
+ // partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
}
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil)
+
// IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It
// corresponds to struct ipt_replace in
// include/uapi/linux/netfilter_ipv4/ip_tables.h.
type IPTReplace struct {
- Name [XT_TABLE_MAXNAMELEN]byte
+ Name TableName
ValidHooks uint32
NumEntries uint32
Size uint32
@@ -309,11 +509,172 @@ type IPTReplace struct {
// SizeOfIPTReplace is the size of an IPTReplace.
const SizeOfIPTReplace = 96
-func tableName(name []byte) string {
- for i, c := range name {
+// ExtensionName holds the name of a netfilter extension.
+type ExtensionName [XT_EXTENSION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ExtensionName) String() string {
+ return goString(en[:])
+}
+
+// TableName holds the name of a netfilter table.
+//
+// +marshal
+type TableName [XT_TABLE_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (tn TableName) String() string {
+ return goString(tn[:])
+}
+
+// ErrorName holds the name of a netfilter error. These can also hold
+// user-defined chains.
+type ErrorName [XT_FUNCTION_MAXNAMELEN]byte
+
+// String implements fmt.Stringer.
+func (en ErrorName) String() string {
+ return goString(en[:])
+}
+
+func goString(cstring []byte) string {
+ for i, c := range cstring {
if c == 0 {
- return string(name[:i])
+ return string(cstring[:i])
}
}
- return string(name)
+ return string(cstring)
+}
+
+// XTTCP holds data for matching TCP packets. It corresponds to struct xt_tcp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTTCP struct {
+ // SourcePortStart specifies the inclusive start of the range of source
+ // ports to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd specifies the inclusive end of the range of source ports
+ // to which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart specifies the start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd specifies the end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // Option specifies that a particular TCP option must be set.
+ Option uint8
+
+ // FlagMask masks TCP flags when comparing to the FlagCompare byte. It allows
+ // for specification of which flags are important to the matcher.
+ FlagMask uint8
+
+ // FlagCompare, in combination with FlagMask, is used to match only packets
+ // that have certain flags set.
+ FlagCompare uint8
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // TX_TCP_INV_* flags.
+ InverseFlags uint8
}
+
+// SizeOfXTTCP is the size of an XTTCP.
+const SizeOfXTTCP = 12
+
+// Flags in XTTCP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_TCP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_TCP_INV_DSTPT = 0x02
+ // Invert the meaning of FlagCompare.
+ XT_TCP_INV_FLAGS = 0x04
+ // Invert the meaning of Option.
+ XT_TCP_INV_OPTION = 0x08
+ // Enable all flags.
+ XT_TCP_INV_MASK = 0x0F
+)
+
+// XTUDP holds data for matching UDP packets. It corresponds to struct xt_udp
+// in include/uapi/linux/netfilter/xt_tcpudp.h.
+type XTUDP struct {
+ // SourcePortStart is the inclusive start of the range of source ports
+ // to which the matcher applies.
+ SourcePortStart uint16
+
+ // SourcePortEnd is the inclusive end of the range of source ports to
+ // which the matcher applies.
+ SourcePortEnd uint16
+
+ // DestinationPortStart is the inclusive start of the destination port
+ // range to which the matcher applies.
+ DestinationPortStart uint16
+
+ // DestinationPortEnd is the inclusive end of the destination port
+ // range to which the matcher applies.
+ DestinationPortEnd uint16
+
+ // InverseFlags flips the meaning of certain fields. See the
+ // TX_UDP_INV_* flags.
+ InverseFlags uint8
+
+ _ uint8
+}
+
+// SizeOfXTUDP is the size of an XTUDP.
+const SizeOfXTUDP = 10
+
+// Flags in XTUDP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_tcpudp.h.
+const (
+ // Invert the meaning of SourcePortStart/End.
+ XT_UDP_INV_SRCPT = 0x01
+ // Invert the meaning of DestinationPortStart/End.
+ XT_UDP_INV_DSTPT = 0x02
+ // Enable all flags.
+ XT_UDP_INV_MASK = 0x03
+)
+
+// IPTOwnerInfo holds data for matching packets with owner. It corresponds
+// to struct ipt_owner_info in libxt_owner.c of iptables binary.
+type IPTOwnerInfo struct {
+ // UID is user id which created the packet.
+ UID uint32
+
+ // GID is group id which created the packet.
+ GID uint32
+
+ // PID is process id of the process which created the packet.
+ PID uint32
+
+ // SID is session id which created the packet.
+ SID uint32
+
+ // Comm is the command name which created the packet.
+ Comm [16]byte
+
+ // Match is used to match UID/GID of the socket. See the
+ // XT_OWNER_* flags below.
+ Match uint8
+
+ // Invert flips the meaning of Match field.
+ Invert uint8
+}
+
+// SizeOfIPTOwnerInfo is the size of an XTOwnerMatchInfo.
+const SizeOfIPTOwnerInfo = 34
+
+// Flags in IPTOwnerInfo.Match. Corresponding constants are in
+// include/uapi/linux/netfilter/xt_owner.h.
+const (
+ // Match the UID of the packet.
+ XT_OWNER_UID = 1 << 0
+ // Match the GID of the packet.
+ XT_OWNER_GID = 1 << 1
+ // Match if the socket exists for the packet. Forwarded
+ // packets do not have an associated socket.
+ XT_OWNER_SOCKET = 1 << 2
+)
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
new file mode 100644
index 000000000..9bb9efb10
--- /dev/null
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -0,0 +1,310 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
+// This file contains structures required to support IPv6 netfilter and
+// ip6tables. Some constants and structs are equal to their IPv4 analogues, and
+// are only distinguished by context (e.g. whether used on an IPv4 of IPv6
+// socket).
+
+// Socket options for SOL_SOCLET. These correspond to values in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+const (
+ IP6T_BASE_CTL = 64
+ IP6T_SO_SET_REPLACE = IPT_BASE_CTL
+ IP6T_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1
+ IP6T_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS
+
+ IP6T_SO_GET_INFO = IPT_BASE_CTL
+ IP6T_SO_GET_ENTRIES = IPT_BASE_CTL + 1
+ IP6T_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 4
+ IP6T_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 5
+ IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET
+)
+
+// IP6T_ORIGINAL_DST is the ip6tables SOL_IPV6 socket option. Corresponds to
+// the value in include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+// TODO(gvisor.dev/issue/3549): Support IPv6 original destination.
+const IP6T_ORIGINAL_DST = 80
+
+// IP6TReplace is the argument for the IP6T_SO_SET_REPLACE sockopt. It
+// corresponds to struct ip6t_replace in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+//
+// +marshal
+type IP6TReplace struct {
+ Name TableName
+ ValidHooks uint32
+ NumEntries uint32
+ Size uint32
+ HookEntry [NF_INET_NUMHOOKS]uint32
+ Underflow [NF_INET_NUMHOOKS]uint32
+ NumCounters uint32
+ Counters uint64 // This is really a *XTCounters.
+ // Entries is omitted here because it would cause IP6TReplace to be an
+ // extra byte longer (see http://www.catb.org/esr/structure-packing/).
+ // Entries [0]IP6TEntry
+}
+
+// SizeOfIP6TReplace is the size of an IP6TReplace.
+const SizeOfIP6TReplace = 96
+
+// KernelIP6TGetEntries is identical to IP6TGetEntries, but includes the
+// Entrytable field. This has been manually made marshal.Marshallable since it
+// is dynamically sized.
+type KernelIP6TGetEntries struct {
+ IPTGetEntries
+ Entrytable []KernelIP6TEntry
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIP6TGetEntries) SizeBytes() int {
+ res := ke.IPTGetEntries.SizeBytes()
+ for _, entry := range ke.Entrytable {
+ res += entry.SizeBytes()
+ }
+ return res
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) {
+ ke.IPTGetEntries.MarshalBytes(dst)
+ marshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := range ke.Entrytable {
+ ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:])
+ marshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) {
+ ke.IPTGetEntries.UnmarshalBytes(src)
+ unmarshalledUntil := ke.IPTGetEntries.SizeBytes()
+ for i := range ke.Entrytable {
+ ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:])
+ unmarshalledUntil += ke.Entrytable[i].SizeBytes()
+ }
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (ke *KernelIP6TGetEntries) Packed() bool {
+ // KernelIP6TGetEntries isn't packed because the ke.Entrytable contains
+ // an indirection to the actual data we want to marshal (the slice data
+ // pointer), and the memory for KernelIP6TGetEntries contains the slice
+ // header which we don't want to marshal.
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (ke *KernelIP6TGetEntries) MarshalUnsafe(dst []byte) {
+ // Fall back to safe Marshal because the type in not packed.
+ ke.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) {
+ // Fall back to safe Unmarshal because the type in not packed.
+ ke.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay.
+ length, err := task.CopyInBytes(addr, buf) // escapes: okay.
+ // Unmarshal unconditionally. If we had a short copy-in, this results
+ // in a partially unmarshalled struct.
+ ke.UnmarshalBytes(buf) // escapes: fallback.
+ return length, err
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ // Type KernelIP6TGetEntries doesn't have a packed layout in memory,
+ // fall back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task))
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall
+ // back to MarshalBytes.
+ return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit])
+}
+
+func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte {
+ buf := task.CopyScratchBuffer(ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ return buf
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (ke *KernelIP6TGetEntries) WriteTo(w io.Writer) (int64, error) {
+ buf := make([]byte, ke.SizeBytes())
+ ke.MarshalBytes(buf)
+ length, err := w.Write(buf)
+ return int64(length), err
+}
+
+var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil)
+
+// IP6TEntry is an iptables rule. It corresponds to struct ip6t_entry in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+//
+// +marshal
+type IP6TEntry struct {
+ // IPv6 is used to filter packets based on the IPv6 header.
+ IPv6 IP6TIP
+
+ // NFCache relates to kernel-internal caching and isn't used by
+ // userspace.
+ NFCache uint32
+
+ // TargetOffset is the byte offset from the beginning of this IPTEntry
+ // to the start of the entry's target.
+ TargetOffset uint16
+
+ // NextOffset is the byte offset from the beginning of this IPTEntry to
+ // the start of the next entry. It is thus also the size of the entry.
+ NextOffset uint16
+
+ // Comeback is a return pointer. It is not used by userspace.
+ Comeback uint32
+
+ _ [4]byte
+
+ // Counters holds the packet and byte counts for this rule.
+ Counters XTCounters
+
+ // Elems holds the data for all this rule's matches followed by the
+ // target. It is variable length -- users have to iterate over any
+ // matches and use TargetOffset and NextOffset to make sense of the
+ // data.
+ //
+ // Elems is omitted here because it would cause IPTEntry to be an extra
+ // byte larger (see http://www.catb.org/esr/structure-packing/).
+ //
+ // Elems [0]byte
+}
+
+// SizeOfIP6TEntry is the size of an IP6TEntry.
+const SizeOfIP6TEntry = 168
+
+// KernelIP6TEntry is identical to IP6TEntry, but includes the Elems field.
+// KernelIP6TEntry itself is not Marshallable but it implements some methods of
+// marshal.Marshallable that help in other implementations of Marshallable.
+type KernelIP6TEntry struct {
+ Entry IP6TEntry
+
+ // Elems holds the data for all this rule's matches followed by the
+ // target. It is variable length -- users have to iterate over any
+ // matches and use TargetOffset and NextOffset to make sense of the
+ // data.
+ Elems primitive.ByteSlice
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (ke *KernelIP6TEntry) SizeBytes() int {
+ return ke.Entry.SizeBytes() + ke.Elems.SizeBytes()
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) {
+ ke.Entry.MarshalBytes(dst)
+ ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():])
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) {
+ ke.Entry.UnmarshalBytes(src)
+ ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():])
+}
+
+// IP6TIP contains information for matching a packet's IP header.
+// It corresponds to struct ip6t_ip6 in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+//
+// +marshal
+type IP6TIP struct {
+ // Src is the source IP address.
+ Src Inet6Addr
+
+ // Dst is the destination IP address.
+ Dst Inet6Addr
+
+ // SrcMask is the source IP mask.
+ SrcMask Inet6Addr
+
+ // DstMask is the destination IP mask.
+ DstMask Inet6Addr
+
+ // InputInterface is the input network interface.
+ InputInterface [IFNAMSIZ]byte
+
+ // OutputInterface is the output network interface.
+ OutputInterface [IFNAMSIZ]byte
+
+ // InputInterfaceMask is the input interface mask.
+ InputInterfaceMask [IFNAMSIZ]byte
+
+ // OuputInterfaceMask is the output interface mask.
+ OutputInterfaceMask [IFNAMSIZ]byte
+
+ // Protocol is the transport protocol.
+ Protocol uint16
+
+ // TOS matches TOS flags when Flags indicates filtering by TOS.
+ TOS uint8
+
+ // Flags define matching behavior for the IP header.
+ Flags uint8
+
+ // InverseFlags invert the meaning of fields in struct IPTIP. See the
+ // IP6T_INV_* flags.
+ InverseFlags uint8
+
+ // Linux defines in6_addr (Inet6Addr for us) as the union of a
+ // 16-element byte array and a 4-element 32-bit integer array, so the
+ // whole struct is 4-byte aligned.
+ _ [3]byte
+}
+
+const SizeOfIP6TIP = 136
+
+// Flags in IP6TIP.InverseFlags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+const (
+ // Invert the meaning of InputInterface.
+ IP6T_INV_VIA_IN = 0x01
+ // Invert the meaning of OutputInterface.
+ IP6T_INV_VIA_OUT = 0x02
+ // Invert the meaning of TOS.
+ IP6T_INV_TOS = 0x04
+ // Invert the meaning of Src.
+ IP6T_INV_SRCIP = 0x08
+ // Invert the meaning of Dst.
+ IP6T_INV_DSTIP = 0x10
+ // Invert the meaning of the IPT_F_FRAG flag.
+ IP6T_INV_FRAG = 0x20
+ // Enable all flags.
+ IP6T_INV_MASK = 0x7F
+)
diff --git a/pkg/abi/linux/netfilter_test.go b/pkg/abi/linux/netfilter_test.go
index 21e237f92..bf73271c6 100644
--- a/pkg/abi/linux/netfilter_test.go
+++ b/pkg/abi/linux/netfilter_test.go
@@ -29,12 +29,16 @@ func TestSizes(t *testing.T) {
{IPTGetEntries{}, SizeOfIPTGetEntries},
{IPTGetinfo{}, SizeOfIPTGetinfo},
{IPTIP{}, SizeOfIPTIP},
+ {IPTOwnerInfo{}, SizeOfIPTOwnerInfo},
{IPTReplace{}, SizeOfIPTReplace},
{XTCounters{}, SizeOfXTCounters},
{XTEntryMatch{}, SizeOfXTEntryMatch},
{XTEntryTarget{}, SizeOfXTEntryTarget},
{XTErrorTarget{}, SizeOfXTErrorTarget},
{XTStandardTarget{}, SizeOfXTStandardTarget},
+ {IP6TReplace{}, SizeOfIP6TReplace},
+ {IP6TEntry{}, SizeOfIP6TEntry},
+ {IP6TIP{}, SizeOfIP6TIP},
}
for _, tc := range testCases {
diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go
index 3898d2314..ceda0a8d3 100644
--- a/pkg/abi/linux/netlink_route.go
+++ b/pkg/abi/linux/netlink_route.go
@@ -187,10 +187,12 @@ const (
// Device types, from uapi/linux/if_arp.h.
const (
+ ARPHRD_NONE = 65534
+ ARPHRD_ETHER = 1
ARPHRD_LOOPBACK = 772
)
-// RouteMessage struct rtmsg, from uapi/linux/rtnetlink.h.
+// RouteMessage is struct rtmsg, from uapi/linux/rtnetlink.h.
type RouteMessage struct {
Family uint8
DstLen uint8
@@ -205,6 +207,9 @@ type RouteMessage struct {
Flags uint32
}
+// SizeOfRouteMessage is the size of RouteMessage.
+const SizeOfRouteMessage = 12
+
// Route types, from uapi/linux/rtnetlink.h.
const (
// RTN_UNSPEC represents an unspecified route type.
@@ -331,3 +336,13 @@ const (
RTF_GATEWAY = 0x2
RTF_UP = 0x1
)
+
+// RtAttr is the header of optional addition route information, as a netlink
+// attribute. From include/uapi/linux/rtnetlink.h.
+type RtAttr struct {
+ Len uint16
+ Type uint16
+}
+
+// SizeOfRtAttr is the size of RtAttr.
+const SizeOfRtAttr = 4
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
new file mode 100644
index 000000000..ed3881e27
--- /dev/null
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
+ R15 uint64
+ R14 uint64
+ R13 uint64
+ R12 uint64
+ Rbp uint64
+ Rbx uint64
+ R11 uint64
+ R10 uint64
+ R9 uint64
+ R8 uint64
+ Rax uint64
+ Rcx uint64
+ Rdx uint64
+ Rsi uint64
+ Rdi uint64
+ Orig_rax uint64
+ Rip uint64
+ Cs uint64
+ Eflags uint64
+ Rsp uint64
+ Ss uint64
+ Fs_base uint64
+ Gs_base uint64
+ Ds uint64
+ Es uint64
+ Fs uint64
+ Gs uint64
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
new file mode 100644
index 000000000..6147738b3
--- /dev/null
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+// PtraceRegs is the set of CPU registers exposed by ptrace. Source:
+// syscall.PtraceRegs.
+//
+// +marshal
+// +stateify savable
+type PtraceRegs struct {
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+}
diff --git a/pkg/abi/linux/rseq.go b/pkg/abi/linux/rseq.go
new file mode 100644
index 000000000..76253ba30
--- /dev/null
+++ b/pkg/abi/linux/rseq.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Flags passed to rseq(2).
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_FLAG_UNREGISTER unregisters the current thread.
+ RSEQ_FLAG_UNREGISTER = 1 << 0
+)
+
+// Critical section flags used in RSeqCriticalSection.Flags and RSeq.Flags.
+//
+// Defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT inhibits restart on preemption.
+ RSEQ_CS_FLAG_NO_RESTART_ON_PREEMPT = 1 << 0
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL inhibits restart on signal
+ // delivery.
+ RSEQ_CS_FLAG_NO_RESTART_ON_SIGNAL = 1 << 1
+
+ // RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE inhibits restart on CPU
+ // migration.
+ RSEQ_CS_FLAG_NO_RESTART_ON_MIGRATE = 1 << 2
+)
+
+// RSeqCriticalSection describes a restartable sequences critical section. It
+// is equivalent to struct rseq_cs, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+//
+// +marshal
+type RSeqCriticalSection struct {
+ // Version is the version of this structure. Version 0 is defined here.
+ Version uint32
+
+ // Flags are the critical section flags, defined above.
+ Flags uint32
+
+ // Start is the start address of the critical section.
+ Start uint64
+
+ // PostCommitOffset is the offset from Start of the first instruction
+ // outside of the critical section.
+ PostCommitOffset uint64
+
+ // Abort is the abort address. It must be outside the critical section,
+ // and the 4 bytes prior must match the abort signature.
+ Abort uint64
+}
+
+const (
+ // SizeOfRSeqCriticalSection is the size of RSeqCriticalSection.
+ SizeOfRSeqCriticalSection = 32
+
+ // SizeOfRSeqSignature is the size of the signature immediately
+ // preceding RSeqCriticalSection.Abort.
+ SizeOfRSeqSignature = 4
+)
+
+// Special values for RSeq.CPUID, defined in include/uapi/linux/rseq.h.
+const (
+ // RSEQ_CPU_ID_UNINITIALIZED indicates that this thread has not
+ // performed rseq initialization.
+ RSEQ_CPU_ID_UNINITIALIZED = ^uint32(0) // -1
+
+ // RSEQ_CPU_ID_REGISTRATION_FAILED indicates that rseq initialization
+ // failed.
+ RSEQ_CPU_ID_REGISTRATION_FAILED = ^uint32(1) // -2
+)
+
+// RSeq is the thread-local restartable sequences config/status. It
+// is equivalent to struct rseq, defined in include/uapi/linux/rseq.h.
+//
+// In userspace, this structure is always aligned to 32 bytes.
+type RSeq struct {
+ // CPUIDStart contains the current CPU ID if rseq is initialized.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUIDStart uint32
+
+ // CPUID contains the current CPU ID or one of the CPU ID special
+ // values defined above.
+ //
+ // This field should only be read by the thread which registered this
+ // structure, and must be read atomically.
+ CPUID uint32
+
+ // RSeqCriticalSection is a pointer to the current RSeqCriticalSection
+ // block, or NULL. It is reset to NULL by the kernel on restart or
+ // non-restarting preempt/signal.
+ //
+ // This field should only be written by the thread which registered
+ // this structure, and must be written atomically.
+ RSeqCriticalSection uint64
+
+ // Flags are the critical section flags that apply to all critical
+ // sections on this thread, defined above.
+ Flags uint32
+}
+
+const (
+ // SizeOfRSeq is the size of RSeq.
+ //
+ // Note that RSeq is naively 24 bytes. However, it has 32-byte
+ // alignment, which in C increases sizeof to 32. That is the size that
+ // the Linux kernel uses.
+ SizeOfRSeq = 32
+
+ // AlignOfRSeq is the standard alignment of RSeq.
+ AlignOfRSeq = 32
+
+ // OffsetOfRSeqCriticalSection is the offset of RSeqCriticalSection in RSeq.
+ OffsetOfRSeqCriticalSection = 8
+)
diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go
index 4eeb5cd7a..d0607e256 100644
--- a/pkg/abi/linux/seccomp.go
+++ b/pkg/abi/linux/seccomp.go
@@ -63,3 +63,10 @@ func (a BPFAction) String() string {
func (a BPFAction) Data() uint16 {
return uint16(a & SECCOMP_RET_DATA)
}
+
+// SockFprog is sock_fprog taken from <linux/filter.h>.
+type SockFprog struct {
+ Len uint16
+ pad [6]byte
+ Filter *BPFInstruction
+}
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 2e2cc6be7..e37c8727d 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -83,7 +83,6 @@ const (
MSG_MORE = 0x8000
MSG_WAITFORONE = 0x10000
MSG_SENDPAGE_NOTLAST = 0x20000
- MSG_REINJECT = 0x8000000
MSG_ZEROCOPY = 0x4000000
MSG_FASTOPEN = 0x20000000
MSG_CMSG_CLOEXEC = 0x40000000
@@ -134,6 +133,15 @@ const (
SHUT_RDWR = 2
)
+// Packet types from <linux/if_packet.h>
+const (
+ PACKET_HOST = 0 // To us
+ PACKET_BROADCAST = 1 // To all
+ PACKET_MULTICAST = 2 // To group
+ PACKET_OTHERHOST = 3 // To someone else
+ PACKET_OUTGOING = 4 // Outgoing of any type
+)
+
// Socket options from socket.h.
const (
SO_DEBUG = 1
@@ -225,14 +233,18 @@ const (
const SockAddrMax = 128
// InetAddr is struct in_addr, from uapi/linux/in.h.
+//
+// +marshal
type InetAddr [4]byte
// SockAddrInet is struct sockaddr_in, from uapi/linux/in.h.
+//
+// +marshal
type SockAddrInet struct {
Family uint16
Port uint16
Addr InetAddr
- Zero [8]uint8 // pad to sizeof(struct sockaddr).
+ _ [8]uint8 // pad to sizeof(struct sockaddr).
}
// InetMulticastRequest is struct ip_mreq, from uapi/linux/in.h.
@@ -247,6 +259,11 @@ type InetMulticastRequestWithNIC struct {
InterfaceIndex int32
}
+// Inet6Addr is struct in6_addr, from uapi/linux/in6.h.
+//
+// +marshal
+type Inet6Addr [16]byte
+
// SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h.
type SockAddrInet6 struct {
Family uint16
@@ -294,6 +311,8 @@ func (s *SockAddrUnix) implementsSockAddr() {}
func (s *SockAddrNetlink) implementsSockAddr() {}
// Linger is struct linger, from include/linux/socket.h.
+//
+// +marshal
type Linger struct {
OnOff int32
Linger int32
@@ -308,6 +327,8 @@ const SizeOfLinger = 8
// the end of this struct or within existing unusued space, so its size grows
// over time. The current iteration is based on linux v4.17. New versions are
// always backwards compatible.
+//
+// +marshal
type TCPInfo struct {
State uint8
CaState uint8
@@ -405,12 +426,23 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{}))
// A ControlMessageCredentials is an SCM_CREDENTIALS socket control message.
//
// ControlMessageCredentials represents struct ucred from linux/socket.h.
+//
+// +marshal
type ControlMessageCredentials struct {
PID int32
UID uint32
GID uint32
}
+// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
@@ -422,6 +454,19 @@ type ControlMessageRights []int32
// ControlMessageRights.
const SizeOfControlMessageRight = 4
+// SizeOfControlMessageInq is the size of a TCP_INQ control message.
+const SizeOfControlMessageInq = 4
+
+// SizeOfControlMessageTOS is the size of an IP_TOS control message.
+const SizeOfControlMessageTOS = 1
+
+// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
+const SizeOfControlMessageTClass = 4
+
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/abi/linux/tcp.go b/pkg/abi/linux/tcp.go
index 174d470e2..2a8d4708b 100644
--- a/pkg/abi/linux/tcp.go
+++ b/pkg/abi/linux/tcp.go
@@ -57,4 +57,5 @@ const (
const (
MAX_TCP_KEEPIDLE = 32767
MAX_TCP_KEEPINTVL = 32767
+ MAX_TCP_KEEPCNT = 127
)
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index 546668bca..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -101,6 +101,8 @@ func NsecToTimeT(nsec int64) TimeT {
}
// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
type Timespec struct {
Sec int64
Nsec int64
@@ -155,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -228,12 +232,27 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
_ int32
}
+// ToNsec returns the nanosecond representation.
+func (sxts StatxTimestamp) ToNsec() int64 {
+ return int64(sxts.Sec)*1e9 + int64(sxts.Nsec)
+}
+
+// ToNsecCapped returns the safe nanosecond representation.
+func (sxts StatxTimestamp) ToNsecCapped() int64 {
+ if sxts.Sec > maxSecInDuration {
+ return math.MaxInt64
+ }
+ return sxts.ToNsec()
+}
+
// NsecToStatxTimestamp translates nanoseconds to StatxTimestamp.
func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
return StatxTimestamp{
@@ -243,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
new file mode 100644
index 000000000..99180b208
--- /dev/null
+++ b/pkg/abi/linux/xattr.go
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// Constants for extended attributes.
+const (
+ XATTR_NAME_MAX = 255
+ XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
+
+ XATTR_CREATE = 1
+ XATTR_REPLACE = 2
+
+ XATTR_USER_PREFIX = "user."
+ XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX)
+)
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 6bc486b62..ffc918846 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -1,18 +1,18 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "amutex",
srcs = ["amutex.go"],
- importpath = "gvisor.dev/gvisor/pkg/amutex",
visibility = ["//:sandbox"],
+ deps = ["//pkg/syserror"],
)
go_test(
name = "amutex_test",
size = "small",
srcs = ["amutex_test.go"],
- embed = [":amutex"],
+ library = ":amutex",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go
index 1c4fd1784..a078a31db 100644
--- a/pkg/amutex/amutex.go
+++ b/pkg/amutex/amutex.go
@@ -18,6 +18,8 @@ package amutex
import (
"sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/syserror"
)
// Sleeper must be implemented by users of the abortable mutex to allow for
@@ -53,6 +55,21 @@ func (NoopSleeper) SleepFinish(success bool) {}
// Interrupted implements Sleeper.Interrupted.
func (NoopSleeper) Interrupted() bool { return false }
+// Block blocks until either receiving from ch succeeds (in which case it
+// returns nil) or sleeper is interrupted (in which case it returns
+// syserror.ErrInterrupted).
+func Block(sleeper Sleeper, ch <-chan struct{}) error {
+ cancel := sleeper.SleepStart()
+ select {
+ case <-ch:
+ sleeper.SleepFinish(true)
+ return nil
+ case <-cancel:
+ sleeper.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
+
// AbortableMutex is an abortable mutex. It allows Lock() to be aborted while it
// waits to acquire the mutex.
type AbortableMutex struct {
diff --git a/pkg/amutex/amutex_test.go b/pkg/amutex/amutex_test.go
index 1d7f45641..8a3952f2a 100644
--- a/pkg/amutex/amutex_test.go
+++ b/pkg/amutex/amutex_test.go
@@ -15,9 +15,10 @@
package amutex
import (
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
type sleeper struct {
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 36beaade9..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -1,23 +1,22 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "atomicbitops",
srcs = [
- "atomic_bitops.go",
- "atomic_bitops_amd64.s",
- "atomic_bitops_arm64.s",
- "atomic_bitops_common.go",
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
],
- importpath = "gvisor.dev/gvisor/pkg/atomicbitops",
visibility = ["//:sandbox"],
)
go_test(
name = "atomicbitops_test",
size = "small",
- srcs = ["atomic_bitops_test.go"],
- embed = [":atomicbitops"],
+ srcs = ["atomicbitops_test.go"],
+ library = ":atomicbitops",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go
index fcc41a9ea..1be081719 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -14,47 +14,34 @@
// +build amd64 arm64
-// Package atomicbitops provides basic bitwise operations in an atomic way.
-// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives, and the arm64 leverages the LDAXR
-// and STLXR pair primitives.
+// Package atomicbitops provides extensions to the sync/atomic package.
//
-// WARNING: the bitwise ops provided in this package doesn't imply any memory
-// ordering. Using them to construct locks must employ proper memory barriers.
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
package atomicbitops
-// AndUint32 atomically applies bitwise and operation to *addr with val.
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
func AndUint32(addr *uint32, val uint32)
-// OrUint32 atomically applies bitwise or operation to *addr with val.
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
func OrUint32(addr *uint32, val uint32)
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
func XorUint32(addr *uint32, val uint32)
// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
-// AndUint64 atomically applies bitwise and operation to *addr with val.
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
func AndUint64(addr *uint64, val uint64)
-// OrUint64 atomically applies bitwise or operation to *addr with val.
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
func OrUint64(addr *uint64, val uint64)
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
func XorUint64(addr *uint64, val uint64)
// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index db0972001..54c887ee5 100644
--- a/pkg/atomicbitops/atomic_bitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
CMPXCHGQ DX, 0(DI)
MOVQ AX, ret+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- TESTL AX, AX
- JZ fail
- LEAL 1(AX), DX
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB AX, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- LEAL -1(AX), DX
- TESTL DX, DX
- JZ fail
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB DX, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 97f8808c1..5c780851b 100644
--- a/pkg/atomicbitops/atomic_bitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
-
again:
LDAXRW (R0), R3
CMPW R1, R3
@@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
-
again:
LDAXR (R0), R3
CMP R1, R3
@@ -105,35 +103,3 @@ again:
done:
MOVD R3, prev+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- CBZ R1, fail
- ADDW $1, R1
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- SUBSW $1, R1, R1
- BEQ fail
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 85163ad62..3b2898256 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-// AndUint32 atomically applies bitwise and operation to *addr with val.
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-// OrUint32 atomically applies bitwise or operation to *addr with val.
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-// AndUint64 atomically applies bitwise and operation to *addr with val.
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-// OrUint64 atomically applies bitwise or operation to *addr with val.
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
-// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
@@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
}
}
}
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v+1) {
- return true
- }
- }
-}
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 1 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v-1) {
- return true
- }
- }
-}
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
index 965e9be79..73af71bb4 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -16,8 +16,9 @@ package atomicbitops
import (
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const iterations = 100
@@ -195,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) {
}
}
}
-
-func TestIncUnlessZeroInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: 0,
- ret: false,
- },
- {
- initial: 1,
- final: 2,
- ret: true,
- },
- {
- initial: 2,
- final: 3,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := IncUnlessZeroInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
-
-func TestDecUnlessOneInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: -1,
- ret: true,
- },
- {
- initial: 1,
- final: 1,
- ret: false,
- },
- {
- initial: 2,
- final: 1,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := DecUnlessOneInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
index 543fb54bf..7ca2fda90 100644
--- a/pkg/binary/BUILD
+++ b/pkg/binary/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "binary",
srcs = ["binary.go"],
- importpath = "gvisor.dev/gvisor/pkg/binary",
visibility = ["//:sandbox"],
)
@@ -14,5 +12,5 @@ go_test(
name = "binary_test",
size = "small",
srcs = ["binary_test.go"],
- embed = [":binary"],
+ library = ":binary",
)
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
index 631785f7b..25065aef9 100644
--- a/pkg/binary/binary.go
+++ b/pkg/binary/binary.go
@@ -254,3 +254,13 @@ func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
_, err := w.Write(buf)
return err
}
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 93b88a29a..63f4670d7 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
@@ -15,7 +14,6 @@ go_library(
"uint64_arch_arm64_asm.s",
"uint64_arch_generic.go",
],
- importpath = "gvisor.dev/gvisor/pkg/bits",
visibility = ["//:sandbox"],
)
@@ -53,5 +51,5 @@ go_test(
name = "bits_test",
size = "small",
srcs = ["uint64_test.go"],
- embed = [":bits"],
+ library = ":bits",
)
diff --git a/pkg/bits/bits_template.go b/pkg/bits/bits_template.go
index 93a435b80..998645388 100644
--- a/pkg/bits/bits_template.go
+++ b/pkg/bits/bits_template.go
@@ -42,3 +42,11 @@ func Mask(is ...int) T {
func MaskOf(i int) T {
return T(1) << T(i)
}
+
+// IsPowerOfTwo returns true if v is power of 2.
+func IsPowerOfTwo(v T) bool {
+ if v == 0 {
+ return false
+ }
+ return v&(v-1) == 0
+}
diff --git a/pkg/bits/uint64_test.go b/pkg/bits/uint64_test.go
index 1b018d808..193d1ebcd 100644
--- a/pkg/bits/uint64_test.go
+++ b/pkg/bits/uint64_test.go
@@ -114,3 +114,21 @@ func TestIsOn(t *testing.T) {
}
}
}
+
+func TestIsPowerOfTwo(t *testing.T) {
+ for _, tc := range []struct {
+ v uint64
+ want bool
+ }{
+ {v: 0, want: false},
+ {v: 1, want: true},
+ {v: 2, want: true},
+ {v: 3, want: false},
+ {v: 4, want: true},
+ {v: 5, want: false},
+ } {
+ if got := IsPowerOfTwo64(tc.v); got != tc.want {
+ t.Errorf("IsPowerOfTwo(%d) = %t, want: %t", tc.v, got, tc.want)
+ }
+ }
+}
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index fba5643e8..2a6977f85 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,7 +11,6 @@ go_library(
"interpreter.go",
"program_builder.go",
],
- importpath = "gvisor.dev/gvisor/pkg/bpf",
visibility = ["//visibility:public"],
deps = ["//pkg/abi/linux"],
)
@@ -25,7 +23,7 @@ go_test(
"interpreter_test.go",
"program_builder_test.go",
],
- embed = [":bpf"],
+ library = ":bpf",
deps = [
"//pkg/abi/linux",
"//pkg/binary",
diff --git a/pkg/bpf/interpreter_test.go b/pkg/bpf/interpreter_test.go
index 547921d0a..c85d786b9 100644
--- a/pkg/bpf/interpreter_test.go
+++ b/pkg/bpf/interpreter_test.go
@@ -767,7 +767,7 @@ func TestSimpleFilter(t *testing.T) {
expectedRet: 0,
},
{
- desc: "Whitelisted syscall is allowed",
+ desc: "Allowed syscall is indeed allowed",
seccompData: seccompData{nr: 231 /* __NR_exit_group */, arch: 0xc000003e},
expectedRet: 0x7fff0000,
},
diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD
new file mode 100644
index 000000000..dcd086298
--- /dev/null
+++ b/pkg/buffer/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "buffer_list",
+ out = "buffer_list.go",
+ package = "buffer",
+ prefix = "buffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*buffer",
+ "Linker": "*buffer",
+ },
+)
+
+go_library(
+ name = "buffer",
+ srcs = [
+ "buffer.go",
+ "buffer_list.go",
+ "safemem.go",
+ "view.go",
+ "view_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/log",
+ "//pkg/safemem",
+ ],
+)
+
+go_test(
+ name = "buffer_test",
+ size = "small",
+ srcs = [
+ "safemem_test.go",
+ "view_test.go",
+ ],
+ library = ":buffer",
+ deps = ["//pkg/safemem"],
+)
diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go
new file mode 100644
index 000000000..c6d089fd9
--- /dev/null
+++ b/pkg/buffer/buffer.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package buffer provides the implementation of a buffer view.
+//
+// A view is an flexible buffer, backed by a pool, supporting the safecopy
+// operations natively as well as the ability to grow via either prepend or
+// append, as well as shrink.
+package buffer
+
+import (
+ "sync"
+)
+
+const bufferSize = 8144 // See below.
+
+// buffer encapsulates a queueable byte buffer.
+//
+// Note that the total size is slightly less than two pages. This is done
+// intentionally to ensure that the buffer object aligns with runtime
+// internals. We have no hard size or alignment requirements. This two page
+// size will effectively minimize internal fragmentation, but still have a
+// large enough chunk to limit excessive segmentation.
+//
+// +stateify savable
+type buffer struct {
+ data [bufferSize]byte
+ read int
+ write int
+ bufferEntry
+}
+
+// reset resets internal data.
+//
+// This must be called before returning the buffer to the pool.
+func (b *buffer) Reset() {
+ b.read = 0
+ b.write = 0
+}
+
+// Full indicates the buffer is full.
+//
+// This indicates there is no capacity left to write.
+func (b *buffer) Full() bool {
+ return b.write == len(b.data)
+}
+
+// ReadSize returns the number of bytes available for reading.
+func (b *buffer) ReadSize() int {
+ return b.write - b.read
+}
+
+// ReadMove advances the read index by the given amount.
+func (b *buffer) ReadMove(n int) {
+ b.read += n
+}
+
+// ReadSlice returns the read slice for this buffer.
+func (b *buffer) ReadSlice() []byte {
+ return b.data[b.read:b.write]
+}
+
+// WriteSize returns the number of bytes available for writing.
+func (b *buffer) WriteSize() int {
+ return len(b.data) - b.write
+}
+
+// WriteMove advances the write index by the given amount.
+func (b *buffer) WriteMove(n int) {
+ b.write += n
+}
+
+// WriteSlice returns the write slice for this buffer.
+func (b *buffer) WriteSlice() []byte {
+ return b.data[b.write:]
+}
+
+// bufferPool is a pool for buffers.
+var bufferPool = sync.Pool{
+ New: func() interface{} {
+ return new(buffer)
+ },
+}
diff --git a/pkg/buffer/safemem.go b/pkg/buffer/safemem.go
new file mode 100644
index 000000000..b789e56e9
--- /dev/null
+++ b/pkg/buffer/safemem.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// WriteBlock returns this buffer as a write Block.
+func (b *buffer) WriteBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.WriteSlice())
+}
+
+// ReadBlock returns this buffer as a read Block.
+func (b *buffer) ReadBlock() safemem.Block {
+ return safemem.BlockFromSafeSlice(b.ReadSlice())
+}
+
+// WriteFromSafememReader writes up to count bytes from r to v and advances the
+// write index by the number of bytes written. It calls r.ReadToBlocks() at
+// most once.
+func (v *View) WriteFromSafememReader(r safemem.Reader, count uint64) (uint64, error) {
+ if count == 0 {
+ return 0, nil
+ }
+
+ var (
+ dst safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ // Need at least one buffer.
+ firstBuf := v.data.Back()
+ if firstBuf == nil {
+ firstBuf = bufferPool.Get().(*buffer)
+ v.data.PushBack(firstBuf)
+ }
+
+ // Does the last block have sufficient capacity alone?
+ if l := uint64(firstBuf.WriteSize()); l >= count {
+ dst = safemem.BlockSeqOf(firstBuf.WriteBlock().TakeFirst64(count))
+ } else {
+ // Append blocks until sufficient.
+ count -= l
+ blocks = append(blocks, firstBuf.WriteBlock())
+ for count > 0 {
+ emptyBuf := bufferPool.Get().(*buffer)
+ v.data.PushBack(emptyBuf)
+ block := emptyBuf.WriteBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
+ }
+ dst = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform I/O.
+ n, err := r.ReadToBlocks(dst)
+ v.size += int64(n)
+
+ // Update all indices.
+ for left := n; left > 0; firstBuf = firstBuf.Next() {
+ if l := firstBuf.WriteSize(); left >= uint64(l) {
+ firstBuf.WriteMove(l) // Whole block.
+ left -= uint64(l)
+ } else {
+ firstBuf.WriteMove(int(left)) // Partial block.
+ left = 0
+ }
+ }
+
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. It advances the
+// write index by the number of bytes written.
+func (v *View) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return v.WriteFromSafememReader(&safemem.BlockSeqReader{srcs}, srcs.NumBytes())
+}
+
+// ReadToSafememWriter reads up to count bytes from v to w. It does not advance
+// the read index. It calls w.WriteFromBlocks() at most once.
+func (v *View) ReadToSafememWriter(w safemem.Writer, count uint64) (uint64, error) {
+ if count == 0 {
+ return 0, nil
+ }
+
+ var (
+ src safemem.BlockSeq
+ blocks []safemem.Block
+ )
+
+ firstBuf := v.data.Front()
+ if firstBuf == nil {
+ return 0, nil // No EOF.
+ }
+
+ // Is all the data in a single block?
+ if l := uint64(firstBuf.ReadSize()); l >= count {
+ src = safemem.BlockSeqOf(firstBuf.ReadBlock().TakeFirst64(count))
+ } else {
+ // Build a list of all the buffers.
+ count -= l
+ blocks = append(blocks, firstBuf.ReadBlock())
+ for buf := firstBuf.Next(); buf != nil && count > 0; buf = buf.Next() {
+ block := buf.ReadBlock().TakeFirst64(count)
+ count -= uint64(block.Len())
+ blocks = append(blocks, block)
+ }
+ src = safemem.BlockSeqFromSlice(blocks)
+ }
+
+ // Perform I/O. As documented, we don't advance the read index.
+ return w.WriteFromBlocks(src)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks. It does not advance the
+// read index by the number of bytes read, such that it's only safe to call if
+// the caller guarantees that ReadToBlocks will only be called once.
+func (v *View) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return v.ReadToSafememWriter(&safemem.BlockSeqWriter{dsts}, dsts.NumBytes())
+}
diff --git a/pkg/buffer/safemem_test.go b/pkg/buffer/safemem_test.go
new file mode 100644
index 000000000..47f357e0c
--- /dev/null
+++ b/pkg/buffer/safemem_test.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+func TestSafemem(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ readLen int
+ op func(*View)
+ }{
+ // Basic coverage.
+ {
+ name: "short",
+ input: "010",
+ output: "010",
+ },
+ {
+ name: "long",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize) + "0",
+ },
+ {
+ name: "short-read",
+ input: "0",
+ readLen: 100, // > size.
+ output: "0",
+ },
+ {
+ name: "zero-read",
+ input: "0",
+ output: "",
+ },
+ {
+ name: "read-empty",
+ input: "",
+ readLen: 1, // > size.
+ output: "",
+ },
+
+ // Ensure offsets work.
+ {
+ name: "offsets-short",
+ input: "012",
+ output: "2",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "offsets-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(v *View) {
+ v.TrimFront(2)
+ },
+ },
+ {
+ name: "offsets-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "10",
+ op: func(v *View) {
+ v.TrimFront(bufferSize)
+ },
+ },
+
+ // Ensure truncation works.
+ {
+ name: "truncate-short",
+ input: "012",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ {
+ name: "truncate-long0",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(v *View) {
+ v.Truncate(bufferSize + 1)
+ },
+ },
+ {
+ name: "truncate-long1",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(v *View) {
+ v.Truncate(bufferSize)
+ },
+ },
+ {
+ name: "truncate-long2",
+ input: "0" + strings.Repeat("1", bufferSize) + "0",
+ output: "01",
+ op: func(v *View) {
+ v.Truncate(2)
+ },
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Construct the new view.
+ var view View
+ bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice([]byte(tc.input)))
+ n, err := view.WriteFromBlocks(bs)
+ if err != nil {
+ t.Errorf("expected err nil, got %v", err)
+ }
+ if n != uint64(len(tc.input)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.input), n)
+ }
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(&view)
+ }
+
+ // Read and validate.
+ readLen := tc.readLen
+ if readLen == 0 {
+ readLen = len(tc.output) // Default.
+ }
+ out := make([]byte, readLen)
+ bs = safemem.BlockSeqOf(safemem.BlockFromSafeSlice(out))
+ n, err = view.ReadToBlocks(bs)
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if n != uint64(len(tc.output)) {
+ t.Errorf("expected %d bytes, got %d", len(tc.output), n)
+ }
+
+ // Ensure the contents are correct.
+ if !bytes.Equal(out[:n], []byte(tc.output[:n])) {
+ t.Errorf("contents are wrong: expected %q, got %q", tc.output, string(out))
+ }
+ })
+ }
+}
diff --git a/pkg/buffer/view.go b/pkg/buffer/view.go
new file mode 100644
index 000000000..e6901eadb
--- /dev/null
+++ b/pkg/buffer/view.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "fmt"
+ "io"
+)
+
+// View is a non-linear buffer.
+//
+// All methods are thread compatible.
+//
+// +stateify savable
+type View struct {
+ data bufferList
+ size int64
+}
+
+// TrimFront removes the first count bytes from the buffer.
+func (v *View) TrimFront(count int64) {
+ if count >= v.size {
+ v.advanceRead(v.size)
+ } else {
+ v.advanceRead(count)
+ }
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (v *View) ReadAt(p []byte, offset int64) (int, error) {
+ var (
+ skipped int64
+ done int64
+ )
+ for buf := v.data.Front(); buf != nil && done < int64(len(p)); buf = buf.Next() {
+ needToSkip := int(offset - skipped)
+ if sz := buf.ReadSize(); sz <= needToSkip {
+ skipped += int64(sz)
+ continue
+ }
+
+ // Actually read data.
+ n := copy(p[done:], buf.ReadSlice()[needToSkip:])
+ skipped += int64(needToSkip)
+ done += int64(n)
+ }
+ if int(done) < len(p) || offset+done == v.size {
+ return int(done), io.EOF
+ }
+ return int(done), nil
+}
+
+// advanceRead advances the view's read index.
+//
+// Precondition: there must be sufficient bytes in the buffer.
+func (v *View) advanceRead(count int64) {
+ for buf := v.data.Front(); buf != nil && count > 0; {
+ sz := int64(buf.ReadSize())
+ if sz > count {
+ // There is still data for reading.
+ buf.ReadMove(int(count))
+ v.size -= count
+ count = 0
+ break
+ }
+
+ // Consume the whole buffer.
+ oldBuf := buf
+ buf = buf.Next() // Iterate.
+ v.data.Remove(oldBuf)
+ oldBuf.Reset()
+ bufferPool.Put(oldBuf)
+
+ // Update counts.
+ count -= sz
+ v.size -= sz
+ }
+ if count > 0 {
+ panic(fmt.Sprintf("advanceRead still has %d bytes remaining", count))
+ }
+}
+
+// Truncate truncates the view to the given bytes.
+//
+// This will not grow the view, only shrink it. If a length is passed that is
+// greater than the current size of the view, then nothing will happen.
+//
+// Precondition: length must be >= 0.
+func (v *View) Truncate(length int64) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ if length >= v.size {
+ return // Nothing to do.
+ }
+ for buf := v.data.Back(); buf != nil && v.size > length; buf = v.data.Back() {
+ sz := int64(buf.ReadSize())
+ if after := v.size - sz; after < length {
+ // Truncate the buffer locally.
+ left := (length - after)
+ buf.write = buf.read + int(left)
+ v.size = length
+ break
+ }
+
+ // Drop the buffer completely; see above.
+ v.data.Remove(buf)
+ buf.Reset()
+ bufferPool.Put(buf)
+ v.size -= sz
+ }
+}
+
+// Grow grows the given view to the number of bytes, which will be appended. If
+// zero is true, all these bytes will be zero. If zero is false, then this is
+// the caller's responsibility.
+//
+// Precondition: length must be >= 0.
+func (v *View) Grow(length int64, zero bool) {
+ if length < 0 {
+ panic("negative length provided")
+ }
+ for v.size < length {
+ buf := v.data.Back()
+
+ // Is there some space in the last buffer?
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Write up to length bytes.
+ sz := buf.WriteSize()
+ if int64(sz) > length-v.size {
+ sz = int(length - v.size)
+ }
+
+ // Zero the written section; note that this pattern is
+ // specifically recognized and optimized by the compiler.
+ if zero {
+ for i := buf.write; i < buf.write+sz; i++ {
+ buf.data[i] = 0
+ }
+ }
+
+ // Advance the index.
+ buf.WriteMove(sz)
+ v.size += int64(sz)
+ }
+}
+
+// Prepend prepends the given data.
+func (v *View) Prepend(data []byte) {
+ // Is there any space in the first buffer?
+ if buf := v.data.Front(); buf != nil && buf.read > 0 {
+ // Fill up before the first write.
+ avail := buf.read
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read -= n
+ }
+
+ for len(data) > 0 {
+ // Do we need an empty buffer?
+ buf := bufferPool.Get().(*buffer)
+ v.data.PushFront(buf)
+
+ // The buffer is empty; copy last chunk.
+ avail := len(buf.data)
+ bStart := 0
+ dStart := len(data) - avail
+ if avail > len(data) {
+ bStart = avail - len(data)
+ dStart = 0
+ }
+
+ // We have to put the data at the end of the current
+ // buffer in order to ensure that the next prepend will
+ // correctly fill up the beginning of this buffer.
+ n := copy(buf.data[bStart:], data[dStart:])
+ data = data[:dStart]
+ v.size += int64(n)
+ buf.read = len(buf.data) - n
+ buf.write = len(buf.data)
+ }
+}
+
+// Append appends the given data.
+func (v *View) Append(data []byte) {
+ for done := 0; done < len(data); {
+ buf := v.data.Back()
+
+ // Ensure there's a buffer with space.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Copy in to the given buffer.
+ n := copy(buf.WriteSlice(), data[done:])
+ done += n
+ buf.WriteMove(n)
+ v.size += int64(n)
+ }
+}
+
+// Flatten returns a flattened copy of this data.
+//
+// This method should not be used in any performance-sensitive paths. It may
+// allocate a fresh byte slice sufficiently large to contain all the data in
+// the buffer. This is principally for debugging.
+//
+// N.B. Tee data still belongs to this view, as if there is a single buffer
+// present, then it will be returned directly. This should be used for
+// temporary use only, and a reference to the given slice should not be held.
+func (v *View) Flatten() []byte {
+ if buf := v.data.Front(); buf == nil {
+ return nil // No data at all.
+ } else if buf.Next() == nil {
+ return buf.ReadSlice() // Only one buffer.
+ }
+ data := make([]byte, 0, v.size) // Need to flatten.
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ // Copy to the allocated slice.
+ data = append(data, buf.ReadSlice()...)
+ }
+ return data
+}
+
+// Size indicates the total amount of data available in this view.
+func (v *View) Size() int64 {
+ return v.size
+}
+
+// Copy makes a strict copy of this view.
+func (v *View) Copy() (other View) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ other.Append(buf.ReadSlice())
+ }
+ return
+}
+
+// Apply applies the given function across all valid data.
+func (v *View) Apply(fn func([]byte)) {
+ for buf := v.data.Front(); buf != nil; buf = buf.Next() {
+ fn(buf.ReadSlice())
+ }
+}
+
+// Merge merges the provided View with this one.
+//
+// The other view will be appended to v, and other will be empty after this
+// operation completes.
+func (v *View) Merge(other *View) {
+ // Copy over all buffers.
+ for buf := other.data.Front(); buf != nil; buf = other.data.Front() {
+ other.data.Remove(buf)
+ v.data.PushBack(buf)
+ }
+
+ // Adjust sizes.
+ v.size += other.size
+ other.size = 0
+}
+
+// WriteFromReader writes to the buffer from an io.Reader.
+//
+// A minimum read size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ for done < count {
+ buf := v.data.Back()
+
+ // Ensure we have an empty buffer.
+ if buf == nil || buf.Full() {
+ buf = bufferPool.Get().(*buffer)
+ v.data.PushBack(buf)
+ }
+
+ // Is this less than the minimum batch?
+ if buf.WriteSize() < minBatch && (count-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = r.Read(tmp)
+ v.Append(tmp[:n])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the read, if necessary.
+ sz := buf.WriteSize()
+ if left := count - done; int64(sz) > left {
+ sz = int(left)
+ }
+
+ // Pass the relevant portion of the buffer.
+ n, err = r.Read(buf.WriteSlice()[:sz])
+ buf.WriteMove(n)
+ done += int64(n)
+ v.size += int64(n)
+ if err == io.EOF {
+ err = nil // Short write allowed.
+ break
+ } else if err != nil {
+ break
+ }
+ }
+ return done, err
+}
+
+// ReadToWriter reads from the buffer into an io.Writer.
+//
+// N.B. This does not consume the bytes read. TrimFront should
+// be called appropriately after this call in order to do so.
+//
+// A minimum write size equal to unsafe.Sizeof(unintptr) is enforced,
+// provided that count is greater than or equal to unsafe.Sizeof(uintptr).
+func (v *View) ReadToWriter(w io.Writer, count int64) (int64, error) {
+ var (
+ done int64
+ n int
+ err error
+ )
+ offset := 0 // Spill-over for batching.
+ for buf := v.data.Front(); buf != nil && done < count; buf = buf.Next() {
+ // Has this been consumed? Skip it.
+ sz := buf.ReadSize()
+ if sz <= offset {
+ offset -= sz
+ continue
+ }
+ sz -= offset
+
+ // Is this less than the minimum batch?
+ left := count - done
+ if sz < minBatch && left >= int64(minBatch) && (v.size-done) >= int64(minBatch) {
+ tmp := make([]byte, minBatch)
+ n, err = v.ReadAt(tmp, done)
+ w.Write(tmp[:n])
+ done += int64(n)
+ offset = n - sz // Reset below.
+ if err != nil {
+ break
+ }
+ continue
+ }
+
+ // Limit the write if necessary.
+ if int64(sz) >= left {
+ sz = int(left)
+ }
+
+ // Perform the actual write.
+ n, err = w.Write(buf.ReadSlice()[offset : offset+sz])
+ done += int64(n)
+ if err != nil {
+ break
+ }
+
+ // Reset spill-over.
+ offset = 0
+ }
+ return done, err
+}
diff --git a/pkg/buffer/view_test.go b/pkg/buffer/view_test.go
new file mode 100644
index 000000000..3db1bc6ee
--- /dev/null
+++ b/pkg/buffer/view_test.go
@@ -0,0 +1,467 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+)
+
+func fillAppend(v *View, data []byte) {
+ v.Append(data)
+}
+
+func fillAppendEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ v.Append(data)
+ v.TrimFront(bufferSize - 1)
+}
+
+func fillWriteFromReader(v *View, data []byte) {
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+}
+
+func fillWriteFromReaderEnd(v *View, data []byte) {
+ v.Grow(bufferSize-1, false)
+ b := bytes.NewBuffer(data)
+ v.WriteFromReader(b, int64(len(data)))
+ v.TrimFront(bufferSize - 1)
+}
+
+var fillFuncs = map[string]func(*View, []byte){
+ "append": fillAppend,
+ "appendEnd": fillAppendEnd,
+ "writeFromReader": fillWriteFromReader,
+ "writeFromReaderEnd": fillWriteFromReaderEnd,
+}
+
+func testReadAt(t *testing.T, v *View, offset int64, n int, wantStr string, wantErr error) {
+ t.Helper()
+ d := make([]byte, n)
+ n, err := v.ReadAt(d, offset)
+ if n != len(wantStr) {
+ t.Errorf("got %d, want %d", n, len(wantStr))
+ }
+ if err != wantErr {
+ t.Errorf("got err %v, want %v", err, wantErr)
+ }
+ if !bytes.Equal(d[:n], []byte(wantStr)) {
+ t.Errorf("got %q, want %q", string(d[:n]), wantStr)
+ }
+}
+
+func TestView(t *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ output string
+ op func(*testing.T, *View)
+ }{
+ // Preconditions.
+ {
+ name: "truncate-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Truncate(-1) did not panic")
+ }
+ }()
+ v.Truncate(-1)
+ },
+ },
+ {
+ name: "grow-check",
+ input: "hello",
+ output: "hello", // Not touched.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Grow(-1) did not panic")
+ }
+ }()
+ v.Grow(-1, false)
+ },
+ },
+ {
+ name: "advance-check",
+ input: "hello",
+ output: "", // Consumed.
+ op: func(t *testing.T, v *View) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("advanceRead(Size()+1) did not panic")
+ }
+ }()
+ v.advanceRead(v.Size() + 1)
+ },
+ },
+
+ // Prepend.
+ {
+ name: "prepend",
+ input: "world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("hello "))
+ },
+ },
+ {
+ name: "prepend-backfill-full",
+ input: "hello world",
+ output: "jello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("j"))
+ },
+ },
+ {
+ name: "prepend-backfill-under",
+ input: "hello world",
+ output: "hola world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(5)
+ v.Prepend([]byte("hola"))
+ },
+ },
+ {
+ name: "prepend-backfill-over",
+ input: "hello world",
+ output: "smello world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ v.Prepend([]byte("sm"))
+ },
+ },
+ {
+ name: "prepend-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: "0" + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: "0" + strings.Repeat("1", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte("0"))
+ },
+ },
+ {
+ name: "prepend-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("0", bufferSize*3) + strings.Repeat("1", bufferSize-1),
+ op: func(t *testing.T, v *View) {
+ v.Prepend([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Append and write.
+ {
+ name: "append",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(" world"))
+ },
+ },
+ {
+ name: "append-fill",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-overflow",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + "0",
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte("0"))
+ },
+ },
+ {
+ name: "append-multiple-buffers",
+ input: strings.Repeat("1", bufferSize-1),
+ output: strings.Repeat("1", bufferSize-1) + strings.Repeat("0", bufferSize*3),
+ op: func(t *testing.T, v *View) {
+ v.Append([]byte(strings.Repeat("0", bufferSize*3)))
+ },
+ },
+
+ // Truncate.
+ {
+ name: "truncate",
+ input: "hello world",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+ {
+ name: "truncate-noop",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(v.Size() + 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.Truncate(bufferSize*2 - 1)
+ },
+ },
+ {
+ name: "truncate-multiple-buffers-to-one",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "11111",
+ op: func(t *testing.T, v *View) {
+ v.Truncate(5)
+ },
+ },
+
+ // TrimFront.
+ {
+ name: "trim",
+ input: "hello world",
+ output: "world",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(6)
+ },
+ },
+ {
+ name: "trim-too-large",
+ input: "hello world",
+ output: "",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(v.Size() + 1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers",
+ input: strings.Repeat("1", bufferSize*2),
+ output: strings.Repeat("1", bufferSize*2-1),
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(1)
+ },
+ },
+ {
+ name: "trim-multiple-buffers-to-one-buffer",
+ input: strings.Repeat("1", bufferSize*2),
+ output: "1",
+ op: func(t *testing.T, v *View) {
+ v.TrimFront(bufferSize*2 - 1)
+ },
+ },
+
+ // Grow.
+ {
+ name: "grow",
+ input: "hello world",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ v.Grow(1, true)
+ },
+ },
+ {
+ name: "grow-from-zero",
+ output: strings.Repeat("\x00", 1024),
+ op: func(t *testing.T, v *View) {
+ v.Grow(1024, true)
+ },
+ },
+ {
+ name: "grow-from-non-zero",
+ input: strings.Repeat("1", bufferSize),
+ output: strings.Repeat("1", bufferSize) + strings.Repeat("\x00", bufferSize),
+ op: func(t *testing.T, v *View) {
+ v.Grow(bufferSize*2, true)
+ },
+ },
+
+ // Copy.
+ {
+ name: "copy",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte("hello")
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+ {
+ name: "copy-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ other := v.Copy()
+ bs := other.Flatten()
+ want := []byte(strings.Repeat("1", bufferSize+1))
+ if !bytes.Equal(bs, want) {
+ t.Errorf("expected %v, got %v", want, bs)
+ }
+ },
+ },
+
+ // Merge.
+ {
+ name: "merge",
+ input: "hello",
+ output: "hello world",
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(" world"))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+ {
+ name: "merge-large",
+ input: strings.Repeat("1", bufferSize+1),
+ output: strings.Repeat("1", bufferSize+1) + strings.Repeat("0", bufferSize+1),
+ op: func(t *testing.T, v *View) {
+ var other View
+ other.Append([]byte(strings.Repeat("0", bufferSize+1)))
+ v.Merge(&other)
+ if sz := other.Size(); sz != 0 {
+ t.Errorf("expected 0, got %d", sz)
+ }
+ },
+ },
+
+ // ReadAt.
+ {
+ name: "readat",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 6, "hello", io.EOF) },
+ },
+ {
+ name: "readat-long",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 8, "hello", io.EOF) },
+ },
+ {
+ name: "readat-short",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 0, 3, "hel", nil) },
+ },
+ {
+ name: "readat-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 3, "llo", io.EOF) },
+ },
+ {
+ name: "readat-long-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 8, "llo", io.EOF) },
+ },
+ {
+ name: "readat-short-offset",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, 2, 2, "ll", nil) },
+ },
+ {
+ name: "readat-skip-all",
+ input: "hello",
+ output: "hello",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "", io.EOF) },
+ },
+ {
+ name: "readat-second-buffer",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 1, "1", nil) },
+ },
+ {
+ name: "readat-second-buffer-end",
+ input: strings.Repeat("0", bufferSize+1) + "12",
+ output: strings.Repeat("0", bufferSize+1) + "12",
+ op: func(t *testing.T, v *View) { testReadAt(t, v, bufferSize+1, 2, "12", io.EOF) },
+ },
+ }
+
+ for _, tc := range testCases {
+ for fillName, fn := range fillFuncs {
+ t.Run(fillName+"/"+tc.name, func(t *testing.T) {
+ // Construct & fill the view.
+ var view View
+ fn(&view, []byte(tc.input))
+
+ // Run the operation.
+ if tc.op != nil {
+ tc.op(t, &view)
+ }
+
+ // Flatten and validate.
+ out := view.Flatten()
+ if !bytes.Equal([]byte(tc.output), out) {
+ t.Errorf("expected %q, got %q", tc.output, string(out))
+ }
+
+ // Ensure the size is correct.
+ if len(out) != int(view.Size()) {
+ t.Errorf("size is wrong: expected %d, got %d", len(out), view.Size())
+ }
+
+ // Calculate contents via apply.
+ var appliedOut []byte
+ view.Apply(func(b []byte) {
+ appliedOut = append(appliedOut, b...)
+ })
+ if len(appliedOut) != len(out) {
+ t.Errorf("expected %d, got %d", len(out), len(appliedOut))
+ }
+ if !bytes.Equal(appliedOut, out) {
+ t.Errorf("expected %v, got %v", out, appliedOut)
+ }
+
+ // Calculate contents via ReadToWriter.
+ var b bytes.Buffer
+ n, err := view.ReadToWriter(&b, int64(len(out)))
+ if n != int64(len(out)) {
+ t.Errorf("expected %d, got %d", len(out), n)
+ }
+ if err != nil {
+ t.Errorf("expected nil, got %v", err)
+ }
+ if !bytes.Equal(b.Bytes(), out) {
+ t.Errorf("expected %v, got %v", out, b.Bytes())
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/buffer/view_unsafe.go b/pkg/buffer/view_unsafe.go
new file mode 100644
index 000000000..d1ef39b26
--- /dev/null
+++ b/pkg/buffer/view_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package buffer
+
+import (
+ "unsafe"
+)
+
+// minBatch is the smallest Read or Write operation that the
+// WriteFromReader and ReadToWriter functions will use.
+//
+// This is defined as the size of a native pointer.
+const minBatch = int(unsafe.Sizeof(uintptr(0)))
diff --git a/pkg/cleanup/BUILD b/pkg/cleanup/BUILD
new file mode 100644
index 000000000..5c34b9872
--- /dev/null
+++ b/pkg/cleanup/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "cleanup",
+ srcs = ["cleanup.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ ],
+)
+
+go_test(
+ name = "cleanup_test",
+ srcs = ["cleanup_test.go"],
+ library = ":cleanup",
+)
diff --git a/pkg/cleanup/cleanup.go b/pkg/cleanup/cleanup.go
new file mode 100644
index 000000000..14a05f076
--- /dev/null
+++ b/pkg/cleanup/cleanup.go
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cleanup provides utilities to clean "stuff" on defers.
+package cleanup
+
+// Cleanup allows defers to be aborted when cleanup needs to happen
+// conditionally. Usage:
+// cu := cleanup.Make(func() { f.Close() })
+// defer cu.Clean() // failure before release is called will close the file.
+// ...
+// cu.Add(func() { f2.Close() }) // Adds another cleanup function
+// ...
+// cu.Release() // on success, aborts closing the file.
+// return f
+type Cleanup struct {
+ cleaners []func()
+}
+
+// Make creates a new Cleanup object.
+func Make(f func()) Cleanup {
+ return Cleanup{cleaners: []func(){f}}
+}
+
+// Add adds a new function to be called on Clean().
+func (c *Cleanup) Add(f func()) {
+ c.cleaners = append(c.cleaners, f)
+}
+
+// Clean calls all cleanup functions in reverse order.
+func (c *Cleanup) Clean() {
+ clean(c.cleaners)
+ c.cleaners = nil
+}
+
+// Release releases the cleanup from its duties, i.e. cleanup functions are not
+// called after this point. Returns a function that calls all registered
+// functions in case the caller has use for them.
+func (c *Cleanup) Release() func() {
+ old := c.cleaners
+ c.cleaners = nil
+ return func() { clean(old) }
+}
+
+func clean(cleaners []func()) {
+ for i := len(cleaners) - 1; i >= 0; i-- {
+ cleaners[i]()
+ }
+}
diff --git a/pkg/cleanup/cleanup_test.go b/pkg/cleanup/cleanup_test.go
new file mode 100644
index 000000000..ab3d9ed95
--- /dev/null
+++ b/pkg/cleanup/cleanup_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cleanup
+
+import "testing"
+
+func testCleanupHelper(clean, cleanAdd *bool, release bool) func() {
+ cu := Make(func() {
+ *clean = true
+ })
+ cu.Add(func() {
+ *cleanAdd = true
+ })
+ defer cu.Clean()
+ if release {
+ return cu.Release()
+ }
+ return nil
+}
+
+func TestCleanup(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ testCleanupHelper(&clean, &cleanAdd, false)
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
+
+func TestRelease(t *testing.T) {
+ clean := false
+ cleanAdd := false
+ cleaner := testCleanupHelper(&clean, &cleanAdd, true)
+
+ // Check that clean was not called after release.
+ if clean {
+ t.Fatalf("cleanup function was called.")
+ }
+ if cleanAdd {
+ t.Fatalf("added cleanup function was called.")
+ }
+
+ // Call the cleaner function and check that both cleanup functions are called.
+ cleaner()
+ if !clean {
+ t.Fatalf("cleanup function was not called.")
+ }
+ if !cleanAdd {
+ t.Fatalf("added cleanup function was not called.")
+ }
+}
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index a0b21d4bd..1f75319a7 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -1,19 +1,20 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "compressio",
srcs = ["compressio.go"],
- importpath = "gvisor.dev/gvisor/pkg/compressio",
visibility = ["//:sandbox"],
- deps = ["//pkg/binary"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sync",
+ ],
)
go_test(
name = "compressio_test",
size = "medium",
srcs = ["compressio_test.go"],
- embed = [":compressio"],
+ library = ":compressio",
)
diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go
index 3b0bb086e..b094c5662 100644
--- a/pkg/compressio/compressio.go
+++ b/pkg/compressio/compressio.go
@@ -52,9 +52,9 @@ import (
"hash"
"io"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sync"
)
var bufPool = sync.Pool{
@@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
}
}
-// reader chunks reads and decompresses.
-type reader struct {
+// Reader is a compressed reader.
+type Reader struct {
pool
// in is the source.
in io.Reader
}
+var _ io.Reader = (*Reader)(nil)
+
// NewReader returns a new compressed reader. If key is non-nil, the data stream
// is assumed to contain expected hash values, which will be compared against
// hash values computed from the compressed bytes. See package comments for
// details.
-func NewReader(in io.Reader, key []byte) (io.Reader, error) {
- r := &reader{
+func NewReader(in io.Reader, key []byte) (*Reader, error) {
+ r := &Reader{
in: in,
}
@@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready")
// ErrHashMismatch is returned if the hash does not match.
var ErrHashMismatch = errors.New("hash mismatch")
+// ReadByte implements wire.Reader.ReadByte.
+func (r *Reader) ReadByte() (byte, error) {
+ var p [1]byte
+ n, err := r.Read(p[:])
+ if n != 1 {
+ return p[0], err
+ }
+ // Suppress EOF.
+ return p[0], nil
+}
+
// Read implements io.Reader.Read.
-func (r *reader) Read(p []byte) (int, error) {
+func (r *Reader) Read(p []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) {
return done, nil
}
-// writer chunks and schedules writes.
-type writer struct {
+// Writer is a compressed writer.
+type Writer struct {
pool
// out is the underlying writer.
@@ -562,6 +575,8 @@ type writer struct {
closed bool
}
+var _ io.Writer = (*Writer)(nil)
+
// NewWriter returns a new compressed writer. If key is non-nil, hash values are
// generated and written out for compressed bytes. See package comments for
// details.
@@ -569,8 +584,8 @@ type writer struct {
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
-func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) {
- w := &writer{
+func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
+ w := &Writer{
pool: pool{
chunkSize: chunkSize,
buf: bufPool.Get().(*bytes.Buffer),
@@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write
}
// flush writes a single buffer.
-func (w *writer) flush(c *chunk) error {
+func (w *Writer) flush(c *chunk) error {
// Prefix each chunk with a length; this allows the reader to safely
// limit reads while buffering.
l := uint32(c.compressed.Len())
@@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error {
return nil
}
+// WriteByte implements wire.Writer.WriteByte.
+//
+// Note that this implementation is necessary on the object itself, as an
+// interface-based dispatch cannot tell whether the array backing the slice
+// escapes, therefore the all bytes written will generate an escape.
+func (w *Writer) WriteByte(b byte) error {
+ var p [1]byte
+ p[0] = b
+ n, err := w.Write(p[:])
+ if n != 1 {
+ return err
+ }
+ return nil
+}
+
// Write implements io.Writer.Write.
-func (w *writer) Write(p []byte) (int, error) {
+func (w *Writer) Write(p []byte) (int, error) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) {
}
// Close implements io.Closer.Close.
-func (w *writer) Close() error {
+func (w *Writer) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
diff --git a/pkg/sentry/context/BUILD b/pkg/context/BUILD
index 8dc1a77b1..239f31149 100644
--- a/pkg/sentry/context/BUILD
+++ b/pkg/context/BUILD
@@ -1,12 +1,11 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "context",
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/context",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/amutex",
"//pkg/log",
diff --git a/pkg/sentry/context/context.go b/pkg/context/context.go
index dfd62cbdb..5319b6d8d 100644
--- a/pkg/sentry/context/context.go
+++ b/pkg/context/context.go
@@ -12,10 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package context defines the sentry's Context type.
+// Package context defines an internal context type.
+//
+// The given Context conforms to the standard Go context, but mandates
+// additional methods that are specific to the kernel internals. Note however,
+// that the Context described by this package carries additional constraints
+// regarding concurrent access and retaining beyond the scope of a call.
+//
+// See the Context type for complete details.
package context
import (
+ "context"
+ "time"
+
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
type Context interface {
log.Logger
amutex.Sleeper
+ context.Context
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
@@ -72,19 +83,36 @@ type Context interface {
// AddressSpace is activated. Normally activate is the same value as the
// deactivate parameter passed to UninterruptibleSleepStart.
UninterruptibleSleepFinish(activate bool)
+}
+
+// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
+// methods for anonymous embedding in other types that do not implement sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// Deadline returns zero values, meaning no deadline.
+func (NoopSleeper) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
- // Value returns the value associated with this Context for key, or nil if
- // no value is associated with key. Successive calls to Value with the same
- // key returns the same result.
- //
- // A key identifies a specific value in a Context. Functions that wish to
- // retrieve values from Context typically allocate a key in a global
- // variable then use that key as the argument to Context.Value. A key can
- // be any type that supports equality; packages should define keys as an
- // unexported type to avoid collisions.
- Value(key interface{}) interface{}
+// Done returns nil.
+func (NoopSleeper) Done() <-chan struct{} {
+ return nil
+}
+
+// Err returns nil.
+func (NoopSleeper) Err() error {
+ return nil
}
+// logContext implements basic logging.
type logContext struct {
log.Logger
NoopSleeper
@@ -95,27 +123,10 @@ func (logContext) Value(key interface{}) interface{} {
return nil
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and
-// Context.UninterruptibleSleep* methods for anonymous embedding in other types
-// that do not want to notify kernel.Task about sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
-}
-
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
-
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
-
// bgContext is the context returned by context.Background.
var bgContext = &logContext{Logger: log.Log()}
// Background returns an empty context using the default logger.
-//
-// Users should be wary of using a Background context. Please tag any use with
-// FIXME(b/38173783) and a note to remove this use.
-//
// Generally, one should use the Task as their context when available, or avoid
// having to use a context in places where a Task is unavailable.
//
diff --git a/pkg/control/client/BUILD b/pkg/control/client/BUILD
index 066d7b1a1..1b9e10ee7 100644
--- a/pkg/control/client/BUILD
+++ b/pkg/control/client/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -7,7 +7,6 @@ go_library(
srcs = [
"client.go",
],
- importpath = "gvisor.dev/gvisor/pkg/control/client",
visibility = ["//:sandbox"],
deps = [
"//pkg/unet",
diff --git a/pkg/control/server/BUILD b/pkg/control/server/BUILD
index 21adf3adf..002d2ef44 100644
--- a/pkg/control/server/BUILD
+++ b/pkg/control/server/BUILD
@@ -1,14 +1,14 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "server",
srcs = ["server.go"],
- importpath = "gvisor.dev/gvisor/pkg/control/server",
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
],
diff --git a/pkg/control/server/server.go b/pkg/control/server/server.go
index a56152d10..41abe1f2d 100644
--- a/pkg/control/server/server.go
+++ b/pkg/control/server/server.go
@@ -22,9 +22,9 @@ package server
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
)
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index ed111fd2a..d6cb1a549 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,8 +7,9 @@ go_library(
srcs = [
"cpu_amd64.s",
"cpuid.go",
+ "cpuid_arm64.go",
+ "cpuid_x86.go",
],
- importpath = "gvisor.dev/gvisor/pkg/cpuid",
visibility = ["//:sandbox"],
deps = ["//pkg/log"],
)
@@ -17,16 +17,19 @@ go_library(
go_test(
name = "cpuid_test",
size = "small",
- srcs = ["cpuid_test.go"],
- embed = [":cpuid"],
+ srcs = [
+ "cpuid_arm64_test.go",
+ "cpuid_x86_test.go",
+ ],
+ library = ":cpuid",
)
go_test(
name = "cpuid_parse_test",
size = "small",
srcs = [
- "cpuid_parse_test.go",
+ "cpuid_parse_x86_test.go",
],
- embed = [":cpuid"],
+ library = ":cpuid",
tags = ["manual"],
)
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 5d61dc2ff..f7f9dbf86 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
-
// Package cpuid provides basic functionality for creating and adjusting CPU
// feature sets.
//
@@ -21,1060 +19,20 @@
// known platform, or HostFeatureSet()) and then add, remove, and test for
// features as desired.
//
-// For example: Test for hardware extended state saving, and if we don't have
-// it, don't expose AVX, which cannot be saved with fxsave.
+// For example: on x86, test for hardware extended state saving, and if
+// we don't have it, don't expose AVX, which cannot be saved with fxsave.
//
// if !HostFeatureSet().HasFeature(X86FeatureXSAVE) {
// exposedFeatures.Remove(X86FeatureAVX)
// }
package cpuid
-import (
- "bytes"
- "fmt"
- "io/ioutil"
- "strconv"
- "strings"
-
- "gvisor.dev/gvisor/pkg/log"
-)
-
-// Common references for CPUID leaves and bits:
-//
-// Intel:
-// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
-// * Intel Application Note 485 (more detailed)
-//
-// AMD:
-// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
-
// Feature is a unique identifier for a particular cpu feature. We just use an
-// int as a feature number on x86.
+// int as a feature number on x86 and arm64.
//
-// Features are numbered according to "blocks". Each block is 32 bits, and
-// feature bits from the same source (cpuid leaf/level) are in the same block.
-type Feature int
-
-// block is a collection of 32 Feature bits.
-type block int
-
-const blockSize = 32
-
-// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// On x86, features are numbered according to "blocks". Each block is 32 bits, and
// feature bits from the same source (cpuid leaf/level) are in the same block.
-func featureID(b block, bit int) Feature {
- return Feature(32*int(b) + bit)
-}
-
-// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
-// ecx with eax=1.
-const (
- X86FeatureSSE3 Feature = iota
- X86FeaturePCLMULDQ
- X86FeatureDTES64
- X86FeatureMONITOR
- X86FeatureDSCPL
- X86FeatureVMX
- X86FeatureSMX
- X86FeatureEST
- X86FeatureTM2
- X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
- X86FeatureCNXTID
- X86FeatureSDBG
- X86FeatureFMA
- X86FeatureCX16
- X86FeatureXTPR
- X86FeaturePDCM
- _ // ecx bit 16 is reserved.
- X86FeaturePCID
- X86FeatureDCA
- X86FeatureSSE4_1
- X86FeatureSSE4_2
- X86FeatureX2APIC
- X86FeatureMOVBE
- X86FeaturePOPCNT
- X86FeatureTSCD
- X86FeatureAES
- X86FeatureXSAVE
- X86FeatureOSXSAVE
- X86FeatureAVX
- X86FeatureF16C
- X86FeatureRDRAND
- _ // ecx bit 31 is reserved.
-)
-
-// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
-// edx with eax=1.
-const (
- X86FeatureFPU Feature = 32 + iota
- X86FeatureVME
- X86FeatureDE
- X86FeaturePSE
- X86FeatureTSC
- X86FeatureMSR
- X86FeaturePAE
- X86FeatureMCE
- X86FeatureCX8
- X86FeatureAPIC
- _ // edx bit 10 is reserved.
- X86FeatureSEP
- X86FeatureMTRR
- X86FeaturePGE
- X86FeatureMCA
- X86FeatureCMOV
- X86FeaturePAT
- X86FeaturePSE36
- X86FeaturePSN
- X86FeatureCLFSH
- _ // edx bit 20 is reserved.
- X86FeatureDS
- X86FeatureACPI
- X86FeatureMMX
- X86FeatureFXSR
- X86FeatureSSE
- X86FeatureSSE2
- X86FeatureSS
- X86FeatureHTT
- X86FeatureTM
- X86FeatureIA64
- X86FeaturePBE
-)
-
-// Block 2 bits are the "structured extended" features returned in ebx for
-// eax=7, ecx=0.
-const (
- X86FeatureFSGSBase Feature = 2*32 + iota
- X86FeatureTSC_ADJUST
- _ // ebx bit 2 is reserved.
- X86FeatureBMI1
- X86FeatureHLE
- X86FeatureAVX2
- X86FeatureFDP_EXCPTN_ONLY
- X86FeatureSMEP
- X86FeatureBMI2
- X86FeatureERMS
- X86FeatureINVPCID
- X86FeatureRTM
- X86FeatureCQM
- X86FeatureFPCSDS
- X86FeatureMPX
- X86FeatureRDT
- X86FeatureAVX512F
- X86FeatureAVX512DQ
- X86FeatureRDSEED
- X86FeatureADX
- X86FeatureSMAP
- X86FeatureAVX512IFMA
- X86FeaturePCOMMIT
- X86FeatureCLFLUSHOPT
- X86FeatureCLWB
- X86FeatureIPT // Intel processor trace.
- X86FeatureAVX512PF
- X86FeatureAVX512ER
- X86FeatureAVX512CD
- X86FeatureSHA
- X86FeatureAVX512BW
- X86FeatureAVX512VL
-)
-
-// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
-const (
- X86FeaturePREFETCHWT1 Feature = 3*32 + iota
- X86FeatureAVX512VBMI
- X86FeatureUMIP
- X86FeaturePKU
-)
-
-// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
-// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
-const (
- X86FeatureXSAVEOPT Feature = 4*32 + iota
- X86FeatureXSAVEC
- X86FeatureXGETBV1
- X86FeatureXSAVES
- // EAX[31:4] are reserved.
-)
-
-// Block 5 constants are the extended feature bits in
-// CPUID.(EAX=0x80000001):ECX.
-const (
- X86FeatureLAHF64 Feature = 5*32 + iota
- X86FeatureCMP_LEGACY
- X86FeatureSVM
- X86FeatureEXTAPIC
- X86FeatureCR8_LEGACY
- X86FeatureLZCNT
- X86FeatureSSE4A
- X86FeatureMISALIGNSSE
- X86FeaturePREFETCHW
- X86FeatureOSVW
- X86FeatureIBS
- X86FeatureXOP
- X86FeatureSKINIT
- X86FeatureWDT
- _ // ecx bit 14 is reserved.
- X86FeatureLWP
- X86FeatureFMA4
- X86FeatureTCE
- _ // ecx bit 18 is reserved.
- _ // ecx bit 19 is reserved.
- _ // ecx bit 20 is reserved.
- X86FeatureTBM
- X86FeatureTOPOLOGY
- X86FeaturePERFCTR_CORE
- X86FeaturePERFCTR_NB
- _ // ecx bit 25 is reserved.
- X86FeatureBPEXT
- X86FeaturePERFCTR_TSC
- X86FeaturePERFCTR_LLC
- X86FeatureMWAITX
- // ECX[31:30] are reserved.
-)
-
-// Block 6 constants are the extended feature bits in
-// CPUID.(EAX=0x80000001):EDX.
//
-// These are sparse, and so the bit positions are assigned manually.
-const (
- // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
- // also defined in block 1 (in identical bit positions). Those features
- // are not listed here.
- block6DuplicateMask = 0x183f3ff
-
- X86FeatureSYSCALL Feature = 6*32 + 11
- X86FeatureNX Feature = 6*32 + 20
- X86FeatureMMXEXT Feature = 6*32 + 22
- X86FeatureFXSR_OPT Feature = 6*32 + 25
- X86FeatureGBPAGES Feature = 6*32 + 26
- X86FeatureRDTSCP Feature = 6*32 + 27
- X86FeatureLM Feature = 6*32 + 29
- X86Feature3DNOWEXT Feature = 6*32 + 30
- X86Feature3DNOW Feature = 6*32 + 31
-)
-
-// linuxBlockOrder defines the order in which linux organizes the feature
-// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
-// which doesn't match well here, so for the /proc/cpuinfo generation we simply
-// re-map the blocks to Linux's ordering and then go through the bits in each
-// block.
-var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
-
-// To make emulation of /proc/cpuinfo easy, these names match the names of the
-// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
-var x86FeatureStrings = map[Feature]string{
- // Block 0.
- X86FeatureSSE3: "pni",
- X86FeaturePCLMULDQ: "pclmulqdq",
- X86FeatureDTES64: "dtes64",
- X86FeatureMONITOR: "monitor",
- X86FeatureDSCPL: "ds_cpl",
- X86FeatureVMX: "vmx",
- X86FeatureSMX: "smx",
- X86FeatureEST: "est",
- X86FeatureTM2: "tm2",
- X86FeatureSSSE3: "ssse3",
- X86FeatureCNXTID: "cid",
- X86FeatureSDBG: "sdbg",
- X86FeatureFMA: "fma",
- X86FeatureCX16: "cx16",
- X86FeatureXTPR: "xtpr",
- X86FeaturePDCM: "pdcm",
- X86FeaturePCID: "pcid",
- X86FeatureDCA: "dca",
- X86FeatureSSE4_1: "sse4_1",
- X86FeatureSSE4_2: "sse4_2",
- X86FeatureX2APIC: "x2apic",
- X86FeatureMOVBE: "movbe",
- X86FeaturePOPCNT: "popcnt",
- X86FeatureTSCD: "tsc_deadline_timer",
- X86FeatureAES: "aes",
- X86FeatureXSAVE: "xsave",
- X86FeatureAVX: "avx",
- X86FeatureF16C: "f16c",
- X86FeatureRDRAND: "rdrand",
-
- // Block 1.
- X86FeatureFPU: "fpu",
- X86FeatureVME: "vme",
- X86FeatureDE: "de",
- X86FeaturePSE: "pse",
- X86FeatureTSC: "tsc",
- X86FeatureMSR: "msr",
- X86FeaturePAE: "pae",
- X86FeatureMCE: "mce",
- X86FeatureCX8: "cx8",
- X86FeatureAPIC: "apic",
- X86FeatureSEP: "sep",
- X86FeatureMTRR: "mtrr",
- X86FeaturePGE: "pge",
- X86FeatureMCA: "mca",
- X86FeatureCMOV: "cmov",
- X86FeaturePAT: "pat",
- X86FeaturePSE36: "pse36",
- X86FeaturePSN: "pn",
- X86FeatureCLFSH: "clflush",
- X86FeatureDS: "dts",
- X86FeatureACPI: "acpi",
- X86FeatureMMX: "mmx",
- X86FeatureFXSR: "fxsr",
- X86FeatureSSE: "sse",
- X86FeatureSSE2: "sse2",
- X86FeatureSS: "ss",
- X86FeatureHTT: "ht",
- X86FeatureTM: "tm",
- X86FeatureIA64: "ia64",
- X86FeaturePBE: "pbe",
-
- // Block 2.
- X86FeatureFSGSBase: "fsgsbase",
- X86FeatureTSC_ADJUST: "tsc_adjust",
- X86FeatureBMI1: "bmi1",
- X86FeatureHLE: "hle",
- X86FeatureAVX2: "avx2",
- X86FeatureSMEP: "smep",
- X86FeatureBMI2: "bmi2",
- X86FeatureERMS: "erms",
- X86FeatureINVPCID: "invpcid",
- X86FeatureRTM: "rtm",
- X86FeatureCQM: "cqm",
- X86FeatureMPX: "mpx",
- X86FeatureRDT: "rdt_a",
- X86FeatureAVX512F: "avx512f",
- X86FeatureAVX512DQ: "avx512dq",
- X86FeatureRDSEED: "rdseed",
- X86FeatureADX: "adx",
- X86FeatureSMAP: "smap",
- X86FeatureCLWB: "clwb",
- X86FeatureAVX512PF: "avx512pf",
- X86FeatureAVX512ER: "avx512er",
- X86FeatureAVX512CD: "avx512cd",
- X86FeatureSHA: "sha_ni",
- X86FeatureAVX512BW: "avx512bw",
- X86FeatureAVX512VL: "avx512vl",
-
- // Block 3.
- X86FeatureAVX512VBMI: "avx512vbmi",
- X86FeatureUMIP: "umip",
- X86FeaturePKU: "pku",
-
- // Block 4.
- X86FeatureXSAVEOPT: "xsaveopt",
- X86FeatureXSAVEC: "xsavec",
- X86FeatureXGETBV1: "xgetbv1",
- X86FeatureXSAVES: "xsaves",
-
- // Block 5.
- X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
- X86FeatureCMP_LEGACY: "cmp_legacy",
- X86FeatureSVM: "svm",
- X86FeatureEXTAPIC: "extapic",
- X86FeatureCR8_LEGACY: "cr8_legacy",
- X86FeatureLZCNT: "abm", // Advanced bit manipulation
- X86FeatureSSE4A: "sse4a",
- X86FeatureMISALIGNSSE: "misalignsse",
- X86FeaturePREFETCHW: "3dnowprefetch",
- X86FeatureOSVW: "osvw",
- X86FeatureIBS: "ibs",
- X86FeatureXOP: "xop",
- X86FeatureSKINIT: "skinit",
- X86FeatureWDT: "wdt",
- X86FeatureLWP: "lwp",
- X86FeatureFMA4: "fma4",
- X86FeatureTCE: "tce",
- X86FeatureTBM: "tbm",
- X86FeatureTOPOLOGY: "topoext",
- X86FeaturePERFCTR_CORE: "perfctr_core",
- X86FeaturePERFCTR_NB: "perfctr_nb",
- X86FeatureBPEXT: "bpext",
- X86FeaturePERFCTR_TSC: "ptsc",
- X86FeaturePERFCTR_LLC: "perfctr_llc",
- X86FeatureMWAITX: "mwaitx",
-
- // Block 6.
- X86FeatureSYSCALL: "syscall",
- X86FeatureNX: "nx",
- X86FeatureMMXEXT: "mmxext",
- X86FeatureFXSR_OPT: "fxsr_opt",
- X86FeatureGBPAGES: "pdpe1gb",
- X86FeatureRDTSCP: "rdtscp",
- X86FeatureLM: "lm",
- X86Feature3DNOWEXT: "3dnowext",
- X86Feature3DNOW: "3dnow",
-}
-
-// These flags are parse only---they can be used for setting / unsetting the
-// flags, but will not get printed out in /proc/cpuinfo.
-var x86FeatureParseOnlyStrings = map[Feature]string{
- // Block 0.
- X86FeatureOSXSAVE: "osxsave",
-
- // Block 2.
- X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
- X86FeatureFPCSDS: "fpcsds",
- X86FeatureIPT: "pt",
- X86FeatureCLFLUSHOPT: "clfushopt",
-
- // Block 3.
- X86FeaturePREFETCHWT1: "prefetchwt1",
-}
-
-// intelCacheDescriptors describe the caches and TLBs on the system. They are
-// returned in the registers for eax=2. Intel only.
-type intelCacheDescriptor uint8
-
-// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
-// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
-const (
- intelNullDescriptor intelCacheDescriptor = 0
- intelNoTLBDescriptor intelCacheDescriptor = 0xfe
- intelNoCacheDescriptor intelCacheDescriptor = 0xff
-
- // Most descriptors omitted for brevity as they are currently unused.
-)
-
-// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
-type CacheType uint8
-
-const (
- // cacheNull indicates that there are no more entries.
- cacheNull CacheType = iota
-
- // CacheData is a data cache.
- CacheData
-
- // CacheInstruction is an instruction cache.
- CacheInstruction
-
- // CacheUnified is a unified instruction and data cache.
- CacheUnified
-)
-
-// Cache describes the parameters of a single cache on the system.
-//
-// +stateify savable
-type Cache struct {
- // Level is the hierarchical level of this cache (L1, L2, etc).
- Level uint32
-
- // Type is the type of cache.
- Type CacheType
-
- // FullyAssociative indicates that entries may be placed in any block.
- FullyAssociative bool
-
- // Partitions is the number of physical partitions in the cache.
- Partitions uint32
-
- // Ways is the number of ways of associativity in the cache.
- Ways uint32
-
- // Sets is the number of sets in the cache.
- Sets uint32
-
- // InvalidateHierarchical indicates that WBINVD/INVD from threads
- // sharing this cache acts upon lower level caches for threads sharing
- // this cache.
- InvalidateHierarchical bool
-
- // Inclusive indicates that this cache is inclusive of lower cache
- // levels.
- Inclusive bool
-
- // DirectMapped indicates that this cache is directly mapped from
- // address, rather than using a hash function.
- DirectMapped bool
-}
-
-// Just a way to wrap cpuid function numbers.
-type cpuidFunction uint32
-
-// The constants below are the lower or "standard" cpuid functions, ordered as
-// defined by the hardware.
-const (
- vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
- featureInfo // Returns basic feature bits and processor signature.
- intelCacheDescriptors // Returns list of cache descriptors. Intel only.
- intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
- intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
- monitorMwaitParams // Returns information about monitor/mwait instructions.
- powerParams // Returns information about power management and thermal sensors.
- extendedFeatureInfo // Returns extended feature bits.
- _ // Function 0x8 is reserved.
- intelDCAParams // Returns direct cache access information. Intel only.
- intelPMCInfo // Returns information about performance monitoring features. Intel only.
- intelX2APICInfo // Returns core/logical processor topology. Intel only.
- _ // Function 0xc is reserved.
- xSaveInfo // Returns information about extended state management.
-)
-
-// The "extended" functions start at 0x80000000.
-const (
- extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
- extendedFeatures // Returns some extended feature bits in edx and ecx.
-)
-
-// These are the extended floating point state features. They are used to
-// enumerate floating point features in XCR0, XSTATE_BV, etc.
-const (
- XSAVEFeatureX87 = 1 << 0
- XSAVEFeatureSSE = 1 << 1
- XSAVEFeatureAVX = 1 << 2
- XSAVEFeatureBNDREGS = 1 << 3
- XSAVEFeatureBNDCSR = 1 << 4
- XSAVEFeatureAVX512op = 1 << 5
- XSAVEFeatureAVX512zmm0 = 1 << 6
- XSAVEFeatureAVX512zmm16 = 1 << 7
- XSAVEFeaturePKRU = 1 << 9
-)
-
-var cpuFreqMHz float64
-
-// x86FeaturesFromString includes features from x86FeatureStrings and
-// x86FeatureParseOnlyStrings.
-var x86FeaturesFromString = make(map[string]Feature)
-
-// FeatureFromString returns the Feature associated with the given feature
-// string plus a bool to indicate if it could find the feature.
-func FeatureFromString(s string) (Feature, bool) {
- f, b := x86FeaturesFromString[s]
- return f, b
-}
-
-// String implements fmt.Stringer.
-func (f Feature) String() string {
- if s := f.flagString(false); s != "" {
- return s
- }
-
- block := int(f) / 32
- bit := int(f) % 32
- return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
-}
-
-func (f Feature) flagString(cpuinfoOnly bool) string {
- if s, ok := x86FeatureStrings[f]; ok {
- return s
- }
- if !cpuinfoOnly {
- return x86FeatureParseOnlyStrings[f]
- }
- return ""
-}
-
-// FeatureSet is a set of Features for a CPU.
-//
-// +stateify savable
-type FeatureSet struct {
- // Set is the set of features that are enabled in this FeatureSet.
- Set map[Feature]bool
-
- // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
- VendorID string
-
- // ExtendedFamily is part of the processor signature.
- ExtendedFamily uint8
-
- // ExtendedModel is part of the processor signature.
- ExtendedModel uint8
-
- // ProcessorType is part of the processor signature.
- ProcessorType uint8
-
- // Family is part of the processor signature.
- Family uint8
-
- // Model is part of the processor signature.
- Model uint8
-
- // SteppingID is part of the processor signature.
- SteppingID uint8
-
- // Caches describes the caches on the CPU.
- Caches []Cache
-
- // CacheLine is the size of a cache line in bytes.
- //
- // All caches use the same line size. This is not enforced in the CPUID
- // encoding, but is true on all known x86 processors.
- CacheLine uint32
-}
-
-// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
-// equivalent to the "flags" field in /proc/cpuinfo.
-func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
- var s []string
- for _, b := range linuxBlockOrder {
- for i := 0; i < blockSize; i++ {
- if f := featureID(b, i); fs.Set[f] {
- if fstr := f.flagString(cpuinfoOnly); fstr != "" {
- s = append(s, fstr)
- }
- }
- }
- }
- return strings.Join(s, " ")
-}
-
-// CPUInfo is to generate a section of one cpu in /proc/cpuinfo. This is a
-// minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
-// not always printed in Linux. The bogomips field is simply made up.
-func (fs FeatureSet) CPUInfo(cpu uint) string {
- var b bytes.Buffer
- fmt.Fprintf(&b, "processor\t: %d\n", cpu)
- fmt.Fprintf(&b, "vendor_id\t: %s\n", fs.VendorID)
- fmt.Fprintf(&b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
- fmt.Fprintf(&b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
- fmt.Fprintf(&b, "model name\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "stepping\t: %s\n", "unknown") // Unknown for now.
- fmt.Fprintf(&b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
- fmt.Fprintln(&b, "fpu\t\t: yes")
- fmt.Fprintln(&b, "fpu_exception\t: yes")
- fmt.Fprintf(&b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
- fmt.Fprintln(&b, "wp\t\t: yes")
- fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
- fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
- fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
- fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
- fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
- return b.String()
-}
-
-const (
- amdVendorID = "AuthenticAMD"
- intelVendorID = "GenuineIntel"
-)
-
-// AMD returns true if fs describes an AMD CPU.
-func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == amdVendorID
-}
-
-// Intel returns true if fs describes an Intel CPU.
-func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == intelVendorID
-}
-
-// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
-// subset of the host feature set.
-type ErrIncompatible struct {
- message string
-}
-
-// Error implements error.
-func (e ErrIncompatible) Error() string {
- return e.message
-}
-
-// CheckHostCompatible returns nil if fs is a subset of the host feature set.
-func (fs *FeatureSet) CheckHostCompatible() error {
- hfs := HostFeatureSet()
-
- if diff := fs.Subtract(hfs); diff != nil {
- return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
- }
-
- // The size of a cache line must match, as it is critical to correctly
- // utilizing CLFLUSH. Other cache properties are allowed to change, as
- // they are not important to correctness.
- if fs.CacheLine != hfs.CacheLine {
- return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
- }
-
- return nil
-}
-
-// Helper to convert 3 regs into 12-byte vendor ID.
-func vendorIDFromRegs(bx, cx, dx uint32) string {
- bytes := make([]byte, 0, 12)
- for i := uint(0); i < 4; i++ {
- b := byte(bx >> (i * 8))
- bytes = append(bytes, b)
- }
-
- for i := uint(0); i < 4; i++ {
- b := byte(dx >> (i * 8))
- bytes = append(bytes, b)
- }
-
- for i := uint(0); i < 4; i++ {
- b := byte(cx >> (i * 8))
- bytes = append(bytes, b)
- }
- return string(bytes)
-}
-
-// ExtendedStateSize returns the number of bytes needed to save the "extended
-// state" for this processor and the boundary it must be aligned to. Extended
-// state includes floating point registers, and other cpu state that's not
-// associated with the normal task context.
-//
-// Note: We can save some space here with an optimization where we use a
-// smaller chunk of memory depending on features that are actually enabled.
-// Currently we just use the largest possible size for simplicity (which is
-// about 2.5K worst case, with avx512).
-func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
- if fs.UseXsave() {
- // Leaf 0 of xsaveinfo function returns the size for currently
- // enabled xsave features in ebx, the maximum size if all valid
- // features are saved with xsave in ecx, and valid XCR0 bits in
- // edx:eax.
- _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
- return uint(maxSize), 64
- }
-
- // If we don't support xsave, we fall back to fxsave, which requires
- // 512 bytes aligned to 16 bytes.
- return 512, 16
-}
-
-// ValidXCR0Mask returns the bits that may be set to 1 in control register
-// XCR0.
-func (fs *FeatureSet) ValidXCR0Mask() uint64 {
- if !fs.UseXsave() {
- return 0
- }
- eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
- return uint64(edx)<<32 | uint64(eax)
-}
-
-// vendorIDRegs returns the 3 register values used to construct the 12-byte
-// vendor ID string for eax=0.
-func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
- for i := uint(0); i < 4; i++ {
- bx |= uint32(fs.VendorID[i]) << (i * 8)
- }
-
- for i := uint(0); i < 4; i++ {
- dx |= uint32(fs.VendorID[i+4]) << (i * 8)
- }
-
- for i := uint(0); i < 4; i++ {
- cx |= uint32(fs.VendorID[i+8]) << (i * 8)
- }
- return
-}
-
-// signature returns the signature dword that's returned in eax when eax=1.
-func (fs *FeatureSet) signature() uint32 {
- var s uint32
- s |= uint32(fs.SteppingID & 0xf)
- s |= uint32(fs.Model&0xf) << 4
- s |= uint32(fs.Family&0xf) << 8
- s |= uint32(fs.ProcessorType&0x3) << 12
- s |= uint32(fs.ExtendedModel&0xf) << 16
- s |= uint32(fs.ExtendedFamily&0xff) << 20
- return s
-}
-
-// Helper to deconstruct signature dword.
-func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
- sid = uint8(v & 0xf)
- m = uint8(v>>4) & 0xf
- f = uint8(v>>8) & 0xf
- pt = uint8(v>>12) & 0x3
- em = uint8(v>>16) & 0xf
- ef = uint8(v >> 20)
- return
-}
-
-// Helper to convert blockwise feature bit masks into a set of features. Masks
-// must be provided in order for each block, without skipping them. If a block
-// does not matter for this feature set, 0 is specified.
-func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
- s := make(map[Feature]bool)
- for b, blockMask := range blocks {
- for i := 0; i < blockSize; i++ {
- if blockMask&1 != 0 {
- s[featureID(block(b), i)] = true
- }
- blockMask >>= 1
- }
- }
- return s
-}
-
-// blockMask returns the 32-bit mask associated with a block of features.
-func (fs *FeatureSet) blockMask(b block) uint32 {
- var mask uint32
- for i := 0; i < blockSize; i++ {
- if fs.Set[featureID(b, i)] {
- mask |= 1 << uint(i)
- }
- }
- return mask
-}
-
-// Remove removes a Feature from a FeatureSet. It ignores features
-// that are not in the FeatureSet.
-func (fs *FeatureSet) Remove(feature Feature) {
- delete(fs.Set, feature)
-}
-
-// Add adds a Feature to a FeatureSet. It ignores duplicate features.
-func (fs *FeatureSet) Add(feature Feature) {
- fs.Set[feature] = true
-}
-
-// HasFeature tests whether or not a feature is in the given feature set.
-func (fs *FeatureSet) HasFeature(feature Feature) bool {
- return fs.Set[feature]
-}
-
-// Subtract returns the features present in fs that are not present in other.
-// If all features in fs are present in other, Subtract returns nil.
-func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
- for f := range fs.Set {
- if !other.Set[f] {
- if diff == nil {
- diff = make(map[Feature]bool)
- }
- diff[f] = true
- }
- }
-
- return
-}
-
-// EmulateID emulates a cpuid instruction based on the feature set.
-func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
- switch cpuidFunction(origAx) {
- case vendorID:
- ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
- bx, dx, cx = fs.vendorIDRegs()
- case featureInfo:
- // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
- bx = (fs.CacheLine / 8) << 8
- cx = fs.blockMask(block(0))
- dx = fs.blockMask(block(1))
- ax = fs.signature()
- case intelCacheDescriptors:
- if !fs.Intel() {
- // Reserved on non-Intel.
- return 0, 0, 0, 0
- }
-
- // "The least-significant byte in register EAX (register AL)
- // will always return 01H. Software should ignore this value
- // and not interpret it as an informational descriptor." - SDM
- //
- // We only support reporting cache parameters via
- // intelDeterministicCacheParams; report as much here.
- //
- // We do not support exposing TLB information at all.
- ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
- case intelDeterministicCacheParams:
- if !fs.Intel() {
- // Reserved on non-Intel.
- return 0, 0, 0, 0
- }
-
- // cx is the index of the cache to describe.
- if int(origCx) >= len(fs.Caches) {
- return uint32(cacheNull), 0, 0, 0
- }
- c := fs.Caches[origCx]
-
- ax = uint32(c.Type)
- ax |= c.Level << 5
- ax |= 1 << 8 // Always claim the cache is "self-initializing".
- if c.FullyAssociative {
- ax |= 1 << 9
- }
- // Processor topology not supported.
-
- bx = fs.CacheLine - 1
- bx |= (c.Partitions - 1) << 12
- bx |= (c.Ways - 1) << 22
-
- cx = c.Sets - 1
-
- if !c.InvalidateHierarchical {
- dx |= 1
- }
- if c.Inclusive {
- dx |= 1 << 1
- }
- if !c.DirectMapped {
- dx |= 1 << 2
- }
- case xSaveInfo:
- if !fs.UseXsave() {
- return 0, 0, 0, 0
- }
- return HostID(uint32(xSaveInfo), origCx)
- case extendedFeatureInfo:
- if origCx != 0 {
- break // Only leaf 0 is supported.
- }
- bx = fs.blockMask(block(2))
- cx = fs.blockMask(block(3))
- case extendedFunctionInfo:
- // We only support showing the extended features.
- ax = uint32(extendedFeatures)
- cx = 0
- case extendedFeatures:
- cx = fs.blockMask(block(5))
- dx = fs.blockMask(block(6))
- if fs.AMD() {
- // AMD duplicates some block 1 features in block 6.
- dx |= fs.blockMask(block(1)) & block6DuplicateMask
- }
- }
-
- return
-}
-
-// UseXsave returns the choice of fp state saving instruction.
-func (fs *FeatureSet) UseXsave() bool {
- return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
-}
-
-// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
-func (fs *FeatureSet) UseXsaveopt() bool {
- return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
-}
-
-// HostID executes a native CPUID instruction.
-func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
-
-// HostFeatureSet uses cpuid to get host values and construct a feature set
-// that matches that of the host machine. Note that there are several places
-// where there appear to be some unnecessary assignments between register names
-// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
-// where the different feature blocks come from, to make the code easier to
-// inspect and read.
-func HostFeatureSet() *FeatureSet {
- // eax=0 gets max supported feature and vendor ID.
- _, bx, cx, dx := HostID(0, 0)
- vendorID := vendorIDFromRegs(bx, cx, dx)
-
- // eax=1 gets basic features in ecx:edx.
- ax, bx, cx, dx := HostID(1, 0)
- featureBlock0 := cx
- featureBlock1 := dx
- ef, em, pt, f, m, sid := signatureSplit(ax)
- cacheLine := 8 * (bx >> 8) & 0xff
-
- // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
- var caches []Cache
- if vendorID == intelVendorID {
- // ecx selects the cache index until a null type is returned.
- for i := uint32(0); ; i++ {
- ax, bx, cx, dx := HostID(4, i)
- t := CacheType(ax & 0xf)
- if t == cacheNull {
- break
- }
-
- lineSize := (bx & 0xfff) + 1
- if lineSize != cacheLine {
- panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
- }
-
- caches = append(caches, Cache{
- Type: t,
- Level: (ax >> 5) & 0x7,
- FullyAssociative: ((ax >> 9) & 1) == 1,
- Partitions: ((bx >> 12) & 0x3ff) + 1,
- Ways: ((bx >> 22) & 0x3ff) + 1,
- Sets: cx + 1,
- InvalidateHierarchical: (dx & 1) == 0,
- Inclusive: ((dx >> 1) & 1) == 1,
- DirectMapped: ((dx >> 2) & 1) == 0,
- })
- }
- }
-
- // eax=7, ecx=0 gets extended features in ecx:ebx.
- _, bx, cx, _ = HostID(7, 0)
- featureBlock2 := bx
- featureBlock3 := cx
-
- // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
- var featureBlock4 uint32
- if (featureBlock0 & (1 << 26)) != 0 {
- featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
- }
-
- // eax=0x80000000 gets supported extended levels. We use this to
- // determine if there are any non-zero block 4 or block 6 bits to find.
- var featureBlock5, featureBlock6 uint32
- if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
- // eax=0x80000001 gets AMD added feature bits.
- _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
- featureBlock5 = cx
- // Ignore features duplicated from block 1 on AMD. These bits
- // are reserved on Intel.
- featureBlock6 = dx &^ block6DuplicateMask
- }
-
- set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
- return &FeatureSet{
- Set: set,
- VendorID: vendorID,
- ExtendedFamily: ef,
- ExtendedModel: em,
- ProcessorType: pt,
- Family: f,
- Model: m,
- SteppingID: sid,
- CacheLine: cacheLine,
- Caches: caches,
- }
-}
-
-// Reads max cpu frequency from host /proc/cpuinfo. Must run before
-// whitelisting. This value is used to create the fake /proc/cpuinfo from a
-// FeatureSet.
-func initCPUFreq() {
- cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
- if err != nil {
- // Leave it as 0... The standalone VDSO bails out in the same
- // way.
- log.Warningf("Could not read /proc/cpuinfo: %v", err)
- return
- }
- cpuinfo := string(cpuinfob)
-
- // We get the value straight from host /proc/cpuinfo. On machines with
- // frequency scaling enabled, this will only get the current value
- // which will likely be inaccurate. This is fine on machines with
- // frequency scaling disabled.
- for _, line := range strings.Split(cpuinfo, "\n") {
- if strings.Contains(line, "cpu MHz") {
- splitMHz := strings.Split(line, ":")
- if len(splitMHz) < 2 {
- log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
- return
- }
-
- // If there was a problem, leave cpuFreqMHz as 0.
- var err error
- cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
- if err != nil {
- log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
- cpuFreqMHz = 0
- return
- }
- return
- }
- }
- log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
-}
-
-func initFeaturesFromString() {
- for f, s := range x86FeatureStrings {
- x86FeaturesFromString[s] = f
- }
- for f, s := range x86FeatureParseOnlyStrings {
- x86FeaturesFromString[s] = f
- }
-}
-
-func init() {
- // initCpuFreq must be run before whitelists are enabled.
- initCPUFreq()
- initFeaturesFromString()
-}
+// On arm64, features are numbered according to the ELF HWCAP definition.
+// arch/arm64/include/uapi/asm/hwcap.h
+type Feature int
diff --git a/pkg/cpuid/cpuid_arm64.go b/pkg/cpuid/cpuid_arm64.go
new file mode 100644
index 000000000..ac7bb6774
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64.go
@@ -0,0 +1,483 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// ARM64 doesn't have a 'cpuid' equivalent, which means it have no architected
+// discovery mechanism for hardware features available to userspace code at EL0.
+// The kernel exposes the presence of these features to userspace through a set
+// of flags(HWCAP/HWCAP2) bits, exposed in the auxilliary vector.
+// Ref Documentation/arm64/elf_hwcaps.rst for more info.
+//
+// Currently, only the HWCAP bits are supported.
+
+const (
+ // ARM64FeatureFP indicates support for single and double precision
+ // float point types.
+ ARM64FeatureFP Feature = iota
+
+ // ARM64FeatureASIMD indicates support for Advanced SIMD with single
+ // and double precision float point arithmetic.
+ ARM64FeatureASIMD
+
+ // ARM64FeatureEVTSTRM indicates support for the generic timer
+ // configured to generate events at a frequency of approximately
+ // 100KHz.
+ ARM64FeatureEVTSTRM
+
+ // ARM64FeatureAES indicates support for AES instructions
+ // (AESE/AESD/AESMC/AESIMC).
+ ARM64FeatureAES
+
+ // ARM64FeaturePMULL indicates support for AES instructions
+ // (PMULL/PMULL2).
+ ARM64FeaturePMULL
+
+ // ARM64FeatureSHA1 indicates support for SHA1 instructions
+ // (SHA1C/SHA1P/SHA1M etc).
+ ARM64FeatureSHA1
+
+ // ARM64FeatureSHA2 indicates support for SHA2 instructions
+ // (SHA256H/SHA256H2/SHA256SU0 etc).
+ ARM64FeatureSHA2
+
+ // ARM64FeatureCRC32 indicates support for CRC32 instructions
+ // (CRC32B/CRC32H/CRC32W etc).
+ ARM64FeatureCRC32
+
+ // ARM64FeatureATOMICS indicates support for atomic instructions
+ // (LDADD/LDCLR/LDEOR/LDSET etc).
+ ARM64FeatureATOMICS
+
+ // ARM64FeatureFPHP indicates support for half precision float point
+ // arithmetic.
+ ARM64FeatureFPHP
+
+ // ARM64FeatureASIMDHP indicates support for ASIMD with half precision
+ // float point arithmetic.
+ ARM64FeatureASIMDHP
+
+ // ARM64FeatureCPUID indicates support for EL0 access to certain ID
+ // registers is available.
+ ARM64FeatureCPUID
+
+ // ARM64FeatureASIMDRDM indicates support for SQRDMLAH and SQRDMLSH
+ // instructions.
+ ARM64FeatureASIMDRDM
+
+ // ARM64FeatureJSCVT indicates support for the FJCVTZS instruction.
+ ARM64FeatureJSCVT
+
+ // ARM64FeatureFCMA indicates support for the FCMLA and FCADD
+ // instructions.
+ ARM64FeatureFCMA
+
+ // ARM64FeatureLRCPC indicates support for the LDAPRB/LDAPRH/LDAPR
+ // instructions.
+ ARM64FeatureLRCPC
+
+ // ARM64FeatureDCPOP indicates support for DC instruction (DC CVAP).
+ ARM64FeatureDCPOP
+
+ // ARM64FeatureSHA3 indicates support for SHA3 instructions
+ // (EOR3/RAX1/XAR/BCAX).
+ ARM64FeatureSHA3
+
+ // ARM64FeatureSM3 indicates support for SM3 instructions
+ // (SM3SS1/SM3TT1A/SM3TT1B).
+ ARM64FeatureSM3
+
+ // ARM64FeatureSM4 indicates support for SM4 instructions
+ // (SM4E/SM4EKEY).
+ ARM64FeatureSM4
+
+ // ARM64FeatureASIMDDP indicates support for dot product instructions
+ // (UDOT/SDOT).
+ ARM64FeatureASIMDDP
+
+ // ARM64FeatureSHA512 indicates support for SHA2 instructions
+ // (SHA512H/SHA512H2/SHA512SU0).
+ ARM64FeatureSHA512
+
+ // ARM64FeatureSVE indicates support for Scalable Vector Extension.
+ ARM64FeatureSVE
+
+ // ARM64FeatureASIMDFHM indicates support for FMLAL and FMLSL
+ // instructions.
+ ARM64FeatureASIMDFHM
+)
+
+// ELF auxiliary vector tags
+const (
+ _AT_NULL = 0 // End of vector
+ _AT_HWCAP = 16 // hardware capability bit vector
+ _AT_HWCAP2 = 26 // hardware capability bit vector 2
+)
+
+// These should not be changed after they are initialized.
+var hwCap uint
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/arm64/kernel/cpuinfo.c.
+var arm64FeatureStrings = map[Feature]string{
+ ARM64FeatureFP: "fp",
+ ARM64FeatureASIMD: "asimd",
+ ARM64FeatureEVTSTRM: "evtstrm",
+ ARM64FeatureAES: "aes",
+ ARM64FeaturePMULL: "pmull",
+ ARM64FeatureSHA1: "sha1",
+ ARM64FeatureSHA2: "sha2",
+ ARM64FeatureCRC32: "crc32",
+ ARM64FeatureATOMICS: "atomics",
+ ARM64FeatureFPHP: "fphp",
+ ARM64FeatureASIMDHP: "asimdhp",
+ ARM64FeatureCPUID: "cpuid",
+ ARM64FeatureASIMDRDM: "asimdrdm",
+ ARM64FeatureJSCVT: "jscvt",
+ ARM64FeatureFCMA: "fcma",
+ ARM64FeatureLRCPC: "lrcpc",
+ ARM64FeatureDCPOP: "dcpop",
+ ARM64FeatureSHA3: "sha3",
+ ARM64FeatureSM3: "sm3",
+ ARM64FeatureSM4: "sm4",
+ ARM64FeatureASIMDDP: "asimddp",
+ ARM64FeatureSHA512: "sha512",
+ ARM64FeatureSVE: "sve",
+ ARM64FeatureASIMDFHM: "asimdfhm",
+}
+
+var (
+ cpuFreqMHz float64
+ cpuImplHex uint64
+ cpuArchDec uint64
+ cpuVarHex uint64
+ cpuPartHex uint64
+ cpuRevDec uint64
+)
+
+// arm64FeaturesFromString includes features from arm64FeatureStrings.
+var arm64FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string plus a bool to indicate if it could find the feature.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := arm64FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(); s != "" {
+ return s
+ }
+
+ return fmt.Sprintf("<cpuflag %d>", f)
+}
+
+func (f Feature) flagString() string {
+ if s, ok := arm64FeatureStrings[f]; ok {
+ return s
+ }
+
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // CPUImplementer is part of the processor signature.
+ CPUImplementer uint8
+
+ // CPUArchitecture is part of the processor signature.
+ CPUArchitecture uint8
+
+ // CPUVariant is part of the processor signature.
+ CPUVariant uint8
+
+ // CPUPartnum is part of the processor signature.
+ CPUPartnum uint16
+
+ // CPURevision is part of the processor signature.
+ CPURevision uint8
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+// Noop on arm64.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ return nil
+}
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point(NEON) registers, and other cpu state that's not
+// associated with the normal task context.
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ // ARMv8 provide 32x128bits NEON registers.
+ //
+ // Ref arch/arm64/include/uapi/asm/ptrace.h
+ // struct user_fpsimd_state {
+ // __uint128_t vregs[32];
+ // __u32 fpsr;
+ // __u32 fpcr;
+ // __u32 __reserved[2];
+ // };
+ return 528, 16
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// UseXsave returns true if 'fs' supports the "xsave" instruction.
+//
+// Irrelevant on arm64.
+func (fs *FeatureSet) UseXsave() bool {
+ return false
+}
+
+// FlagsString prints out supported CPU "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString() string {
+ var s []string
+ for f, _ := range arm64FeatureStrings {
+ if fs.Set[f] {
+ if fstr := f.flagString(); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo, and the bogomips field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "BogoMIPS\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "Features\t\t: %s\n", fs.FlagsString())
+ fmt.Fprintf(b, "CPU implementer\t: 0x%x\n", cpuImplHex)
+ fmt.Fprintf(b, "CPU architecture\t: %d\n", cpuArchDec)
+ fmt.Fprintf(b, "CPU variant\t: 0x%x\n", cpuVarHex)
+ fmt.Fprintf(b, "CPU part\t: 0x%x\n", cpuPartHex)
+ fmt.Fprintf(b, "CPU revision\t: %d\n", cpuRevDec)
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+// HostFeatureSet uses hwCap to get host values and construct a feature set
+// that matches that of the host machine.
+func HostFeatureSet() *FeatureSet {
+ s := make(map[Feature]bool)
+
+ for f, _ := range arm64FeatureStrings {
+ if hwCap&(1<<f) != 0 {
+ s[f] = true
+ }
+ }
+
+ return &FeatureSet{
+ Set: s,
+ CPUImplementer: uint8(cpuImplHex),
+ CPUArchitecture: uint8(cpuArchDec),
+ CPUVariant: uint8(cpuVarHex),
+ CPUPartnum: uint16(cpuPartHex),
+ CPURevision: uint8(cpuRevDec),
+ }
+}
+
+// Reads bogomips from host /proc/cpuinfo. Must run before syscall filter
+// installation. This value is used to create the fake /proc/cpuinfo from a
+// FeatureSet.
+func initCPUInfo() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0. The standalone VDSO bails out in the same way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ switch {
+ case strings.Contains(line, "BogoMIPS"):
+ {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed BogoMIPS")
+ break
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse BogoMIPS value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ }
+ }
+ case strings.Contains(line, "CPU implementer"):
+ {
+ splitImpl := strings.Split(line, ":")
+ if len(splitImpl) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU implementer")
+ break
+ }
+
+ // If there was a problem, leave cpuImplHex as 0.
+ var err error
+ cpuImplHex, err = strconv.ParseUint(strings.TrimSpace(splitImpl[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU implementer value %v: %v", splitImpl[1], err)
+ cpuImplHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU architecture"):
+ {
+ splitArch := strings.Split(line, ":")
+ if len(splitArch) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU architecture")
+ break
+ }
+
+ // If there was a problem, leave cpuArchDec as 0.
+ var err error
+ cpuArchDec, err = strconv.ParseUint(strings.TrimSpace(splitArch[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU architecture value %v: %v", splitArch[1], err)
+ cpuArchDec = 0
+ }
+ }
+ case strings.Contains(line, "CPU variant"):
+ {
+ splitVar := strings.Split(line, ":")
+ if len(splitVar) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU variant")
+ break
+ }
+
+ // If there was a problem, leave cpuVarHex as 0.
+ var err error
+ cpuVarHex, err = strconv.ParseUint(strings.TrimSpace(splitVar[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU variant value %v: %v", splitVar[1], err)
+ cpuVarHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU part"):
+ {
+ splitPart := strings.Split(line, ":")
+ if len(splitPart) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU part")
+ break
+ }
+
+ // If there was a problem, leave cpuPartHex as 0.
+ var err error
+ cpuPartHex, err = strconv.ParseUint(strings.TrimSpace(splitPart[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU part value %v: %v", splitPart[1], err)
+ cpuPartHex = 0
+ }
+ }
+ case strings.Contains(line, "CPU revision"):
+ {
+ splitRev := strings.Split(line, ":")
+ if len(splitRev) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed CPU revision")
+ break
+ }
+
+ // If there was a problem, leave cpuRevDec as 0.
+ var err error
+ cpuRevDec, err = strconv.ParseUint(strings.TrimSpace(splitRev[1]), 0, 64)
+ if err != nil {
+ log.Warningf("Could not parse CPU revision value %v: %v", splitRev[1], err)
+ cpuRevDec = 0
+ }
+ }
+ }
+ }
+}
+
+// The auxiliary vector of a process on the Linux system can be read
+// from /proc/self/auxv, and tags and values are stored as 8-bytes
+// decimal key-value pairs on the 64-bit system.
+//
+// $ od -t d8 /proc/self/auxv
+// 0000000 33 140734615224320
+// 0000020 16 3219913727
+// 0000040 6 4096
+// 0000060 17 100
+// 0000100 3 94665627353152
+// 0000120 4 56
+// 0000140 5 9
+// 0000160 7 140425502162944
+// 0000200 8 0
+// 0000220 9 94665627365760
+// 0000240 11 1000
+// 0000260 12 1000
+// 0000300 13 1000
+// 0000320 14 1000
+// 0000340 23 0
+// 0000360 25 140734614619513
+// 0000400 26 0
+// 0000420 31 140734614626284
+// 0000440 15 140734614619529
+// 0000460 0 0
+func initHwCap() {
+ auxv, err := ioutil.ReadFile("/proc/self/auxv")
+ if err != nil {
+ log.Warningf("Could not read /proc/self/auxv: %v", err)
+ return
+ }
+
+ l := len(auxv) / 16
+ for i := 0; i < l; i++ {
+ tag := binary.LittleEndian.Uint64(auxv[i*16:])
+ val := binary.LittleEndian.Uint64(auxv[(i*16 + 8):])
+ if tag == _AT_HWCAP {
+ hwCap = uint(val)
+ break
+ }
+ }
+}
+
+func initFeaturesFromString() {
+ for f, s := range arm64FeatureStrings {
+ arm64FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ initCPUInfo()
+ initHwCap()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_arm64_test.go b/pkg/cpuid/cpuid_arm64_test.go
new file mode 100644
index 000000000..a34f67779
--- /dev/null
+++ b/pkg/cpuid/cpuid_arm64_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package cpuid
+
+import (
+ "testing"
+)
+
+var justFP = &FeatureSet{
+ Set: map[Feature]bool{
+ ARM64FeatureFP: true,
+ }}
+
+func TestHostFeatureSet(t *testing.T) {
+ hostFeatures := HostFeatureSet()
+ if len(hostFeatures.Set) == 0 {
+ t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
+ }
+}
+
+func TestHasFeature(t *testing.T) {
+ if !justFP.HasFeature(ARM64FeatureFP) {
+ t.Errorf("HasFeature failed, %v should contain %v", justFP, ARM64FeatureFP)
+ }
+
+ if justFP.HasFeature(ARM64FeatureSM3) {
+ t.Errorf("HasFeature failed, %v should not contain %v", justFP, ARM64FeatureSM3)
+ }
+}
+
+func TestFeatureFromString(t *testing.T) {
+ f, ok := FeatureFromString("asimd")
+ if f != ARM64FeatureASIMD || !ok {
+ t.Errorf("got %v want asimd", f)
+ }
+
+ f, ok = FeatureFromString("bad")
+ if ok {
+ t.Errorf("got %v want nothing", f)
+ }
+}
diff --git a/pkg/cpuid/cpuid_parse_test.go b/pkg/cpuid/cpuid_parse_x86_test.go
index dd9969db4..c9bd40e1b 100644
--- a/pkg/cpuid/cpuid_parse_test.go
+++ b/pkg/cpuid/cpuid_parse_x86_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build 386 amd64
+
package cpuid
import (
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
new file mode 100644
index 000000000..17a89c00d
--- /dev/null
+++ b/pkg/cpuid/cpuid_x86.go
@@ -0,0 +1,1111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build 386 amd64
+
+package cpuid
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// Common references for CPUID leaves and bits:
+//
+// Intel:
+// * Intel SDM Volume 2, Chapter 3.2 "CPUID" (more up-to-date)
+// * Intel Application Note 485 (more detailed)
+//
+// AMD:
+// * AMD64 APM Volume 3, Appendix 3 "Obtaining Processor Information ..."
+
+// block is a collection of 32 Feature bits.
+type block int
+
+const blockSize = 32
+
+// Feature bits are numbered according to "blocks". Each block is 32 bits, and
+// feature bits from the same source (cpuid leaf/level) are in the same block.
+func featureID(b block, bit int) Feature {
+ return Feature(32*int(b) + bit)
+}
+
+// Block 0 constants are all of the "basic" feature bits returned by a cpuid in
+// ecx with eax=1.
+const (
+ X86FeatureSSE3 Feature = iota
+ X86FeaturePCLMULDQ
+ X86FeatureDTES64
+ X86FeatureMONITOR
+ X86FeatureDSCPL
+ X86FeatureVMX
+ X86FeatureSMX
+ X86FeatureEST
+ X86FeatureTM2
+ X86FeatureSSSE3 // Not a typo, "supplemental" SSE3.
+ X86FeatureCNXTID
+ X86FeatureSDBG
+ X86FeatureFMA
+ X86FeatureCX16
+ X86FeatureXTPR
+ X86FeaturePDCM
+ _ // ecx bit 16 is reserved.
+ X86FeaturePCID
+ X86FeatureDCA
+ X86FeatureSSE4_1
+ X86FeatureSSE4_2
+ X86FeatureX2APIC
+ X86FeatureMOVBE
+ X86FeaturePOPCNT
+ X86FeatureTSCD
+ X86FeatureAES
+ X86FeatureXSAVE
+ X86FeatureOSXSAVE
+ X86FeatureAVX
+ X86FeatureF16C
+ X86FeatureRDRAND
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 1 constants are all of the "basic" feature bits returned by a cpuid in
+// edx with eax=1.
+const (
+ X86FeatureFPU Feature = 32 + iota
+ X86FeatureVME
+ X86FeatureDE
+ X86FeaturePSE
+ X86FeatureTSC
+ X86FeatureMSR
+ X86FeaturePAE
+ X86FeatureMCE
+ X86FeatureCX8
+ X86FeatureAPIC
+ _ // edx bit 10 is reserved.
+ X86FeatureSEP
+ X86FeatureMTRR
+ X86FeaturePGE
+ X86FeatureMCA
+ X86FeatureCMOV
+ X86FeaturePAT
+ X86FeaturePSE36
+ X86FeaturePSN
+ X86FeatureCLFSH
+ _ // edx bit 20 is reserved.
+ X86FeatureDS
+ X86FeatureACPI
+ X86FeatureMMX
+ X86FeatureFXSR
+ X86FeatureSSE
+ X86FeatureSSE2
+ X86FeatureSS
+ X86FeatureHTT
+ X86FeatureTM
+ X86FeatureIA64
+ X86FeaturePBE
+)
+
+// Block 2 bits are the "structured extended" features returned in ebx for
+// eax=7, ecx=0.
+const (
+ X86FeatureFSGSBase Feature = 2*32 + iota
+ X86FeatureTSC_ADJUST
+ _ // ebx bit 2 is reserved.
+ X86FeatureBMI1
+ X86FeatureHLE
+ X86FeatureAVX2
+ X86FeatureFDP_EXCPTN_ONLY
+ X86FeatureSMEP
+ X86FeatureBMI2
+ X86FeatureERMS
+ X86FeatureINVPCID
+ X86FeatureRTM
+ X86FeatureCQM
+ X86FeatureFPCSDS
+ X86FeatureMPX
+ X86FeatureRDT
+ X86FeatureAVX512F
+ X86FeatureAVX512DQ
+ X86FeatureRDSEED
+ X86FeatureADX
+ X86FeatureSMAP
+ X86FeatureAVX512IFMA
+ X86FeaturePCOMMIT
+ X86FeatureCLFLUSHOPT
+ X86FeatureCLWB
+ X86FeatureIPT // Intel processor trace.
+ X86FeatureAVX512PF
+ X86FeatureAVX512ER
+ X86FeatureAVX512CD
+ X86FeatureSHA
+ X86FeatureAVX512BW
+ X86FeatureAVX512VL
+)
+
+// Block 3 bits are the "extended" features returned in ecx for eax=7, ecx=0.
+const (
+ X86FeaturePREFETCHWT1 Feature = 3*32 + iota
+ X86FeatureAVX512VBMI
+ X86FeatureUMIP
+ X86FeaturePKU
+ X86FeatureOSPKE
+ X86FeatureWAITPKG
+ X86FeatureAVX512_VBMI2
+ _ // ecx bit 7 is reserved
+ X86FeatureGFNI
+ X86FeatureVAES
+ X86FeatureVPCLMULQDQ
+ X86FeatureAVX512_VNNI
+ X86FeatureAVX512_BITALG
+ X86FeatureTME
+ X86FeatureAVX512_VPOPCNTDQ
+ _ // ecx bit 15 is reserved
+ X86FeatureLA57
+ // ecx bits 17-21 are reserved
+ _
+ _
+ _
+ _
+ _
+ X86FeatureRDPID
+ // ecx bits 23-24 are reserved
+ _
+ _
+ X86FeatureCLDEMOTE
+ _ // ecx bit 26 is reserved
+ X86FeatureMOVDIRI
+ X86FeatureMOVDIR64B
+)
+
+// Block 4 constants are for xsave capabilities in CPUID.(EAX=0DH,ECX=01H):EAX.
+// The CPUID leaf is available only if 'X86FeatureXSAVE' is present.
+const (
+ X86FeatureXSAVEOPT Feature = 4*32 + iota
+ X86FeatureXSAVEC
+ X86FeatureXGETBV1
+ X86FeatureXSAVES
+ // EAX[31:4] are reserved.
+)
+
+// Block 5 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):ECX.
+const (
+ X86FeatureLAHF64 Feature = 5*32 + iota
+ X86FeatureCMP_LEGACY
+ X86FeatureSVM
+ X86FeatureEXTAPIC
+ X86FeatureCR8_LEGACY
+ X86FeatureLZCNT
+ X86FeatureSSE4A
+ X86FeatureMISALIGNSSE
+ X86FeaturePREFETCHW
+ X86FeatureOSVW
+ X86FeatureIBS
+ X86FeatureXOP
+ X86FeatureSKINIT
+ X86FeatureWDT
+ _ // ecx bit 14 is reserved.
+ X86FeatureLWP
+ X86FeatureFMA4
+ X86FeatureTCE
+ _ // ecx bit 18 is reserved.
+ _ // ecx bit 19 is reserved.
+ _ // ecx bit 20 is reserved.
+ X86FeatureTBM
+ X86FeatureTOPOLOGY
+ X86FeaturePERFCTR_CORE
+ X86FeaturePERFCTR_NB
+ _ // ecx bit 25 is reserved.
+ X86FeatureBPEXT
+ X86FeaturePERFCTR_TSC
+ X86FeaturePERFCTR_LLC
+ X86FeatureMWAITX
+ // TODO(b/152776797): Some CPUs set this but it is not documented anywhere.
+ X86FeatureBlock5Bit30
+ _ // ecx bit 31 is reserved.
+)
+
+// Block 6 constants are the extended feature bits in
+// CPUID.(EAX=0x80000001):EDX.
+//
+// These are sparse, and so the bit positions are assigned manually.
+const (
+ // On AMD, EDX[24:23] | EDX[17:12] | EDX[9:0] are duplicate features
+ // also defined in block 1 (in identical bit positions). Those features
+ // are not listed here.
+ block6DuplicateMask = 0x183f3ff
+
+ X86FeatureSYSCALL Feature = 6*32 + 11
+ X86FeatureNX Feature = 6*32 + 20
+ X86FeatureMMXEXT Feature = 6*32 + 22
+ X86FeatureFXSR_OPT Feature = 6*32 + 25
+ X86FeatureGBPAGES Feature = 6*32 + 26
+ X86FeatureRDTSCP Feature = 6*32 + 27
+ X86FeatureLM Feature = 6*32 + 29
+ X86Feature3DNOWEXT Feature = 6*32 + 30
+ X86Feature3DNOW Feature = 6*32 + 31
+)
+
+// linuxBlockOrder defines the order in which linux organizes the feature
+// blocks. Linux also tracks feature bits in 32-bit blocks, but in an order
+// which doesn't match well here, so for the /proc/cpuinfo generation we simply
+// re-map the blocks to Linux's ordering and then go through the bits in each
+// block.
+var linuxBlockOrder = []block{1, 6, 0, 5, 2, 4, 3}
+
+// To make emulation of /proc/cpuinfo easy, these names match the names of the
+// basic features in Linux defined in arch/x86/kernel/cpu/capflags.c.
+var x86FeatureStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureSSE3: "pni",
+ X86FeaturePCLMULDQ: "pclmulqdq",
+ X86FeatureDTES64: "dtes64",
+ X86FeatureMONITOR: "monitor",
+ X86FeatureDSCPL: "ds_cpl",
+ X86FeatureVMX: "vmx",
+ X86FeatureSMX: "smx",
+ X86FeatureEST: "est",
+ X86FeatureTM2: "tm2",
+ X86FeatureSSSE3: "ssse3",
+ X86FeatureCNXTID: "cid",
+ X86FeatureSDBG: "sdbg",
+ X86FeatureFMA: "fma",
+ X86FeatureCX16: "cx16",
+ X86FeatureXTPR: "xtpr",
+ X86FeaturePDCM: "pdcm",
+ X86FeaturePCID: "pcid",
+ X86FeatureDCA: "dca",
+ X86FeatureSSE4_1: "sse4_1",
+ X86FeatureSSE4_2: "sse4_2",
+ X86FeatureX2APIC: "x2apic",
+ X86FeatureMOVBE: "movbe",
+ X86FeaturePOPCNT: "popcnt",
+ X86FeatureTSCD: "tsc_deadline_timer",
+ X86FeatureAES: "aes",
+ X86FeatureXSAVE: "xsave",
+ X86FeatureAVX: "avx",
+ X86FeatureF16C: "f16c",
+ X86FeatureRDRAND: "rdrand",
+
+ // Block 1.
+ X86FeatureFPU: "fpu",
+ X86FeatureVME: "vme",
+ X86FeatureDE: "de",
+ X86FeaturePSE: "pse",
+ X86FeatureTSC: "tsc",
+ X86FeatureMSR: "msr",
+ X86FeaturePAE: "pae",
+ X86FeatureMCE: "mce",
+ X86FeatureCX8: "cx8",
+ X86FeatureAPIC: "apic",
+ X86FeatureSEP: "sep",
+ X86FeatureMTRR: "mtrr",
+ X86FeaturePGE: "pge",
+ X86FeatureMCA: "mca",
+ X86FeatureCMOV: "cmov",
+ X86FeaturePAT: "pat",
+ X86FeaturePSE36: "pse36",
+ X86FeaturePSN: "pn",
+ X86FeatureCLFSH: "clflush",
+ X86FeatureDS: "dts",
+ X86FeatureACPI: "acpi",
+ X86FeatureMMX: "mmx",
+ X86FeatureFXSR: "fxsr",
+ X86FeatureSSE: "sse",
+ X86FeatureSSE2: "sse2",
+ X86FeatureSS: "ss",
+ X86FeatureHTT: "ht",
+ X86FeatureTM: "tm",
+ X86FeatureIA64: "ia64",
+ X86FeaturePBE: "pbe",
+
+ // Block 2.
+ X86FeatureFSGSBase: "fsgsbase",
+ X86FeatureTSC_ADJUST: "tsc_adjust",
+ X86FeatureBMI1: "bmi1",
+ X86FeatureHLE: "hle",
+ X86FeatureAVX2: "avx2",
+ X86FeatureSMEP: "smep",
+ X86FeatureBMI2: "bmi2",
+ X86FeatureERMS: "erms",
+ X86FeatureINVPCID: "invpcid",
+ X86FeatureRTM: "rtm",
+ X86FeatureCQM: "cqm",
+ X86FeatureMPX: "mpx",
+ X86FeatureRDT: "rdt_a",
+ X86FeatureAVX512F: "avx512f",
+ X86FeatureAVX512DQ: "avx512dq",
+ X86FeatureRDSEED: "rdseed",
+ X86FeatureADX: "adx",
+ X86FeatureSMAP: "smap",
+ X86FeatureCLWB: "clwb",
+ X86FeatureAVX512PF: "avx512pf",
+ X86FeatureAVX512ER: "avx512er",
+ X86FeatureAVX512CD: "avx512cd",
+ X86FeatureSHA: "sha_ni",
+ X86FeatureAVX512BW: "avx512bw",
+ X86FeatureAVX512VL: "avx512vl",
+
+ // Block 3.
+ X86FeatureAVX512VBMI: "avx512vbmi",
+ X86FeatureUMIP: "umip",
+ X86FeaturePKU: "pku",
+ X86FeatureOSPKE: "ospke",
+ X86FeatureWAITPKG: "waitpkg",
+ X86FeatureAVX512_VBMI2: "avx512_vbmi2",
+ X86FeatureGFNI: "gfni",
+ X86FeatureVAES: "vaes",
+ X86FeatureVPCLMULQDQ: "vpclmulqdq",
+ X86FeatureAVX512_VNNI: "avx512_vnni",
+ X86FeatureAVX512_BITALG: "avx512_bitalg",
+ X86FeatureTME: "tme",
+ X86FeatureAVX512_VPOPCNTDQ: "avx512_vpopcntdq",
+ X86FeatureLA57: "la57",
+ X86FeatureRDPID: "rdpid",
+ X86FeatureCLDEMOTE: "cldemote",
+ X86FeatureMOVDIRI: "movdiri",
+ X86FeatureMOVDIR64B: "movdir64b",
+
+ // Block 4.
+ X86FeatureXSAVEOPT: "xsaveopt",
+ X86FeatureXSAVEC: "xsavec",
+ X86FeatureXGETBV1: "xgetbv1",
+ X86FeatureXSAVES: "xsaves",
+
+ // Block 5.
+ X86FeatureLAHF64: "lahf_lm", // LAHF/SAHF in long mode
+ X86FeatureCMP_LEGACY: "cmp_legacy",
+ X86FeatureSVM: "svm",
+ X86FeatureEXTAPIC: "extapic",
+ X86FeatureCR8_LEGACY: "cr8_legacy",
+ X86FeatureLZCNT: "abm", // Advanced bit manipulation
+ X86FeatureSSE4A: "sse4a",
+ X86FeatureMISALIGNSSE: "misalignsse",
+ X86FeaturePREFETCHW: "3dnowprefetch",
+ X86FeatureOSVW: "osvw",
+ X86FeatureIBS: "ibs",
+ X86FeatureXOP: "xop",
+ X86FeatureSKINIT: "skinit",
+ X86FeatureWDT: "wdt",
+ X86FeatureLWP: "lwp",
+ X86FeatureFMA4: "fma4",
+ X86FeatureTCE: "tce",
+ X86FeatureTBM: "tbm",
+ X86FeatureTOPOLOGY: "topoext",
+ X86FeaturePERFCTR_CORE: "perfctr_core",
+ X86FeaturePERFCTR_NB: "perfctr_nb",
+ X86FeatureBPEXT: "bpext",
+ X86FeaturePERFCTR_TSC: "ptsc",
+ X86FeaturePERFCTR_LLC: "perfctr_llc",
+ X86FeatureMWAITX: "mwaitx",
+
+ // Block 6.
+ X86FeatureSYSCALL: "syscall",
+ X86FeatureNX: "nx",
+ X86FeatureMMXEXT: "mmxext",
+ X86FeatureFXSR_OPT: "fxsr_opt",
+ X86FeatureGBPAGES: "pdpe1gb",
+ X86FeatureRDTSCP: "rdtscp",
+ X86FeatureLM: "lm",
+ X86Feature3DNOWEXT: "3dnowext",
+ X86Feature3DNOW: "3dnow",
+}
+
+// These flags are parse only---they can be used for setting / unsetting the
+// flags, but will not get printed out in /proc/cpuinfo.
+var x86FeatureParseOnlyStrings = map[Feature]string{
+ // Block 0.
+ X86FeatureOSXSAVE: "osxsave",
+
+ // Block 2.
+ X86FeatureFDP_EXCPTN_ONLY: "fdp_excptn_only",
+ X86FeatureFPCSDS: "fpcsds",
+ X86FeatureIPT: "pt",
+ X86FeatureCLFLUSHOPT: "clfushopt",
+
+ // Block 3.
+ X86FeaturePREFETCHWT1: "prefetchwt1",
+
+ // Block 5.
+ X86FeatureBlock5Bit30: "block5_bit30",
+}
+
+// intelCacheDescriptors describe the caches and TLBs on the system. They are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
+// Just a way to wrap cpuid function numbers.
+type cpuidFunction uint32
+
+// The constants below are the lower or "standard" cpuid functions, ordered as
+// defined by the hardware.
+const (
+ vendorID cpuidFunction = iota // Returns vendor ID and largest standard function.
+ featureInfo // Returns basic feature bits and processor signature.
+ intelCacheDescriptors // Returns list of cache descriptors. Intel only.
+ intelSerialNumber // Returns processor serial number (obsolete on new hardware). Intel only.
+ intelDeterministicCacheParams // Returns deterministic cache information. Intel only.
+ monitorMwaitParams // Returns information about monitor/mwait instructions.
+ powerParams // Returns information about power management and thermal sensors.
+ extendedFeatureInfo // Returns extended feature bits.
+ _ // Function 0x8 is reserved.
+ intelDCAParams // Returns direct cache access information. Intel only.
+ intelPMCInfo // Returns information about performance monitoring features. Intel only.
+ intelX2APICInfo // Returns core/logical processor topology. Intel only.
+ _ // Function 0xc is reserved.
+ xSaveInfo // Returns information about extended state management.
+)
+
+// The "extended" functions start at 0x80000000.
+const (
+ extendedFunctionInfo cpuidFunction = 0x80000000 + iota // Returns highest available extended function in eax.
+ extendedFeatures // Returns some extended feature bits in edx and ecx.
+)
+
+// These are the extended floating point state features. They are used to
+// enumerate floating point features in XCR0, XSTATE_BV, etc.
+const (
+ XSAVEFeatureX87 = 1 << 0
+ XSAVEFeatureSSE = 1 << 1
+ XSAVEFeatureAVX = 1 << 2
+ XSAVEFeatureBNDREGS = 1 << 3
+ XSAVEFeatureBNDCSR = 1 << 4
+ XSAVEFeatureAVX512op = 1 << 5
+ XSAVEFeatureAVX512zmm0 = 1 << 6
+ XSAVEFeatureAVX512zmm16 = 1 << 7
+ XSAVEFeaturePKRU = 1 << 9
+)
+
+var cpuFreqMHz float64
+
+// x86FeaturesFromString includes features from x86FeatureStrings and
+// x86FeatureParseOnlyStrings.
+var x86FeaturesFromString = make(map[string]Feature)
+
+// FeatureFromString returns the Feature associated with the given feature
+// string plus a bool to indicate if it could find the feature.
+func FeatureFromString(s string) (Feature, bool) {
+ f, b := x86FeaturesFromString[s]
+ return f, b
+}
+
+// String implements fmt.Stringer.
+func (f Feature) String() string {
+ if s := f.flagString(false); s != "" {
+ return s
+ }
+
+ block := int(f) / 32
+ bit := int(f) % 32
+ return fmt.Sprintf("<cpuflag %d; block %d bit %d>", f, block, bit)
+}
+
+func (f Feature) flagString(cpuinfoOnly bool) string {
+ if s, ok := x86FeatureStrings[f]; ok {
+ return s
+ }
+ if !cpuinfoOnly {
+ return x86FeatureParseOnlyStrings[f]
+ }
+ return ""
+}
+
+// FeatureSet is a set of Features for a CPU.
+//
+// +stateify savable
+type FeatureSet struct {
+ // Set is the set of features that are enabled in this FeatureSet.
+ Set map[Feature]bool
+
+ // VendorID is the 12-char string returned in ebx:edx:ecx for eax=0.
+ VendorID string
+
+ // ExtendedFamily is part of the processor signature.
+ ExtendedFamily uint8
+
+ // ExtendedModel is part of the processor signature.
+ ExtendedModel uint8
+
+ // ProcessorType is part of the processor signature.
+ ProcessorType uint8
+
+ // Family is part of the processor signature.
+ Family uint8
+
+ // Model is part of the processor signature.
+ Model uint8
+
+ // SteppingID is part of the processor signature.
+ SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
+}
+
+// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
+// equivalent to the "flags" field in /proc/cpuinfo.
+func (fs *FeatureSet) FlagsString(cpuinfoOnly bool) string {
+ var s []string
+ for _, b := range linuxBlockOrder {
+ for i := 0; i < blockSize; i++ {
+ if f := featureID(b, i); fs.Set[f] {
+ if fstr := f.flagString(cpuinfoOnly); fstr != "" {
+ s = append(s, fstr)
+ }
+ }
+ }
+ }
+ return strings.Join(s, " ")
+}
+
+// WriteCPUInfoTo is to generate a section of one cpu in /proc/cpuinfo. This is
+// a minimal /proc/cpuinfo, it is missing some fields like "microcode" that are
+// not always printed in Linux. The bogomips field is simply made up.
+func (fs FeatureSet) WriteCPUInfoTo(cpu uint, b *bytes.Buffer) {
+ fmt.Fprintf(b, "processor\t: %d\n", cpu)
+ fmt.Fprintf(b, "vendor_id\t: %s\n", fs.VendorID)
+ fmt.Fprintf(b, "cpu family\t: %d\n", ((fs.ExtendedFamily<<4)&0xff)|fs.Family)
+ fmt.Fprintf(b, "model\t\t: %d\n", ((fs.ExtendedModel<<4)&0xff)|fs.Model)
+ fmt.Fprintf(b, "model name\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "stepping\t: %s\n", "unknown") // Unknown for now.
+ fmt.Fprintf(b, "cpu MHz\t\t: %.3f\n", cpuFreqMHz)
+ fmt.Fprintln(b, "fpu\t\t: yes")
+ fmt.Fprintln(b, "fpu_exception\t: yes")
+ fmt.Fprintf(b, "cpuid level\t: %d\n", uint32(xSaveInfo)) // Same as ax in vendorID.
+ fmt.Fprintln(b, "wp\t\t: yes")
+ fmt.Fprintf(b, "flags\t\t: %s\n", fs.FlagsString(true))
+ fmt.Fprintf(b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
+ fmt.Fprintf(b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "cache_alignment\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
+ fmt.Fprintln(b, "power management:") // This is always here, but can be blank.
+ fmt.Fprintln(b, "") // The /proc/cpuinfo file ends with an extra newline.
+}
+
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
+// AMD returns true if fs describes an AMD CPU.
+func (fs *FeatureSet) AMD() bool {
+ return fs.VendorID == amdVendorID
+}
+
+// Intel returns true if fs describes an Intel CPU.
+func (fs *FeatureSet) Intel() bool {
+ return fs.VendorID == intelVendorID
+}
+
+// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
+// subset of the host feature set.
+type ErrIncompatible struct {
+ message string
+}
+
+// Error implements error.
+func (e ErrIncompatible) Error() string {
+ return e.message
+}
+
+// CheckHostCompatible returns nil if fs is a subset of the host feature set.
+func (fs *FeatureSet) CheckHostCompatible() error {
+ hfs := HostFeatureSet()
+
+ if diff := fs.Subtract(hfs); diff != nil {
+ return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
+ }
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
+ return nil
+}
+
+// Helper to convert 3 regs into 12-byte vendor ID.
+func vendorIDFromRegs(bx, cx, dx uint32) string {
+ bytes := make([]byte, 0, 12)
+ for i := uint(0); i < 4; i++ {
+ b := byte(bx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(dx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ b := byte(cx >> (i * 8))
+ bytes = append(bytes, b)
+ }
+ return string(bytes)
+}
+
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
+// ExtendedStateSize returns the number of bytes needed to save the "extended
+// state" for this processor and the boundary it must be aligned to. Extended
+// state includes floating point registers, and other cpu state that's not
+// associated with the normal task context.
+//
+// Note: We can save some space here with an optimization where we use a
+// smaller chunk of memory depending on features that are actually enabled.
+// Currently we just use the largest possible size for simplicity (which is
+// about 2.5K worst case, with avx512).
+func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
+ if fs.UseXsave() {
+ return uint(maxXsaveSize), 64
+ }
+
+ // If we don't support xsave, we fall back to fxsave, which requires
+ // 512 bytes aligned to 16 bytes.
+ return 512, 16
+}
+
+// ValidXCR0Mask returns the bits that may be set to 1 in control register
+// XCR0.
+func (fs *FeatureSet) ValidXCR0Mask() uint64 {
+ if !fs.UseXsave() {
+ return 0
+ }
+ eax, _, _, edx := HostID(uint32(xSaveInfo), 0)
+ return uint64(edx)<<32 | uint64(eax)
+}
+
+// vendorIDRegs returns the 3 register values used to construct the 12-byte
+// vendor ID string for eax=0.
+func (fs *FeatureSet) vendorIDRegs() (bx, dx, cx uint32) {
+ for i := uint(0); i < 4; i++ {
+ bx |= uint32(fs.VendorID[i]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ dx |= uint32(fs.VendorID[i+4]) << (i * 8)
+ }
+
+ for i := uint(0); i < 4; i++ {
+ cx |= uint32(fs.VendorID[i+8]) << (i * 8)
+ }
+ return
+}
+
+// signature returns the signature dword that's returned in eax when eax=1.
+func (fs *FeatureSet) signature() uint32 {
+ var s uint32
+ s |= uint32(fs.SteppingID & 0xf)
+ s |= uint32(fs.Model&0xf) << 4
+ s |= uint32(fs.Family&0xf) << 8
+ s |= uint32(fs.ProcessorType&0x3) << 12
+ s |= uint32(fs.ExtendedModel&0xf) << 16
+ s |= uint32(fs.ExtendedFamily&0xff) << 20
+ return s
+}
+
+// Helper to deconstruct signature dword.
+func signatureSplit(v uint32) (ef, em, pt, f, m, sid uint8) {
+ sid = uint8(v & 0xf)
+ m = uint8(v>>4) & 0xf
+ f = uint8(v>>8) & 0xf
+ pt = uint8(v>>12) & 0x3
+ em = uint8(v>>16) & 0xf
+ ef = uint8(v >> 20)
+ return
+}
+
+// Helper to convert blockwise feature bit masks into a set of features. Masks
+// must be provided in order for each block, without skipping them. If a block
+// does not matter for this feature set, 0 is specified.
+func setFromBlockMasks(blocks ...uint32) map[Feature]bool {
+ s := make(map[Feature]bool)
+ for b, blockMask := range blocks {
+ for i := 0; i < blockSize; i++ {
+ if blockMask&1 != 0 {
+ s[featureID(block(b), i)] = true
+ }
+ blockMask >>= 1
+ }
+ }
+ return s
+}
+
+// blockMask returns the 32-bit mask associated with a block of features.
+func (fs *FeatureSet) blockMask(b block) uint32 {
+ var mask uint32
+ for i := 0; i < blockSize; i++ {
+ if fs.Set[featureID(b, i)] {
+ mask |= 1 << uint(i)
+ }
+ }
+ return mask
+}
+
+// Remove removes a Feature from a FeatureSet. It ignores features
+// that are not in the FeatureSet.
+func (fs *FeatureSet) Remove(feature Feature) {
+ delete(fs.Set, feature)
+}
+
+// Add adds a Feature to a FeatureSet. It ignores duplicate features.
+func (fs *FeatureSet) Add(feature Feature) {
+ fs.Set[feature] = true
+}
+
+// HasFeature tests whether or not a feature is in the given feature set.
+func (fs *FeatureSet) HasFeature(feature Feature) bool {
+ return fs.Set[feature]
+}
+
+// Subtract returns the features present in fs that are not present in other.
+// If all features in fs are present in other, Subtract returns nil.
+func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
+ for f := range fs.Set {
+ if !other.Set[f] {
+ if diff == nil {
+ diff = make(map[Feature]bool)
+ }
+ diff[f] = true
+ }
+ }
+
+ return
+}
+
+// EmulateID emulates a cpuid instruction based on the feature set.
+func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
+ switch cpuidFunction(origAx) {
+ case vendorID:
+ ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
+ bx, dx, cx = fs.vendorIDRegs()
+ case featureInfo:
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
+ cx = fs.blockMask(block(0))
+ dx = fs.blockMask(block(1))
+ ax = fs.signature()
+ case intelCacheDescriptors:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // "The least-significant byte in register EAX (register AL)
+ // will always return 01H. Software should ignore this value
+ // and not interpret it as an informational descriptor." - SDM
+ //
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
+ case xSaveInfo:
+ if !fs.UseXsave() {
+ return 0, 0, 0, 0
+ }
+ return HostID(uint32(xSaveInfo), origCx)
+ case extendedFeatureInfo:
+ if origCx != 0 {
+ break // Only leaf 0 is supported.
+ }
+ bx = fs.blockMask(block(2))
+ cx = fs.blockMask(block(3))
+ case extendedFunctionInfo:
+ // We only support showing the extended features.
+ ax = uint32(extendedFeatures)
+ cx = 0
+ case extendedFeatures:
+ cx = fs.blockMask(block(5))
+ dx = fs.blockMask(block(6))
+ if fs.AMD() {
+ // AMD duplicates some block 1 features in block 6.
+ dx |= fs.blockMask(block(1)) & block6DuplicateMask
+ }
+ }
+
+ return
+}
+
+// UseXsave returns the choice of fp state saving instruction.
+func (fs *FeatureSet) UseXsave() bool {
+ return fs.HasFeature(X86FeatureXSAVE) && fs.HasFeature(X86FeatureOSXSAVE)
+}
+
+// UseXsaveopt returns true if 'fs' supports the "xsaveopt" instruction.
+func (fs *FeatureSet) UseXsaveopt() bool {
+ return fs.UseXsave() && fs.HasFeature(X86FeatureXSAVEOPT)
+}
+
+// HostID executes a native CPUID instruction.
+func HostID(axArg, cxArg uint32) (ax, bx, cx, dx uint32)
+
+// HostFeatureSet uses cpuid to get host values and construct a feature set
+// that matches that of the host machine. Note that there are several places
+// where there appear to be some unnecessary assignments between register names
+// (ax, bx, cx, or dx) and featureBlockN variables. This is to explicitly show
+// where the different feature blocks come from, to make the code easier to
+// inspect and read.
+func HostFeatureSet() *FeatureSet {
+ // eax=0 gets max supported feature and vendor ID.
+ _, bx, cx, dx := HostID(0, 0)
+ vendorID := vendorIDFromRegs(bx, cx, dx)
+
+ // eax=1 gets basic features in ecx:edx.
+ ax, bx, cx, dx := HostID(1, 0)
+ featureBlock0 := cx
+ featureBlock1 := dx
+ ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
+
+ // eax=7, ecx=0 gets extended features in ecx:ebx.
+ _, bx, cx, _ = HostID(7, 0)
+ featureBlock2 := bx
+ featureBlock3 := cx
+
+ // Leaf 0xd is supported only if CPUID.1:ECX.XSAVE[bit 26] is set.
+ var featureBlock4 uint32
+ if (featureBlock0 & (1 << 26)) != 0 {
+ featureBlock4, _, _, _ = HostID(uint32(xSaveInfo), 1)
+ }
+
+ // eax=0x80000000 gets supported extended levels. We use this to
+ // determine if there are any non-zero block 4 or block 6 bits to find.
+ var featureBlock5, featureBlock6 uint32
+ if ax, _, _, _ := HostID(uint32(extendedFunctionInfo), 0); ax >= uint32(extendedFeatures) {
+ // eax=0x80000001 gets AMD added feature bits.
+ _, _, cx, dx = HostID(uint32(extendedFeatures), 0)
+ featureBlock5 = cx
+ // Ignore features duplicated from block 1 on AMD. These bits
+ // are reserved on Intel.
+ featureBlock6 = dx &^ block6DuplicateMask
+ }
+
+ set := setFromBlockMasks(featureBlock0, featureBlock1, featureBlock2, featureBlock3, featureBlock4, featureBlock5, featureBlock6)
+ return &FeatureSet{
+ Set: set,
+ VendorID: vendorID,
+ ExtendedFamily: ef,
+ ExtendedModel: em,
+ ProcessorType: pt,
+ Family: f,
+ Model: m,
+ SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
+ }
+}
+
+// Reads max cpu frequency from host /proc/cpuinfo. Must run before syscall
+// filter installation. This value is used to create the fake /proc/cpuinfo
+// from a FeatureSet.
+func initCPUFreq() {
+ cpuinfob, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ // Leave it as 0... The standalone VDSO bails out in the same
+ // way.
+ log.Warningf("Could not read /proc/cpuinfo: %v", err)
+ return
+ }
+ cpuinfo := string(cpuinfob)
+
+ // We get the value straight from host /proc/cpuinfo. On machines with
+ // frequency scaling enabled, this will only get the current value
+ // which will likely be inaccurate. This is fine on machines with
+ // frequency scaling disabled.
+ for _, line := range strings.Split(cpuinfo, "\n") {
+ if strings.Contains(line, "cpu MHz") {
+ splitMHz := strings.Split(line, ":")
+ if len(splitMHz) < 2 {
+ log.Warningf("Could not read /proc/cpuinfo: malformed cpu MHz line")
+ return
+ }
+
+ // If there was a problem, leave cpuFreqMHz as 0.
+ var err error
+ cpuFreqMHz, err = strconv.ParseFloat(strings.TrimSpace(splitMHz[1]), 64)
+ if err != nil {
+ log.Warningf("Could not parse cpu MHz value %v: %v", splitMHz[1], err)
+ cpuFreqMHz = 0
+ return
+ }
+ return
+ }
+ }
+ log.Warningf("Could not parse /proc/cpuinfo, it is empty or does not contain cpu MHz")
+}
+
+func initFeaturesFromString() {
+ for f, s := range x86FeatureStrings {
+ x86FeaturesFromString[s] = f
+ }
+ for f, s := range x86FeatureParseOnlyStrings {
+ x86FeaturesFromString[s] = f
+ }
+}
+
+func init() {
+ initCPUFreq()
+ initFeaturesFromString()
+}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_x86_test.go
index a707ebb55..bacf345c8 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_x86_test.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build 386 amd64
+
package cpuid
import (
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 71f2abc83..bee28b68d 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,6 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
package(licenses = ["notice"])
@@ -10,11 +8,11 @@ go_library(
"event.go",
"rate.go",
],
- importpath = "gvisor.dev/gvisor/pkg/eventchannel",
visibility = ["//:sandbox"],
deps = [
":eventchannel_go_proto",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
@@ -23,22 +21,17 @@ go_library(
)
proto_library(
- name = "eventchannel_proto",
+ name = "eventchannel",
srcs = ["event.proto"],
-)
-
-go_proto_library(
- name = "eventchannel_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto",
- proto = ":eventchannel_proto",
visibility = ["//:sandbox"],
)
go_test(
name = "eventchannel_test",
srcs = ["event_test.go"],
- embed = [":eventchannel"],
+ library = ":eventchannel",
deps = [
+ "//pkg/sync",
"@com_github_golang_protobuf//proto:go_default_library",
],
)
diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go
index d37ad0428..9a29c58bd 100644
--- a/pkg/eventchannel/event.go
+++ b/pkg/eventchannel/event.go
@@ -22,13 +22,13 @@ package eventchannel
import (
"encoding/binary"
"fmt"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go
index 3649097d6..43750360b 100644
--- a/pkg/eventchannel/event_test.go
+++ b/pkg/eventchannel/event_test.go
@@ -16,11 +16,11 @@ package eventchannel
import (
"fmt"
- "sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
// testEmitter is an emitter that can be used in tests. It records all events
@@ -78,7 +78,7 @@ func TestMultiEmitter(t *testing.T) {
for _, name := range names {
m := testMessage{name: name}
if _, err := me.Emit(m); err != nil {
- t.Fatal("me.Emit(%v) failed: %v", m, err)
+ t.Fatalf("me.Emit(%v) failed: %v", m, err)
}
}
@@ -96,7 +96,7 @@ func TestMultiEmitter(t *testing.T) {
// Close multiEmitter.
if err := me.Close(); err != nil {
- t.Fatal("me.Close() failed: %v", err)
+ t.Fatalf("me.Close() failed: %v", err)
}
// All testEmitters should be closed.
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index afa8f7659..872361546 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "fd",
srcs = ["fd.go"],
- importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
)
@@ -14,5 +12,5 @@ go_test(
name = "fd_test",
size = "small",
srcs = ["fd_test.go"],
- embed = [":fd"],
+ library = ":fd",
)
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index 56495cbd9..d9104ef02 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "fdchannel",
srcs = ["fdchannel_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/fdchannel",
visibility = ["//visibility:public"],
)
@@ -14,5 +12,6 @@ go_test(
name = "fdchannel_test",
size = "small",
srcs = ["fdchannel_test.go"],
- embed = [":fdchannel"],
+ library = ":fdchannel",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/fdchannel/fdchannel_test.go b/pkg/fdchannel/fdchannel_test.go
index 5d01dc636..7a8a63a59 100644
--- a/pkg/fdchannel/fdchannel_test.go
+++ b/pkg/fdchannel/fdchannel_test.go
@@ -17,10 +17,11 @@ package fdchannel
import (
"io/ioutil"
"os"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSendRecvFD(t *testing.T) {
diff --git a/pkg/fdnotifier/BUILD b/pkg/fdnotifier/BUILD
index aca2d8a82..235dcc490 100644
--- a/pkg/fdnotifier/BUILD
+++ b/pkg/fdnotifier/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,9 +8,9 @@ go_library(
"fdnotifier.go",
"poll_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/fdnotifier",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/sync",
"//pkg/waiter",
"@org_golang_x_sys//unix:go_default_library",
],
diff --git a/pkg/fdnotifier/fdnotifier.go b/pkg/fdnotifier/fdnotifier.go
index f4aae1953..a6b63c982 100644
--- a/pkg/fdnotifier/fdnotifier.go
+++ b/pkg/fdnotifier/fdnotifier.go
@@ -22,10 +22,10 @@ package fdnotifier
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index 5643d5f26..aa8e4e1f3 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -1,7 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "flipcall",
@@ -12,14 +11,14 @@ go_library(
"futex_linux.go",
"io.go",
"packet_window_allocator.go",
+ "packet_window_mmap.go",
],
- importpath = "gvisor.dev/gvisor/pkg/flipcall",
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
- "//third_party/gvsync",
+ "//pkg/sync",
],
)
@@ -30,5 +29,6 @@ go_test(
"flipcall_example_test.go",
"flipcall_test.go",
],
- embed = [":flipcall"],
+ library = ":flipcall",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 8390915a2..e7c3a3a0b 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error {
return nil
case epsBlocked | epsShutdown:
atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
- return shutdownError{}
+ return ShutdownError{}
default:
// Most likely due to ep.enterFutexWait() being called concurrently
// from multiple goroutines.
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 386cee42c..ec742c091 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -95,7 +95,7 @@ func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...
if pwd.Length > math.MaxUint32 {
return fmt.Errorf("packet window size (%d) exceeds maximum (%d)", pwd.Length, math.MaxUint32)
}
- m, _, e := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ m, e := packetWindowMmap(pwd)
if e != 0 {
return fmt.Errorf("failed to mmap packet window: %v", e)
}
@@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() {
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
-// Endpoint, to unblock and return errors. It does not wait for concurrent
-// calls to return. Successive calls to Shutdown have no effect.
+// Endpoint, to unblock and return ShutdownErrors. It does not wait for
+// concurrent calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool {
return atomic.LoadUint32(&ep.shutdown) != 0
}
-type shutdownError struct{}
+// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown()
+// has been called.
+type ShutdownError struct{}
// Error implements error.Error.
-func (shutdownError) Error() string {
+func (ShutdownError) Error() string {
return "flipcall connection shutdown"
}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index 8d88b845d..2e28a149a 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -17,7 +17,8 @@ package flipcall
import (
"bytes"
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func Example() {
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 168a487ec..33fd55a44 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -16,9 +16,10 @@ package flipcall
import (
"runtime"
- "sync"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
var testPacketWindowSize = pageSize
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index a37952637..ac974b232 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -18,7 +18,7 @@ import (
"reflect"
"unsafe"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Packets consist of a 16-byte header followed by an arbitrarily-sized
@@ -75,13 +75,13 @@ func (ep *Endpoint) Data() []byte {
var ioSync int64
func raceBecomeActive() {
- if gvsync.RaceEnabled {
- gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceAcquire((unsafe.Pointer)(&ioSync))
}
}
func raceBecomeInactive() {
- if gvsync.RaceEnabled {
- gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ if sync.RaceEnabled {
+ sync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
}
}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index b127a2bbb..168c1ccff 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error {
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
switch cs := atomic.LoadUint32(ep.connState()); cs {
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
}
@@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
return nil
case ep.inactiveState:
if ep.isShutdownLocally() {
- return shutdownError{}
+ return ShutdownError{}
}
if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
continue
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
diff --git a/pkg/flipcall/packet_window_allocator.go b/pkg/flipcall/packet_window_allocator.go
index ccb918fab..af9cc3d21 100644
--- a/pkg/flipcall/packet_window_allocator.go
+++ b/pkg/flipcall/packet_window_allocator.go
@@ -134,7 +134,7 @@ func (pwa *PacketWindowAllocator) Allocate(size int) (PacketWindowDescriptor, er
start := pwa.nextAlloc
pwa.nextAlloc = end
return PacketWindowDescriptor{
- FD: pwa.fd,
+ FD: pwa.FD(),
Offset: start,
Length: size,
}, nil
@@ -158,7 +158,7 @@ func (pwa *PacketWindowAllocator) ensureFileSize(min int64) error {
}
newSize = newNewSize
}
- if err := syscall.Ftruncate(pwa.fd, newSize); err != nil {
+ if err := syscall.Ftruncate(pwa.FD(), newSize); err != nil {
return fmt.Errorf("ftruncate failed: %v", err)
}
pwa.fileSize = newSize
diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap.go
new file mode 100644
index 000000000..869183b11
--- /dev/null
+++ b/pkg/flipcall/packet_window_mmap.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flipcall
+
+import (
+ "syscall"
+)
+
+// Return a memory mapping of the pwd in memory that can be shared outside the sandbox.
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) {
+ m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset))
+ return m, err
+}
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 0c5f50397..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -1,20 +1,18 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
-package(
- default_visibility = ["//visibility:public"],
- licenses = ["notice"],
-)
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
- importpath = "gvisor.dev/gvisor/pkg/fspath",
- deps = ["//pkg/syserror"],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
@@ -24,6 +22,5 @@ go_test(
"builder_test.go",
"fspath_test.go",
],
- embed = [":fspath"],
- deps = ["//pkg/syserror"],
+ library = ":fspath",
)
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index f68752560..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -18,19 +18,17 @@ package fspath
import (
"strings"
-
- "gvisor.dev/gvisor/pkg/syserror"
)
const pathSep = '/'
-// Parse parses a pathname as described by path_resolution(7).
-func Parse(pathname string) (Path, error) {
+// Parse parses a pathname as described by path_resolution(7), except that
+// empty pathnames will be parsed successfully to a Path for which
+// Path.Absolute == Path.Dir == Path.HasComponents() == false. (This is
+// necessary to support AT_EMPTY_PATH.)
+func Parse(pathname string) Path {
if len(pathname) == 0 {
- // "... POSIX decrees that an empty pathname must not be resolved
- // successfully. Linux returns ENOENT in this case." -
- // path_resolution(7)
- return Path{}, syserror.ENOENT
+ return Path{}
}
// Skip leading path separators.
i := 0
@@ -41,7 +39,7 @@ func Parse(pathname string) (Path, error) {
return Path{
Absolute: true,
Dir: true,
- }, nil
+ }
}
}
// Skip trailing path separators. This is required by Iterator.Next. This
@@ -64,12 +62,13 @@ func Parse(pathname string) (Path, error) {
},
Absolute: i != 0,
Dir: j != len(pathname)-1,
- }, nil
+ }
}
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
@@ -111,6 +110,12 @@ func (p Path) String() string {
return b.String()
}
+// HasComponents returns true if p contains a non-zero number of path
+// components.
+func (p Path) HasComponents() bool {
+ return p.Begin.Ok()
+}
+
// An Iterator represents either a path component in a Path or a terminal
// iterator indicating that the end of the path has been reached.
//
diff --git a/pkg/fspath/fspath_test.go b/pkg/fspath/fspath_test.go
index 215b35622..d5e9a549a 100644
--- a/pkg/fspath/fspath_test.go
+++ b/pkg/fspath/fspath_test.go
@@ -18,15 +18,10 @@ import (
"reflect"
"strings"
"testing"
-
- "gvisor.dev/gvisor/pkg/syserror"
)
func TestParseIteratorPartialPathnames(t *testing.T) {
- path, err := Parse("/foo//bar///baz////")
- if err != nil {
- t.Fatalf("Parse failed: %v", err)
- }
+ path := Parse("/foo//bar///baz////")
// Parse strips leading slashes, and records their presence as
// Path.Absolute.
if !path.Absolute {
@@ -71,6 +66,12 @@ func TestParse(t *testing.T) {
}
tests := []testCase{
{
+ pathname: "",
+ relpath: []string{},
+ abs: false,
+ dir: false,
+ },
+ {
pathname: "/",
relpath: []string{},
abs: true,
@@ -113,10 +114,7 @@ func TestParse(t *testing.T) {
for _, test := range tests {
t.Run(test.pathname, func(t *testing.T) {
- p, err := Parse(test.pathname)
- if err != nil {
- t.Fatalf("failed to parse pathname %q: %v", test.pathname, err)
- }
+ p := Parse(test.pathname)
t.Logf("pathname %q => path %q", test.pathname, p)
if p.Absolute != test.abs {
t.Errorf("path absoluteness: got %v, wanted %v", p.Absolute, test.abs)
@@ -134,10 +132,3 @@ func TestParse(t *testing.T) {
})
}
}
-
-func TestParseEmptyPathname(t *testing.T) {
- p, err := Parse("")
- if err != syserror.ENOENT {
- t.Errorf("parsing empty pathname: got (%v, %v), wanted (<unspecified>, ENOENT)", p, err)
- }
-}
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index 4b9321711..dd3141143 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,7 +7,6 @@ go_library(
srcs = [
"gate.go",
],
- importpath = "gvisor.dev/gvisor/pkg/gate",
visibility = ["//visibility:public"],
)
@@ -19,5 +17,6 @@ go_test(
],
deps = [
":gate",
+ "//pkg/sync",
],
)
diff --git a/pkg/gate/gate_test.go b/pkg/gate/gate_test.go
index 5dbd8d712..316015e06 100644
--- a/pkg/gate/gate_test.go
+++ b/pkg/gate/gate_test.go
@@ -15,11 +15,12 @@
package gate_test
import (
- "sync"
+ "runtime"
"testing"
"time"
"gvisor.dev/gvisor/pkg/gate"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestBasicEnter(t *testing.T) {
@@ -165,6 +166,8 @@ func worker(g *gate.Gate, done *sync.WaitGroup) {
if !g.Enter() {
break
}
+ // Golang before v1.14 doesn't preempt busyloops.
+ runtime.Gosched()
g.Leave()
}
done.Done()
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..35683fe98
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ stateify = False,
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/goid/BUILD b/pkg/goid/BUILD
new file mode 100644
index 000000000..7a82631c5
--- /dev/null
+++ b/pkg/goid/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "goid",
+ srcs = [
+ "goid.go",
+ "goid_amd64.s",
+ "goid_arm64.s",
+ "goid_race.go",
+ "goid_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+go_test(
+ name = "goid_test",
+ size = "small",
+ srcs = [
+ "empty_test.go",
+ "goid_test.go",
+ ],
+ library = ":goid",
+)
diff --git a/pkg/goid/empty_test.go b/pkg/goid/empty_test.go
new file mode 100644
index 000000000..c0a4b17ab
--- /dev/null
+++ b/pkg/goid/empty_test.go
@@ -0,0 +1,22 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package goid
+
+import "testing"
+
+// TestNothing exists to make the build system happy.
+func TestNothing(t *testing.T) {}
diff --git a/pkg/goid/goid.go b/pkg/goid/goid.go
new file mode 100644
index 000000000..39df30031
--- /dev/null
+++ b/pkg/goid/goid.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ panic("unimplemented for non-race builds")
+}
diff --git a/pkg/goid/goid_amd64.s b/pkg/goid/goid_amd64.s
new file mode 100644
index 000000000..d9f5cd2a3
--- /dev/null
+++ b/pkg/goid/goid_amd64.s
@@ -0,0 +1,21 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVQ (TLS), R14
+ MOVQ R14, ret+0(FP)
+ RET
diff --git a/pkg/goid/goid_arm64.s b/pkg/goid/goid_arm64.s
new file mode 100644
index 000000000..a7465b75d
--- /dev/null
+++ b/pkg/goid/goid_arm64.s
@@ -0,0 +1,21 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// func getg() *g
+TEXT ·getg(SB),NOSPLIT,$0-8
+ MOVD g, R0 // g
+ MOVD R0, ret+0(FP)
+ RET
diff --git a/pkg/goid/goid_race.go b/pkg/goid/goid_race.go
new file mode 100644
index 000000000..1766beaee
--- /dev/null
+++ b/pkg/goid/goid_race.go
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Only available in race/gotsan builds.
+// +build race
+
+// Package goid provides access to the ID of the current goroutine in
+// race/gotsan builds.
+package goid
+
+// Get returns the ID of the current goroutine.
+func Get() int64 {
+ return goid()
+}
diff --git a/pkg/goid/goid_test.go b/pkg/goid/goid_test.go
new file mode 100644
index 000000000..31970ce79
--- /dev/null
+++ b/pkg/goid/goid_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package goid
+
+import (
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestInitialGoID(t *testing.T) {
+ const max = 10000
+ if id := goid(); id < 0 || id > max {
+ t.Errorf("got goid = %d, want 0 < goid <= %d", id, max)
+ }
+}
+
+// TestGoIDSquence verifies that goid returns values which could plausibly be
+// goroutine IDs. If this test breaks or becomes flaky, the structs in
+// goid_unsafe.go may need to be updated.
+func TestGoIDSquence(t *testing.T) {
+ // Goroutine IDs are cached by each P.
+ runtime.GOMAXPROCS(1)
+
+ // Fill any holes in lower range.
+ for i := 0; i < 50; i++ {
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ }
+
+ id := goid()
+ for i := 0; i < 100; i++ {
+ var (
+ newID int64
+ wg sync.WaitGroup
+ )
+ wg.Add(1)
+ go func() {
+ newID = goid()
+ wg.Done()
+
+ // Leak the goroutine to prevent the ID from being
+ // reused.
+ select {}
+ }()
+ wg.Wait()
+ if max := id + 100; newID <= id || newID > max {
+ t.Errorf("unexpected goroutine ID pattern, got goid = %d, want %d < goid <= %d (previous = %d)", newID, id, max, id)
+ }
+ id = newID
+ }
+}
diff --git a/pkg/goid/goid_unsafe.go b/pkg/goid/goid_unsafe.go
new file mode 100644
index 000000000..ded8004dd
--- /dev/null
+++ b/pkg/goid/goid_unsafe.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package goid
+
+// Structs from Go runtime. These may change in the future and require
+// updating. These structs are currently the same on both AMD64 and ARM64,
+// but may diverge in the future.
+
+type stack struct {
+ lo uintptr
+ hi uintptr
+}
+
+type gobuf struct {
+ sp uintptr
+ pc uintptr
+ g uintptr
+ ctxt uintptr
+ ret uint64
+ lr uintptr
+ bp uintptr
+}
+
+type g struct {
+ stack stack
+ stackguard0 uintptr
+ stackguard1 uintptr
+
+ _panic uintptr
+ _defer uintptr
+ m uintptr
+ sched gobuf
+ syscallsp uintptr
+ syscallpc uintptr
+ stktopsp uintptr
+ param uintptr
+ atomicstatus uint32
+ stackLock uint32
+ goid int64
+
+ // More fields...
+ //
+ // We only use goid and the fields before it are only listed to
+ // calculate the correct offset.
+}
+
+func getg() *g
+
+// goid returns the ID of the current goroutine.
+func goid() int64 {
+ return getg().goid
+}
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
index 34d2673ef..3f6eb07df 100644
--- a/pkg/ilist/BUILD
+++ b/pkg/ilist/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
srcs = [
"interface_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/ilist",
visibility = ["//visibility:public"],
)
@@ -41,7 +39,7 @@ go_test(
"list_test.go",
"test_list.go",
],
- embed = [":ilist"],
+ library = ":ilist",
)
go_template(
diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go
index 019caadca..f4a4c33d3 100644
--- a/pkg/ilist/list.go
+++ b/pkg/ilist/list.go
@@ -86,11 +86,21 @@ func (l *List) Back() Element {
return l.tail
}
+// Len returns the number of elements in the list.
+//
+// NOTE: This is an O(n) operation.
+func (l *List) Len() (count int) {
+ for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() {
+ count++
+ }
+ return count
+}
+
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(l.head)
- ElementMapper{}.linkerFor(e).SetPrev(nil)
-
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(l.head)
+ linker.SetPrev(nil)
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
@@ -102,9 +112,9 @@ func (l *List) PushFront(e Element) {
// PushBack inserts the element e at the back of list l.
func (l *List) PushBack(e Element) {
- ElementMapper{}.linkerFor(e).SetNext(nil)
- ElementMapper{}.linkerFor(e).SetPrev(l.tail)
-
+ linker := ElementMapper{}.linkerFor(e)
+ linker.SetNext(nil)
+ linker.SetPrev(l.tail)
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
@@ -125,17 +135,20 @@ func (l *List) PushBackList(m *List) {
l.tail = m.tail
}
-
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *List) InsertAfter(b, e Element) {
- a := ElementMapper{}.linkerFor(b).Next()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(b).SetNext(e)
+ bLinker := ElementMapper{}.linkerFor(b)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ a := bLinker.Next()
+
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ bLinker.SetNext(e)
if a != nil {
ElementMapper{}.linkerFor(a).SetPrev(e)
@@ -146,10 +159,13 @@ func (l *List) InsertAfter(b, e Element) {
// InsertBefore inserts e before a.
func (l *List) InsertBefore(a, e Element) {
- b := ElementMapper{}.linkerFor(a).Prev()
- ElementMapper{}.linkerFor(e).SetNext(a)
- ElementMapper{}.linkerFor(e).SetPrev(b)
- ElementMapper{}.linkerFor(a).SetPrev(e)
+ aLinker := ElementMapper{}.linkerFor(a)
+ eLinker := ElementMapper{}.linkerFor(e)
+
+ b := aLinker.Prev()
+ eLinker.SetNext(a)
+ eLinker.SetPrev(b)
+ aLinker.SetPrev(e)
if b != nil {
ElementMapper{}.linkerFor(b).SetNext(e)
@@ -160,20 +176,24 @@ func (l *List) InsertBefore(a, e Element) {
// Remove removes e from l.
func (l *List) Remove(e Element) {
- prev := ElementMapper{}.linkerFor(e).Prev()
- next := ElementMapper{}.linkerFor(e).Next()
+ linker := ElementMapper{}.linkerFor(e)
+ prev := linker.Prev()
+ next := linker.Next()
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
- } else {
+ } else if l.head == e {
l.head = next
}
if next != nil {
ElementMapper{}.linkerFor(next).SetPrev(prev)
- } else {
+ } else if l.tail == e {
l.tail = prev
}
+
+ linker.SetNext(nil)
+ linker.SetPrev(nil)
}
// Entry is a default implementation of Linker. Users can add anonymous fields
diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD
new file mode 100644
index 000000000..eda82cfc1
--- /dev/null
+++ b/pkg/iovec/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "iovec",
+ srcs = ["iovec.go"],
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/abi/linux"],
+)
+
+go_test(
+ name = "iovec_test",
+ size = "small",
+ srcs = ["iovec_test.go"],
+ library = ":iovec",
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
+)
diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go
new file mode 100644
index 000000000..dd70fe80f
--- /dev/null
+++ b/pkg/iovec/iovec.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+// Package iovec provides helpers to interact with vectorized I/O on host
+// system.
+package iovec
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// MaxIovs is the maximum number of iovecs host platform can accept.
+var MaxIovs = linux.UIO_MAXIOV
+
+// Builder is a builder for slice of syscall.Iovec.
+type Builder struct {
+ iovec []syscall.Iovec
+ storage [8]syscall.Iovec
+
+ // overflow tracks the last buffer when iovec length is at MaxIovs.
+ overflow []byte
+}
+
+// Add adds buf to b preparing to be written. Zero-length buf won't be added.
+func (b *Builder) Add(buf []byte) {
+ if len(buf) == 0 {
+ return
+ }
+ if b.iovec == nil {
+ b.iovec = b.storage[:0]
+ }
+ if len(b.iovec) >= MaxIovs {
+ b.addByAppend(buf)
+ return
+ }
+ b.iovec = append(b.iovec, syscall.Iovec{
+ Base: &buf[0],
+ Len: uint64(len(buf)),
+ })
+ // Keep the last buf if iovec is at max capacity. We will need to append to it
+ // for later bufs.
+ if len(b.iovec) == MaxIovs {
+ n := len(buf)
+ b.overflow = buf[:n:n]
+ }
+}
+
+func (b *Builder) addByAppend(buf []byte) {
+ b.overflow = append(b.overflow, buf...)
+ b.iovec[len(b.iovec)-1] = syscall.Iovec{
+ Base: &b.overflow[0],
+ Len: uint64(len(b.overflow)),
+ }
+}
+
+// Build returns the final Iovec slice. The length of returned iovec will not
+// excceed MaxIovs.
+func (b *Builder) Build() []syscall.Iovec {
+ return b.iovec
+}
diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go
new file mode 100644
index 000000000..a3900c299
--- /dev/null
+++ b/pkg/iovec/iovec_test.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package iovec
+
+import (
+ "bytes"
+ "fmt"
+ "syscall"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func TestBuilderEmpty(t *testing.T) {
+ var builder Builder
+ iovecs := builder.Build()
+ if got, want := len(iovecs), 0; got != want {
+ t.Errorf("len(iovecs) = %d, want %d", got, want)
+ }
+}
+
+func TestBuilderBuild(t *testing.T) {
+ a := []byte{1, 2}
+ b := []byte{3, 4, 5}
+
+ var builder Builder
+ builder.Add(a)
+ builder.Add(b)
+ builder.Add(nil) // Nil slice won't be added.
+ builder.Add([]byte{}) // Empty slice won't be added.
+ iovecs := builder.Build()
+
+ if got, want := len(iovecs), 2; got != want {
+ t.Fatalf("len(iovecs) = %d, want %d", got, want)
+ }
+ for i, data := range [][]byte{a, b} {
+ if got, want := *iovecs[i].Base, data[0]; got != want {
+ t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want)
+ }
+ if got, want := iovecs[i].Len, uint64(len(data)); got != want {
+ t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want)
+ }
+ }
+}
+
+func TestBuilderBuildMaxIov(t *testing.T) {
+ for _, test := range []struct {
+ numIov int
+ }{
+ {
+ numIov: MaxIovs - 1,
+ },
+ {
+ numIov: MaxIovs,
+ },
+ {
+ numIov: MaxIovs + 1,
+ },
+ {
+ numIov: MaxIovs + 10,
+ },
+ } {
+ name := fmt.Sprintf("numIov=%v", test.numIov)
+ t.Run(name, func(t *testing.T) {
+ var data []byte
+ var builder Builder
+ for i := 0; i < test.numIov; i++ {
+ buf := []byte{byte(i)}
+ builder.Add(buf)
+ data = append(data, buf...)
+ }
+ iovec := builder.Build()
+
+ // Check the expected length of iovec.
+ wantNum := test.numIov
+ if wantNum > MaxIovs {
+ wantNum = MaxIovs
+ }
+ if got, want := len(iovec), wantNum; got != want {
+ t.Errorf("len(iovec) = %d, want %d", got, want)
+ }
+
+ // Test a real read-write.
+ var fds [2]int
+ if err := unix.Pipe(fds[:]); err != nil {
+ t.Fatalf("Pipe: %v", err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
+ if int(wrote) != len(data) || e != 0 {
+ t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data))
+ }
+
+ got := make([]byte, len(data))
+ if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil {
+ t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got))
+ }
+
+ if !bytes.Equal(got, data) {
+ t.Errorf("read: got data %v, want %v", got, data)
+ }
+ })
+ }
+}
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index a5d980d14..f84d03700 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -1,17 +1,18 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "linewriter",
srcs = ["linewriter.go"],
- importpath = "gvisor.dev/gvisor/pkg/linewriter",
+ marshal = False,
+ stateify = False,
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
name = "linewriter_test",
srcs = ["linewriter_test.go"],
- embed = [":linewriter"],
+ library = ":linewriter",
)
diff --git a/pkg/linewriter/linewriter.go b/pkg/linewriter/linewriter.go
index cd6e4e2ce..a1b1285d4 100644
--- a/pkg/linewriter/linewriter.go
+++ b/pkg/linewriter/linewriter.go
@@ -17,7 +17,8 @@ package linewriter
import (
"bytes"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Writer is an io.Writer which buffers input, flushing
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index fc5f5779b..3ed6aba5c 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,16 +6,19 @@ go_library(
name = "log",
srcs = [
"glog.go",
- "glog_unsafe.go",
"json.go",
"json_k8s.go",
"log.go",
],
- importpath = "gvisor.dev/gvisor/pkg/log",
+ marshal = False,
+ stateify = False,
visibility = [
"//visibility:public",
],
- deps = ["//pkg/linewriter"],
+ deps = [
+ "//pkg/linewriter",
+ "//pkg/sync",
+ ],
)
go_test(
@@ -26,5 +28,5 @@ go_test(
"json_test.go",
"log_test.go",
],
- embed = [":log"],
+ library = ":log",
)
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index 5732785b4..f57c4427b 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -15,149 +15,71 @@
package log
import (
+ "fmt"
"os"
+ "runtime"
+ "strings"
"time"
)
// GoogleEmitter is a wrapper that emits logs in a format compatible with
// package github.com/golang/glog.
type GoogleEmitter struct {
- // Emitter is the underlying emitter.
- Emitter
-}
-
-// buffer is a simple inline buffer to avoid churn. The data slice is generally
-// kept to the local byte array, and we avoid having to allocate it on the heap.
-type buffer struct {
- local [256]byte
- data []byte
-}
-
-func (b *buffer) start() {
- b.data = b.local[:0]
-}
-
-func (b *buffer) String() string {
- return unsafeString(b.data)
-}
-
-func (b *buffer) write(c byte) {
- b.data = append(b.data, c)
-}
-
-func (b *buffer) writeAll(d []byte) {
- b.data = append(b.data, d...)
-}
-
-func (b *buffer) writeOneDigit(d byte) {
- b.write('0' + d)
-}
-
-func (b *buffer) writeTwoDigits(v int) {
- v = v % 100
- b.writeOneDigit(byte(v / 10))
- b.writeOneDigit(byte(v % 10))
-}
-
-func (b *buffer) writeSixDigits(v int) {
- v = v % 1000000
- b.writeOneDigit(byte(v / 100000))
- b.writeOneDigit(byte((v % 100000) / 10000))
- b.writeOneDigit(byte((v % 10000) / 1000))
- b.writeOneDigit(byte((v % 1000) / 100))
- b.writeOneDigit(byte((v % 100) / 10))
- b.writeOneDigit(byte(v % 10))
-}
-
-func calculateBytes(v int, pad int) []byte {
- var d []byte
- r := 1
-
- for n := 10; v >= r; n = n * 10 {
- d = append(d, '0'+byte((v%n)/r))
- r = n
- }
-
- for i := len(d); i < pad; i++ {
- d = append(d, ' ')
- }
-
- for i := 0; i < len(d)/2; i++ {
- d[i], d[len(d)-(i+1)] = d[len(d)-(i+1)], d[i]
- }
- return d
+ *Writer
}
// pid is used for the threadid component of the header.
-//
-// The glog package logger uses 7 spaces of padding. See
-// glob.loggingT.formatHeader.
-var pid = calculateBytes(os.Getpid(), 7)
-
-// caller is faked out as the caller. See FIXME below.
-var caller = []byte("x:0")
+var pid = os.Getpid()
// Emit emits the message, google-style.
-func (g GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
- var b buffer
- b.start()
-
- // Log lines have this form:
- // Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg...
- //
- // where the fields are defined as follows:
- // L A single character, representing the log level (eg 'I' for INFO)
- // mm The month (zero padded; ie May is '05')
- // dd The day (zero padded)
- // hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds
- // threadid The space-padded thread ID as returned by GetTID()
- // file The file name
- // line The line number
- // msg The user-supplied message
-
+//
+// Log lines have this form:
+// Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg...
+//
+// where the fields are defined as follows:
+// L A single character, representing the log level (eg 'I' for INFO)
+// mm The month (zero padded; ie May is '05')
+// dd The day (zero padded)
+// hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds
+// threadid The space-padded thread ID as returned by GetTID()
+// file The file name
+// line The line number
+// msg The user-supplied message
+//
+func (g GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
+ prefix := byte('?')
switch level {
case Debug:
- b.write('D')
+ prefix = byte('D')
case Info:
- b.write('I')
+ prefix = byte('I')
case Warning:
- b.write('W')
+ prefix = byte('W')
}
// Timestamp.
_, month, day := timestamp.Date()
hour, minute, second := timestamp.Clock()
- b.writeTwoDigits(int(month))
- b.writeTwoDigits(int(day))
- b.write(' ')
- b.writeTwoDigits(int(hour))
- b.write(':')
- b.writeTwoDigits(int(minute))
- b.write(':')
- b.writeTwoDigits(int(second))
- b.write('.')
- b.writeSixDigits(int(timestamp.Nanosecond() / 1000))
- b.write(' ')
-
- // The pid.
- b.writeAll(pid)
- b.write(' ')
-
- // FIXME(b/73383460): The caller, fabricated. This really sucks, but it
- // is unacceptable to put runtime.Callers() in the hot path.
- b.writeAll(caller)
- b.write(']')
- b.write(' ')
-
- // User-provided format string, copied.
- for i := 0; i < len(format); i++ {
- b.write(format[i])
+ microsecond := int(timestamp.Nanosecond() / 1000)
+
+ // 0 = this frame.
+ _, file, line, ok := runtime.Caller(depth + 1)
+ if ok {
+ // Trim any directory path from the file.
+ slash := strings.LastIndexByte(file, byte('/'))
+ if slash >= 0 {
+ file = file[slash+1:]
+ }
+ } else {
+ // We don't have a filename.
+ file = "???"
+ line = 0
}
- // End with a newline.
- b.write('\n')
+ // Generate the message.
+ message := fmt.Sprintf(format, args...)
- // Pass to the underlying routine.
- g.Emitter.Emit(level, timestamp, b.String(), args...)
+ // Emit the formatted result.
+ fmt.Fprintf(g.Writer, "%c%02d%02d %02d:%02d:%02d.%06d % 7d %s:%d] %s\n", prefix, int(month), day, hour, minute, second, microsecond, pid, file, line, message)
}
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..bdf9d691e 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -58,11 +58,11 @@ func (lv *Level) UnmarshalJSON(b []byte) error {
// JSONEmitter logs messages in json format.
type JSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index c2c019915..5883e95e1 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -29,11 +29,11 @@ type k8sJSONLog struct {
// K8sJSONEmitter logs messages in json format that is compatible with
// Kubernetes fluent configuration.
type K8sJSONEmitter struct {
- Writer
+ *Writer
}
// Emit implements Emitter.Emit.
-func (e K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 9387586e6..37e0605ad 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -17,6 +17,18 @@
// This is separate from the standard logging package because logging may be a
// high-impact activity, and therefore we wanted to provide as much flexibility
// as possible in the underlying implementation.
+//
+// Note that logging should still be considered high-impact, and should not be
+// done in the hot path. If necessary, logging statements should be protected
+// with guards regarding the logging level. For example,
+//
+// if log.IsLogging(log.Debug) {
+// log.Debugf(...)
+// }
+//
+// This is because the log.Debugf(...) statement alone will generate a
+// significant amount of garbage and churn in many cases, even if no log
+// message is ultimately emitted.
package log
import (
@@ -25,12 +37,12 @@ import (
stdlog "log"
"os"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/linewriter"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Level is the log level.
@@ -67,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -130,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -138,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
- for _, e := range m {
- e.Emit(level, timestamp, format, v...)
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
+ for _, e := range *m {
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -155,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -186,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -245,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go
index 0634e7c1f..9ff18559b 100644
--- a/pkg/log/log_test.go
+++ b/pkg/log/log_test.go
@@ -16,18 +16,23 @@ package log
import (
"fmt"
+ "strings"
"testing"
)
type testWriter struct {
lines []string
fail bool
+ limit int
}
func (w *testWriter) Write(bytes []byte) (int, error) {
if w.fail {
return 0, fmt.Errorf("simulated failure")
}
+ if w.limit > 0 && len(w.lines) >= w.limit {
+ return len(bytes), nil
+ }
w.lines = append(w.lines, string(bytes))
return len(bytes), nil
}
@@ -47,7 +52,7 @@ func TestDropMessages(t *testing.T) {
t.Fatalf("Write should have failed")
}
- fmt.Printf("writer: %+v\n", w)
+ fmt.Printf("writer: %#v\n", &w)
tw.fail = false
if _, err := w.Write([]byte("line 2\n")); err != nil {
@@ -68,3 +73,33 @@ func TestDropMessages(t *testing.T) {
}
}
}
+
+func TestCaller(t *testing.T) {
+ tw := &testWriter{}
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
+ bl := &BasicLogger{
+ Emitter: e,
+ Level: Debug,
+ }
+ bl.Debugf("testing...\n") // Just for file + line.
+ if len(tw.lines) != 1 {
+ t.Errorf("expected 1 line, got %d", len(tw.lines))
+ }
+ if !strings.Contains(tw.lines[0], "log_test.go") {
+ t.Errorf("expected log_test.go, got %q", tw.lines[0])
+ }
+}
+
+func BenchmarkGoogleLogging(b *testing.B) {
+ tw := &testWriter{
+ limit: 1, // Only record one message.
+ }
+ e := GoogleEmitter{Writer: &Writer{Next: tw}}
+ bl := &BasicLogger{
+ Emitter: e,
+ Level: Debug,
+ }
+ for i := 0; i < b.N; i++ {
+ bl.Debugf("hello %d, %d, %d", 1, 2, 3)
+ }
+}
diff --git a/pkg/memutil/BUILD b/pkg/memutil/BUILD
index 7b50e2b28..9d07d98b4 100644
--- a/pkg/memutil/BUILD
+++ b/pkg/memutil/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "memutil",
srcs = ["memutil_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/memutil",
visibility = ["//visibility:public"],
deps = ["@org_golang_x_sys//unix:go_default_library"],
)
diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD
new file mode 100644
index 000000000..a8fcb2e19
--- /dev/null
+++ b/pkg/merkletree/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "merkletree",
+ srcs = ["merkletree.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/usermem"],
+)
+
+go_test(
+ name = "merkletree_test",
+ srcs = ["merkletree_test.go"],
+ library = ":merkletree",
+ deps = ["//pkg/usermem"],
+)
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
new file mode 100644
index 000000000..955c9c473
--- /dev/null
+++ b/pkg/merkletree/merkletree.go
@@ -0,0 +1,314 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package merkletree implements Merkle tree generating and verification.
+package merkletree
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ // sha256DigestSize specifies the digest size of a SHA256 hash.
+ sha256DigestSize = 32
+)
+
+// Layout defines the scale of a Merkle tree.
+type Layout struct {
+ // blockSize is the size of a data block to be hashed.
+ blockSize int64
+ // digestSize is the size of a generated hash.
+ digestSize int64
+ // levelOffset contains the offset of the begnning of each level in
+ // bytes. The number of levels in the tree is the length of the slice.
+ // The leaf nodes (level 0) contain hashes of blocks of the input data.
+ // Each level N contains hashes of the blocks in level N-1. The highest
+ // level is the root hash.
+ levelOffset []int64
+}
+
+// InitLayout initializes and returns a new Layout object describing the structure
+// of a tree. dataSize specifies the size of input data in bytes.
+func InitLayout(dataSize int64) Layout {
+ layout := Layout{
+ blockSize: usermem.PageSize,
+ // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512).
+ digestSize: sha256DigestSize,
+ }
+ numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize
+ level := 0
+ offset := int64(0)
+
+ // Calculate the number of levels in the Merkle tree and the beginning
+ // offset of each level. Level 0 consists of the leaf nodes that
+ // contain the hashes of the data blocks, while level numLevels - 1 is
+ // the root.
+ for numBlocks > 1 {
+ layout.levelOffset = append(layout.levelOffset, offset*layout.blockSize)
+ // Round numBlocks up to fill up a block.
+ numBlocks += (layout.hashesPerBlock() - numBlocks%layout.hashesPerBlock()) % layout.hashesPerBlock()
+ offset += numBlocks / layout.hashesPerBlock()
+ numBlocks = numBlocks / layout.hashesPerBlock()
+ level++
+ }
+ layout.levelOffset = append(layout.levelOffset, offset*layout.blockSize)
+ return layout
+}
+
+// hashesPerBlock() returns the number of digests in each block. For example,
+// if blockSize is 4096 bytes, and digestSize is 32 bytes, there will be 128
+// hashesPerBlock. Therefore 128 hashes in one level will be combined in one
+// hash in the level above.
+func (layout Layout) hashesPerBlock() int64 {
+ return layout.blockSize / layout.digestSize
+}
+
+// numLevels returns the total number of levels in the Merkle tree.
+func (layout Layout) numLevels() int {
+ return len(layout.levelOffset)
+}
+
+// rootLevel returns the level of the root hash.
+func (layout Layout) rootLevel() int {
+ return layout.numLevels() - 1
+}
+
+// digestOffset finds the offset of a digest from the beginning of the tree.
+// The target digest is at level of the tree, with index from the beginning of
+// the current level.
+func (layout Layout) digestOffset(level int, index int64) int64 {
+ return layout.levelOffset[level] + index*layout.digestSize
+}
+
+// blockOffset finds the offset of a block from the beginning of the tree. The
+// target block is at level of the tree, with index from the beginning of the
+// current level.
+func (layout Layout) blockOffset(level int, index int64) int64 {
+ return layout.levelOffset[level] + index*layout.blockSize
+}
+
+// Generate constructs a Merkle tree for the contents of data. The output is
+// written to treeWriter. The treeReader should be able to read the tree after
+// it has been written. That is, treeWriter and treeReader should point to the
+// same underlying data but have separate cursors.
+func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) {
+ layout := InitLayout(dataSize)
+
+ numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize
+
+ var root []byte
+ for level := 0; level < layout.numLevels(); level++ {
+ for i := int64(0); i < numBlocks; i++ {
+ buf := make([]byte, layout.blockSize)
+ var (
+ n int
+ err error
+ )
+ if level == 0 {
+ // Read data block from the target file since level 0 includes hashes
+ // of blocks in the input data.
+ n, err = data.Read(buf)
+ } else {
+ // Read data block from the tree file since levels higher than 0 are
+ // hashing the lower level hashes.
+ n, err = treeReader.Read(buf)
+ }
+
+ // err is populated as long as the bytes read is smaller than the buffer
+ // size. This could be the case if we are reading the last block, and
+ // break in that case. If this is the last block, the end of the block
+ // will be zero-padded.
+ if n == 0 && err == io.EOF {
+ break
+ } else if err != nil && err != io.EOF {
+ return nil, err
+ }
+ // Hash the bytes in buf.
+ digest := sha256.Sum256(buf)
+
+ if level == layout.rootLevel() {
+ root = digest[:]
+ }
+
+ // Write the generated hash to the end of the tree file.
+ if _, err = treeWriter.Write(digest[:]); err != nil {
+ return nil, err
+ }
+ }
+ // If the generated digests do not round up to a block, zero-padding the
+ // remaining of the last block. But no need to do so for root.
+ if level != layout.rootLevel() && numBlocks%layout.hashesPerBlock() != 0 {
+ zeroBuf := make([]byte, layout.blockSize-(numBlocks%layout.hashesPerBlock())*layout.digestSize)
+ if _, err := treeWriter.Write(zeroBuf[:]); err != nil {
+ return nil, err
+ }
+ }
+ numBlocks = (numBlocks + layout.hashesPerBlock() - 1) / layout.hashesPerBlock()
+ }
+ return root, nil
+}
+
+// Verify verifies the content read from data with offset. The content is
+// verified against tree. If content spans across multiple blocks, each block is
+// verified. Verification fails if the hash of the data does not match the tree
+// at any level, or if the final root hash does not match expectedRoot.
+// Once the data is verified, it will be written using w.
+// Verify will modify the cursor for data, but always restores it to its
+// original position upon exit. The cursor for tree is modified and not
+// restored.
+func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte) error {
+ if readSize <= 0 {
+ return fmt.Errorf("Unexpected read size: %d", readSize)
+ }
+ layout := InitLayout(int64(dataSize))
+
+ // Calculate the index of blocks that includes the target range in input
+ // data.
+ firstDataBlock := readOffset / layout.blockSize
+ lastDataBlock := (readOffset + readSize - 1) / layout.blockSize
+
+ // Store the current offset, so we can set it back once verification
+ // finishes.
+ origOffset, err := data.Seek(0, io.SeekCurrent)
+ if err != nil {
+ return fmt.Errorf("Find current data offset failed: %v", err)
+ }
+ defer data.Seek(origOffset, io.SeekStart)
+
+ // Move to the first block that contains target data.
+ if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil {
+ return fmt.Errorf("Seek to datablock start failed: %v", err)
+ }
+
+ buf := make([]byte, layout.blockSize)
+ var readErr error
+ bytesRead := 0
+ for i := firstDataBlock; i <= lastDataBlock; i++ {
+ // Read a block that includes all or part of target range in
+ // input data.
+ bytesRead, readErr = data.Read(buf)
+ // If at the end of input data and all previous blocks are
+ // verified, return the verified input data and EOF.
+ if readErr == io.EOF && bytesRead == 0 {
+ break
+ }
+ if readErr != nil && readErr != io.EOF {
+ return fmt.Errorf("Read from data failed: %v", err)
+ }
+ // If this is the end of file, zero the remaining bytes in buf,
+ // otherwise they are still from the previous block.
+ // TODO(b/162908070): Investigate possible issues with zero
+ // padding the data.
+ if bytesRead < len(buf) {
+ for j := bytesRead; j < len(buf); j++ {
+ buf[j] = 0
+ }
+ }
+ if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil {
+ return err
+ }
+ // startOff is the beginning of the read range within the
+ // current data block. Note that for all blocks other than the
+ // first, startOff should be 0.
+ startOff := int64(0)
+ if i == firstDataBlock {
+ startOff = readOffset % layout.blockSize
+ }
+ // endOff is the end of the read range within the current data
+ // block. Note that for all blocks other than the last, endOff
+ // should be the block size.
+ endOff := layout.blockSize
+ if i == lastDataBlock {
+ endOff = (readOffset+readSize-1)%layout.blockSize + 1
+ }
+ // If the provided size exceeds the end of input data, we should
+ // only copy the parts in buf that's part of input data.
+ if startOff > int64(bytesRead) {
+ startOff = int64(bytesRead)
+ }
+ if endOff > int64(bytesRead) {
+ endOff = int64(bytesRead)
+ }
+ w.Write(buf[startOff:endOff])
+
+ }
+ return readErr
+}
+
+// verifyBlock verifies a block against tree. index is the number of block in
+// original data. The block is verified through each level of the tree. It
+// fails if the calculated hash from block is different from any level of
+// hashes stored in tree. And the final root hash is compared with
+// expectedRoot. verifyBlock modifies the cursor for tree. Users needs to
+// maintain the cursor if intended.
+func verifyBlock(tree io.ReadSeeker, layout Layout, dataBlock []byte, blockIndex int64, expectedRoot []byte) error {
+ if len(dataBlock) != int(layout.blockSize) {
+ return fmt.Errorf("incorrect block size")
+ }
+
+ expectedDigest := make([]byte, layout.digestSize)
+ treeBlock := make([]byte, layout.blockSize)
+ var digest []byte
+ for level := 0; level < layout.numLevels(); level++ {
+ // Calculate hash.
+ if level == 0 {
+ digestArray := sha256.Sum256(dataBlock)
+ digest = digestArray[:]
+ } else {
+ // Read a block in previous level that contains the
+ // hash we just generated, and generate a next level
+ // hash from it.
+ if _, err := tree.Seek(layout.blockOffset(level-1, blockIndex), io.SeekStart); err != nil {
+ return err
+ }
+ if _, err := tree.Read(treeBlock); err != nil {
+ return err
+ }
+ digestArray := sha256.Sum256(treeBlock)
+ digest = digestArray[:]
+ }
+
+ // Move to stored hash for the current block, read the digest
+ // and store in expectedDigest.
+ if _, err := tree.Seek(layout.digestOffset(level, blockIndex), io.SeekStart); err != nil {
+ return err
+ }
+ if _, err := tree.Read(expectedDigest); err != nil {
+ return err
+ }
+
+ if !bytes.Equal(digest, expectedDigest) {
+ return fmt.Errorf("Verification failed")
+ }
+
+ // If this is the root layer, no need to generate next level
+ // hash.
+ if level == layout.rootLevel() {
+ break
+ }
+ blockIndex = blockIndex / layout.hashesPerBlock()
+ }
+
+ // Verification for the tree succeeded. Now compare the root hash in the
+ // tree with expectedRoot.
+ if !bytes.Equal(digest[:], expectedRoot) {
+ return fmt.Errorf("Verification failed")
+ }
+ return nil
+}
diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go
new file mode 100644
index 000000000..911f61df9
--- /dev/null
+++ b/pkg/merkletree/merkletree_test.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package merkletree
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestLayout(t *testing.T) {
+ testCases := []struct {
+ dataSize int64
+ expectedLevelOffset []int64
+ }{
+ {
+ dataSize: 100,
+ expectedLevelOffset: []int64{0},
+ },
+ {
+ dataSize: 1000000,
+ expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize},
+ },
+ {
+ dataSize: 4096 * int64(usermem.PageSize),
+ expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) {
+ p := InitLayout(tc.dataSize)
+ if p.blockSize != int64(usermem.PageSize) {
+ t.Errorf("got blockSize %d, want %d", p.blockSize, usermem.PageSize)
+ }
+ if p.digestSize != sha256DigestSize {
+ t.Errorf("got digestSize %d, want %d", p.digestSize, sha256DigestSize)
+ }
+ if p.numLevels() != len(tc.expectedLevelOffset) {
+ t.Errorf("got levels %d, want %d", p.numLevels(), len(tc.expectedLevelOffset))
+ }
+ for i := 0; i < p.numLevels() && i < len(tc.expectedLevelOffset); i++ {
+ if p.levelOffset[i] != tc.expectedLevelOffset[i] {
+ t.Errorf("got levelStart[%d] %d, want %d", i, p.levelOffset[i], tc.expectedLevelOffset[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGenerate(t *testing.T) {
+ // The input data has size dataSize. It starts with the data in startWith,
+ // and all other bytes are zeroes.
+ testCases := []struct {
+ data []byte
+ expectedRoot []byte
+ }{
+ {
+ data: bytes.Repeat([]byte{0}, usermem.PageSize),
+ expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167},
+ },
+ {
+ data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1),
+ expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132},
+ },
+ {
+ data: []byte{'a'},
+ expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210},
+ },
+ {
+ data: bytes.Repeat([]byte{'a'}, usermem.PageSize),
+ expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) {
+ var tree bytes.Buffer
+
+ root, err := Generate(bytes.NewBuffer(tc.data), int64(len(tc.data)), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ if !bytes.Equal(root, tc.expectedRoot) {
+ t.Errorf("Unexpected root")
+ }
+ })
+ }
+}
+
+// bytesReadWriter is used to read from/write to/seek in a byte array. Unlike
+// bytes.Buffer, it keeps the whole buffer during read so that it can be reused.
+type bytesReadWriter struct {
+ // bytes contains the underlying byte array.
+ bytes []byte
+ // readPos is the currently location for Read. Write always appends to
+ // the end of the array.
+ readPos int
+}
+
+func (brw *bytesReadWriter) Write(p []byte) (int, error) {
+ brw.bytes = append(brw.bytes, p...)
+ return len(p), nil
+}
+
+func (brw *bytesReadWriter) Read(p []byte) (int, error) {
+ if brw.readPos >= len(brw.bytes) {
+ return 0, io.EOF
+ }
+ bytesRead := copy(p, brw.bytes[brw.readPos:])
+ brw.readPos += bytesRead
+ if bytesRead < len(p) {
+ return bytesRead, io.EOF
+ }
+ return bytesRead, nil
+}
+
+func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) {
+ off := offset
+ if whence == io.SeekCurrent {
+ off += int64(brw.readPos)
+ }
+ if whence == io.SeekEnd {
+ off += int64(len(brw.bytes))
+ }
+ if off < 0 {
+ panic("seek with negative offset")
+ }
+ if off >= int64(len(brw.bytes)) {
+ return 0, io.EOF
+ }
+ brw.readPos = int(off)
+ return off, nil
+}
+
+func TestVerify(t *testing.T) {
+ // The input data has size dataSize. The portion to be verified ranges from
+ // verifyStart with verifySize. A bit is flipped in outOfRangeByteIndex to
+ // confirm that modifications outside the verification range does not cause
+ // issue. And a bit is flipped in modifyByte to confirm that
+ // modifications in the verification range is caught during verification.
+ testCases := []struct {
+ dataSize int64
+ verifyStart int64
+ verifySize int64
+ // A byte in input data is modified during the test. If the
+ // modified byte falls in verification range, Verify should
+ // fail, otherwise Verify should still succeed.
+ modifyByte int64
+ shouldSucceed bool
+ }{
+ // Verify range start outside the data range should fail.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: usermem.PageSize,
+ verifySize: 1,
+ modifyByte: 0,
+ shouldSucceed: false,
+ },
+ // Verifying range is valid if it starts inside data and ends
+ // outside data range, in that case start to the end of data is
+ // verified.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 0,
+ shouldSucceed: false,
+ },
+ // Invalid verify range (negative size) should fail.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 1,
+ verifySize: -1,
+ modifyByte: 0,
+ shouldSucceed: false,
+ },
+ // Invalid verify range (0 size) should fail.
+ {
+ dataSize: usermem.PageSize,
+ verifyStart: 0,
+ verifySize: 0,
+ modifyByte: 0,
+ shouldSucceed: false,
+ },
+ // The test cases below use a block-aligned verify range.
+ // Modifying a byte in the verified range should cause verify
+ // to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4 * usermem.PageSize,
+ verifySize: usermem.PageSize,
+ modifyByte: 4 * usermem.PageSize,
+ shouldSucceed: false,
+ },
+ // Modifying a byte before the verified range should not cause
+ // verify to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4 * usermem.PageSize,
+ verifySize: usermem.PageSize,
+ modifyByte: 4*usermem.PageSize - 1,
+ shouldSucceed: true,
+ },
+ // Modifying a byte after the verified range should not cause
+ // verify to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4 * usermem.PageSize,
+ verifySize: usermem.PageSize,
+ modifyByte: 5 * usermem.PageSize,
+ shouldSucceed: true,
+ },
+ // The tests below use a non-block-aligned verify range.
+ // Modifying a byte at strat of verify range should cause
+ // verify to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4*usermem.PageSize + 123,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 4*usermem.PageSize + 123,
+ shouldSucceed: false,
+ },
+ // Modifying a byte at the end of verify range should cause
+ // verify to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4*usermem.PageSize + 123,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 6*usermem.PageSize + 123,
+ shouldSucceed: false,
+ },
+ // Modifying a byte in the middle verified block should cause
+ // verify to fail.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4*usermem.PageSize + 123,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 5*usermem.PageSize + 123,
+ shouldSucceed: false,
+ },
+ // Modifying a byte in the first block in the verified range
+ // should cause verify to fail, even the modified bit itself is
+ // out of verify range.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4*usermem.PageSize + 123,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 4*usermem.PageSize + 122,
+ shouldSucceed: false,
+ },
+ // Modifying a byte in the last block in the verified range
+ // should cause verify to fail, even the modified bit itself is
+ // out of verify range.
+ {
+ dataSize: 8 * usermem.PageSize,
+ verifyStart: 4*usermem.PageSize + 123,
+ verifySize: 2 * usermem.PageSize,
+ modifyByte: 6*usermem.PageSize + 124,
+ shouldSucceed: false,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("%d", tc.modifyByte), func(t *testing.T) {
+ data := make([]byte, tc.dataSize)
+ // Generate random bytes in data.
+ rand.Read(data)
+ var tree bytesReadWriter
+
+ root, err := Generate(bytes.NewBuffer(data), int64(tc.dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ // Flip a bit in data and checks Verify results.
+ var buf bytes.Buffer
+ data[tc.modifyByte] ^= 1
+ if tc.shouldSucceed {
+ if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root); err != nil && err != io.EOF {
+ t.Errorf("Verification failed when expected to succeed: %v", err)
+ }
+ if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) {
+ t.Errorf("Incorrect output from Verify")
+ }
+ } else {
+ if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root); err == nil {
+ t.Errorf("Verification succeeded when expected to fail")
+ }
+ }
+ })
+ }
+}
+
+func TestVerifyRandom(t *testing.T) {
+ rand.Seed(time.Now().UnixNano())
+ // Use a random dataSize. Minimum size 2 so that we can pick a random
+ // portion from it.
+ dataSize := rand.Int63n(200*usermem.PageSize) + 2
+ data := make([]byte, dataSize)
+ // Generate random bytes in data.
+ rand.Read(data)
+ var tree bytesReadWriter
+
+ root, err := Generate(bytes.NewBuffer(data), int64(dataSize), &tree, &tree)
+ if err != nil {
+ t.Fatalf("Generate failed: %v", err)
+ }
+
+ // Pick a random portion of data.
+ start := rand.Int63n(dataSize - 1)
+ size := rand.Int63n(dataSize) + 1
+
+ var buf bytes.Buffer
+ // Checks that the random portion of data from the original data is
+ // verified successfully.
+ if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root); err != nil && err != io.EOF {
+ t.Errorf("Verification failed for correct data: %v", err)
+ }
+ if size > dataSize-start {
+ size = dataSize - start
+ }
+ if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) {
+ t.Errorf("Incorrect output from Verify")
+ }
+
+ buf.Reset()
+ // Flip a random bit in randPortion, and check that verification fails.
+ randBytePos := rand.Int63n(size)
+ data[start+randBytePos] ^= 1
+
+ if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root); err == nil {
+ t.Errorf("Verification succeeded for modified data")
+ }
+}
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index dd6ca6d39..58305009d 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,45 +1,29 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
package(licenses = ["notice"])
go_library(
name = "metric",
srcs = ["metric.go"],
- importpath = "gvisor.dev/gvisor/pkg/metric",
visibility = ["//:sandbox"],
deps = [
":metric_go_proto",
"//pkg/eventchannel",
"//pkg/log",
+ "//pkg/sync",
],
)
proto_library(
- name = "metric_proto",
+ name = "metric",
srcs = ["metric.proto"],
visibility = ["//:sandbox"],
)
-cc_proto_library(
- name = "metric_cc_proto",
- visibility = ["//:sandbox"],
- deps = [":metric_proto"],
-)
-
-go_proto_library(
- name = "metric_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
- proto = ":metric_proto",
- visibility = ["//:sandbox"],
-)
-
go_test(
name = "metric_test",
srcs = ["metric_test.go"],
- embed = [":metric"],
+ library = ":metric",
deps = [
":metric_go_proto",
"//pkg/eventchannel",
diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go
index eadde06e4..64aa365ce 100644
--- a/pkg/metric/metric.go
+++ b/pkg/metric/metric.go
@@ -18,12 +18,12 @@ package metric
import (
"errors"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
@@ -39,17 +39,11 @@ var (
// Uint64Metric encapsulates a uint64 that represents some kind of metric to be
// monitored.
//
-// All metrics must be cumulative, meaning that their values will only increase
-// over time.
-//
// Metrics are not saved across save/restore and thus reset to zero on restore.
//
-// TODO(b/67298402): Support non-cumulative metrics.
// TODO(b/67298427): Support metric fields.
-//
type Uint64Metric struct {
- // value is the actual value of the metric. It must be accessed
- // atomically.
+ // value is the actual value of the metric. It must be accessed atomically.
value uint64
}
@@ -111,13 +105,10 @@ type customUint64Metric struct {
// Register must only be called at init and will return and error if called
// after Initialized.
//
-// All metrics must be cumulative, meaning that the return values of value must
-// only increase over time.
-//
// Preconditions:
// * name must be globally unique.
// * Initialize/Disable have not been called.
-func RegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) error {
+func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error {
if initialized {
return ErrInitializationDone
}
@@ -130,9 +121,10 @@ func RegisterCustomUint64Metric(name string, sync bool, description string, valu
metadata: &pb.MetricMetadata{
Name: name,
Description: description,
- Cumulative: true,
+ Cumulative: cumulative,
Sync: sync,
- Type: pb.MetricMetadata_UINT64,
+ Type: pb.MetricMetadata_TYPE_UINT64,
+ Units: units,
},
value: value,
}
@@ -141,24 +133,32 @@ func RegisterCustomUint64Metric(name string, sync bool, description string, valu
// MustRegisterCustomUint64Metric calls RegisterCustomUint64Metric and panics
// if it returns an error.
-func MustRegisterCustomUint64Metric(name string, sync bool, description string, value func() uint64) {
- if err := RegisterCustomUint64Metric(name, sync, description, value); err != nil {
+func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func() uint64) {
+ if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value); err != nil {
panic(fmt.Sprintf("Unable to register metric %q: %v", name, err))
}
}
-// NewUint64Metric creates and registers a new metric with the given name.
+// NewUint64Metric creates and registers a new cumulative metric with the given name.
//
// Metrics must be statically defined (i.e., at init).
-func NewUint64Metric(name string, sync bool, description string) (*Uint64Metric, error) {
+func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, description string) (*Uint64Metric, error) {
var m Uint64Metric
- return &m, RegisterCustomUint64Metric(name, sync, description, m.Value)
+ return &m, RegisterCustomUint64Metric(name, true /* cumulative */, sync, units, description, m.Value)
}
-// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an
-// error.
+// MustCreateNewUint64Metric calls NewUint64Metric and panics if it returns an error.
func MustCreateNewUint64Metric(name string, sync bool, description string) *Uint64Metric {
- m, err := NewUint64Metric(name, sync, description)
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description)
+ if err != nil {
+ panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
+ }
+ return m
+}
+
+// MustCreateNewUint64NanosecondsMetric calls NewUint64Metric and panics if it returns an error.
+func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric {
+ m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description)
if err != nil {
panic(fmt.Sprintf("Unable to create metric %q: %v", name, err))
}
@@ -245,6 +245,6 @@ func EmitMetricUpdate() {
return
}
- log.Debugf("Emitting metrics: %v", m)
+ log.Debugf("Emitting metrics: %v", &m)
eventchannel.Emit(&m)
}
diff --git a/pkg/metric/metric.proto b/pkg/metric/metric.proto
index a2c2bd1ba..3cc89047d 100644
--- a/pkg/metric/metric.proto
+++ b/pkg/metric/metric.proto
@@ -36,10 +36,18 @@ message MetricMetadata {
// the monitoring system.
bool sync = 4;
- enum Type { UINT64 = 0; }
+ enum Type { TYPE_UINT64 = 0; }
// type is the type of the metric value.
Type type = 5;
+
+ enum Units {
+ UNITS_NONE = 0;
+ UNITS_NANOSECONDS = 1;
+ }
+
+ // units is the units of the metric value.
+ Units units = 6;
}
// MetricRegistration contains the metadata for all metrics that will be in
diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go
index 34969385a..c425ea532 100644
--- a/pkg/metric/metric_test.go
+++ b/pkg/metric/metric_test.go
@@ -66,12 +66,12 @@ const (
func TestInitialize(t *testing.T) {
defer reset()
- _, err := NewUint64Metric("/foo", false, fooDescription)
+ _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- _, err = NewUint64Metric("/bar", true, barDescription)
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NANOSECONDS, barDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
@@ -94,8 +94,8 @@ func TestInitialize(t *testing.T) {
foundFoo := false
foundBar := false
for _, m := range mr.Metrics {
- if m.Type != pb.MetricMetadata_UINT64 {
- t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_UINT64)
+ if m.Type != pb.MetricMetadata_TYPE_UINT64 {
+ t.Errorf("Metadata %+v Type got %v want %v", m, m.Type, pb.MetricMetadata_TYPE_UINT64)
}
if !m.Cumulative {
t.Errorf("Metadata %+v Cumulative got false want true", m)
@@ -110,6 +110,9 @@ func TestInitialize(t *testing.T) {
if m.Sync {
t.Errorf("/foo %+v Sync got true want false", m)
}
+ if m.Units != pb.MetricMetadata_UNITS_NONE {
+ t.Errorf("/foo %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NONE)
+ }
case "/bar":
foundBar = true
if m.Description != barDescription {
@@ -118,6 +121,9 @@ func TestInitialize(t *testing.T) {
if !m.Sync {
t.Errorf("/bar %+v Sync got true want false", m)
}
+ if m.Units != pb.MetricMetadata_UNITS_NANOSECONDS {
+ t.Errorf("/bar %+v Units got %v want %v", m, m.Units, pb.MetricMetadata_UNITS_NANOSECONDS)
+ }
}
}
@@ -132,12 +138,12 @@ func TestInitialize(t *testing.T) {
func TestDisable(t *testing.T) {
defer reset()
- _, err := NewUint64Metric("/foo", false, fooDescription)
+ _, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- _, err = NewUint64Metric("/bar", true, barDescription)
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
@@ -161,12 +167,12 @@ func TestDisable(t *testing.T) {
func TestEmitMetricUpdate(t *testing.T) {
defer reset()
- foo, err := NewUint64Metric("/foo", false, fooDescription)
+ foo, err := NewUint64Metric("/foo", false, pb.MetricMetadata_UNITS_NONE, fooDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
- _, err = NewUint64Metric("/bar", true, barDescription)
+ _, err = NewUint64Metric("/bar", true, pb.MetricMetadata_UNITS_NONE, barDescription)
if err != nil {
t.Fatalf("NewUint64Metric got err %v want nil", err)
}
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index f32244c69..8904afad9 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -17,18 +16,18 @@ go_library(
"messages.go",
"p9.go",
"path_tree.go",
- "pool.go",
"server.go",
"transport.go",
"transport_flipcall.go",
"version.go",
],
- importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
"//pkg/fdchannel",
"//pkg/flipcall",
"//pkg/log",
+ "//pkg/pool",
+ "//pkg/sync",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -42,11 +41,10 @@ go_test(
"client_test.go",
"messages_test.go",
"p9_test.go",
- "pool_test.go",
"transport_test.go",
"version_test.go",
],
- embed = [":p9"],
+ library = ":p9",
deps = [
"//pkg/fd",
"//pkg/unet",
diff --git a/pkg/p9/buffer.go b/pkg/p9/buffer.go
index 249536d8a..6a4951821 100644
--- a/pkg/p9/buffer.go
+++ b/pkg/p9/buffer.go
@@ -20,16 +20,16 @@ import (
// encoder is used for messages and 9P primitives.
type encoder interface {
- // Decode decodes from the given buffer. Decode may be called more than once
+ // decode decodes from the given buffer. decode may be called more than once
// to reuse the instance. It must clear any previous state.
//
// This may not fail, exhaustion will be recorded in the buffer.
- Decode(b *buffer)
+ decode(b *buffer)
- // Encode encodes to the given buffer.
+ // encode encodes to the given buffer.
//
// This may not fail.
- Encode(b *buffer)
+ encode(b *buffer)
}
// order is the byte order used for encoding.
@@ -39,7 +39,7 @@ var order = binary.LittleEndian
//
// This is passed to the encoder methods.
type buffer struct {
- // data is the underlying data. This may grow during Encode.
+ // data is the underlying data. This may grow during encode.
data []byte
// overflow indicates whether an overflow has occurred.
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 221516c6c..71e944c30 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -17,12 +17,13 @@ package p9
import (
"errors"
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/pool"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -74,10 +75,10 @@ type Client struct {
socket *unet.Socket
// tagPool is the collection of available tags.
- tagPool pool
+ tagPool pool.Pool
// fidPool is the collection of available fids.
- fidPool pool
+ fidPool pool.Pool
// messageSize is the maximum total size of a message.
messageSize uint32
@@ -155,8 +156,8 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
}
c := &Client{
socket: socket,
- tagPool: pool{start: 1, limit: uint64(NoTag)},
- fidPool: pool{start: 1, limit: uint64(NoFID)},
+ tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)},
+ fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)},
pending: make(map[Tag]*response),
recvr: make(chan bool, 1),
messageSize: messageSize,
@@ -173,7 +174,7 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
// our sendRecv function to use that functionality. Otherwise,
// we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecvLegacy(&Tversion{
+ _, err := c.sendRecvLegacy(&Tversion{
Version: versionString(requested),
MSize: messageSize,
}, &rversion)
@@ -218,11 +219,11 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.sendRecv = c.sendRecvChannel
} else {
// Channel setup failed; fallback.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
} else {
// No channels available: use the legacy mechanism.
- c.sendRecv = c.sendRecvLegacy
+ c.sendRecv = c.sendRecvLegacySyscallErr
}
// Ensure that the socket and channels are closed when the socket is shut
@@ -304,7 +305,7 @@ func (c *Client) openChannel(id int) error {
)
// Open the data channel.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 0,
}, &rchannel0); err != nil {
@@ -318,7 +319,7 @@ func (c *Client) openChannel(id int) error {
defer rchannel0.FilePayload().Close()
// Open the channel for file descriptors.
- if err := c.sendRecvLegacy(&Tchannel{
+ if _, err := c.sendRecvLegacy(&Tchannel{
ID: uint32(id),
Control: 1,
}, &rchannel1); err != nil {
@@ -430,13 +431,28 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
+// sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
+// non-syscall errors to EIO.
+func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
+ received, err := c.sendRecvLegacy(t, r)
+ if !received {
+ log.Warningf("p9.Client.sendRecvChannel: %v", err)
+ return syscall.EIO
+ }
+ return err
+}
+
// sendRecvLegacy performs a roundtrip message exchange.
//
+// sendRecvLegacy returns true if a message was received. This allows us to
+// differentiate between failed receives and successful receives where the
+// response was an error message.
+//
// This is called by internal functions.
-func (c *Client) sendRecvLegacy(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
tag, ok := c.tagPool.Get()
if !ok {
- return ErrOutOfTags
+ return false, ErrOutOfTags
}
defer c.tagPool.Put(tag)
@@ -456,12 +472,12 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
err := send(c.socket, Tag(tag), t)
c.sendMu.Unlock()
if err != nil {
- return err
+ return false, err
}
// Co-ordinate with other receivers.
if err := c.waitAndRecv(resp.done); err != nil {
- return err
+ return false, err
}
// Is it an error message?
@@ -469,14 +485,14 @@ func (c *Client) sendRecvLegacy(t message, r message) error {
// For convenience, we transform these directly
// into errors. Handlers need not handle this case.
if rlerr, ok := resp.r.(*Rlerror); ok {
- return syscall.Errno(rlerr.Error)
+ return true, syscall.Errno(rlerr.Error)
}
// At this point, we know it matches.
//
// Per recv call above, we will only allow a type
// match (and give our r) or an instance of Rlerror.
- return nil
+ return true, nil
}
// sendRecvChannel uses channels to send a message.
@@ -485,7 +501,7 @@ func (c *Client) sendRecvChannel(t message, r message) error {
c.channelsMu.Lock()
if len(c.availableChannels) == 0 {
c.channelsMu.Unlock()
- return c.sendRecvLegacy(t, r)
+ return c.sendRecvLegacySyscallErr(t, r)
}
idx := len(c.availableChannels) - 1
ch := c.availableChannels[idx]
@@ -525,7 +541,11 @@ func (c *Client) sendRecvChannel(t message, r message) error {
}
// Parse the server's response.
- _, retErr := ch.recv(r, rsz)
+ resp, retErr := ch.recv(r, rsz)
+ if resp == nil {
+ log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
+ retErr = syscall.EIO
+ }
// Release the channel.
c.channelsMu.Lock()
diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go
index a6cc0617e..2ee07b664 100644
--- a/pkg/p9/client_file.go
+++ b/pkg/p9/client_file.go
@@ -17,7 +17,6 @@ package p9
import (
"fmt"
"io"
- "runtime"
"sync/atomic"
"syscall"
@@ -45,15 +44,10 @@ func (c *Client) Attach(name string) (File, error) {
// newFile returns a new client file.
func (c *Client) newFile(fid FID) *clientFile {
- cf := &clientFile{
+ return &clientFile{
client: c,
fid: fid,
}
-
- // Make sure the file is closed.
- runtime.SetFinalizer(cf, (*clientFile).Close)
-
- return cf
}
// clientFile is provided to clients.
@@ -171,6 +165,68 @@ func (c *clientFile) SetAttr(valid SetAttrMask, attr SetAttr) error {
return c.client.sendRecv(&Tsetattr{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattr{})
}
+// GetXattr implements File.GetXattr.
+func (c *clientFile) GetXattr(name string, size uint64) (string, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return "", syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return "", syscall.EOPNOTSUPP
+ }
+
+ rgetxattr := Rgetxattr{}
+ if err := c.client.sendRecv(&Tgetxattr{FID: c.fid, Name: name, Size: size}, &rgetxattr); err != nil {
+ return "", err
+ }
+
+ return rgetxattr.Value, nil
+}
+
+// SetXattr implements File.SetXattr.
+func (c *clientFile) SetXattr(name, value string, flags uint32) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsGetSetXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tsetxattr{FID: c.fid, Name: name, Value: value, Flags: flags}, &Rsetxattr{})
+}
+
+// ListXattr implements File.ListXattr.
+func (c *clientFile) ListXattr(size uint64) (map[string]struct{}, error) {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return nil, syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return nil, syscall.EOPNOTSUPP
+ }
+
+ rlistxattr := Rlistxattr{}
+ if err := c.client.sendRecv(&Tlistxattr{FID: c.fid, Size: size}, &rlistxattr); err != nil {
+ return nil, err
+ }
+
+ xattrs := make(map[string]struct{}, len(rlistxattr.Xattrs))
+ for _, x := range rlistxattr.Xattrs {
+ xattrs[x] = struct{}{}
+ }
+ return xattrs, nil
+}
+
+// RemoveXattr implements File.RemoveXattr.
+func (c *clientFile) RemoveXattr(name string) error {
+ if atomic.LoadUint32(&c.closed) != 0 {
+ return syscall.EBADF
+ }
+ if !versionSupportsListRemoveXattr(c.client.version) {
+ return syscall.EOPNOTSUPP
+ }
+
+ return c.client.sendRecv(&Tremovexattr{FID: c.fid, Name: name}, &Rremovexattr{})
+}
+
// Allocate implements File.Allocate.
func (c *clientFile) Allocate(mode AllocateMode, offset, length uint64) error {
if atomic.LoadUint32(&c.closed) != 0 {
@@ -192,7 +248,6 @@ func (c *clientFile) Remove() error {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return syscall.EBADF
}
- runtime.SetFinalizer(c, nil)
// Send the remove message.
if err := c.client.sendRecv(&Tremove{FID: c.fid}, &Rremove{}); err != nil {
@@ -214,7 +269,6 @@ func (c *clientFile) Close() error {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return syscall.EBADF
}
- runtime.SetFinalizer(c, nil)
// Send the close message.
if err := c.client.sendRecv(&Tclunk{FID: c.fid}, &Rclunk{}); err != nil {
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 29a0afadf..b78fdab7a 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -62,6 +62,8 @@ func TestVersion(t *testing.T) {
}
func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ b.ReportAllocs()
+
// See above.
serverSocket, clientSocket, err := unet.SocketPair(false)
if err != nil {
@@ -96,7 +98,12 @@ func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) e
}
func BenchmarkSendRecvLegacy(b *testing.B) {
- benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error {
+ return func(t message, r message) error {
+ _, err := c.sendRecvLegacy(t, r)
+ return err
+ }
+ })
}
func BenchmarkSendRecvChannel(b *testing.B) {
diff --git a/pkg/p9/file.go b/pkg/p9/file.go
index 907445e15..cab35896f 100644
--- a/pkg/p9/file.go
+++ b/pkg/p9/file.go
@@ -89,6 +89,38 @@ type File interface {
// On the server, SetAttr has a write concurrency guarantee.
SetAttr(valid SetAttrMask, attr SetAttr) error
+ // GetXattr returns extended attributes of this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, GetXattr has a read concurrency guarantee.
+ GetXattr(name string, size uint64) (string, error)
+
+ // SetXattr sets extended attributes on this node.
+ //
+ // On the server, SetXattr has a write concurrency guarantee.
+ SetXattr(name, value string, flags uint32) error
+
+ // ListXattr lists the names of the extended attributes on this node.
+ //
+ // Size indicates the size of the buffer that has been allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ //
+ // On the server, ListXattr has a read concurrency guarantee.
+ ListXattr(size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes extended attributes on this node.
+ //
+ // On the server, RemoveXattr has a write concurrency guarantee.
+ RemoveXattr(name string) error
+
// Allocate allows the caller to directly manipulate the allocated disk space
// for the file. See fallocate(2) for more details.
Allocate(mode AllocateMode, offset, length uint64) error
@@ -116,7 +148,7 @@ type File interface {
// N.B. The server must resolve any lazy paths when open is called.
// After this point, read and write may be called on files with no
// deletion check, so resolving in the data path is not viable.
- Open(mode OpenFlags) (*fd.FD, QID, uint32, error)
+ Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
// Read reads from this file. Open must be called first.
//
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index ba9a55d6d..1db5797dd 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -48,6 +48,8 @@ func ExtractErrno(err error) syscall.Errno {
return ExtractErrno(e.Err)
case *os.SyscallError:
return ExtractErrno(e.Err)
+ case *os.LinkError:
+ return ExtractErrno(e.Err)
}
// Default case.
@@ -257,7 +259,6 @@ func CanOpen(mode FileMode) bool {
// handle implements handler.handle.
func (t *Tlopen) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -272,15 +273,15 @@ func (t *Tlopen) handle(cs *connState) message {
return newErr(syscall.EINVAL)
}
- // Are flags valid?
- flags := t.Flags &^ OpenFlagsIgnoreMask
- if flags&^OpenFlagsModeMask != 0 {
- return newErr(syscall.EINVAL)
- }
-
- // Is this an attempt to open a directory as writable? Don't accept.
- if ref.mode.IsDir() && flags != ReadOnly {
- return newErr(syscall.EINVAL)
+ if ref.mode.IsDir() {
+ // Directory must be opened ReadOnly.
+ if t.Flags&OpenFlagsModeMask != ReadOnly {
+ return newErr(syscall.EISDIR)
+ }
+ // Directory not truncatable.
+ if t.Flags&OpenTruncate != 0 {
+ return newErr(syscall.EISDIR)
+ }
}
var (
@@ -294,7 +295,6 @@ func (t *Tlopen) handle(cs *connState) message {
return syscall.EINVAL
}
- // Do the open.
osFile, qid, ioUnit, err = ref.file.Open(t.Flags)
return err
}); err != nil {
@@ -311,12 +311,10 @@ func (t *Tlopen) handle(cs *connState) message {
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return nil, syscall.EBADF
@@ -390,12 +388,10 @@ func (t *Tsymlink) handle(cs *connState) message {
}
func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -426,19 +422,16 @@ func (t *Tsymlink) do(cs *connState, uid UID) (*Rsymlink, error) {
// handle implements handler.handle.
func (t *Tlink) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the other FID.
refTarget, ok := cs.LookupFID(t.Target)
if !ok {
return newErr(syscall.EBADF)
@@ -467,7 +460,6 @@ func (t *Tlink) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Trenameat) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.OldName); err != nil {
return newErr(err)
}
@@ -475,14 +467,12 @@ func (t *Trenameat) handle(cs *connState) message {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.OldDirectory)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the other FID.
refTarget, ok := cs.LookupFID(t.NewDirectory)
if !ok {
return newErr(syscall.EBADF)
@@ -523,12 +513,10 @@ func (t *Trenameat) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tunlinkat) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -577,19 +565,16 @@ func (t *Tunlinkat) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Trename) handle(cs *connState) message {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return newErr(err)
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
}
defer ref.DecRef()
- // Lookup the target.
refTarget, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -641,7 +626,6 @@ func (t *Trename) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Treadlink) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -669,7 +653,6 @@ func (t *Treadlink) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tread) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -708,7 +691,6 @@ func (t *Tread) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Twrite) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -747,12 +729,10 @@ func (t *Tmknod) handle(cs *connState) message {
}
func (t *Tmknod) do(cs *connState, uid UID) (*Rmknod, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -791,12 +771,10 @@ func (t *Tmkdir) handle(cs *connState) message {
}
func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
- // Don't allow complex names.
if err := checkSafeName(t.Name); err != nil {
return nil, err
}
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return nil, syscall.EBADF
@@ -827,7 +805,6 @@ func (t *Tmkdir) do(cs *connState, uid UID) (*Rmkdir, error) {
// handle implements handler.handle.
func (t *Tgetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -856,7 +833,6 @@ func (t *Tgetattr) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tsetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -883,7 +859,6 @@ func (t *Tsetattr) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tallocate) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -917,7 +892,6 @@ func (t *Tallocate) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Txattrwalk) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -930,7 +904,6 @@ func (t *Txattrwalk) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Txattrcreate) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -942,8 +915,96 @@ func (t *Txattrcreate) handle(cs *connState) message {
}
// handle implements handler.handle.
+func (t *Tgetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var val string
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow getxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ val, err = ref.file.GetXattr(t.Name, t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rgetxattr{Value: val}
+}
+
+// handle implements handler.handle.
+func (t *Tsetxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow setxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.SetXattr(t.Name, t.Value, t.Flags)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rsetxattr{}
+}
+
+// handle implements handler.handle.
+func (t *Tlistxattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ var xattrs map[string]struct{}
+ if err := ref.safelyRead(func() (err error) {
+ // Don't allow listxattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ xattrs, err = ref.file.ListXattr(t.Size)
+ return err
+ }); err != nil {
+ return newErr(err)
+ }
+
+ xattrList := make([]string, 0, len(xattrs))
+ for x := range xattrs {
+ xattrList = append(xattrList, x)
+ }
+ return &Rlistxattr{Xattrs: xattrList}
+}
+
+// handle implements handler.handle.
+func (t *Tremovexattr) handle(cs *connState) message {
+ ref, ok := cs.LookupFID(t.FID)
+ if !ok {
+ return newErr(syscall.EBADF)
+ }
+ defer ref.DecRef()
+
+ if err := ref.safelyWrite(func() error {
+ // Don't allow removexattr on files that have been deleted.
+ if ref.isDeleted() {
+ return syscall.EINVAL
+ }
+ return ref.file.RemoveXattr(t.Name)
+ }); err != nil {
+ return newErr(err)
+ }
+ return &Rremovexattr{}
+}
+
+// handle implements handler.handle.
func (t *Treaddir) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.Directory)
if !ok {
return newErr(syscall.EBADF)
@@ -977,7 +1038,6 @@ func (t *Treaddir) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tfsync) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1001,7 +1061,6 @@ func (t *Tfsync) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tstatfs) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1192,7 +1251,6 @@ func doWalk(cs *connState, ref *fidRef, names []string, getattr bool) (qids []QI
// handle implements handler.handle.
func (t *Twalk) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1213,7 +1271,6 @@ func (t *Twalk) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Twalkgetattr) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1270,7 +1327,6 @@ func (t *Tumknod) handle(cs *connState) message {
// handle implements handler.handle.
func (t *Tlconnect) handle(cs *connState) message {
- // Lookup the FID.
ref, ok := cs.LookupFID(t.FID)
if !ok {
return newErr(syscall.EBADF)
@@ -1303,7 +1359,6 @@ func (t *Tchannel) handle(cs *connState) message {
return newErr(err)
}
- // Lookup the given channel.
ch := cs.lookupChannel(t.ID)
if ch == nil {
return newErr(syscall.ENOSYS)
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index ffdd7e8c6..2cb59f934 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -51,7 +51,7 @@ type payloader interface {
// SetPayload returns the decoded message.
//
// This is going to be total message size - FixedSize. But this should
- // be validated during Decode, which will be called after SetPayload.
+ // be validated during decode, which will be called after SetPayload.
SetPayload([]byte)
}
@@ -90,14 +90,14 @@ type Tversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (t *Tversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tversion) decode(b *buffer) {
t.MSize = b.Read32()
t.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tversion) encode(b *buffer) {
b.Write32(t.MSize)
b.WriteString(t.Version)
}
@@ -121,14 +121,14 @@ type Rversion struct {
Version string
}
-// Decode implements encoder.Decode.
-func (r *Rversion) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rversion) decode(b *buffer) {
r.MSize = b.Read32()
r.Version = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rversion) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rversion) encode(b *buffer) {
b.Write32(r.MSize)
b.WriteString(r.Version)
}
@@ -149,13 +149,13 @@ type Tflush struct {
OldTag Tag
}
-// Decode implements encoder.Decode.
-func (t *Tflush) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflush) decode(b *buffer) {
t.OldTag = b.ReadTag()
}
-// Encode implements encoder.Encode.
-func (t *Tflush) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflush) encode(b *buffer) {
b.WriteTag(t.OldTag)
}
@@ -173,12 +173,12 @@ func (t *Tflush) String() string {
type Rflush struct {
}
-// Decode implements encoder.Decode.
-func (*Rflush) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rflush) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflush) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rflush) encode(*buffer) {
}
// Type implements message.Type.
@@ -188,7 +188,7 @@ func (*Rflush) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rflush) String() string {
- return fmt.Sprintf("RFlush{}")
+ return "RFlush{}"
}
// Twalk is a walk request.
@@ -203,8 +203,8 @@ type Twalk struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -214,8 +214,8 @@ func (t *Twalk) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -240,22 +240,22 @@ type Rwalk struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwalk) decode(b *buffer) {
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwalk) encode(b *buffer) {
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -275,13 +275,13 @@ type Tclunk struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tclunk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tclunk) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tclunk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tclunk) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -299,12 +299,12 @@ func (t *Tclunk) String() string {
type Rclunk struct {
}
-// Decode implements encoder.Decode.
-func (*Rclunk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rclunk) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rclunk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rclunk) encode(*buffer) {
}
// Type implements message.Type.
@@ -314,7 +314,7 @@ func (*Rclunk) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rclunk) String() string {
- return fmt.Sprintf("Rclunk{}")
+ return "Rclunk{}"
}
// Tremove is a remove request.
@@ -325,13 +325,13 @@ type Tremove struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tremove) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tremove) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tremove) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tremove) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -349,12 +349,12 @@ func (t *Tremove) String() string {
type Rremove struct {
}
-// Decode implements encoder.Decode.
-func (*Rremove) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rremove) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rremove) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rremove) encode(*buffer) {
}
// Type implements message.Type.
@@ -364,7 +364,7 @@ func (*Rremove) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rremove) String() string {
- return fmt.Sprintf("Rremove{}")
+ return "Rremove{}"
}
// Rlerror is an error response.
@@ -374,13 +374,13 @@ type Rlerror struct {
Error uint32
}
-// Decode implements encoder.Decode.
-func (r *Rlerror) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rlerror) decode(b *buffer) {
r.Error = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlerror) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rlerror) encode(b *buffer) {
b.Write32(r.Error)
}
@@ -409,16 +409,16 @@ type Tauth struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tauth) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tauth) decode(b *buffer) {
t.AuthenticationFID = b.ReadFID()
t.UserName = b.ReadString()
t.AttachName = b.ReadString()
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tauth) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tauth) encode(b *buffer) {
b.WriteFID(t.AuthenticationFID)
b.WriteString(t.UserName)
b.WriteString(t.AttachName)
@@ -437,7 +437,7 @@ func (t *Tauth) String() string {
// Rauth is an authentication response.
//
-// Encode, Decode and Length are inherited directly from QID.
+// encode and decode are inherited directly from QID.
type Rauth struct {
QID
}
@@ -463,16 +463,16 @@ type Tattach struct {
Auth Tauth
}
-// Decode implements encoder.Decode.
-func (t *Tattach) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tattach) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Auth.Decode(b)
+ t.Auth.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tattach) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tattach) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Auth.Encode(b)
+ t.Auth.encode(b)
}
// Type implements message.Type.
@@ -509,14 +509,14 @@ type Tlopen struct {
Flags OpenFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlopen) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlopen) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadOpenFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlopen) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlopen) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteOpenFlags(t.Flags)
}
@@ -542,15 +542,15 @@ type Rlopen struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlopen) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rlopen) decode(b *buffer) {
+ r.QID.decode(b)
r.IoUnit = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rlopen) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rlopen) encode(b *buffer) {
+ r.QID.encode(b)
b.Write32(r.IoUnit)
}
@@ -587,8 +587,8 @@ type Tlcreate struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tlcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.OpenFlags = b.ReadOpenFlags()
@@ -596,8 +596,8 @@ func (t *Tlcreate) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tlcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.WriteOpenFlags(t.OpenFlags)
@@ -617,7 +617,7 @@ func (t *Tlcreate) String() string {
// Rlcreate is a create response.
//
-// The Encode, Decode, etc. methods are inherited from Rlopen.
+// The encode, decode, etc. methods are inherited from Rlopen.
type Rlcreate struct {
Rlopen
}
@@ -647,16 +647,16 @@ type Tsymlink struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tsymlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsymlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Target = b.ReadString()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tsymlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsymlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteString(t.Target)
@@ -679,14 +679,14 @@ type Rsymlink struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rsymlink) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rsymlink) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rsymlink) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rsymlink) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -711,15 +711,15 @@ type Tlink struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Tlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlink) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Target = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Tlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlink) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteFID(t.Target)
b.WriteString(t.Name)
@@ -744,17 +744,17 @@ func (*Rlink) Type() MsgType {
return MsgRlink
}
-// Decode implements encoder.Decode.
-func (*Rlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rlink) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rlink) encode(*buffer) {
}
// String implements fmt.Stringer.
func (r *Rlink) String() string {
- return fmt.Sprintf("Rlink{}")
+ return "Rlink{}"
}
// Trenameat is a rename request.
@@ -772,16 +772,16 @@ type Trenameat struct {
NewName string
}
-// Decode implements encoder.Decode.
-func (t *Trenameat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trenameat) decode(b *buffer) {
t.OldDirectory = b.ReadFID()
t.OldName = b.ReadString()
t.NewDirectory = b.ReadFID()
t.NewName = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trenameat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trenameat) encode(b *buffer) {
b.WriteFID(t.OldDirectory)
b.WriteString(t.OldName)
b.WriteFID(t.NewDirectory)
@@ -802,12 +802,12 @@ func (t *Trenameat) String() string {
type Rrenameat struct {
}
-// Decode implements encoder.Decode.
-func (*Rrenameat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rrenameat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrenameat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rrenameat) encode(*buffer) {
}
// Type implements message.Type.
@@ -817,7 +817,7 @@ func (*Rrenameat) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rrenameat) String() string {
- return fmt.Sprintf("Rrenameat{}")
+ return "Rrenameat{}"
}
// Tunlinkat is an unlink request.
@@ -832,15 +832,15 @@ type Tunlinkat struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Tunlinkat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tunlinkat) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tunlinkat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tunlinkat) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.Write32(t.Flags)
@@ -860,12 +860,12 @@ func (t *Tunlinkat) String() string {
type Runlinkat struct {
}
-// Decode implements encoder.Decode.
-func (*Runlinkat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Runlinkat) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Runlinkat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Runlinkat) encode(*buffer) {
}
// Type implements message.Type.
@@ -875,7 +875,7 @@ func (*Runlinkat) Type() MsgType {
// String implements fmt.Stringer.
func (r *Runlinkat) String() string {
- return fmt.Sprintf("Runlinkat{}")
+ return "Runlinkat{}"
}
// Trename is a rename request.
@@ -893,15 +893,15 @@ type Trename struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Trename) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Trename) decode(b *buffer) {
t.FID = b.ReadFID()
t.Directory = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Trename) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Trename) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.Directory)
b.WriteString(t.Name)
@@ -921,12 +921,12 @@ func (t *Trename) String() string {
type Rrename struct {
}
-// Decode implements encoder.Decode.
-func (*Rrename) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rrename) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rrename) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rrename) encode(*buffer) {
}
// Type implements message.Type.
@@ -936,7 +936,7 @@ func (*Rrename) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rrename) String() string {
- return fmt.Sprintf("Rrename{}")
+ return "Rrename{}"
}
// Treadlink is a readlink request.
@@ -945,13 +945,13 @@ type Treadlink struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Treadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treadlink) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Treadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treadlink) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -971,13 +971,13 @@ type Rreadlink struct {
Target string
}
-// Decode implements encoder.Decode.
-func (r *Rreadlink) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreadlink) decode(b *buffer) {
r.Target = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (r *Rreadlink) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreadlink) encode(b *buffer) {
b.WriteString(r.Target)
}
@@ -1003,15 +1003,15 @@ type Tread struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Tread) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tread) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tread) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tread) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1033,20 +1033,20 @@ type Rread struct {
Data []byte
}
-// Decode implements encoder.Decode.
+// decode implements encoder.decode.
//
// Data is automatically decoded via Payload.
-func (r *Rread) Decode(b *buffer) {
+func (r *Rread) decode(b *buffer) {
count := b.Read32()
if count != uint32(len(r.Data)) {
b.markOverrun()
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// Data is automatically encoded via Payload.
-func (r *Rread) Encode(b *buffer) {
+func (r *Rread) encode(b *buffer) {
b.Write32(uint32(len(r.Data)))
}
@@ -1087,8 +1087,8 @@ type Twrite struct {
Data []byte
}
-// Decode implements encoder.Decode.
-func (t *Twrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twrite) decode(b *buffer) {
t.FID = b.ReadFID()
t.Offset = b.Read64()
count := b.Read32()
@@ -1097,10 +1097,10 @@ func (t *Twrite) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
+// encode implements encoder.encode.
//
// This uses the buffer payload to avoid a copy.
-func (t *Twrite) Encode(b *buffer) {
+func (t *Twrite) encode(b *buffer) {
b.WriteFID(t.FID)
b.Write64(t.Offset)
b.Write32(uint32(len(t.Data)))
@@ -1137,13 +1137,13 @@ type Rwrite struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (r *Rwrite) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rwrite) decode(b *buffer) {
r.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (r *Rwrite) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rwrite) encode(b *buffer) {
b.Write32(r.Count)
}
@@ -1178,8 +1178,8 @@ type Tmknod struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmknod) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmknod) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Mode = b.ReadFileMode()
@@ -1188,8 +1188,8 @@ func (t *Tmknod) Decode(b *buffer) {
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmknod) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmknod) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WriteFileMode(t.Mode)
@@ -1214,14 +1214,14 @@ type Rmknod struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmknod) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmknod) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmknod) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmknod) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1249,16 +1249,16 @@ type Tmkdir struct {
GID GID
}
-// Decode implements encoder.Decode.
-func (t *Tmkdir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tmkdir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Name = b.ReadString()
t.Permissions = b.ReadPermissions()
t.GID = b.ReadGID()
}
-// Encode implements encoder.Encode.
-func (t *Tmkdir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tmkdir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.WriteString(t.Name)
b.WritePermissions(t.Permissions)
@@ -1281,14 +1281,14 @@ type Rmkdir struct {
QID QID
}
-// Decode implements encoder.Decode.
-func (r *Rmkdir) Decode(b *buffer) {
- r.QID.Decode(b)
+// decode implements encoder.decode.
+func (r *Rmkdir) decode(b *buffer) {
+ r.QID.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rmkdir) Encode(b *buffer) {
- r.QID.Encode(b)
+// encode implements encoder.encode.
+func (r *Rmkdir) encode(b *buffer) {
+ r.QID.encode(b)
}
// Type implements message.Type.
@@ -1310,16 +1310,16 @@ type Tgetattr struct {
AttrMask AttrMask
}
-// Decode implements encoder.Decode.
-func (t *Tgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.AttrMask.Decode(b)
+ t.AttrMask.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.AttrMask.Encode(b)
+ t.AttrMask.encode(b)
}
// Type implements message.Type.
@@ -1344,18 +1344,18 @@ type Rgetattr struct {
Attr Attr
}
-// Decode implements encoder.Decode.
-func (r *Rgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.QID.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.QID.decode(b)
+ r.Attr.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.QID.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.QID.encode(b)
+ r.Attr.encode(b)
}
// Type implements message.Type.
@@ -1380,18 +1380,18 @@ type Tsetattr struct {
SetAttr SetAttr
}
-// Decode implements encoder.Decode.
-func (t *Tsetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tsetattr) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Valid.Decode(b)
- t.SetAttr.Decode(b)
+ t.Valid.decode(b)
+ t.SetAttr.decode(b)
}
-// Encode implements encoder.Encode.
-func (t *Tsetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tsetattr) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Valid.Encode(b)
- t.SetAttr.Encode(b)
+ t.Valid.encode(b)
+ t.SetAttr.encode(b)
}
// Type implements message.Type.
@@ -1408,12 +1408,12 @@ func (t *Tsetattr) String() string {
type Rsetattr struct {
}
-// Decode implements encoder.Decode.
-func (*Rsetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rsetattr) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rsetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rsetattr) encode(*buffer) {
}
// Type implements message.Type.
@@ -1423,7 +1423,7 @@ func (*Rsetattr) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rsetattr) String() string {
- return fmt.Sprintf("Rsetattr{}")
+ return "Rsetattr{}"
}
// Tallocate is an allocate request. This is an extension to 9P protocol, not
@@ -1435,18 +1435,18 @@ type Tallocate struct {
Length uint64
}
-// Decode implements encoder.Decode.
-func (t *Tallocate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tallocate) decode(b *buffer) {
t.FID = b.ReadFID()
- t.Mode.Decode(b)
+ t.Mode.decode(b)
t.Offset = b.Read64()
t.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (t *Tallocate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tallocate) encode(b *buffer) {
b.WriteFID(t.FID)
- t.Mode.Encode(b)
+ t.Mode.encode(b)
b.Write64(t.Offset)
b.Write64(t.Length)
}
@@ -1465,12 +1465,12 @@ func (t *Tallocate) String() string {
type Rallocate struct {
}
-// Decode implements encoder.Decode.
-func (*Rallocate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rallocate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rallocate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rallocate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1480,7 +1480,71 @@ func (*Rallocate) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rallocate) String() string {
- return fmt.Sprintf("Rallocate{}")
+ return "Rallocate{}"
+}
+
+// Tlistxattr is a listxattr request.
+type Tlistxattr struct {
+ // FID refers to the file on which to list xattrs.
+ FID FID
+
+ // Size is the buffer size for the xattr list.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tlistxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tlistxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tlistxattr) Type() MsgType {
+ return MsgTlistxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tlistxattr) String() string {
+ return fmt.Sprintf("Tlistxattr{FID: %d, Size: %d}", t.FID, t.Size)
+}
+
+// Rlistxattr is a listxattr response.
+type Rlistxattr struct {
+ // Xattrs is a list of extended attribute names.
+ Xattrs []string
+}
+
+// decode implements encoder.decode.
+func (r *Rlistxattr) decode(b *buffer) {
+ n := b.Read16()
+ r.Xattrs = r.Xattrs[:0]
+ for i := 0; i < int(n); i++ {
+ r.Xattrs = append(r.Xattrs, b.ReadString())
+ }
+}
+
+// encode implements encoder.encode.
+func (r *Rlistxattr) encode(b *buffer) {
+ b.Write16(uint16(len(r.Xattrs)))
+ for _, x := range r.Xattrs {
+ b.WriteString(x)
+ }
+}
+
+// Type implements message.Type.
+func (*Rlistxattr) Type() MsgType {
+ return MsgRlistxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rlistxattr) String() string {
+ return fmt.Sprintf("Rlistxattr{Xattrs: %v}", r.Xattrs)
}
// Txattrwalk walks extended attributes.
@@ -1495,15 +1559,15 @@ type Txattrwalk struct {
Name string
}
-// Decode implements encoder.Decode.
-func (t *Txattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrwalk) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
t.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (t *Txattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrwalk) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.WriteString(t.Name)
@@ -1525,13 +1589,13 @@ type Rxattrwalk struct {
Size uint64
}
-// Decode implements encoder.Decode.
-func (r *Rxattrwalk) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrwalk) decode(b *buffer) {
r.Size = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rxattrwalk) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrwalk) encode(b *buffer) {
b.Write64(r.Size)
}
@@ -1563,16 +1627,16 @@ type Txattrcreate struct {
Flags uint32
}
-// Decode implements encoder.Decode.
-func (t *Txattrcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Txattrcreate) decode(b *buffer) {
t.FID = b.ReadFID()
t.Name = b.ReadString()
t.AttrSize = b.Read64()
t.Flags = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Txattrcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Txattrcreate) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteString(t.Name)
b.Write64(t.AttrSize)
@@ -1593,12 +1657,12 @@ func (t *Txattrcreate) String() string {
type Rxattrcreate struct {
}
-// Decode implements encoder.Decode.
-func (r *Rxattrcreate) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rxattrcreate) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (r *Rxattrcreate) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rxattrcreate) encode(*buffer) {
}
// Type implements message.Type.
@@ -1608,7 +1672,185 @@ func (*Rxattrcreate) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rxattrcreate) String() string {
- return fmt.Sprintf("Rxattrcreate{}")
+ return "Rxattrcreate{}"
+}
+
+// Tgetxattr is a getxattr request.
+type Tgetxattr struct {
+ // FID refers to the file for which to get xattrs.
+ FID FID
+
+ // Name is the xattr to get.
+ Name string
+
+ // Size is the buffer size for the xattr to get.
+ Size uint64
+}
+
+// decode implements encoder.decode.
+func (t *Tgetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Size = b.Read64()
+}
+
+// encode implements encoder.encode.
+func (t *Tgetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.Write64(t.Size)
+}
+
+// Type implements message.Type.
+func (*Tgetxattr) Type() MsgType {
+ return MsgTgetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tgetxattr) String() string {
+ return fmt.Sprintf("Tgetxattr{FID: %d, Name: %s, Size: %d}", t.FID, t.Name, t.Size)
+}
+
+// Rgetxattr is a getxattr response.
+type Rgetxattr struct {
+ // Value is the extended attribute value.
+ Value string
+}
+
+// decode implements encoder.decode.
+func (r *Rgetxattr) decode(b *buffer) {
+ r.Value = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (r *Rgetxattr) encode(b *buffer) {
+ b.WriteString(r.Value)
+}
+
+// Type implements message.Type.
+func (*Rgetxattr) Type() MsgType {
+ return MsgRgetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rgetxattr) String() string {
+ return fmt.Sprintf("Rgetxattr{Value: %s}", r.Value)
+}
+
+// Tsetxattr sets extended attributes.
+type Tsetxattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+
+ // Value is the attribute value.
+ Value string
+
+ // Linux setxattr(2) flags.
+ Flags uint32
+}
+
+// decode implements encoder.decode.
+func (t *Tsetxattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+ t.Value = b.ReadString()
+ t.Flags = b.Read32()
+}
+
+// encode implements encoder.encode.
+func (t *Tsetxattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+ b.WriteString(t.Value)
+ b.Write32(t.Flags)
+}
+
+// Type implements message.Type.
+func (*Tsetxattr) Type() MsgType {
+ return MsgTsetxattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tsetxattr) String() string {
+ return fmt.Sprintf("Tsetxattr{FID: %d, Name: %s, Value: %s, Flags: %d}", t.FID, t.Name, t.Value, t.Flags)
+}
+
+// Rsetxattr is a setxattr response.
+type Rsetxattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rsetxattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rsetxattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rsetxattr) Type() MsgType {
+ return MsgRsetxattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rsetxattr) String() string {
+ return "Rsetxattr{}"
+}
+
+// Tremovexattr is a removexattr request.
+type Tremovexattr struct {
+ // FID refers to the file on which to set xattrs.
+ FID FID
+
+ // Name is the attribute name.
+ Name string
+}
+
+// decode implements encoder.decode.
+func (t *Tremovexattr) decode(b *buffer) {
+ t.FID = b.ReadFID()
+ t.Name = b.ReadString()
+}
+
+// encode implements encoder.encode.
+func (t *Tremovexattr) encode(b *buffer) {
+ b.WriteFID(t.FID)
+ b.WriteString(t.Name)
+}
+
+// Type implements message.Type.
+func (*Tremovexattr) Type() MsgType {
+ return MsgTremovexattr
+}
+
+// String implements fmt.Stringer.
+func (t *Tremovexattr) String() string {
+ return fmt.Sprintf("Tremovexattr{FID: %d, Name: %s}", t.FID, t.Name)
+}
+
+// Rremovexattr is a removexattr response.
+type Rremovexattr struct {
+}
+
+// decode implements encoder.decode.
+func (r *Rremovexattr) decode(*buffer) {
+}
+
+// encode implements encoder.encode.
+func (r *Rremovexattr) encode(*buffer) {
+}
+
+// Type implements message.Type.
+func (*Rremovexattr) Type() MsgType {
+ return MsgRremovexattr
+}
+
+// String implements fmt.Stringer.
+func (r *Rremovexattr) String() string {
+ return "Rremovexattr{}"
}
// Treaddir is a readdir request.
@@ -1623,15 +1865,15 @@ type Treaddir struct {
Count uint32
}
-// Decode implements encoder.Decode.
-func (t *Treaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Treaddir) decode(b *buffer) {
t.Directory = b.ReadFID()
t.Offset = b.Read64()
t.Count = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Treaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Treaddir) encode(b *buffer) {
b.WriteFID(t.Directory)
b.Write64(t.Offset)
b.Write32(t.Count)
@@ -1665,14 +1907,14 @@ type Rreaddir struct {
payload []byte
}
-// Decode implements encoder.Decode.
-func (r *Rreaddir) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rreaddir) decode(b *buffer) {
r.Count = b.Read32()
entriesBuf := buffer{data: r.payload}
r.Entries = r.Entries[:0]
for {
var d Dirent
- d.Decode(&entriesBuf)
+ d.decode(&entriesBuf)
if entriesBuf.isOverrun() {
// Couldn't decode a complete entry.
break
@@ -1681,22 +1923,20 @@ func (r *Rreaddir) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (r *Rreaddir) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rreaddir) encode(b *buffer) {
entriesBuf := buffer{}
+ payloadSize := 0
for _, d := range r.Entries {
- d.Encode(&entriesBuf)
- if len(entriesBuf.data) >= int(r.Count) {
+ d.encode(&entriesBuf)
+ if len(entriesBuf.data) > int(r.Count) {
break
}
+ payloadSize = len(entriesBuf.data)
}
- if len(entriesBuf.data) < int(r.Count) {
- r.Count = uint32(len(entriesBuf.data))
- r.payload = entriesBuf.data
- } else {
- r.payload = entriesBuf.data[:r.Count]
- }
- b.Write32(uint32(r.Count))
+ r.Count = uint32(payloadSize)
+ r.payload = entriesBuf.data[:payloadSize]
+ b.Write32(r.Count)
}
// Type implements message.Type.
@@ -1730,13 +1970,13 @@ type Tfsync struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tfsync) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tfsync) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tfsync) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tfsync) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1754,12 +1994,12 @@ func (t *Tfsync) String() string {
type Rfsync struct {
}
-// Decode implements encoder.Decode.
-func (*Rfsync) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rfsync) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rfsync) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rfsync) encode(*buffer) {
}
// Type implements message.Type.
@@ -1769,7 +2009,7 @@ func (*Rfsync) Type() MsgType {
// String implements fmt.Stringer.
func (r *Rfsync) String() string {
- return fmt.Sprintf("Rfsync{}")
+ return "Rfsync{}"
}
// Tstatfs is a stat request.
@@ -1778,13 +2018,13 @@ type Tstatfs struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tstatfs) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tstatfs) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tstatfs) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tstatfs) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1804,14 +2044,14 @@ type Rstatfs struct {
FSStat FSStat
}
-// Decode implements encoder.Decode.
-func (r *Rstatfs) Decode(b *buffer) {
- r.FSStat.Decode(b)
+// decode implements encoder.decode.
+func (r *Rstatfs) decode(b *buffer) {
+ r.FSStat.decode(b)
}
-// Encode implements encoder.Encode.
-func (r *Rstatfs) Encode(b *buffer) {
- r.FSStat.Encode(b)
+// encode implements encoder.encode.
+func (r *Rstatfs) encode(b *buffer) {
+ r.FSStat.encode(b)
}
// Type implements message.Type.
@@ -1830,13 +2070,13 @@ type Tflushf struct {
FID FID
}
-// Decode implements encoder.Decode.
-func (t *Tflushf) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tflushf) decode(b *buffer) {
t.FID = b.ReadFID()
}
-// Encode implements encoder.Encode.
-func (t *Tflushf) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tflushf) encode(b *buffer) {
b.WriteFID(t.FID)
}
@@ -1854,12 +2094,12 @@ func (t *Tflushf) String() string {
type Rflushf struct {
}
-// Decode implements encoder.Decode.
-func (*Rflushf) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (*Rflushf) decode(*buffer) {
}
-// Encode implements encoder.Encode.
-func (*Rflushf) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (*Rflushf) encode(*buffer) {
}
// Type implements message.Type.
@@ -1869,7 +2109,7 @@ func (*Rflushf) Type() MsgType {
// String implements fmt.Stringer.
func (*Rflushf) String() string {
- return fmt.Sprintf("Rflushf{}")
+ return "Rflushf{}"
}
// Twalkgetattr is a walk request.
@@ -1884,8 +2124,8 @@ type Twalkgetattr struct {
Names []string
}
-// Decode implements encoder.Decode.
-func (t *Twalkgetattr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Twalkgetattr) decode(b *buffer) {
t.FID = b.ReadFID()
t.NewFID = b.ReadFID()
n := b.Read16()
@@ -1895,8 +2135,8 @@ func (t *Twalkgetattr) Decode(b *buffer) {
}
}
-// Encode implements encoder.Encode.
-func (t *Twalkgetattr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Twalkgetattr) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteFID(t.NewFID)
b.Write16(uint16(len(t.Names)))
@@ -1927,26 +2167,26 @@ type Rwalkgetattr struct {
QIDs []QID
}
-// Decode implements encoder.Decode.
-func (r *Rwalkgetattr) Decode(b *buffer) {
- r.Valid.Decode(b)
- r.Attr.Decode(b)
+// decode implements encoder.decode.
+func (r *Rwalkgetattr) decode(b *buffer) {
+ r.Valid.decode(b)
+ r.Attr.decode(b)
n := b.Read16()
r.QIDs = r.QIDs[:0]
for i := 0; i < int(n); i++ {
var q QID
- q.Decode(b)
+ q.decode(b)
r.QIDs = append(r.QIDs, q)
}
}
-// Encode implements encoder.Encode.
-func (r *Rwalkgetattr) Encode(b *buffer) {
- r.Valid.Encode(b)
- r.Attr.Encode(b)
+// encode implements encoder.encode.
+func (r *Rwalkgetattr) encode(b *buffer) {
+ r.Valid.encode(b)
+ r.Attr.encode(b)
b.Write16(uint16(len(r.QIDs)))
for _, q := range r.QIDs {
- q.Encode(b)
+ q.encode(b)
}
}
@@ -1968,15 +2208,15 @@ type Tucreate struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tucreate) Decode(b *buffer) {
- t.Tlcreate.Decode(b)
+// decode implements encoder.decode.
+func (t *Tucreate) decode(b *buffer) {
+ t.Tlcreate.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tucreate) Encode(b *buffer) {
- t.Tlcreate.Encode(b)
+// encode implements encoder.encode.
+func (t *Tucreate) encode(b *buffer) {
+ t.Tlcreate.encode(b)
b.WriteUID(t.UID)
}
@@ -2013,15 +2253,15 @@ type Tumkdir struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumkdir) Decode(b *buffer) {
- t.Tmkdir.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumkdir) decode(b *buffer) {
+ t.Tmkdir.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumkdir) Encode(b *buffer) {
- t.Tmkdir.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumkdir) encode(b *buffer) {
+ t.Tmkdir.encode(b)
b.WriteUID(t.UID)
}
@@ -2058,15 +2298,15 @@ type Tumknod struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tumknod) Decode(b *buffer) {
- t.Tmknod.Decode(b)
+// decode implements encoder.decode.
+func (t *Tumknod) decode(b *buffer) {
+ t.Tmknod.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tumknod) Encode(b *buffer) {
- t.Tmknod.Encode(b)
+// encode implements encoder.encode.
+func (t *Tumknod) encode(b *buffer) {
+ t.Tmknod.encode(b)
b.WriteUID(t.UID)
}
@@ -2103,15 +2343,15 @@ type Tusymlink struct {
UID UID
}
-// Decode implements encoder.Decode.
-func (t *Tusymlink) Decode(b *buffer) {
- t.Tsymlink.Decode(b)
+// decode implements encoder.decode.
+func (t *Tusymlink) decode(b *buffer) {
+ t.Tsymlink.decode(b)
t.UID = b.ReadUID()
}
-// Encode implements encoder.Encode.
-func (t *Tusymlink) Encode(b *buffer) {
- t.Tsymlink.Encode(b)
+// encode implements encoder.encode.
+func (t *Tusymlink) encode(b *buffer) {
+ t.Tsymlink.encode(b)
b.WriteUID(t.UID)
}
@@ -2149,14 +2389,14 @@ type Tlconnect struct {
Flags ConnectFlags
}
-// Decode implements encoder.Decode.
-func (t *Tlconnect) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tlconnect) decode(b *buffer) {
t.FID = b.ReadFID()
t.Flags = b.ReadConnectFlags()
}
-// Encode implements encoder.Encode.
-func (t *Tlconnect) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tlconnect) encode(b *buffer) {
b.WriteFID(t.FID)
b.WriteConnectFlags(t.Flags)
}
@@ -2176,11 +2416,11 @@ type Rlconnect struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rlconnect) Decode(*buffer) {}
+// decode implements encoder.decode.
+func (r *Rlconnect) decode(*buffer) {}
-// Encode implements encoder.Encode.
-func (r *Rlconnect) Encode(*buffer) {}
+// encode implements encoder.encode.
+func (r *Rlconnect) encode(*buffer) {}
// Type implements message.Type.
func (*Rlconnect) Type() MsgType {
@@ -2203,14 +2443,14 @@ type Tchannel struct {
Control uint32
}
-// Decode implements encoder.Decode.
-func (t *Tchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (t *Tchannel) decode(b *buffer) {
t.ID = b.Read32()
t.Control = b.Read32()
}
-// Encode implements encoder.Encode.
-func (t *Tchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (t *Tchannel) encode(b *buffer) {
b.Write32(t.ID)
b.Write32(t.Control)
}
@@ -2232,14 +2472,14 @@ type Rchannel struct {
filePayload
}
-// Decode implements encoder.Decode.
-func (r *Rchannel) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (r *Rchannel) decode(b *buffer) {
r.Offset = b.Read64()
r.Length = b.Read64()
}
-// Encode implements encoder.Encode.
-func (r *Rchannel) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (r *Rchannel) encode(b *buffer) {
b.Write64(r.Offset)
b.Write64(r.Length)
}
@@ -2266,7 +2506,7 @@ type msgFactory struct {
var msgRegistry registry
type registry struct {
- factories [math.MaxUint8]msgFactory
+ factories [math.MaxUint8 + 1]msgFactory
// largestFixedSize is computed so that given some message size M, you can
// compute the maximum payload size (e.g. for Twrite, Rread) with
@@ -2335,7 +2575,7 @@ func calculateSize(m message) uint32 {
return p.FixedSize()
}
var dataBuf buffer
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
return uint32(len(dataBuf.data))
}
@@ -2359,10 +2599,18 @@ func init() {
msgRegistry.register(MsgRgetattr, func() message { return &Rgetattr{} })
msgRegistry.register(MsgTsetattr, func() message { return &Tsetattr{} })
msgRegistry.register(MsgRsetattr, func() message { return &Rsetattr{} })
+ msgRegistry.register(MsgTlistxattr, func() message { return &Tlistxattr{} })
+ msgRegistry.register(MsgRlistxattr, func() message { return &Rlistxattr{} })
msgRegistry.register(MsgTxattrwalk, func() message { return &Txattrwalk{} })
msgRegistry.register(MsgRxattrwalk, func() message { return &Rxattrwalk{} })
msgRegistry.register(MsgTxattrcreate, func() message { return &Txattrcreate{} })
msgRegistry.register(MsgRxattrcreate, func() message { return &Rxattrcreate{} })
+ msgRegistry.register(MsgTgetxattr, func() message { return &Tgetxattr{} })
+ msgRegistry.register(MsgRgetxattr, func() message { return &Rgetxattr{} })
+ msgRegistry.register(MsgTsetxattr, func() message { return &Tsetxattr{} })
+ msgRegistry.register(MsgRsetxattr, func() message { return &Rsetxattr{} })
+ msgRegistry.register(MsgTremovexattr, func() message { return &Tremovexattr{} })
+ msgRegistry.register(MsgRremovexattr, func() message { return &Rremovexattr{} })
msgRegistry.register(MsgTreaddir, func() message { return &Treaddir{} })
msgRegistry.register(MsgRreaddir, func() message { return &Rreaddir{} })
msgRegistry.register(MsgTfsync, func() message { return &Tfsync{} })
diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go
index 6ba6a1654..7facc9f5e 100644
--- a/pkg/p9/messages_test.go
+++ b/pkg/p9/messages_test.go
@@ -194,6 +194,21 @@ func TestEncodeDecode(t *testing.T) {
Flags: 3,
},
&Rxattrcreate{},
+ &Tgetxattr{
+ FID: 1,
+ Name: "abc",
+ Size: 2,
+ },
+ &Rgetxattr{
+ Value: "xyz",
+ },
+ &Tsetxattr{
+ FID: 1,
+ Name: "abc",
+ Value: "xyz",
+ Flags: 2,
+ },
+ &Rsetxattr{},
&Treaddir{
Directory: 1,
Offset: 2,
@@ -201,7 +216,7 @@ func TestEncodeDecode(t *testing.T) {
},
&Rreaddir{
// Count must be sufficient to encode a dirent.
- Count: 0x18,
+ Count: 0x1a,
Entries: []Dirent{{QID: QID{Type: 2}}},
},
&Tfsync{
@@ -367,7 +382,7 @@ func TestEncodeDecode(t *testing.T) {
// Encode the original.
data := make([]byte, initialBufferLength)
buf := buffer{data: data[:0]}
- enc.Encode(&buf)
+ enc.encode(&buf)
// Create a new object, same as the first.
enc2 := reflect.New(reflect.ValueOf(enc).Elem().Type()).Interface().(encoder)
@@ -384,7 +399,7 @@ func TestEncodeDecode(t *testing.T) {
}
// Mark sure it was okay.
- enc2.Decode(&buf2)
+ enc2.decode(&buf2)
if buf2.isOverrun() {
t.Errorf("object %#v->%#v got overrun on decode", enc, enc2)
continue
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index 25530adca..122c457d2 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -32,21 +32,21 @@ import (
type OpenFlags uint32
const (
- // ReadOnly is a Topen and Tcreate flag indicating read-only mode.
+ // ReadOnly is a Tlopen and Tlcreate flag indicating read-only mode.
ReadOnly OpenFlags = 0
- // WriteOnly is a Topen and Tcreate flag indicating write-only mode.
+ // WriteOnly is a Tlopen and Tlcreate flag indicating write-only mode.
WriteOnly OpenFlags = 1
- // ReadWrite is a Topen flag indicates read-write mode.
+ // ReadWrite is a Tlopen flag indicates read-write mode.
ReadWrite OpenFlags = 2
// OpenFlagsModeMask is a mask of valid OpenFlags mode bits.
OpenFlagsModeMask OpenFlags = 3
- // OpenFlagsIgnoreMask is a list of OpenFlags mode bits that are ignored for Tlopen.
- // Note that syscall.O_LARGEFILE is set to zero, use value from Linux fcntl.h.
- OpenFlagsIgnoreMask OpenFlags = syscall.O_DIRECTORY | syscall.O_NOATIME | 0100000
+ // OpenTruncate is a Tlopen flag indicating that the opened file should be
+ // truncated.
+ OpenTruncate OpenFlags = 01000
)
// ConnectFlags is the mode passed to Connect operations.
@@ -71,25 +71,32 @@ const (
// OSFlags converts a p9.OpenFlags to an int compatible with open(2).
func (o OpenFlags) OSFlags() int {
- return int(o & OpenFlagsModeMask)
+ // "flags contains Linux open(2) flags bits" - 9P2000.L
+ return int(o)
}
// String implements fmt.Stringer.
func (o OpenFlags) String() string {
- switch o {
+ var buf strings.Builder
+ switch mode := o & OpenFlagsModeMask; mode {
case ReadOnly:
- return "ReadOnly"
+ buf.WriteString("ReadOnly")
case WriteOnly:
- return "WriteOnly"
+ buf.WriteString("WriteOnly")
case ReadWrite:
- return "ReadWrite"
- case OpenFlagsModeMask:
- return "OpenFlagsModeMask"
- case OpenFlagsIgnoreMask:
- return "OpenFlagsIgnoreMask"
+ buf.WriteString("ReadWrite")
default:
- return "UNDEFINED"
+ fmt.Fprintf(&buf, "%#o", mode)
}
+ otherFlags := o &^ OpenFlagsModeMask
+ if otherFlags&OpenTruncate != 0 {
+ buf.WriteString("|OpenTruncate")
+ otherFlags &^= OpenTruncate
+ }
+ if otherFlags != 0 {
+ fmt.Fprintf(&buf, "|%#o", otherFlags)
+ }
+ return buf.String()
}
// Tag is a message tag.
@@ -328,10 +335,18 @@ const (
MsgRgetattr = 25
MsgTsetattr = 26
MsgRsetattr = 27
+ MsgTlistxattr = 28
+ MsgRlistxattr = 29
MsgTxattrwalk = 30
MsgRxattrwalk = 31
MsgTxattrcreate = 32
MsgRxattrcreate = 33
+ MsgTgetxattr = 34
+ MsgRgetxattr = 35
+ MsgTsetxattr = 36
+ MsgRsetxattr = 37
+ MsgTremovexattr = 38
+ MsgRremovexattr = 39
MsgTreaddir = 40
MsgRreaddir = 41
MsgTfsync = 50
@@ -435,15 +450,15 @@ func (q QID) String() string {
return fmt.Sprintf("QID{Type: %d, Version: %d, Path: %d}", q.Type, q.Version, q.Path)
}
-// Decode implements encoder.Decode.
-func (q *QID) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (q *QID) decode(b *buffer) {
q.Type = b.ReadQIDType()
q.Version = b.Read32()
q.Path = b.Read64()
}
-// Encode implements encoder.Encode.
-func (q *QID) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (q *QID) encode(b *buffer) {
b.WriteQIDType(q.Type)
b.Write32(q.Version)
b.Write64(q.Path)
@@ -500,8 +515,8 @@ type FSStat struct {
NameLength uint32
}
-// Decode implements encoder.Decode.
-func (f *FSStat) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (f *FSStat) decode(b *buffer) {
f.Type = b.Read32()
f.BlockSize = b.Read32()
f.Blocks = b.Read64()
@@ -513,8 +528,8 @@ func (f *FSStat) Decode(b *buffer) {
f.NameLength = b.Read32()
}
-// Encode implements encoder.Encode.
-func (f *FSStat) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (f *FSStat) encode(b *buffer) {
b.Write32(f.Type)
b.Write32(f.BlockSize)
b.Write64(f.Blocks)
@@ -664,8 +679,8 @@ func (a AttrMask) String() string {
return fmt.Sprintf("AttrMask{with: %s}", strings.Join(masks, " "))
}
-// Decode implements encoder.Decode.
-func (a *AttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AttrMask) decode(b *buffer) {
mask := b.Read64()
a.Mode = mask&0x00000001 != 0
a.NLink = mask&0x00000002 != 0
@@ -683,8 +698,8 @@ func (a *AttrMask) Decode(b *buffer) {
a.DataVersion = mask&0x00002000 != 0
}
-// Encode implements encoder.Encode.
-func (a *AttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AttrMask) encode(b *buffer) {
var mask uint64
if a.Mode {
mask |= 0x00000001
@@ -759,8 +774,8 @@ func (a Attr) String() string {
a.Mode, a.UID, a.GID, a.NLink, a.RDev, a.Size, a.BlockSize, a.Blocks, a.ATimeSeconds, a.ATimeNanoSeconds, a.MTimeSeconds, a.MTimeNanoSeconds, a.CTimeSeconds, a.CTimeNanoSeconds, a.BTimeSeconds, a.BTimeNanoSeconds, a.Gen, a.DataVersion)
}
-// Encode implements encoder.Encode.
-func (a *Attr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *Attr) encode(b *buffer) {
b.WriteFileMode(a.Mode)
b.WriteUID(a.UID)
b.WriteGID(a.GID)
@@ -781,8 +796,8 @@ func (a *Attr) Encode(b *buffer) {
b.Write64(a.DataVersion)
}
-// Decode implements encoder.Decode.
-func (a *Attr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *Attr) decode(b *buffer) {
a.Mode = b.ReadFileMode()
a.UID = b.ReadUID()
a.GID = b.ReadGID()
@@ -814,7 +829,7 @@ func StatToAttr(s *syscall.Stat_t, req AttrMask) (Attr, AttrMask) {
attr.Mode = FileMode(s.Mode)
}
if req.NLink {
- attr.NLink = s.Nlink
+ attr.NLink = uint64(s.Nlink)
}
if req.UID {
attr.UID = UID(s.Uid)
@@ -911,8 +926,8 @@ func (s SetAttrMask) Empty() bool {
return !s.Permissions && !s.UID && !s.GID && !s.Size && !s.ATime && !s.MTime && !s.CTime && !s.ATimeNotSystemTime && !s.MTimeNotSystemTime
}
-// Decode implements encoder.Decode.
-func (s *SetAttrMask) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttrMask) decode(b *buffer) {
mask := b.Read32()
s.Permissions = mask&0x00000001 != 0
s.UID = mask&0x00000002 != 0
@@ -957,8 +972,8 @@ func (s SetAttrMask) bitmask() uint32 {
return mask
}
-// Encode implements encoder.Encode.
-func (s *SetAttrMask) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttrMask) encode(b *buffer) {
b.Write32(s.bitmask())
}
@@ -979,8 +994,8 @@ func (s SetAttr) String() string {
return fmt.Sprintf("SetAttr{Permissions: 0o%o, UID: %d, GID: %d, Size: %d, ATime: {Sec: %d, NanoSec: %d}, MTime: {Sec: %d, NanoSec: %d}}", s.Permissions, s.UID, s.GID, s.Size, s.ATimeSeconds, s.ATimeNanoSeconds, s.MTimeSeconds, s.MTimeNanoSeconds)
}
-// Decode implements encoder.Decode.
-func (s *SetAttr) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (s *SetAttr) decode(b *buffer) {
s.Permissions = b.ReadPermissions()
s.UID = b.ReadUID()
s.GID = b.ReadGID()
@@ -991,8 +1006,8 @@ func (s *SetAttr) Decode(b *buffer) {
s.MTimeNanoSeconds = b.Read64()
}
-// Encode implements encoder.Encode.
-func (s *SetAttr) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (s *SetAttr) encode(b *buffer) {
b.WritePermissions(s.Permissions)
b.WriteUID(s.UID)
b.WriteGID(s.GID)
@@ -1049,17 +1064,17 @@ func (d Dirent) String() string {
return fmt.Sprintf("Dirent{QID: %d, Offset: %d, Type: 0x%X, Name: %s}", d.QID, d.Offset, d.Type, d.Name)
}
-// Decode implements encoder.Decode.
-func (d *Dirent) Decode(b *buffer) {
- d.QID.Decode(b)
+// decode implements encoder.decode.
+func (d *Dirent) decode(b *buffer) {
+ d.QID.decode(b)
d.Offset = b.Read64()
d.Type = b.ReadQIDType()
d.Name = b.ReadString()
}
-// Encode implements encoder.Encode.
-func (d *Dirent) Encode(b *buffer) {
- d.QID.Encode(b)
+// encode implements encoder.encode.
+func (d *Dirent) encode(b *buffer) {
+ d.QID.encode(b)
b.Write64(d.Offset)
b.WriteQIDType(d.Type)
b.WriteString(d.Name)
@@ -1076,6 +1091,19 @@ type AllocateMode struct {
Unshare bool
}
+// ToAllocateMode returns an AllocateMode from a fallocate(2) mode.
+func ToAllocateMode(mode uint64) AllocateMode {
+ return AllocateMode{
+ KeepSize: mode&unix.FALLOC_FL_KEEP_SIZE != 0,
+ PunchHole: mode&unix.FALLOC_FL_PUNCH_HOLE != 0,
+ NoHideStale: mode&unix.FALLOC_FL_NO_HIDE_STALE != 0,
+ CollapseRange: mode&unix.FALLOC_FL_COLLAPSE_RANGE != 0,
+ ZeroRange: mode&unix.FALLOC_FL_ZERO_RANGE != 0,
+ InsertRange: mode&unix.FALLOC_FL_INSERT_RANGE != 0,
+ Unshare: mode&unix.FALLOC_FL_UNSHARE_RANGE != 0,
+ }
+}
+
// ToLinux converts to a value compatible with fallocate(2)'s mode.
func (a *AllocateMode) ToLinux() uint32 {
rv := uint32(0)
@@ -1103,8 +1131,8 @@ func (a *AllocateMode) ToLinux() uint32 {
return rv
}
-// Decode implements encoder.Decode.
-func (a *AllocateMode) Decode(b *buffer) {
+// decode implements encoder.decode.
+func (a *AllocateMode) decode(b *buffer) {
mask := b.Read32()
a.KeepSize = mask&0x01 != 0
a.PunchHole = mask&0x02 != 0
@@ -1115,8 +1143,8 @@ func (a *AllocateMode) Decode(b *buffer) {
a.Unshare = mask&0x40 != 0
}
-// Encode implements encoder.Encode.
-func (a *AllocateMode) Encode(b *buffer) {
+// encode implements encoder.encode.
+func (a *AllocateMode) encode(b *buffer) {
mask := uint32(0)
if a.KeepSize {
mask |= 0x01
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 28707c0ca..7ca67cb19 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_library", "go_test")
package(licenses = ["notice"])
@@ -64,12 +63,12 @@ go_library(
"mocks.go",
"p9test.go",
],
- importpath = "gvisor.dev/gvisor/pkg/p9/p9test",
visibility = ["//:sandbox"],
deps = [
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/unet",
"@com_github_golang_mock//gomock:go_default_library",
],
@@ -79,10 +78,11 @@ go_test(
name = "client_test",
size = "medium",
srcs = ["client_test.go"],
- embed = [":p9test"],
+ library = ":p9test",
deps = [
"//pkg/fd",
"//pkg/p9",
+ "//pkg/sync",
"@com_github_golang_mock//gomock:go_default_library",
],
)
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index 8bbdb2488..6e7bb3db2 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -22,7 +22,6 @@ import (
"os"
"reflect"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -30,6 +29,7 @@ import (
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestPanic(t *testing.T) {
@@ -1044,11 +1044,11 @@ func TestReaddir(t *testing.T) {
if _, err := f.Readdir(0, 1); err != syscall.EINVAL {
t.Errorf("readdir got %v, wanted EINVAL", err)
}
- if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EINVAL {
- t.Errorf("readdir got %v, wanted EINVAL", err)
+ if _, _, _, err := f.Open(p9.ReadWrite); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
}
- if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EINVAL {
- t.Errorf("readdir got %v, wanted EINVAL", err)
+ if _, _, _, err := f.Open(p9.WriteOnly); err != syscall.EISDIR {
+ t.Errorf("readdir got %v, wanted EISDIR", err)
}
backend.EXPECT().Open(p9.ReadOnly).Times(1)
if _, _, _, err := f.Open(p9.ReadOnly); err != nil {
@@ -1065,75 +1065,93 @@ func TestReaddir(t *testing.T) {
func TestOpen(t *testing.T) {
type openTest struct {
name string
- mode p9.OpenFlags
+ flags p9.OpenFlags
err error
match func(p9.FileMode) bool
}
cases := []openTest{
{
- name: "invalid",
- mode: ^p9.OpenFlagsModeMask,
- err: syscall.EINVAL,
- match: func(p9.FileMode) bool { return true },
- },
- {
name: "not-openable-read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "not-openable-write-only",
- mode: p9.WriteOnly,
+ flags: p9.WriteOnly,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "not-openable-read-write",
- mode: p9.ReadWrite,
+ flags: p9.ReadWrite,
err: syscall.EINVAL,
match: func(mode p9.FileMode) bool { return !p9.CanOpen(mode) },
},
{
name: "directory-read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: nil,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "directory-read-write",
- mode: p9.ReadWrite,
- err: syscall.EINVAL,
+ flags: p9.ReadWrite,
+ err: syscall.EISDIR,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "directory-write-only",
- mode: p9.WriteOnly,
- err: syscall.EINVAL,
+ flags: p9.WriteOnly,
+ err: syscall.EISDIR,
match: func(mode p9.FileMode) bool { return mode.IsDir() },
},
{
name: "read-only",
- mode: p9.ReadOnly,
+ flags: p9.ReadOnly,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) },
},
{
name: "write-only",
- mode: p9.WriteOnly,
+ flags: p9.WriteOnly,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
},
{
name: "read-write",
- mode: p9.ReadWrite,
+ flags: p9.ReadWrite,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "directory-read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: syscall.EISDIR,
+ match: func(mode p9.FileMode) bool { return mode.IsDir() },
+ },
+ {
+ name: "read-only-truncate",
+ flags: p9.ReadOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "write-only-truncate",
+ flags: p9.WriteOnly | p9.OpenTruncate,
+ err: nil,
+ match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
+ },
+ {
+ name: "read-write-truncate",
+ flags: p9.ReadWrite | p9.OpenTruncate,
err: nil,
match: func(mode p9.FileMode) bool { return p9.CanOpen(mode) && !mode.IsDir() },
},
}
- // Open(mode OpenFlags) (*fd.FD, QID, uint32, error)
+ // Open(flags OpenFlags) (*fd.FD, QID, uint32, error)
// - only works on Regular, NamedPipe, BLockDevice, CharacterDevice
// - returning a file works as expected
for name := range newTypeMap(nil) {
@@ -1171,25 +1189,25 @@ func TestOpen(t *testing.T) {
// Attempt the given open.
if tc.err != nil {
// We expect an error, just test and return.
- if _, _, _, err := f.Open(tc.mode); err != tc.err {
- t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err)
+ if _, _, _, err := f.Open(tc.flags); err != tc.err {
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
}
return
}
// Run an FD test, since we expect success.
fdTest(t, func(send *fd.FD) *fd.FD {
- backend.EXPECT().Open(tc.mode).Return(send, p9.QID{}, uint32(0), nil).Times(1)
- recv, _, _, err := f.Open(tc.mode)
+ backend.EXPECT().Open(tc.flags).Return(send, p9.QID{}, uint32(0), nil).Times(1)
+ recv, _, _, err := f.Open(tc.flags)
if err != tc.err {
- t.Fatalf("open with mode %v got %v, want %v", tc.mode, err, tc.err)
+ t.Fatalf("open with flags %v got %v, want %v", tc.flags, err, tc.err)
}
return recv
})
// If the open was successful, attempt another one.
- if _, _, _, err := f.Open(tc.mode); err != syscall.EINVAL {
- t.Errorf("second open with mode %v got %v, want EINVAL", tc.mode, err)
+ if _, _, _, err := f.Open(tc.flags); err != syscall.EINVAL {
+ t.Errorf("second open with flags %v got %v, want EINVAL", tc.flags, err)
}
// Ensure that all illegal operations fail.
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 4d3271b37..dd8b01b6d 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -17,13 +17,13 @@ package p9test
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
"testing"
"github.com/golang/mock/gomock"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/p9/path_tree.go b/pkg/p9/path_tree.go
index 865459411..72ef53313 100644
--- a/pkg/p9/path_tree.go
+++ b/pkg/p9/path_tree.go
@@ -16,7 +16,8 @@ package p9
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// pathNode is a single node in a path traversal.
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index e717e6161..60cf94fa1 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -17,7 +17,6 @@ package p9
import (
"io"
"runtime/debug"
- "sync"
"sync/atomic"
"syscall"
@@ -25,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fdchannel"
"gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) {
go func() { // S/R-SAFE: Server side.
defer cs.channelWg.Done()
if err := res.service(cs); err != nil {
- log.Warningf("p9.channel.service: %v", err)
+ // Don't log flipcall.ShutdownErrors, which we expect to be
+ // returned during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("p9.channel.service: %v", err)
+ }
}
}()
}
@@ -478,10 +482,10 @@ func (cs *connState) handle(m message) (r message) {
defer func() {
if r == nil {
// Don't allow a panic to propagate.
- recover()
+ err := recover()
// Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ log.Warningf("panic in handler: %v\n%s", err, debug.Stack())
// Wrap in an EFAULT error; we don't really have a
// better way to describe this kind of error. It will
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 6e8b4bbcd..02e665345 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -19,11 +19,11 @@ import (
"fmt"
"io"
"io/ioutil"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -66,21 +66,24 @@ const (
var dataPool = sync.Pool{
New: func() interface{} {
// These buffers are used for decoding without a payload.
- return make([]byte, initialBufferLength)
+ // We need to return a pointer to avoid unnecessary allocations
+ // (see https://staticcheck.io/docs/checks#SA6002).
+ b := make([]byte, initialBufferLength)
+ return &b
},
}
// send sends the given message over the socket.
func send(s *unet.Socket, tag Tag, m message) error {
- data := dataPool.Get().([]byte)
- dataBuf := buffer{data: data[:0]}
+ data := dataPool.Get().(*[]byte)
+ dataBuf := buffer{data: (*data)[:0]}
if log.IsLogging(log.Debug) {
log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
}
// Encode the message. The buffer will grow automatically.
- m.Encode(&dataBuf)
+ m.encode(&dataBuf)
// Get our vectors to send.
var hdr [headerLength]byte
@@ -141,7 +144,7 @@ func send(s *unet.Socket, tag Tag, m message) error {
}
// All set.
- dataPool.Put(dataBuf.data)
+ dataPool.Put(&dataBuf.data)
return nil
}
@@ -227,12 +230,29 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
// Not yet initialized.
var dataBuf buffer
+ var vecs [][]byte
+
+ appendBuffer := func(size int) *[]byte {
+ // Pull a data buffer from the pool.
+ datap := dataPool.Get().(*[]byte)
+ data := *datap
+ if size > len(data) {
+ // Create a larger data buffer.
+ data = make([]byte, size)
+ datap = &data
+ } else {
+ // Limit the data buffer.
+ data = data[:size]
+ }
+ dataBuf = buffer{data: data}
+ vecs = append(vecs, data)
+ return datap
+ }
// Read the rest of the payload.
//
// This requires some special care to ensure that the vectors all line
// up the way they should. We do this to minimize copying data around.
- var vecs [][]byte
if payloader, ok := m.(payloader); ok {
fixedSize := payloader.FixedSize()
@@ -246,22 +266,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
}
if fixedSize != 0 {
- // Pull a data buffer from the pool.
- data := dataPool.Get().([]byte)
- if int(fixedSize) > len(data) {
- // Create a larger data buffer, ensuring
- // sufficient capicity for the message.
- data = make([]byte, fixedSize)
- defer dataPool.Put(data)
- dataBuf = buffer{data: data}
- vecs = append(vecs, data)
- } else {
- // Limit the data buffer, and make sure it
- // gets filled before the payload buffer.
- defer dataPool.Put(data)
- dataBuf = buffer{data: data[:fixedSize]}
- vecs = append(vecs, data[:fixedSize])
- }
+ datap := appendBuffer(int(fixedSize))
+ defer dataPool.Put(datap)
}
// Include the payload.
@@ -274,20 +280,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
vecs = append(vecs, p)
}
} else if remaining != 0 {
- // Pull a data buffer from the pool.
- data := dataPool.Get().([]byte)
- if int(remaining) > len(data) {
- // Create a larger data buffer.
- data = make([]byte, remaining)
- defer dataPool.Put(data)
- dataBuf = buffer{data: data}
- vecs = append(vecs, data)
- } else {
- // Limit the data buffer.
- defer dataPool.Put(data)
- dataBuf = buffer{data: data[:remaining]}
- vecs = append(vecs, data[:remaining])
- }
+ datap := appendBuffer(int(remaining))
+ defer dataPool.Put(datap)
}
if len(vecs) > 0 {
@@ -316,7 +310,7 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message,
}
// Decode the message data.
- m.Decode(&dataBuf)
+ m.decode(&dataBuf)
if dataBuf.isOverrun() {
// No need to drain the socket.
return NoTag, nil, ErrNoValidMessage
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
index 233f825e3..38038abdf 100644
--- a/pkg/p9/transport_flipcall.go
+++ b/pkg/p9/transport_flipcall.go
@@ -151,7 +151,7 @@ func (ch *channel) send(m message) (uint32, error) {
} else {
ch.buf.Write8(0) // No incoming FD.
}
- m.Encode(&ch.buf)
+ m.encode(&ch.buf)
ssz := uint32(len(ch.buf.data)) // Updated below.
// Is there a payload?
@@ -205,7 +205,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
ch.buf.data = ch.buf.data[:fs]
}
- r.Decode(&ch.buf)
+ r.decode(&ch.buf)
if ch.buf.isOverrun() {
// Nothing valid was available.
log.Debugf("recv [got %d bytes, needed more]", rsz)
@@ -236,7 +236,7 @@ func (ch *channel) recv(r message, rsz uint32) (message, error) {
// Convert errors appropriately; see above.
if rlerr, ok := r.(*Rlerror); ok {
- return nil, syscall.Errno(rlerr.Error)
+ return r, syscall.Errno(rlerr.Error)
}
return r, nil
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index 2f50ff3ea..e7406b374 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -56,8 +56,8 @@ func TestSendRecv(t *testing.T) {
// badDecode overruns on decode.
type badDecode struct{}
-func (*badDecode) Decode(b *buffer) { b.markOverrun() }
-func (*badDecode) Encode(b *buffer) {}
+func (*badDecode) decode(b *buffer) { b.markOverrun() }
+func (*badDecode) encode(b *buffer) {}
func (*badDecode) Type() MsgType { return MsgTypeBadDecode }
func (*badDecode) String() string { return "badDecode{}" }
@@ -81,8 +81,8 @@ func TestRecvOverrun(t *testing.T) {
// unregistered is not registered on decode.
type unregistered struct{}
-func (*unregistered) Decode(b *buffer) {}
-func (*unregistered) Encode(b *buffer) {}
+func (*unregistered) decode(b *buffer) {}
+func (*unregistered) encode(b *buffer) {}
func (*unregistered) Type() MsgType { return MsgTypeUnregistered }
func (*unregistered) String() string { return "unregistered{}" }
@@ -182,6 +182,8 @@ func TestSendClosed(t *testing.T) {
}
func BenchmarkSendRecv(b *testing.B) {
+ b.ReportAllocs()
+
server, client, err := unet.SocketPair(false)
if err != nil {
b.Fatalf("socketpair got err %v expected nil", err)
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index f1ffdd23a..09cde9f5a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 8
+ highestSupportedVersion uint32 = 11
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -155,3 +155,21 @@ func versionSupportsTallocate(v uint32) bool {
func versionSupportsFlipcall(v uint32) bool {
return v >= 8
}
+
+// VersionSupportsOpenTruncateFlag returns true if version v supports
+// passing the OpenTruncate flag to Tlopen.
+func VersionSupportsOpenTruncateFlag(v uint32) bool {
+ return v >= 9
+}
+
+// versionSupportsGetSetXattr returns true if version v supports
+// the Tgetxattr and Tsetxattr messages.
+func versionSupportsGetSetXattr(v uint32) bool {
+ return v >= 10
+}
+
+// versionSupportsListRemoveXattr returns true if version v supports
+// the Tlistxattr and Tremovexattr messages.
+func versionSupportsListRemoveXattr(v uint32) bool {
+ return v >= 11
+}
diff --git a/pkg/pool/BUILD b/pkg/pool/BUILD
new file mode 100644
index 000000000..7b1c6b75b
--- /dev/null
+++ b/pkg/pool/BUILD
@@ -0,0 +1,25 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "pool",
+ srcs = [
+ "pool.go",
+ ],
+ deps = [
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "pool_test",
+ size = "small",
+ srcs = [
+ "pool_test.go",
+ ],
+ library = ":pool",
+)
diff --git a/pkg/p9/pool.go b/pkg/pool/pool.go
index 52de889e1..a1b2e0cfe 100644
--- a/pkg/p9/pool.go
+++ b/pkg/pool/pool.go
@@ -12,33 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
-// pool is a simple allocator.
-//
-// It is used for both tags and FIDs.
-type pool struct {
+// Pool is a simple allocator.
+type Pool struct {
mu sync.Mutex
// cache is the set of returned values.
cache []uint64
- // start is the starting value (if needed).
- start uint64
+ // Start is the starting value (if needed).
+ Start uint64
// max is the current maximum issued.
max uint64
- // limit is the upper limit.
- limit uint64
+ // Limit is the upper limit.
+ Limit uint64
}
// Get gets a value from the pool.
-func (p *pool) Get() (uint64, bool) {
+func (p *Pool) Get() (uint64, bool) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -50,18 +48,18 @@ func (p *pool) Get() (uint64, bool) {
}
// Over the limit?
- if p.start == p.limit {
+ if p.Start == p.Limit {
return 0, false
}
// Generate a new value.
- v := p.start
- p.start++
+ v := p.Start
+ p.Start++
return v, true
}
// Put returns a value to the pool.
-func (p *pool) Put(v uint64) {
+func (p *Pool) Put(v uint64) {
p.mu.Lock()
p.cache = append(p.cache, v)
p.mu.Unlock()
diff --git a/pkg/p9/pool_test.go b/pkg/pool/pool_test.go
index e4746b8da..d928439c1 100644
--- a/pkg/p9/pool_test.go
+++ b/pkg/pool/pool_test.go
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package p9
+package pool
import (
"testing"
)
func TestPoolUnique(t *testing.T) {
- p := pool{start: 1, limit: 3}
+ p := Pool{Start: 1, Limit: 3}
got := make(map[uint64]bool)
for {
@@ -39,7 +39,7 @@ func TestPoolUnique(t *testing.T) {
}
func TestExausted(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
for i := 0; i < 499; i++ {
_, ok := p.Get()
if !ok {
@@ -54,7 +54,7 @@ func TestExausted(t *testing.T) {
}
func TestPoolRecycle(t *testing.T) {
- p := pool{start: 1, limit: 500}
+ p := Pool{Start: 1, Limit: 500}
n1, _ := p.Get()
p.Put(n1)
n2, _ := p.Get()
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 078f084b2..aa3e3ac0b 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,7 +9,6 @@ go_library(
"procid_amd64.s",
"procid_arm64.s",
],
- importpath = "gvisor.dev/gvisor/pkg/procid",
visibility = ["//visibility:public"],
)
@@ -20,7 +18,8 @@ go_test(
srcs = [
"procid_test.go",
],
- embed = [":procid"],
+ library = ":procid",
+ deps = ["//pkg/sync"],
)
go_test(
@@ -30,5 +29,6 @@ go_test(
"procid_net_test.go",
"procid_test.go",
],
- embed = [":procid"],
+ library = ":procid",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index 30ec8e6e2..7c622e5d7 100644
--- a/pkg/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
@@ -14,7 +14,7 @@
// +build amd64
// +build go1.8
-// +build !go1.14
+// +build !go1.16
#include "textflag.h"
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index e340d9f98..48ebb5fd1 100644
--- a/pkg/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
@@ -14,7 +14,7 @@
// +build arm64
// +build go1.8
-// +build !go1.14
+// +build !go1.16
#include "textflag.h"
diff --git a/pkg/procid/procid_test.go b/pkg/procid/procid_test.go
index 88dd0b3ae..9ec08c3d6 100644
--- a/pkg/procid/procid_test.go
+++ b/pkg/procid/procid_test.go
@@ -17,9 +17,10 @@ package procid
import (
"os"
"runtime"
- "sync"
"syscall"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// runOnMain is used to send functions to run on the main (initial) thread.
diff --git a/pkg/rand/BUILD b/pkg/rand/BUILD
index f4f2001f3..80b8ceb02 100644
--- a/pkg/rand/BUILD
+++ b/pkg/rand/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,9 @@ go_library(
"rand.go",
"rand_linux.go",
],
- importpath = "gvisor.dev/gvisor/pkg/rand",
visibility = ["//:sandbox"],
- deps = ["@org_golang_x_sys//unix:go_default_library"],
+ deps = [
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/pkg/rand/rand_linux.go b/pkg/rand/rand_linux.go
index 2b92db3e6..fa6a21026 100644
--- a/pkg/rand/rand_linux.go
+++ b/pkg/rand/rand_linux.go
@@ -17,11 +17,12 @@
package rand
import (
+ "bufio"
"crypto/rand"
"io"
- "sync"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/sync"
)
// reader implements an io.Reader that returns pseudorandom bytes.
@@ -45,8 +46,22 @@ func (r *reader) Read(p []byte) (int, error) {
return rand.Read(p)
}
+// bufferedReader implements a threadsafe buffered io.Reader.
+type bufferedReader struct {
+ mu sync.Mutex
+ r *bufio.Reader
+}
+
+// Read implements io.Reader.Read.
+func (b *bufferedReader) Read(p []byte) (int, error) {
+ b.mu.Lock()
+ n, err := b.r.Read(p)
+ b.mu.Unlock()
+ return n, err
+}
+
// Reader is the default reader.
-var Reader io.Reader = &reader{}
+var Reader io.Reader = &bufferedReader{r: bufio.NewReader(&reader{})}
// Read reads from the default reader.
func Read(b []byte) (int, error) {
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 7ad59dfd7..9888cce9c 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,10 +22,11 @@ go_library(
"refcounter_state.go",
"weak_ref_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/refs",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/context",
"//pkg/log",
+ "//pkg/sync",
],
)
@@ -34,5 +34,9 @@ go_test(
name = "refs_test",
size = "small",
srcs = ["refcounter_test.go"],
- embed = [":refs"],
+ library = ":refs",
+ deps = [
+ "//pkg/context",
+ "//pkg/sync",
+ ],
)
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index ad69e0757..d9d5e6bcb 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -21,10 +21,11 @@ import (
"fmt"
"reflect"
"runtime"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RefCounter is the interface to be implemented by objects that are reference
@@ -38,7 +39,7 @@ type RefCounter interface {
// Note that AtomicRefCounter.DecRef() does not support destructors.
// If a type has a destructor, it must implement its own DecRef()
// method and call AtomicRefCounter.DecRefWithDestructor(destructor).
- DecRef()
+ DecRef(ctx context.Context)
// TryIncRef attempts to increase the reference counter on the object,
// but may fail if all references have already been dropped. This
@@ -57,7 +58,7 @@ type RefCounter interface {
// A WeakRefUser is notified when the last non-weak reference is dropped.
type WeakRefUser interface {
// WeakRefGone is called when the last non-weak reference is dropped.
- WeakRefGone()
+ WeakRefGone(ctx context.Context)
}
// WeakRef is a weak reference.
@@ -123,7 +124,7 @@ func (w *WeakRef) Get() RefCounter {
// Drop drops this weak reference. You should always call drop when you are
// finished with the weak reference. You may not use this object after calling
// drop.
-func (w *WeakRef) Drop() {
+func (w *WeakRef) Drop(ctx context.Context) {
rc, ok := w.get()
if !ok {
// We've been zapped already. When the refcounter has called
@@ -145,7 +146,7 @@ func (w *WeakRef) Drop() {
// And now aren't on the object's list of weak references. So it won't
// zap us if this causes the reference count to drop to zero.
- rc.DecRef()
+ rc.DecRef(ctx)
// Return to the pool.
weakRefPool.Put(w)
@@ -214,6 +215,8 @@ type AtomicRefCount struct {
// LeakMode configures the leak checker.
type LeakMode uint32
+// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref
+// counting is gone.
const (
// UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
UninitializedLeakChecking LeakMode = iota
@@ -243,6 +246,11 @@ func SetLeakMode(mode LeakMode) {
atomic.StoreUint32(&leakMode, uint32(mode))
}
+// GetLeakMode returns the current leak mode.
+func GetLeakMode() LeakMode {
+ return LeakMode(atomic.LoadUint32(&leakMode))
+}
+
const maxStackFrames = 40
type fileLine struct {
@@ -427,7 +435,7 @@ func (r *AtomicRefCount) dropWeakRef(w *WeakRef) {
// A: TryIncRef [transform speculative to real]
//
//go:nosplit
-func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
+func (r *AtomicRefCount) DecRefWithDestructor(ctx context.Context, destroy func(context.Context)) {
switch v := atomic.AddInt64(&r.refCount, -1); {
case v < -1:
panic("Decrementing non-positive ref count")
@@ -448,7 +456,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
if user != nil {
r.mu.Unlock()
- user.WeakRefGone()
+ user.WeakRefGone(ctx)
r.mu.Lock()
}
}
@@ -456,7 +464,7 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
// Call the destructor.
if destroy != nil {
- destroy()
+ destroy(ctx)
}
}
}
@@ -464,6 +472,16 @@ func (r *AtomicRefCount) DecRefWithDestructor(destroy func()) {
// DecRef decrements this object's reference count.
//
//go:nosplit
-func (r *AtomicRefCount) DecRef() {
- r.DecRefWithDestructor(nil)
+func (r *AtomicRefCount) DecRef(ctx context.Context) {
+ r.DecRefWithDestructor(ctx, nil)
+}
+
+// OnExit is called on sandbox exit. It runs GC to enqueue refcount finalizers,
+// which check for reference leaks. There is no way to guarantee that every
+// finalizer will run before exiting, but this at least ensures that they will
+// be discovered/enqueued by GC.
+func OnExit() {
+ if LeakMode(atomic.LoadUint32(&leakMode)) != NoLeakChecking {
+ runtime.GC()
+ }
}
diff --git a/pkg/refs/refcounter_test.go b/pkg/refs/refcounter_test.go
index ffd3d3f07..6d0dd1018 100644
--- a/pkg/refs/refcounter_test.go
+++ b/pkg/refs/refcounter_test.go
@@ -16,8 +16,10 @@ package refs
import (
"reflect"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
type testCounter struct {
@@ -30,11 +32,11 @@ type testCounter struct {
destroyed bool
}
-func (t *testCounter) DecRef() {
- t.AtomicRefCount.DecRefWithDestructor(t.destroy)
+func (t *testCounter) DecRef(ctx context.Context) {
+ t.AtomicRefCount.DecRefWithDestructor(ctx, t.destroy)
}
-func (t *testCounter) destroy() {
+func (t *testCounter) destroy(context.Context) {
t.mu.Lock()
defer t.mu.Unlock()
t.destroyed = true
@@ -52,7 +54,7 @@ func newTestCounter() *testCounter {
func TestOneRef(t *testing.T) {
tc := newTestCounter()
- tc.DecRef()
+ tc.DecRef(context.Background())
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -62,8 +64,9 @@ func TestOneRef(t *testing.T) {
func TestTwoRefs(t *testing.T) {
tc := newTestCounter()
tc.IncRef()
- tc.DecRef()
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
+ tc.DecRef(ctx)
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -73,12 +76,13 @@ func TestTwoRefs(t *testing.T) {
func TestMultiRefs(t *testing.T) {
tc := newTestCounter()
tc.IncRef()
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
tc.IncRef()
- tc.DecRef()
+ tc.DecRef(ctx)
- tc.DecRef()
+ tc.DecRef(ctx)
if !tc.IsDestroyed() {
t.Errorf("object should have been destroyed")
@@ -88,19 +92,20 @@ func TestMultiRefs(t *testing.T) {
func TestWeakRef(t *testing.T) {
tc := newTestCounter()
w := NewWeakRef(tc, nil)
+ ctx := context.Background()
// Try resolving.
if x := w.Get(); x == nil {
t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
} else {
- x.DecRef()
+ x.DecRef(ctx)
}
// Try resolving again.
if x := w.Get(); x == nil {
t.Errorf("weak reference didn't resolve: expected %v, got nil", tc)
} else {
- x.DecRef()
+ x.DecRef(ctx)
}
// Shouldn't be destroyed yet. (Can't continue if this fails.)
@@ -109,7 +114,7 @@ func TestWeakRef(t *testing.T) {
}
// Drop the original reference.
- tc.DecRef()
+ tc.DecRef(ctx)
// Assert destroyed.
if !tc.IsDestroyed() {
@@ -125,7 +130,8 @@ func TestWeakRef(t *testing.T) {
func TestWeakRefDrop(t *testing.T) {
tc := newTestCounter()
w := NewWeakRef(tc, nil)
- w.Drop()
+ ctx := context.Background()
+ w.Drop(ctx)
// Just assert the list is empty.
if !tc.weakRefs.Empty() {
@@ -133,14 +139,14 @@ func TestWeakRefDrop(t *testing.T) {
}
// Drop the original reference.
- tc.DecRef()
+ tc.DecRef(ctx)
}
type testWeakRefUser struct {
weakRefGone func()
}
-func (u *testWeakRefUser) WeakRefGone() {
+func (u *testWeakRefUser) WeakRefGone(ctx context.Context) {
u.weakRefGone()
}
@@ -164,7 +170,8 @@ func TestCallback(t *testing.T) {
}})
// Drop the original reference, this must trigger the callback.
- tc.DecRef()
+ ctx := context.Background()
+ tc.DecRef(ctx)
if !called {
t.Fatalf("Callback not called")
diff --git a/pkg/refs_vfs2/BUILD b/pkg/refs_vfs2/BUILD
new file mode 100644
index 000000000..7b3e10683
--- /dev/null
+++ b/pkg/refs_vfs2/BUILD
@@ -0,0 +1,26 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(licenses = ["notice"])
+
+go_template(
+ name = "refs_template",
+ srcs = [
+ "refs_template.go",
+ ],
+ types = [
+ "T",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/log",
+ "//pkg/refs",
+ ],
+)
+
+go_library(
+ name = "refs_vfs2",
+ srcs = ["refs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/context"],
+)
diff --git a/pkg/refs_vfs2/refs.go b/pkg/refs_vfs2/refs.go
new file mode 100644
index 000000000..99a074e96
--- /dev/null
+++ b/pkg/refs_vfs2/refs.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package refs_vfs2 defines an interface for a reference-counted object.
+package refs_vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// RefCounter is the interface to be implemented by objects that are reference
+// counted.
+type RefCounter interface {
+ // IncRef increments the reference counter on the object.
+ IncRef()
+
+ // DecRef decrements the object's reference count. Users of refs_template.Refs
+ // may specify a destructor to be called once the reference count reaches zero.
+ DecRef(ctx context.Context)
+
+ // TryIncRef attempts to increment the reference count, but may fail if all
+ // references have already been dropped, in which case it returns false. If
+ // true is returned, then a valid reference is now held on the object.
+ TryIncRef() bool
+}
diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refs_vfs2/refs_template.go
new file mode 100644
index 000000000..99c43c065
--- /dev/null
+++ b/pkg/refs_vfs2/refs_template.go
@@ -0,0 +1,133 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package refs_template defines a template that can be used by reference counted
+// objects.
+package refs_template
+
+import (
+ "runtime"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+)
+
+// T is the type of the reference counted object. It is only used to customize
+// debug output when leak checking.
+type T interface{}
+
+// ownerType is used to customize logging. Note that we use a pointer to T so
+// that we do not copy the entire object when passed as a format parameter.
+var ownerType *T
+
+// Refs implements refs.RefCounter. It keeps a reference count using atomic
+// operations and calls the destructor when the count reaches zero.
+//
+// Note that the number of references is actually refCount + 1 so that a default
+// zero-value Refs object contains one reference.
+//
+// +stateify savable
+type Refs struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a CompareAndSwap
+ // loop. See IncRef, DecRef and TryIncRef for details of how these fields are
+ // used.
+ refCount int64
+}
+
+func (r *Refs) finalize() {
+ var note string
+ switch refs_vfs1.GetLeakMode() {
+ case refs_vfs1.NoLeakChecking:
+ return
+ case refs_vfs1.UninitializedLeakChecking:
+ note = "(Leak checker uninitialized): "
+ }
+ if n := r.ReadRefs(); n != 0 {
+ log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n)
+ }
+}
+
+// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+func (r *Refs) EnableLeakCheck() {
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ runtime.SetFinalizer(r, (*Refs).finalize)
+ }
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *Refs) ReadRefs() int64 {
+ // Account for the internal -1 offset on refcounts.
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef implements refs.RefCounter.IncRef.
+//
+//go:nosplit
+func (r *Refs) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic("Incrementing non-positive ref count")
+ }
+}
+
+// TryIncRef implements refs.RefCounter.TryIncRef.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to distinguish
+// other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *Refs) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+ // This object has already been freed.
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ // Turn into a real reference.
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *Refs) DecRef(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic("Decrementing non-positive ref count")
+
+ case v == -1:
+ // Call the destructor.
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/safecopy/BUILD
index 6769cd0a5..426ef30c9 100644
--- a/pkg/sentry/platform/safecopy/BUILD
+++ b/pkg/safecopy/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -17,8 +16,7 @@ go_library(
"sighandler_amd64.s",
"sighandler_arm64.s",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = ["//pkg/syserror"],
)
@@ -27,5 +25,5 @@ go_test(
srcs = [
"safecopy_test.go",
],
- embed = [":safecopy"],
+ library = ":safecopy",
)
diff --git a/pkg/sentry/platform/safecopy/LICENSE b/pkg/safecopy/LICENSE
index 6a66aea5e..6a66aea5e 100644
--- a/pkg/sentry/platform/safecopy/LICENSE
+++ b/pkg/safecopy/LICENSE
diff --git a/pkg/sentry/platform/safecopy/atomic_amd64.s b/pkg/safecopy/atomic_amd64.s
index a0cd78f33..a0cd78f33 100644
--- a/pkg/sentry/platform/safecopy/atomic_amd64.s
+++ b/pkg/safecopy/atomic_amd64.s
diff --git a/pkg/sentry/platform/safecopy/atomic_arm64.s b/pkg/safecopy/atomic_arm64.s
index d58ed71f7..d58ed71f7 100644
--- a/pkg/sentry/platform/safecopy/atomic_arm64.s
+++ b/pkg/safecopy/atomic_arm64.s
diff --git a/pkg/sentry/platform/safecopy/memclr_amd64.s b/pkg/safecopy/memclr_amd64.s
index 64cf32f05..64cf32f05 100644
--- a/pkg/sentry/platform/safecopy/memclr_amd64.s
+++ b/pkg/safecopy/memclr_amd64.s
diff --git a/pkg/sentry/platform/safecopy/memclr_arm64.s b/pkg/safecopy/memclr_arm64.s
index 7361b9067..7361b9067 100644
--- a/pkg/sentry/platform/safecopy/memclr_arm64.s
+++ b/pkg/safecopy/memclr_arm64.s
diff --git a/pkg/sentry/platform/safecopy/memcpy_amd64.s b/pkg/safecopy/memcpy_amd64.s
index 129691d68..00b46c18f 100644
--- a/pkg/sentry/platform/safecopy/memcpy_amd64.s
+++ b/pkg/safecopy/memcpy_amd64.s
@@ -55,15 +55,9 @@ TEXT ·memcpy(SB), NOSPLIT, $0-36
MOVQ from+8(FP), SI
MOVQ n+16(FP), BX
- // REP instructions have a high startup cost, so we handle small sizes
- // with some straightline code. The REP MOVSQ instruction is really fast
- // for large sizes. The cutover is approximately 2K.
tail:
- // move_129through256 or smaller work whether or not the source and the
- // destination memory regions overlap because they load all data into
- // registers before writing it back. move_256through2048 on the other
- // hand can be used only when the memory regions don't overlap or the copy
- // direction is forward.
+ // BSR+branch table make almost all memmove/memclr benchmarks worse. Not
+ // worth doing.
TESTQ BX, BX
JEQ move_0
CMPQ BX, $2
@@ -83,31 +77,45 @@ tail:
JBE move_65through128
CMPQ BX, $256
JBE move_129through256
- // TODO: use branch table and BSR to make this just a single dispatch
-/*
- * forward copy loop
- */
- CMPQ BX, $2048
- JLS move_256through2048
-
- // Check alignment
- MOVL SI, AX
- ORL DI, AX
- TESTL $7, AX
- JEQ fwdBy8
-
- // Do 1 byte at a time
- MOVQ BX, CX
- REP; MOVSB
- RET
-
-fwdBy8:
- // Do 8 bytes at a time
- MOVQ BX, CX
- SHRQ $3, CX
- ANDQ $7, BX
- REP; MOVSQ
+move_257plus:
+ SUBQ $256, BX
+ MOVOU (SI), X0
+ MOVOU X0, (DI)
+ MOVOU 16(SI), X1
+ MOVOU X1, 16(DI)
+ MOVOU 32(SI), X2
+ MOVOU X2, 32(DI)
+ MOVOU 48(SI), X3
+ MOVOU X3, 48(DI)
+ MOVOU 64(SI), X4
+ MOVOU X4, 64(DI)
+ MOVOU 80(SI), X5
+ MOVOU X5, 80(DI)
+ MOVOU 96(SI), X6
+ MOVOU X6, 96(DI)
+ MOVOU 112(SI), X7
+ MOVOU X7, 112(DI)
+ MOVOU 128(SI), X8
+ MOVOU X8, 128(DI)
+ MOVOU 144(SI), X9
+ MOVOU X9, 144(DI)
+ MOVOU 160(SI), X10
+ MOVOU X10, 160(DI)
+ MOVOU 176(SI), X11
+ MOVOU X11, 176(DI)
+ MOVOU 192(SI), X12
+ MOVOU X12, 192(DI)
+ MOVOU 208(SI), X13
+ MOVOU X13, 208(DI)
+ MOVOU 224(SI), X14
+ MOVOU X14, 224(DI)
+ MOVOU 240(SI), X15
+ MOVOU X15, 240(DI)
+ CMPQ BX, $256
+ LEAQ 256(SI), SI
+ LEAQ 256(DI), DI
+ JGE move_257plus
JMP tail
move_1or2:
@@ -209,42 +217,3 @@ move_129through256:
MOVOU -16(SI)(BX*1), X15
MOVOU X15, -16(DI)(BX*1)
RET
-move_256through2048:
- SUBQ $256, BX
- MOVOU (SI), X0
- MOVOU X0, (DI)
- MOVOU 16(SI), X1
- MOVOU X1, 16(DI)
- MOVOU 32(SI), X2
- MOVOU X2, 32(DI)
- MOVOU 48(SI), X3
- MOVOU X3, 48(DI)
- MOVOU 64(SI), X4
- MOVOU X4, 64(DI)
- MOVOU 80(SI), X5
- MOVOU X5, 80(DI)
- MOVOU 96(SI), X6
- MOVOU X6, 96(DI)
- MOVOU 112(SI), X7
- MOVOU X7, 112(DI)
- MOVOU 128(SI), X8
- MOVOU X8, 128(DI)
- MOVOU 144(SI), X9
- MOVOU X9, 144(DI)
- MOVOU 160(SI), X10
- MOVOU X10, 160(DI)
- MOVOU 176(SI), X11
- MOVOU X11, 176(DI)
- MOVOU 192(SI), X12
- MOVOU X12, 192(DI)
- MOVOU 208(SI), X13
- MOVOU X13, 208(DI)
- MOVOU 224(SI), X14
- MOVOU X14, 224(DI)
- MOVOU 240(SI), X15
- MOVOU X15, 240(DI)
- CMPQ BX, $256
- LEAQ 256(SI), SI
- LEAQ 256(DI), DI
- JGE move_256through2048
- JMP tail
diff --git a/pkg/sentry/platform/safecopy/memcpy_arm64.s b/pkg/safecopy/memcpy_arm64.s
index e7e541565..e7e541565 100644
--- a/pkg/sentry/platform/safecopy/memcpy_arm64.s
+++ b/pkg/safecopy/memcpy_arm64.s
diff --git a/pkg/sentry/platform/safecopy/safecopy.go b/pkg/safecopy/safecopy.go
index 2fb7e5809..2fb7e5809 100644
--- a/pkg/sentry/platform/safecopy/safecopy.go
+++ b/pkg/safecopy/safecopy.go
diff --git a/pkg/sentry/platform/safecopy/safecopy_test.go b/pkg/safecopy/safecopy_test.go
index 5818f7f9b..7f7f69d61 100644
--- a/pkg/sentry/platform/safecopy/safecopy_test.go
+++ b/pkg/safecopy/safecopy_test.go
@@ -138,10 +138,14 @@ func TestSwapUint32Success(t *testing.T) {
func TestSwapUint32AlignmentError(t *testing.T) {
// Test that SwapUint32 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := SwapUint32(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := SwapUint32(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -171,10 +175,14 @@ func TestSwapUint64Success(t *testing.T) {
func TestSwapUint64AlignmentError(t *testing.T) {
// Test that SwapUint64 returns an AlignmentError when passed an unaligned
// address.
- data := new(struct{ val1, val2 uint64 })
- addr := uintptr(unsafe.Pointer(&data.val1)) + 1
- want := AlignmentError{Addr: addr, Alignment: 8}
- if _, err := SwapUint64(unsafe.Pointer(addr), 1); err != want {
+ data := make([]byte, 16) // 2 * sizeof(uint64).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 8; offset != 0 {
+ alignedIndex = 8 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 8}
+ if _, err := SwapUint64(ptr, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -201,10 +209,14 @@ func TestCompareAndSwapUint32Success(t *testing.T) {
func TestCompareAndSwapUint32AlignmentError(t *testing.T) {
// Test that CompareAndSwapUint32 returns an AlignmentError when passed an
// unaligned address.
- data := new(struct{ val uint64 })
- addr := uintptr(unsafe.Pointer(&data.val)) + 1
- want := AlignmentError{Addr: addr, Alignment: 4}
- if _, err := CompareAndSwapUint32(unsafe.Pointer(addr), 0, 1); err != want {
+ data := make([]byte, 8) // 2 * sizeof(uint32).
+ alignedIndex := uintptr(0)
+ if offset := uintptr(unsafe.Pointer(&data[0])) % 4; offset != 0 {
+ alignedIndex = 4 - offset
+ }
+ ptr := unsafe.Pointer(&data[alignedIndex+1])
+ want := AlignmentError{Addr: uintptr(ptr), Alignment: 4}
+ if _, err := CompareAndSwapUint32(ptr, 0, 1); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
}
}
@@ -252,8 +264,8 @@ func TestCopyInSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -276,8 +288,8 @@ func TestCopyInBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := CopyIn(dst, src)
if n != bytesBeforeFault {
@@ -300,8 +312,8 @@ func TestCopyOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -324,8 +336,8 @@ func TestCopyOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := CopyOut(dst, src)
if n != bytesBeforeFault {
@@ -348,8 +360,8 @@ func TestCopySourceSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -372,8 +384,8 @@ func TestCopySourceBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- src := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ src := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
dst := randBuf(pageSize)
n, err := Copy(unsafe.Pointer(&dst[0]), src, pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -396,8 +408,8 @@ func TestCopyDestinationSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -420,8 +432,8 @@ func TestCopyDestinationBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting copy %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
src := randBuf(pageSize)
n, err := Copy(dst, unsafe.Pointer(&src[0]), pageSize)
if n != uintptr(bytesBeforeFault) {
@@ -444,8 +456,8 @@ func TestZeroOutSegvError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGSEGV", bytesBeforeFault), func(t *testing.T) {
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -467,8 +479,8 @@ func TestZeroOutBusError(t *testing.T) {
for bytesBeforeFault := 0; bytesBeforeFault <= 2*maxRegisterSize; bytesBeforeFault++ {
t.Run(fmt.Sprintf("starting write %d bytes before SIGBUS", bytesBeforeFault), func(t *testing.T) {
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
- dst := unsafe.Pointer(secondPage - uintptr(bytesBeforeFault))
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
+ dst := unsafe.Pointer(&mapping[pageSize-bytesBeforeFault])
n, err := ZeroOut(dst, pageSize)
if n != uintptr(bytesBeforeFault) {
t.Errorf("Unexpected write length: got %v, want %v", n, bytesBeforeFault)
@@ -488,7 +500,7 @@ func TestSwapUint32SegvError(t *testing.T) {
// Test that SwapUint32 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -500,7 +512,7 @@ func TestSwapUint32BusError(t *testing.T) {
// Test that SwapUint32 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint32(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -512,7 +524,7 @@ func TestSwapUint64SegvError(t *testing.T) {
// Test that SwapUint64 returns a SegvError when reaching a page that
// signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -524,7 +536,7 @@ func TestSwapUint64BusError(t *testing.T) {
// Test that SwapUint64 returns a BusError when reaching a page that
// signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := SwapUint64(unsafe.Pointer(secondPage), 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -536,7 +548,7 @@ func TestCompareAndSwapUint32SegvError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a SegvError when reaching a page
// that signals SIGSEGV.
withSegvErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (SegvError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
@@ -548,7 +560,7 @@ func TestCompareAndSwapUint32BusError(t *testing.T) {
// Test that CompareAndSwapUint32 returns a BusError when reaching a page
// that signals SIGBUS.
withBusErrorTestMapping(t, func(mapping []byte) {
- secondPage := uintptr(unsafe.Pointer(&mapping[0])) + pageSize
+ secondPage := uintptr(unsafe.Pointer(&mapping[pageSize]))
_, err := CompareAndSwapUint32(unsafe.Pointer(secondPage), 0, 1)
if want := (BusError{secondPage}); err != want {
t.Errorf("Unexpected error: got %v, want %v", err, want)
diff --git a/pkg/sentry/platform/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go
index eef028e68..41dd567f3 100644
--- a/pkg/sentry/platform/safecopy/safecopy_unsafe.go
+++ b/pkg/safecopy/safecopy_unsafe.go
@@ -16,6 +16,7 @@ package safecopy
import (
"fmt"
+ "runtime"
"syscall"
"unsafe"
)
@@ -35,7 +36,7 @@ const maxRegisterSize = 16
// successfully copied.
//
//go:noescape
-func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
// signal is received during the write, it returns the address that caused the
@@ -47,7 +48,7 @@ func memcpy(dst, src unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32
// successfully written.
//
//go:noescape
-func memclr(ptr unsafe.Pointer, n uintptr) (fault unsafe.Pointer, sig int32)
+func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
@@ -90,29 +91,35 @@ func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
// CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
// copied and an error if SIGSEGV or SIGBUS is received while reading from src.
func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
+ n, err := copyIn(dst, uintptr(src))
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyIn is the underlying definition for CopyIn.
+func copyIn(dst []byte, src uintptr) (int, error) {
toCopy := uintptr(len(dst))
if len(dst) == 0 {
return 0, nil
}
- fault, sig := memcpy(unsafe.Pointer(&dst[0]), src, toCopy)
+ fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy)
if sig == 0 {
return len(dst), nil
}
- faultN, srcN := uintptr(fault), uintptr(src)
- if faultN < srcN || faultN >= srcN+toCopy {
- panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, faultN, srcN, srcN+toCopy))
+ if fault < src || fault >= src+toCopy {
+ panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-srcN > maxRegisterSize {
- done = int(faultN - srcN - maxRegisterSize)
+ if fault-src > maxRegisterSize {
+ done = int(fault - src - maxRegisterSize)
}
- n, err := CopyIn(dst[done:int(faultN-srcN)], unsafe.Pointer(srcN+uintptr(done)))
+ n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done))
done += n
if err != nil {
return done, err
@@ -124,29 +131,35 @@ func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
// bytes done and an error if SIGSEGV or SIGBUS is received while writing to
// dst.
func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
+ n, err := copyOut(uintptr(dst), src)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// copyOut is the underlying definition for CopyOut.
+func copyOut(dst uintptr, src []byte) (int, error) {
toCopy := uintptr(len(src))
if toCopy == 0 {
return 0, nil
}
- fault, sig := memcpy(dst, unsafe.Pointer(&src[0]), toCopy)
+ fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy)
if sig == 0 {
return len(src), nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toCopy {
- panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toCopy))
+ if fault < dst || fault >= dst+toCopy {
+ panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy))
}
// memcpy might have ended the copy up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to copy up to the fault.
var done int
- if faultN-dstN > maxRegisterSize {
- done = int(faultN - dstN - maxRegisterSize)
+ if fault-dst > maxRegisterSize {
+ done = int(fault - dst - maxRegisterSize)
}
- n, err := CopyOut(unsafe.Pointer(dstN+uintptr(done)), src[done:int(faultN-dstN)])
+ n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)])
done += n
if err != nil {
return done, err
@@ -161,6 +174,14 @@ func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
// Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
// the resulting contents of dst are unspecified.
func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
+ n, err := copyN(uintptr(dst), uintptr(src), toCopy)
+ runtime.KeepAlive(dst)
+ runtime.KeepAlive(src)
+ return n, err
+}
+
+// copyN is the underlying definition for Copy.
+func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) {
if toCopy == 0 {
return 0, nil
}
@@ -171,17 +192,16 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
}
// Did the fault occur while reading from src or writing to dst?
- faultN, srcN, dstN := uintptr(fault), uintptr(src), uintptr(dst)
faultAfterSrc := ^uintptr(0)
- if faultN >= srcN {
- faultAfterSrc = faultN - srcN
+ if fault >= src {
+ faultAfterSrc = fault - src
}
faultAfterDst := ^uintptr(0)
- if faultN >= dstN {
- faultAfterDst = faultN - dstN
+ if fault >= dst {
+ faultAfterDst = fault - dst
}
if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
- panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, faultN, srcN, srcN+toCopy, dstN, dstN+toCopy))
+ panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy))
}
faultedAfter := faultAfterSrc
if faultedAfter > faultAfterDst {
@@ -195,7 +215,7 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
if faultedAfter > maxRegisterSize {
done = faultedAfter - maxRegisterSize
}
- n, err := Copy(unsafe.Pointer(dstN+done), unsafe.Pointer(srcN+done), faultedAfter-done)
+ n, err := copyN(dst+done, src+done, faultedAfter-done)
done += n
if err != nil {
return done, err
@@ -206,6 +226,13 @@ func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
// ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
// written and an error if SIGSEGV or SIGBUS is received while writing to dst.
func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
+ n, err := zeroOut(uintptr(dst), toZero)
+ runtime.KeepAlive(dst)
+ return n, err
+}
+
+// zeroOut is the underlying definition for ZeroOut.
+func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) {
if toZero == 0 {
return 0, nil
}
@@ -215,19 +242,18 @@ func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
return toZero, nil
}
- faultN, dstN := uintptr(fault), uintptr(dst)
- if faultN < dstN || faultN >= dstN+toZero {
- panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, faultN, dstN, dstN+toZero))
+ if fault < dst || fault >= dst+toZero {
+ panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero))
}
// memclr might have ended the write up to maxRegisterSize bytes before
// fault, if an instruction caused a memory access that straddled two
// pages, and the second one faulted. Try to write up to the fault.
var done uintptr
- if faultN-dstN > maxRegisterSize {
- done = faultN - dstN - maxRegisterSize
+ if fault-dst > maxRegisterSize {
+ done = fault - dst - maxRegisterSize
}
- n, err := ZeroOut(unsafe.Pointer(dstN+done), faultN-dstN-done)
+ n, err := zeroOut(dst+done, fault-dst-done)
done += n
if err != nil {
return done, err
@@ -243,7 +269,7 @@ func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
old, sig := swapUint32(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
@@ -254,7 +280,7 @@ func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
return 0, AlignmentError{addr, 8}
}
old, sig := swapUint64(ptr, new)
- return old, errorFromFaultSignal(ptr, sig)
+ return old, errorFromFaultSignal(uintptr(ptr), sig)
}
// CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
@@ -265,7 +291,7 @@ func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
prev, sig := compareAndSwapUint32(ptr, old, new)
- return prev, errorFromFaultSignal(ptr, sig)
+ return prev, errorFromFaultSignal(uintptr(ptr), sig)
}
// LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
@@ -277,17 +303,17 @@ func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
return 0, AlignmentError{addr, 4}
}
val, sig := loadUint32(ptr)
- return val, errorFromFaultSignal(ptr, sig)
+ return val, errorFromFaultSignal(uintptr(ptr), sig)
}
-func errorFromFaultSignal(addr unsafe.Pointer, sig int32) error {
+func errorFromFaultSignal(addr uintptr, sig int32) error {
switch sig {
case 0:
return nil
case int32(syscall.SIGSEGV):
- return SegvError{uintptr(addr)}
+ return SegvError{addr}
case int32(syscall.SIGBUS):
- return BusError{uintptr(addr)}
+ return BusError{addr}
default:
panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
}
diff --git a/pkg/sentry/platform/safecopy/sighandler_amd64.s b/pkg/safecopy/sighandler_amd64.s
index 475ae48e9..475ae48e9 100644
--- a/pkg/sentry/platform/safecopy/sighandler_amd64.s
+++ b/pkg/safecopy/sighandler_amd64.s
diff --git a/pkg/sentry/platform/safecopy/sighandler_arm64.s b/pkg/safecopy/sighandler_arm64.s
index 53e4ac2c1..53e4ac2c1 100644
--- a/pkg/sentry/platform/safecopy/sighandler_arm64.s
+++ b/pkg/safecopy/sighandler_arm64.s
diff --git a/pkg/sentry/safemem/BUILD b/pkg/safemem/BUILD
index 884020f7b..ce30382ab 100644
--- a/pkg/sentry/safemem/BUILD
+++ b/pkg/safemem/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,10 +10,9 @@ go_library(
"safemem.go",
"seq_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/safemem",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
- "//pkg/sentry/platform/safecopy",
+ "//pkg/safecopy",
],
)
@@ -25,5 +23,5 @@ go_test(
"io_test.go",
"seq_test.go",
],
- embed = [":safemem"],
+ library = ":safemem",
)
diff --git a/pkg/sentry/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go
index 6f03c94bf..e7fd30743 100644
--- a/pkg/sentry/safemem/block_unsafe.go
+++ b/pkg/safemem/block_unsafe.go
@@ -19,7 +19,7 @@ import (
"reflect"
"unsafe"
- "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.dev/gvisor/pkg/safecopy"
)
// A Block is a range of contiguous bytes, similar to []byte but with the
diff --git a/pkg/sentry/safemem/io.go b/pkg/safemem/io.go
index f039a5c34..f039a5c34 100644
--- a/pkg/sentry/safemem/io.go
+++ b/pkg/safemem/io.go
diff --git a/pkg/sentry/safemem/io_test.go b/pkg/safemem/io_test.go
index 629741bee..629741bee 100644
--- a/pkg/sentry/safemem/io_test.go
+++ b/pkg/safemem/io_test.go
diff --git a/pkg/sentry/safemem/safemem.go b/pkg/safemem/safemem.go
index 3e70d33a2..3e70d33a2 100644
--- a/pkg/sentry/safemem/safemem.go
+++ b/pkg/safemem/safemem.go
diff --git a/pkg/sentry/safemem/seq_test.go b/pkg/safemem/seq_test.go
index eba4bb535..de34005e9 100644
--- a/pkg/sentry/safemem/seq_test.go
+++ b/pkg/safemem/seq_test.go
@@ -20,6 +20,27 @@ import (
"testing"
)
+func TestBlockSeqOfEmptyBlock(t *testing.T) {
+ bs := BlockSeqOf(Block{})
+ if !bs.IsEmpty() {
+ t.Errorf("BlockSeqOf(Block{}).IsEmpty(): got false, wanted true; BlockSeq is %v", bs)
+ }
+}
+
+func TestBlockSeqOfNonemptyBlock(t *testing.T) {
+ b := BlockFromSafeSlice(make([]byte, 1))
+ bs := BlockSeqOf(b)
+ if bs.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).IsEmpty(): got true, wanted false; BlockSeq is %v", bs)
+ }
+ if head := bs.Head(); head != b {
+ t.Fatalf("BlockSeqOf(non-empty Block).Head(): got %v, wanted %v", head, b)
+ }
+ if tail := bs.Tail(); !tail.IsEmpty() {
+ t.Fatalf("BlockSeqOf(non-empty Block).Tail().IsEmpty(): got false, wanted true: tail is %v", tail)
+ }
+}
+
type blockSeqTest struct {
desc string
diff --git a/pkg/sentry/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go
index 354a95dde..f5f0574f8 100644
--- a/pkg/sentry/safemem/seq_unsafe.go
+++ b/pkg/safemem/seq_unsafe.go
@@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"reflect"
+ "syscall"
"unsafe"
)
@@ -55,6 +56,9 @@ type BlockSeq struct {
// BlockSeqOf returns a BlockSeq representing the single Block b.
func BlockSeqOf(b Block) BlockSeq {
+ if b.length == 0 {
+ return BlockSeq{}
+ }
bs := BlockSeq{
data: b.start,
length: -1,
@@ -297,3 +301,19 @@ func ZeroSeq(dsts BlockSeq) (uint64, error) {
}
return done, nil
}
+
+// IovecsFromBlockSeq returns a []syscall.Iovec representing seq.
+func IovecsFromBlockSeq(bs BlockSeq) []syscall.Iovec {
+ iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
+ for ; !bs.IsEmpty(); bs = bs.Tail() {
+ b := bs.Head()
+ iovs = append(iovs, syscall.Iovec{
+ Base: &b.ToSlice()[0],
+ Len: uint64(b.Len()),
+ })
+ // We don't need to care about b.NeedSafecopy(), because the host
+ // kernel will handle such address ranges just fine (by returning
+ // EFAULT).
+ }
+ return iovs
+}
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index af94e944d..29aeaab8c 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -1,12 +1,15 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_embed_data", "go_library", "go_test")
package(licenses = ["notice"])
go_binary(
name = "victim",
testonly = 1,
- srcs = ["seccomp_test_victim.go"],
+ srcs = [
+ "seccomp_test_victim.go",
+ "seccomp_test_victim_amd64.go",
+ "seccomp_test_victim_arm64.go",
+ ],
deps = [":seccomp"],
)
@@ -27,8 +30,7 @@ go_library(
"seccomp_rules.go",
"seccomp_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/seccomp",
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"//pkg/bpf",
@@ -43,7 +45,7 @@ go_test(
"seccomp_test.go",
":victim_data",
],
- embed = [":seccomp"],
+ library = ":seccomp",
deps = [
"//pkg/abi/linux",
"//pkg/binary",
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go
index c7503f2cc..55fd6967e 100644
--- a/pkg/seccomp/seccomp.go
+++ b/pkg/seccomp/seccomp.go
@@ -199,6 +199,10 @@ func ruleViolationLabel(ruleSetIdx int, sysno uintptr, idx int) string {
return fmt.Sprintf("ruleViolation_%v_%v_%v", ruleSetIdx, sysno, idx)
}
+func ruleLabel(ruleSetIdx int, sysno uintptr, idx int, name string) string {
+ return fmt.Sprintf("rule_%v_%v_%v_%v", ruleSetIdx, sysno, idx, name)
+}
+
func checkArgsLabel(sysno uintptr) string {
return fmt.Sprintf("checkArgs_%v", sysno)
}
@@ -215,14 +219,39 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc
switch a := arg.(type) {
case AllowAny:
case AllowValue:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
high, low := uint32(a>>32), uint32(a)
// assert arg_low == low
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
// assert arg_high == high
- p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i))
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
labelled = true
+ case GreaterThan:
+ dataOffsetLow := seccompDataOffsetArgLow(i)
+ dataOffsetHigh := seccompDataOffsetArgHigh(i)
+ if i == RuleIP {
+ dataOffsetLow = seccompDataOffsetIPLow
+ dataOffsetHigh = seccompDataOffsetIPHigh
+ }
+ labelGood := fmt.Sprintf("gt%v", i)
+ high, low := uint32(a>>32), uint32(a)
+ // assert arg_high < high
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ // arg_high > high
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ // arg_low < low
+ p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow)
+ p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx))
+ p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood))
+ labelled = true
default:
return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a))
}
diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go
index 29eec8db1..a52dc1b4e 100644
--- a/pkg/seccomp/seccomp_rules.go
+++ b/pkg/seccomp/seccomp_rules.go
@@ -49,17 +49,24 @@ func (a AllowAny) String() (s string) {
// AllowValue specifies a value that needs to be strictly matched.
type AllowValue uintptr
+// GreaterThan specifies a value that needs to be strictly smaller.
+type GreaterThan uintptr
+
func (a AllowValue) String() (s string) {
return fmt.Sprintf("%#x ", uintptr(a))
}
-// Rule stores the whitelist of syscall arguments.
+// Rule stores the allowed syscall arguments.
//
// For example:
// rule := Rule {
// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
// }
-type Rule [6]interface{}
+type Rule [7]interface{} // 6 arguments + RIP
+
+// RuleIP indicates what rules in the Rule array have to be applied to
+// instruction pointer.
+const RuleIP = 6
func (r Rule) String() (s string) {
if len(r) == 0 {
@@ -75,7 +82,7 @@ func (r Rule) String() (s string) {
return
}
-// SyscallRules stores a map of OR'ed whitelist rules indexed by the syscall number.
+// SyscallRules stores a map of OR'ed argument rules indexed by the syscall number.
// If the 'Rules' is empty, we treat it as any argument is allowed.
//
// For example:
diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go
index 353686ed3..5238df8bd 100644
--- a/pkg/seccomp/seccomp_test.go
+++ b/pkg/seccomp/seccomp_test.go
@@ -91,12 +91,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Single syscall allowed",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Single syscall disallowed",
- data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -125,22 +125,22 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Multiple rulesets allowed (1a)",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0x1}},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Multiple rulesets allowed (1b)",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple rulesets allowed (2)",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple rulesets allowed (2)",
- data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_KILL_THREAD,
},
},
@@ -160,42 +160,42 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Multiple syscalls allowed (1)",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Multiple syscalls allowed (3)",
- data: seccompData{nr: 3, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Multiple syscalls allowed (5)",
- data: seccompData{nr: 5, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Multiple syscalls disallowed (0)",
- data: seccompData{nr: 0, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple syscalls disallowed (2)",
- data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple syscalls disallowed (4)",
- data: seccompData{nr: 4, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple syscalls disallowed (6)",
- data: seccompData{nr: 6, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
{
desc: "Multiple syscalls disallowed (100)",
- data: seccompData{nr: 100, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -231,7 +231,7 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Syscall disallowed, action trap",
- data: seccompData{nr: 2, arch: linux.AUDIT_ARCH_X86_64},
+ data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -254,12 +254,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Syscall argument allowed",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xf}},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Syscall argument disallowed",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf, 0xe}},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}},
want: linux.SECCOMP_RET_TRAP,
},
},
@@ -284,12 +284,12 @@ func TestBasic(t *testing.T) {
specs: []spec{
{
desc: "Syscall argument allowed, two rules",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xf}},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}},
want: linux.SECCOMP_RET_ALLOW,
},
{
desc: "Syscall argument allowed, two rules",
- data: seccompData{nr: 1, arch: linux.AUDIT_ARCH_X86_64, args: [6]uint64{0xe}},
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}},
want: linux.SECCOMP_RET_ALLOW,
},
},
@@ -315,7 +315,7 @@ func TestBasic(t *testing.T) {
desc: "64bit syscall argument allowed",
data: seccompData{
nr: 1,
- arch: linux.AUDIT_ARCH_X86_64,
+ arch: LINUX_AUDIT_ARCH,
args: [6]uint64{0, math.MaxUint64 - 1, math.MaxUint32},
},
want: linux.SECCOMP_RET_ALLOW,
@@ -324,7 +324,7 @@ func TestBasic(t *testing.T) {
desc: "64bit syscall argument disallowed",
data: seccompData{
nr: 1,
- arch: linux.AUDIT_ARCH_X86_64,
+ arch: LINUX_AUDIT_ARCH,
args: [6]uint64{0, math.MaxUint64, math.MaxUint32},
},
want: linux.SECCOMP_RET_TRAP,
@@ -333,13 +333,88 @@ func TestBasic(t *testing.T) {
desc: "64bit syscall argument disallowed",
data: seccompData{
nr: 1,
- arch: linux.AUDIT_ARCH_X86_64,
+ arch: LINUX_AUDIT_ARCH,
args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1},
},
want: linux.SECCOMP_RET_TRAP,
},
},
},
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ GreaterThan(0xf),
+ GreaterThan(0xabcd000d),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "GreaterThan: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (equal)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ {
+ desc: "GreaterThan2: Syscall argument disallowed (smaller)",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
+ {
+ ruleSets: []RuleSet{
+ {
+ Rules: SyscallRules{
+ 1: []Rule{
+ {
+ RuleIP: AllowValue(0x7aabbccdd),
+ },
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ },
+ },
+ defaultAction: linux.SECCOMP_RET_TRAP,
+ specs: []spec{
+ {
+ desc: "IP: Syscall instruction pointer allowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd},
+ want: linux.SECCOMP_RET_ALLOW,
+ },
+ {
+ desc: "IP: Syscall instruction pointer disallowed",
+ data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344},
+ want: linux.SECCOMP_RET_TRAP,
+ },
+ },
+ },
} {
instrs, err := BuildProgram(test.ruleSets, test.defaultAction)
if err != nil {
@@ -376,7 +451,7 @@ func TestRandom(t *testing.T) {
}
}
- fmt.Printf("Testing filters: %v", syscallRules)
+ t.Logf("Testing filters: %v", syscallRules)
instrs, err := BuildProgram([]RuleSet{
RuleSet{
Rules: syscallRules,
@@ -391,7 +466,7 @@ func TestRandom(t *testing.T) {
t.Fatalf("bpf.Compile() got error: %v", err)
}
for i := uint32(0); i < 200; i++ {
- data := seccompData{nr: i, arch: linux.AUDIT_ARCH_X86_64}
+ data := seccompData{nr: i, arch: LINUX_AUDIT_ARCH}
got, err := bpf.Exec(p, data.asInput())
if err != nil {
t.Errorf("bpf.Exec() got error: %v, for syscall %d", err, i)
diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go
index 48413f1fb..fe157f539 100644
--- a/pkg/seccomp/seccomp_test_victim.go
+++ b/pkg/seccomp/seccomp_test_victim.go
@@ -31,17 +31,15 @@ func main() {
syscalls := seccomp.SyscallRules{
syscall.SYS_ACCEPT: {},
- syscall.SYS_ARCH_PRCTL: {},
syscall.SYS_BIND: {},
syscall.SYS_BRK: {},
syscall.SYS_CLOCK_GETTIME: {},
syscall.SYS_CLONE: {},
syscall.SYS_CLOSE: {},
syscall.SYS_DUP: {},
- syscall.SYS_DUP2: {},
+ syscall.SYS_DUP3: {},
syscall.SYS_EPOLL_CREATE1: {},
syscall.SYS_EPOLL_CTL: {},
- syscall.SYS_EPOLL_WAIT: {},
syscall.SYS_EPOLL_PWAIT: {},
syscall.SYS_EXIT: {},
syscall.SYS_EXIT_GROUP: {},
@@ -68,8 +66,6 @@ func main() {
syscall.SYS_MUNLOCK: {},
syscall.SYS_MUNMAP: {},
syscall.SYS_NANOSLEEP: {},
- syscall.SYS_NEWFSTATAT: {},
- syscall.SYS_OPEN: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
syscall.SYS_PSELECT6: {},
@@ -97,6 +93,9 @@ func main() {
syscall.SYS_WRITE: {},
syscall.SYS_WRITEV: {},
}
+
+ arch_syscalls(syscalls)
+
die := *dieFlag
if !die {
syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{
diff --git a/pkg/seccomp/seccomp_test_victim_amd64.go b/pkg/seccomp/seccomp_test_victim_amd64.go
new file mode 100644
index 000000000..5dfc68e25
--- /dev/null
+++ b/pkg/seccomp/seccomp_test_victim_amd64.go
@@ -0,0 +1,32 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Test binary used to test that seccomp filters are properly constructed and
+// indeed kill the process on violation.
+
+// +build amd64
+
+package main
+
+import (
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "syscall"
+)
+
+func arch_syscalls(syscalls seccomp.SyscallRules) {
+ syscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{}
+ syscalls[syscall.SYS_EPOLL_WAIT] = []seccomp.Rule{}
+ syscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{}
+ syscalls[syscall.SYS_OPEN] = []seccomp.Rule{}
+}
diff --git a/pkg/seccomp/seccomp_test_victim_arm64.go b/pkg/seccomp/seccomp_test_victim_arm64.go
new file mode 100644
index 000000000..5184d8ac4
--- /dev/null
+++ b/pkg/seccomp/seccomp_test_victim_arm64.go
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Test binary used to test that seccomp filters are properly constructed and
+// indeed kill the process on violation.
+
+// +build arm64
+
+package main
+
+import (
+ "gvisor.dev/gvisor/pkg/seccomp"
+ "syscall"
+)
+
+func arch_syscalls(syscalls seccomp.SyscallRules) {
+ syscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{}
+}
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index be328db12..f7e986589 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -21,13 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// sockFprog is sock_fprog taken from <linux/filter.h>.
-type sockFprog struct {
- Len uint16
- pad [6]byte
- Filter *linux.BPFInstruction
-}
-
// SetFilter installs the given BPF program.
//
// This is safe to call from an afterFork context.
@@ -39,7 +32,7 @@ func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
return errno
}
- sockProg := sockFprog{
+ sockProg := linux.SockFprog{
Len: uint16(len(instrs)),
Filter: (*linux.BPFInstruction)(unsafe.Pointer(&instrs[0])),
}
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
index 22abdc69f..60f63c7a6 100644
--- a/pkg/secio/BUILD
+++ b/pkg/secio/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"full_reader.go",
"secio.go",
],
- importpath = "gvisor.dev/gvisor/pkg/secio",
visibility = ["//pkg/sentry:internal"],
)
@@ -17,5 +15,5 @@ go_test(
name = "secio_test",
size = "small",
srcs = ["secio_test.go"],
- embed = [":secio"],
+ library = ":secio",
)
diff --git a/pkg/segment/BUILD b/pkg/segment/BUILD
index 1b487b887..f57ccc170 100644
--- a/pkg/segment/BUILD
+++ b/pkg/segment/BUILD
@@ -21,6 +21,8 @@ go_template(
],
opt_consts = [
"minDegree",
+ # trackGaps must either be 0 or 1.
+ "trackGaps",
],
types = [
"Key",
diff --git a/pkg/segment/set.go b/pkg/segment/set.go
index 03e4f258f..1a17ad9cb 100644
--- a/pkg/segment/set.go
+++ b/pkg/segment/set.go
@@ -36,6 +36,34 @@ type Range interface{}
// Value is a required type parameter.
type Value interface{}
+// trackGaps is an optional parameter.
+//
+// If trackGaps is 1, the Set will track maximum gap size recursively,
+// enabling the GapIterator.{Prev,Next}LargeEnoughGap functions. In this
+// case, Key must be an unsigned integer.
+//
+// trackGaps must be 0 or 1.
+const trackGaps = 0
+
+var _ = uint8(trackGaps << 7) // Will fail if not zero or one.
+
+// dynamicGap is a type that disappears if trackGaps is 0.
+type dynamicGap [trackGaps]Key
+
+// Get returns the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Get() Key {
+ return d[:][0]
+}
+
+// Set sets the value of the gap.
+//
+// Precondition: trackGaps must be non-zero.
+func (d *dynamicGap) Set(v Key) {
+ d[:][0] = v
+}
+
// Functions is a required type parameter that must be a struct implementing
// the methods defined by Functions.
type Functions interface {
@@ -327,8 +355,12 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if prev.Ok() && prev.End() == r.Start {
if mval, ok := (Functions{}).Merge(prev.Range(), prev.Value(), r, val); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
prev.SetEndUnchecked(r.End)
prev.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
if next.Ok() && next.Start() == r.End {
val = mval
if mval, ok := (Functions{}).Merge(prev.Range(), val, next.Range(), next.Value()); ok {
@@ -342,11 +374,16 @@ func (s *Set) Insert(gap GapIterator, r Range, val Value) Iterator {
}
if next.Ok() && next.Start() == r.End {
if mval, ok := (Functions{}).Merge(r, val, next.Range(), next.Value()); ok {
+ shrinkMaxGap := trackGaps != 0 && gap.Range().Length() == gap.node.maxGap.Get()
next.SetStartUnchecked(r.Start)
next.SetValue(mval)
+ if shrinkMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return next
}
}
+ // InsertWithoutMergingUnchecked will maintain maxGap if necessary.
return s.InsertWithoutMergingUnchecked(gap, r, val)
}
@@ -373,11 +410,15 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator
// Preconditions: r.Start >= gap.Start(); r.End <= gap.End().
func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator {
gap = gap.node.rebalanceBeforeInsert(gap)
+ splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get())
copy(gap.node.keys[gap.index+1:], gap.node.keys[gap.index:gap.node.nrSegments])
copy(gap.node.values[gap.index+1:], gap.node.values[gap.index:gap.node.nrSegments])
gap.node.keys[gap.index] = r
gap.node.values[gap.index] = val
gap.node.nrSegments++
+ if splitMaxGap {
+ gap.node.updateMaxGapLeaf()
+ }
return Iterator{gap.node, gap.index}
}
@@ -399,12 +440,23 @@ func (s *Set) Remove(seg Iterator) GapIterator {
// overlap.
seg.SetRangeUnchecked(victim.Range())
seg.SetValue(victim.Value())
+ // Need to update the nextAdjacentNode's maxGap because the gap in between
+ // must have been modified by updating seg.Range() to victim.Range().
+ // seg.NextSegment() must exist since the last segment can't be in a
+ // non-leaf node.
+ nextAdjacentNode := seg.NextSegment().node
+ if trackGaps != 0 {
+ nextAdjacentNode.updateMaxGapLeaf()
+ }
return s.Remove(victim).NextGap()
}
copy(seg.node.keys[seg.index:], seg.node.keys[seg.index+1:seg.node.nrSegments])
copy(seg.node.values[seg.index:], seg.node.values[seg.index+1:seg.node.nrSegments])
Functions{}.ClearValue(&seg.node.values[seg.node.nrSegments-1])
seg.node.nrSegments--
+ if trackGaps != 0 {
+ seg.node.updateMaxGapLeaf()
+ }
return seg.node.rebalanceAfterRemove(GapIterator{seg.node, seg.index})
}
@@ -455,6 +507,7 @@ func (s *Set) MergeUnchecked(first, second Iterator) Iterator {
// overlaps second.
first.SetEndUnchecked(second.End())
first.SetValue(mval)
+ // Remove will handle the maxGap update if necessary.
return s.Remove(second).PrevSegment()
}
}
@@ -631,6 +684,12 @@ type node struct {
// than "isLeaf" because false must be the correct value for an empty root.
hasChildren bool
+ // The longest gap within this node. If the node is a leaf, it's simply the
+ // maximum gap among all the (nrSegments+1) gaps formed by its nrSegments keys
+ // including the 0th and nrSegments-th gap possibly shared with its upper-level
+ // nodes; if it's a non-leaf node, it's the max of all children's maxGap.
+ maxGap dynamicGap
+
// Nodes store keys and values in separate arrays to maximize locality in
// the common case (scanning keys for lookup).
keys [maxDegree - 1]Range
@@ -676,12 +735,12 @@ func (n *node) nextSibling() *node {
// required for insertion, and returns an updated iterator to the position
// represented by gap.
func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
- if n.parent != nil {
- gap = n.parent.rebalanceBeforeInsert(gap)
- }
if n.nrSegments < maxDegree-1 {
return gap
}
+ if n.parent != nil {
+ gap = n.parent.rebalanceBeforeInsert(gap)
+ }
if n.parent == nil {
// n is root. Move all segments before and after n's median segment
// into new child nodes adjacent to the median segment, which is now
@@ -719,6 +778,13 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
n.hasChildren = true
n.children[0] = left
n.children[1] = right
+ // In this case, n's maxGap won't violated as it's still the root,
+ // but the left and right children should be updated locally as they
+ // are newly split from n.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ right.updateMaxGapLocal()
+ }
if gap.node != n {
return gap
}
@@ -758,6 +824,12 @@ func (n *node) rebalanceBeforeInsert(gap GapIterator) GapIterator {
}
}
n.nrSegments = minDegree - 1
+ // MaxGap of n's parent is not violated because the segments within is not changed.
+ // n and its sibling's maxGap need to be updated locally as they are two new nodes split from old n.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
// gap.node can't be n.parent because gaps are always in leaf nodes.
if gap.node != n {
return gap
@@ -821,6 +893,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling && gap.index == sibling.nrSegments {
return GapIterator{n, 0}
}
@@ -849,6 +927,12 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
n.nrSegments++
sibling.nrSegments--
+ // n's parent's maxGap does not need to be updated as its content is unmodified.
+ // n and its sibling must be updated with (new) maxGap because of the shift of keys.
+ if trackGaps != 0 {
+ n.updateMaxGapLocal()
+ sibling.updateMaxGapLocal()
+ }
if gap.node == sibling {
if gap.index == 0 {
return GapIterator{n, n.nrSegments}
@@ -886,6 +970,7 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
p.children[0] = nil
p.children[1] = nil
}
+ // No need to update maxGap of p as its content is not changed.
if gap.node == left {
return GapIterator{p, gap.index}
}
@@ -932,11 +1017,152 @@ func (n *node) rebalanceAfterRemove(gap GapIterator) GapIterator {
}
p.children[p.nrSegments] = nil
p.nrSegments--
+ // Update maxGap of left locally, no need to change p and right because
+ // p's contents is not changed and right is already invalid.
+ if trackGaps != 0 {
+ left.updateMaxGapLocal()
+ }
// This process robs p of one segment, so recurse into rebalancing p.
n = p
}
}
+// updateMaxGapLeaf updates maxGap bottom-up from the calling leaf until no
+// necessary update.
+//
+// Preconditions: n must be a leaf node, trackGaps must be 1.
+func (n *node) updateMaxGapLeaf() {
+ if n.hasChildren {
+ panic(fmt.Sprintf("updateMaxGapLeaf should always be called on leaf node: %v", n))
+ }
+ max := n.calculateMaxGapLeaf()
+ if max == n.maxGap.Get() {
+ // If new max equals the old maxGap, no update is needed.
+ return
+ }
+ oldMax := n.maxGap.Get()
+ n.maxGap.Set(max)
+ if max > oldMax {
+ // Grow ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() >= max {
+ // p and its ancestors already contain an equal or larger gap.
+ break
+ }
+ // Only if new maxGap is larger than parent's
+ // old maxGap, propagate this update to parent.
+ p.maxGap.Set(max)
+ }
+ return
+ }
+ // Shrink ancestor maxGaps.
+ for p := n.parent; p != nil; p = p.parent {
+ if p.maxGap.Get() > oldMax {
+ // p and its ancestors still contain a larger gap.
+ break
+ }
+ // If new max is smaller than the old maxGap, and this gap used
+ // to be the maxGap of its parent, iterate parent's children
+ // and calculate parent's new maxGap.(It's probable that parent
+ // has two children with the old maxGap, but we need to check it anyway.)
+ parentNewMax := p.calculateMaxGapInternal()
+ if p.maxGap.Get() == parentNewMax {
+ // p and its ancestors still contain a gap of at least equal size.
+ break
+ }
+ // If p's new maxGap differs from the old one, propagate this update.
+ p.maxGap.Set(parentNewMax)
+ }
+}
+
+// updateMaxGapLocal updates maxGap of the calling node solely with no
+// propagation to ancestor nodes.
+//
+// Precondition: trackGaps must be 1.
+func (n *node) updateMaxGapLocal() {
+ if !n.hasChildren {
+ // Leaf node iterates its gaps.
+ n.maxGap.Set(n.calculateMaxGapLeaf())
+ } else {
+ // Non-leaf node iterates its children.
+ n.maxGap.Set(n.calculateMaxGapInternal())
+ }
+}
+
+// calculateMaxGapLeaf iterates the gaps within a leaf node and calculate the
+// max.
+//
+// Preconditions: n must be a leaf node.
+func (n *node) calculateMaxGapLeaf() Key {
+ max := GapIterator{n, 0}.Range().Length()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := (GapIterator{n, i}).Range().Length(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// calculateMaxGapInternal iterates children's maxGap within an internal node n
+// and calculate the max.
+//
+// Preconditions: n must be a non-leaf node.
+func (n *node) calculateMaxGapInternal() Key {
+ max := n.children[0].maxGap.Get()
+ for i := 1; i <= n.nrSegments; i++ {
+ if current := n.children[i].maxGap.Get(); current > max {
+ max = current
+ }
+ }
+ return max
+}
+
+// searchFirstLargeEnoughGap returns the first gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchFirstLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := 0; i <= n.nrSegments; i++ {
+ if largeEnoughGap := n.children[i].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
+// searchLastLargeEnoughGap returns the last gap having at least minSize length
+// in the subtree rooted by n. If not found, return a terminal gap iterator.
+func (n *node) searchLastLargeEnoughGap(minSize Key) GapIterator {
+ if n.maxGap.Get() < minSize {
+ return GapIterator{}
+ }
+ if n.hasChildren {
+ for i := n.nrSegments; i >= 0; i-- {
+ if largeEnoughGap := n.children[i].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ }
+ } else {
+ for i := n.nrSegments; i >= 0; i-- {
+ currentGap := GapIterator{n, i}
+ if currentGap.Range().Length() >= minSize {
+ return currentGap
+ }
+ }
+ }
+ panic(fmt.Sprintf("invalid maxGap in %v", n))
+}
+
// A Iterator is conceptually one of:
//
// - A pointer to a segment in a set; or
@@ -1243,6 +1469,122 @@ func (gap GapIterator) NextGap() GapIterator {
return seg.NextGap()
}
+// NextLargeEnoughGap returns the iterated gap's first next gap with larger
+// length than minSize. If not found, return a terminal gap iterator (does NOT
+// include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) NextLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.NextSegment().node
+ gap.index = 0
+ return gap.nextLargeEnoughGapHelper(minSize)
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// nextLargeEnoughGapHelper is the helper function used by NextLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the trailing gap of a non-leaf node.
+func (gap GapIterator) nextLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the trailing one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == gap.node.nrSegments)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate subsequent gaps.
+ gap.index++
+ for gap.index <= gap.node.nrSegments {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchFirstLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index++
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == gap.node.nrSegments {
+ // If gap is the trailing gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.nextLargeEnoughGapHelper(minSize)
+}
+
+// PrevLargeEnoughGap returns the iterated gap's first prev gap with larger or
+// equal length than minSize. If not found, return a terminal gap iterator
+// (does NOT include this gap itself).
+//
+// Precondition: trackGaps must be 1.
+func (gap GapIterator) PrevLargeEnoughGap(minSize Key) GapIterator {
+ if trackGaps != 1 {
+ panic("set is not tracking gaps")
+ }
+ if gap.node != nil && gap.node.hasChildren && gap.index == 0 {
+ // If gap is the first gap of an non-leaf node,
+ // translate it to the equivalent gap on leaf level.
+ gap.node = gap.PrevSegment().node
+ gap.index = gap.node.nrSegments
+ return gap.prevLargeEnoughGapHelper(minSize)
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
+// prevLargeEnoughGapHelper is the helper function used by PrevLargeEnoughGap
+// to do the real recursions.
+//
+// Preconditions: gap is NOT the first gap of a non-leaf node.
+func (gap GapIterator) prevLargeEnoughGapHelper(minSize Key) GapIterator {
+ // Crawl up the tree if no large enough gap in current node or the
+ // current gap is the first one on leaf level.
+ for gap.node != nil &&
+ (gap.node.maxGap.Get() < minSize || (!gap.node.hasChildren && gap.index == 0)) {
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ // If no large enough gap throughout the whole set, return a terminal
+ // gap iterator.
+ if gap.node == nil {
+ return GapIterator{}
+ }
+ // Iterate previous gaps.
+ gap.index--
+ for gap.index >= 0 {
+ if gap.node.hasChildren {
+ if largeEnoughGap := gap.node.children[gap.index].searchLastLargeEnoughGap(minSize); largeEnoughGap.Ok() {
+ return largeEnoughGap
+ }
+ } else {
+ if gap.Range().Length() >= minSize {
+ return gap
+ }
+ }
+ gap.index--
+ }
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ if gap.node != nil && gap.index == 0 {
+ // If gap is the first gap of a non-leaf node, crawl up to
+ // parent again and do recursion.
+ gap.node, gap.index = gap.node.parent, gap.node.parentIndex
+ }
+ return gap.prevLargeEnoughGapHelper(minSize)
+}
+
// segmentBeforePosition returns the predecessor segment of the position given
// by n.children[i], which may or may not contain a child. If no such segment
// exists, segmentBeforePosition returns a terminal iterator.
@@ -1271,7 +1613,7 @@ func segmentAfterPosition(n *node, i int) Iterator {
func zeroValueSlice(slice []Value) {
// TODO(jamieliu): check if Go is actually smart enough to optimize a
- // ClearValue that assigns nil to a memset here
+ // ClearValue that assigns nil to a memset here.
for i := range slice {
Functions{}.ClearValue(&slice[i])
}
@@ -1310,7 +1652,15 @@ func (n *node) writeDebugString(buf *bytes.Buffer, prefix string) {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, i))
}
buf.WriteString(prefix)
- buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ if n.hasChildren {
+ if trackGaps != 0 {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v, maxGap: %d\n", i, n.keys[i], n.values[i], n.maxGap.Get()))
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
+ } else {
+ buf.WriteString(fmt.Sprintf("- % 3d: %v => %v\n", i, n.keys[i], n.values[i]))
+ }
}
if child := n.children[n.nrSegments]; child != nil {
child.writeDebugString(buf, fmt.Sprintf("%s- % 3d ", prefix, n.nrSegments))
@@ -1362,3 +1712,43 @@ func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error {
}
return nil
}
+
+// segmentTestCheck returns an error if s is incorrectly sorted, does not
+// contain exactly expectedSegments segments, or contains a segment which
+// fails the passed check.
+//
+// This should be used only for testing, and has been added to this package for
+// templating convenience.
+func (s *Set) segmentTestCheck(expectedSegments int, segFunc func(int, Range, Value) error) error {
+ havePrev := false
+ prev := Key(0)
+ nrSegments := 0
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ next := seg.Start()
+ if havePrev && prev >= next {
+ return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
+ }
+ if segFunc != nil {
+ if err := segFunc(nrSegments, seg.Range(), seg.Value()); err != nil {
+ return err
+ }
+ }
+ prev = next
+ havePrev = true
+ nrSegments++
+ }
+ if nrSegments != expectedSegments {
+ return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+ }
+ return nil
+}
+
+// countSegments counts the number of segments in the set.
+//
+// Similar to Check, this should only be used for testing.
+func (s *Set) countSegments() (segments int) {
+ for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ segments++
+ }
+ return segments
+}
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index a27c35e21..131bf09b9 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(
@@ -30,15 +29,32 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "gap_set",
+ out = "gap_set.go",
+ consts = {
+ "trackGaps": "1",
+ },
+ package = "segment",
+ prefix = "gap",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "int",
+ "Range": "Range",
+ "Value": "int",
+ "Functions": "gapSetFunctions",
+ },
+)
+
go_library(
name = "segment",
testonly = 1,
srcs = [
+ "gap_set.go",
"int_range.go",
"int_set.go",
"set_functions.go",
],
- importpath = "gvisor.dev/gvisor/pkg/segment/segment",
deps = [
"//pkg/state",
],
@@ -48,5 +64,5 @@ go_test(
name = "segment_test",
size = "small",
srcs = ["segment_test.go"],
- embed = [":segment"],
+ library = ":segment",
)
diff --git a/pkg/segment/test/segment_test.go b/pkg/segment/test/segment_test.go
index f19a005f3..85fa19096 100644
--- a/pkg/segment/test/segment_test.go
+++ b/pkg/segment/test/segment_test.go
@@ -17,6 +17,7 @@ package segment
import (
"fmt"
"math/rand"
+ "reflect"
"testing"
)
@@ -32,61 +33,65 @@ const (
// valueOffset is the difference between the value and start of test
// segments.
valueOffset = 100000
+
+ // intervalLength is the interval used by random gap tests.
+ intervalLength = 10
)
func shuffle(xs []int) {
- for i := range xs {
- j := rand.Intn(i + 1)
- xs[i], xs[j] = xs[j], xs[i]
- }
+ rand.Shuffle(len(xs), func(i, j int) { xs[i], xs[j] = xs[j], xs[i] })
}
-func randPermutation(size int) []int {
+func randIntervalPermutation(size int) []int {
p := make([]int, size)
for i := range p {
- p[i] = i
+ p[i] = intervalLength * i
}
shuffle(p)
return p
}
-// checkSet returns an error if s is incorrectly sorted, does not contain
-// exactly expectedSegments segments, or contains a segment for which val !=
-// key + valueOffset.
-func checkSet(s *Set, expectedSegments int) error {
- havePrev := false
- prev := 0
- nrSegments := 0
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- next := seg.Start()
- if havePrev && prev >= next {
- return fmt.Errorf("incorrect order: key %d (segment %d) >= key %d (segment %d)", prev, nrSegments-1, next, nrSegments)
- }
- if got, want := seg.Value(), seg.Start()+valueOffset; got != want {
- return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nrSegments, seg.Start, got, want)
- }
- prev = next
- havePrev = true
- nrSegments++
- }
- if nrSegments != expectedSegments {
- return fmt.Errorf("incorrect number of segments: got %d, wanted %d", nrSegments, expectedSegments)
+// validate can be passed to Check.
+func validate(nr int, r Range, v int) error {
+ if got, want := v, r.Start+valueOffset; got != want {
+ return fmt.Errorf("segment %d has key %d, value %d (expected %d)", nr, r.Start, got, want)
}
return nil
}
-// countSegmentsIn returns the number of segments in s.
-func countSegmentsIn(s *Set) int {
- var count int
- for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- count++
+// checkSetMaxGap returns an error if maxGap inside all nodes of s is not well
+// maintained.
+func checkSetMaxGap(s *gapSet) error {
+ n := s.root
+ return checkNodeMaxGap(&n)
+}
+
+// checkNodeMaxGap returns an error if maxGap inside the subtree rooted by n is
+// not well maintained.
+func checkNodeMaxGap(n *gapnode) error {
+ var max int
+ if !n.hasChildren {
+ max = n.calculateMaxGapLeaf()
+ } else {
+ for i := 0; i <= n.nrSegments; i++ {
+ child := n.children[i]
+ if err := checkNodeMaxGap(child); err != nil {
+ return err
+ }
+ if temp := child.maxGap.Get(); i == 0 || temp > max {
+ max = temp
+ }
+ }
+ }
+ if max != n.maxGap.Get() {
+ return fmt.Errorf("maxGap wrong in node\n%vexpected: %d got: %d", n, max, n.maxGap)
}
- return count
+ return nil
}
func TestAddRandom(t *testing.T) {
var s Set
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
var nrInsertions int
for i, j := range order {
if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
@@ -94,12 +99,12 @@ func TestAddRandom(t *testing.T) {
break
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -115,7 +120,156 @@ func TestRemoveRandom(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(testSize)
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapAddRandom(t *testing.T) {
+ var s gapSet
+ order := rand.Perm(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithRandomInterval(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ var nrInsertions int
+ for i, j := range order {
+ if !s.AddWithoutMerging(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ nrInsertions++
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order[:nrInsertions])
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapAddRandomWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize)
+ nrInsertions := 1
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), nrInsertions; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Insertion order: %v", order)
+ t.Logf("Set contents:\n%v", &s)
+ }
+}
+
+func TestMaxGapRemoveRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{i, i + 1}, i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := rand.Perm(testSize)
+ var nrRemovals int
+ for i, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
+ break
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
+ t.Errorf("Iteration %d: %v", i, err)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
+ t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestMaxGapRemoveHalfRandom(t *testing.T) {
+ var s gapSet
+ for i := 0; i < testSize; i++ {
+ if !s.AddWithoutMerging(Range{intervalLength * i, intervalLength*i + rand.Intn(intervalLength-1) + 1}, intervalLength*i+valueOffset) {
+ t.Fatalf("Failed to insert segment %d", i)
+ }
+ }
+ order := randIntervalPermutation(testSize)
+ order = order[:testSize/2]
var nrRemovals int
for i, j := range order {
seg := s.FindSegment(j)
@@ -123,14 +277,19 @@ func TestRemoveRandom(t *testing.T) {
t.Errorf("Iteration %d: failed to find segment with key %d", i, j)
break
}
+ temprange := seg.Range()
s.Remove(seg)
nrRemovals++
- if err := checkSet(&s, testSize-nrRemovals); err != nil {
+ if err := s.segmentTestCheck(testSize-nrRemovals, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
}
- if got, want := countSegmentsIn(&s), testSize-nrRemovals; got != want {
+ if got, want := s.countSegments(), testSize-nrRemovals; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -140,6 +299,148 @@ func TestRemoveRandom(t *testing.T) {
}
}
+func TestMaxGapAddRandomRemoveRandomHalfWithMerge(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + intervalLength}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ var nrRemovals int
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ nrRemovals++
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ if t.Failed() {
+ t.Logf("Removal order: %v", order[:nrRemovals])
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestNextLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.LowerBoundGap(0).NextLargeEnoughGap(minSize); gap.Ok(); gap = gap.NextLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("NextLargeEnoughGap wrong, gap %v has length %d, wanted %d", gap.Range(), gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.LowerBoundGap(0).NextGap(); gap.Ok(); gap = gap.NextGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
+func TestPrevLargeEnoughGap(t *testing.T) {
+ var s gapSet
+ order := randIntervalPermutation(testSize * 2)
+ order = order[:testSize]
+ for i, j := range order {
+ if !s.Add(Range{j, j + rand.Intn(intervalLength-1) + 1}, j+valueOffset) {
+ t.Errorf("Iteration %d: failed to insert segment with key %d", i, j)
+ break
+ }
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When inserting %d: %v", j, err)
+ break
+ }
+ }
+ end := s.LastSegment().End()
+ shuffle(order)
+ order = order[:testSize/2]
+ for _, j := range order {
+ seg := s.FindSegment(j)
+ if !seg.Ok() {
+ continue
+ }
+ temprange := seg.Range()
+ s.Remove(seg)
+ if err := checkSetMaxGap(&s); err != nil {
+ t.Errorf("When removing %v: %v", temprange, err)
+ break
+ }
+ }
+ minSize := 7
+ var gapArr1 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevLargeEnoughGap(minSize); gap.Ok(); gap = gap.PrevLargeEnoughGap(minSize) {
+ if gap.Range().Length() < minSize {
+ t.Errorf("PrevLargeEnoughGap wrong, gap length %d, wanted %d", gap.Range().Length(), minSize)
+ } else {
+ gapArr1 = append(gapArr1, gap.Range().Start)
+ }
+ }
+ var gapArr2 []int
+ for gap := s.UpperBoundGap(end + intervalLength).PrevGap(); gap.Ok(); gap = gap.PrevGap() {
+ if gap.Range().Length() >= minSize {
+ gapArr2 = append(gapArr2, gap.Range().Start)
+ }
+ }
+ if !reflect.DeepEqual(gapArr2, gapArr1) {
+ t.Errorf("Search result not correct, got: %v, wanted: %v", gapArr1, gapArr2)
+ }
+ if t.Failed() {
+ t.Logf("Set contents:\n%v", &s)
+ t.FailNow()
+ }
+}
+
func TestAddSequentialAdjacent(t *testing.T) {
var s Set
var nrInsertions int
@@ -148,12 +449,12 @@ func TestAddSequentialAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -202,12 +503,12 @@ func TestAddSequentialNonAdjacent(t *testing.T) {
t.Fatalf("Failed to insert segment %d", i)
}
nrInsertions++
- if err := checkSet(&s, nrInsertions); err != nil {
+ if err := s.segmentTestCheck(nrInsertions, validate); err != nil {
t.Errorf("Iteration %d: %v", i, err)
break
}
}
- if got, want := countSegmentsIn(&s), nrInsertions; got != want {
+ if got, want := s.countSegments(), nrInsertions; got != want {
t.Errorf("Wrong final number of segments: got %d, wanted %d", got, want)
}
if t.Failed() {
@@ -293,7 +594,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -351,7 +652,7 @@ Tests:
var i int
for seg := s.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
if i > len(test.final) {
- t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, countSegmentsIn(&s), len(test.final), &s)
+ t.Errorf("%s: Incorrect number of segments: got %d, wanted %d; set contents:\n%v", test.name, s.countSegments(), len(test.final), &s)
continue Tests
}
if got, want := seg.Range(), test.final[i]; got != want {
@@ -378,7 +679,7 @@ func benchmarkAddSequential(b *testing.B, size int) {
}
func benchmarkAddRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -416,7 +717,7 @@ func benchmarkFindRandom(b *testing.B, size int) {
b.Fatalf("Failed to insert segment %d", i)
}
}
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
@@ -470,7 +771,7 @@ func benchmarkAddFindRemoveSequential(b *testing.B, size int) {
}
func benchmarkAddFindRemoveRandom(b *testing.B, size int) {
- order := randPermutation(size)
+ order := rand.Perm(size)
b.ResetTimer()
for n := 0; n < b.N; n++ {
diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go
index bcddb39bb..7cd895cc7 100644
--- a/pkg/segment/test/set_functions.go
+++ b/pkg/segment/test/set_functions.go
@@ -14,21 +14,16 @@
package segment
-// Basic numeric constants that we define because the math package doesn't.
-// TODO(nlacasse): These should be Math.MaxInt64/MinInt64?
-const (
- maxInt = int(^uint(0) >> 1)
- minInt = -maxInt - 1
-)
-
type setFunctions struct{}
-func (setFunctions) MinKey() int {
- return minInt
+// MinKey returns the minimum key for the set.
+func (s setFunctions) MinKey() int {
+ return -s.MaxKey() - 1
}
+// MaxKey returns the maximum key for the set.
func (setFunctions) MaxKey() int {
- return maxInt
+ return int(^uint(0) >> 1)
}
func (setFunctions) ClearValue(*int) {}
@@ -40,3 +35,20 @@ func (setFunctions) Merge(_ Range, val1 int, _ Range, _ int) (int, bool) {
func (setFunctions) Split(_ Range, val int, _ int) (int, int) {
return val, val
}
+
+type gapSetFunctions struct {
+ setFunctions
+}
+
+// MinKey is adjusted to make sure no add overflow would happen in test cases.
+// e.g. A gap with range {MinInt32, 2} would cause overflow in Range().Length().
+//
+// Normally Keys should be unsigned to avoid these issues.
+func (s gapSetFunctions) MinKey() int {
+ return s.setFunctions.MinKey() / 2
+}
+
+// MaxKey returns the maximum key for the set.
+func (s gapSetFunctions) MaxKey() int {
+ return s.setFunctions.MaxKey() / 2
+}
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index 2d6379c86..e759dc36f 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -1,8 +1,8 @@
-# This BUILD file defines a package_group that allows for interdependencies for
-# sentry-internal packages.
-
package(licenses = ["notice"])
+# The "internal" package_group should be used as much as possible by packages
+# that should remain Sentry-internal (i.e. not be exposed directly to command
+# line tooling or APIs).
package_group(
name = "internal",
packages = [
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 18c73cc24..901e0f320 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,6 +1,4 @@
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "proto_library")
package(licenses = ["notice"])
@@ -9,48 +7,42 @@ go_library(
srcs = [
"aligned.go",
"arch.go",
+ "arch_aarch64.go",
"arch_amd64.go",
"arch_amd64.s",
+ "arch_arm64.go",
"arch_state_x86.go",
"arch_x86.go",
+ "arch_x86_impl.go",
"auxv.go",
+ "signal.go",
"signal_act.go",
"signal_amd64.go",
+ "signal_arm64.go",
"signal_info.go",
"signal_stack.go",
"stack.go",
"syscalls_amd64.go",
+ "syscalls_arm64.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/arch",
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
":registers_go_proto",
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/context",
"//pkg/cpuid",
"//pkg/log",
- "//pkg/sentry/context",
"//pkg/sentry/limits",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
proto_library(
- name = "registers_proto",
+ name = "registers",
srcs = ["registers.proto"],
visibility = ["//visibility:public"],
)
-
-cc_proto_library(
- name = "registers_cc_proto",
- visibility = ["//visibility:public"],
- deps = [":registers_proto"],
-)
-
-go_proto_library(
- name = "registers_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
- proto = ":registers_proto",
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index 498ca4669..a903d031c 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Arch describes an architecture.
@@ -88,6 +88,9 @@ type Context interface {
// SyscallNo returns the syscall number.
SyscallNo() uintptr
+ // SyscallSaveOrig save orignal register value.
+ SyscallSaveOrig()
+
// SyscallArgs returns the syscall arguments in an array.
SyscallArgs() SyscallArguments
@@ -125,9 +128,9 @@ type Context interface {
// SetTLS sets the current TLS pointer. Returns false if value is invalid.
SetTLS(value uintptr) bool
- // SetRSEQInterruptedIP sets the register that contains the old IP when a
- // restartable sequence is interrupted.
- SetRSEQInterruptedIP(value uintptr)
+ // SetOldRSeqInterruptedIP sets the register that contains the old IP
+ // when an "old rseq" restartable sequence is interrupted.
+ SetOldRSeqInterruptedIP(value uintptr)
// StateData returns a pointer to underlying architecture state.
StateData() *State
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
new file mode 100644
index 000000000..0f433ee79
--- /dev/null
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/log"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Registers represents the CPU registers for this architecture.
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+
+ // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register.
+ TPIDR_EL0 uint64
+}
+
+const (
+ // SyscallWidth is the width of insturctions.
+ SyscallWidth = 4
+
+ // fpsimdMagic is the magic number which is used in fpsimd_context.
+ fpsimdMagic = 0x46508001
+
+ // fpsimdContextSize is the size of fpsimd_context.
+ fpsimdContextSize = 0x210
+)
+
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
+// aarch64FPState is aarch64 floating point state.
+type aarch64FPState []byte
+
+// initAarch64FPState sets up initial state.
+//
+// Related code in Linux kernel: fpsimd_flush_thread().
+// FPCR = FPCR_RM_RN (0x0 << 22).
+//
+// Currently, aarch64FPState is only a space of 0x210 length for fpstate.
+// The fp head is useless in sentry/ptrace/kvm.
+//
+func initAarch64FPState(data aarch64FPState) {
+}
+
+func newAarch64FPStateSlice() []byte {
+ return alignedBytes(4096, 16)[:fpsimdContextSize]
+}
+
+// newAarch64FPState returns an initialized floating point state.
+//
+// The returned state is large enough to store all floating point state
+// supported by host, even if the app won't use much of it due to a restricted
+// FeatureSet.
+func newAarch64FPState() aarch64FPState {
+ f := aarch64FPState(newAarch64FPStateSlice())
+ initAarch64FPState(f)
+ return f
+}
+
+// fork creates and returns an identical copy of the aarch64 floating point state.
+func (f aarch64FPState) fork() aarch64FPState {
+ n := aarch64FPState(newAarch64FPStateSlice())
+ copy(n, f)
+ return n
+}
+
+// FloatingPointData returns the raw data pointer.
+func (f aarch64FPState) FloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&f[0])
+}
+
+// NewFloatingPointData returns a new floating point data blob.
+//
+// This is primarily for use in tests.
+func NewFloatingPointData() *FloatingPointData {
+ return (*FloatingPointData)(&(newAarch64FPState()[0]))
+}
+
+// State contains the common architecture bits for aarch64 (the build tag of this
+// file ensures it's only built on aarch64).
+//
+// +stateify savable
+type State struct {
+ // The system registers.
+ Regs Registers
+
+ // Our floating point state.
+ aarch64FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+
+ // OrigR0 stores the value of register R0.
+ OrigR0 uint64
+}
+
+// Proto returns a protobuf representation of the system registers in State.
+func (s State) Proto() *rpb.Registers {
+ regs := &rpb.ARM64Registers{
+ R0: s.Regs.Regs[0],
+ R1: s.Regs.Regs[1],
+ R2: s.Regs.Regs[2],
+ R3: s.Regs.Regs[3],
+ R4: s.Regs.Regs[4],
+ R5: s.Regs.Regs[5],
+ R6: s.Regs.Regs[6],
+ R7: s.Regs.Regs[7],
+ R8: s.Regs.Regs[8],
+ R9: s.Regs.Regs[9],
+ R10: s.Regs.Regs[10],
+ R11: s.Regs.Regs[11],
+ R12: s.Regs.Regs[12],
+ R13: s.Regs.Regs[13],
+ R14: s.Regs.Regs[14],
+ R15: s.Regs.Regs[15],
+ R16: s.Regs.Regs[16],
+ R17: s.Regs.Regs[17],
+ R18: s.Regs.Regs[18],
+ R19: s.Regs.Regs[19],
+ R20: s.Regs.Regs[20],
+ R21: s.Regs.Regs[21],
+ R22: s.Regs.Regs[22],
+ R23: s.Regs.Regs[23],
+ R24: s.Regs.Regs[24],
+ R25: s.Regs.Regs[25],
+ R26: s.Regs.Regs[26],
+ R27: s.Regs.Regs[27],
+ R28: s.Regs.Regs[28],
+ R29: s.Regs.Regs[29],
+ R30: s.Regs.Regs[30],
+ Sp: s.Regs.Sp,
+ Pc: s.Regs.Pc,
+ Pstate: s.Regs.Pstate,
+ }
+ return &rpb.Registers{Arch: &rpb.Registers_Arm64{Arm64: regs}}
+}
+
+// Fork creates and returns an identical copy of the state.
+func (s *State) Fork() State {
+ return State{
+ Regs: s.Regs,
+ aarch64FPState: s.aarch64FPState.fork(),
+ FeatureSet: s.FeatureSet,
+ OrigR0: s.OrigR0,
+ }
+}
+
+// StateData implements Context.StateData.
+func (s *State) StateData() *State {
+ return s
+}
+
+// CPUIDEmulate emulates a cpuid instruction.
+func (s *State) CPUIDEmulate(l log.Logger) {
+ // TODO(gvisor.dev/issue/1255): cpuid is not supported.
+}
+
+// SingleStep implements Context.SingleStep.
+func (s *State) SingleStep() bool {
+ return false
+}
+
+// SetSingleStep enables single stepping.
+func (s *State) SetSingleStep() {
+ // Set the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// ClearSingleStep enables single stepping.
+func (s *State) ClearSingleStep() {
+ // Clear the trap flag.
+ // TODO(gvisor.dev/issue/1239): ptrace single-step is not supported.
+}
+
+// RegisterMap returns a map of all registers.
+func (s *State) RegisterMap() (map[string]uintptr, error) {
+ return map[string]uintptr{
+ "R0": uintptr(s.Regs.Regs[0]),
+ "R1": uintptr(s.Regs.Regs[1]),
+ "R2": uintptr(s.Regs.Regs[2]),
+ "R3": uintptr(s.Regs.Regs[3]),
+ "R4": uintptr(s.Regs.Regs[4]),
+ "R5": uintptr(s.Regs.Regs[5]),
+ "R6": uintptr(s.Regs.Regs[6]),
+ "R7": uintptr(s.Regs.Regs[7]),
+ "R8": uintptr(s.Regs.Regs[8]),
+ "R9": uintptr(s.Regs.Regs[9]),
+ "R10": uintptr(s.Regs.Regs[10]),
+ "R11": uintptr(s.Regs.Regs[11]),
+ "R12": uintptr(s.Regs.Regs[12]),
+ "R13": uintptr(s.Regs.Regs[13]),
+ "R14": uintptr(s.Regs.Regs[14]),
+ "R15": uintptr(s.Regs.Regs[15]),
+ "R16": uintptr(s.Regs.Regs[16]),
+ "R17": uintptr(s.Regs.Regs[17]),
+ "R18": uintptr(s.Regs.Regs[18]),
+ "R19": uintptr(s.Regs.Regs[19]),
+ "R20": uintptr(s.Regs.Regs[20]),
+ "R21": uintptr(s.Regs.Regs[21]),
+ "R22": uintptr(s.Regs.Regs[22]),
+ "R23": uintptr(s.Regs.Regs[23]),
+ "R24": uintptr(s.Regs.Regs[24]),
+ "R25": uintptr(s.Regs.Regs[25]),
+ "R26": uintptr(s.Regs.Regs[26]),
+ "R27": uintptr(s.Regs.Regs[27]),
+ "R28": uintptr(s.Regs.Regs[28]),
+ "R29": uintptr(s.Regs.Regs[29]),
+ "R30": uintptr(s.Regs.Regs[30]),
+ "Sp": uintptr(s.Regs.Sp),
+ "Pc": uintptr(s.Regs.Pc),
+ "Pstate": uintptr(s.Regs.Pstate),
+ }, nil
+}
+
+// PtraceGetRegs implements Context.PtraceGetRegs.
+func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
+}
+
+func (s *State) ptraceGetRegs() Registers {
+ return s.Regs
+}
+
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
+
+// PtraceSetRegs implements Context.PtraceSetRegs.
+func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
+ var regs Registers
+ buf := make([]byte, ptraceRegistersSize)
+ if _, err := io.ReadFull(src, buf); err != nil {
+ return 0, err
+ }
+ regs.UnmarshalUnsafe(buf)
+ s.Regs = regs
+ return ptraceRegistersSize, nil
+}
+
+// PtraceGetFPRegs implements Context.PtraceGetFPRegs.
+func (s *State) PtraceGetFPRegs(dst io.Writer) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// PtraceSetFPRegs implements Context.PtraceSetFPRegs.
+func (s *State) PtraceSetFPRegs(src io.Reader) (int, error) {
+ // TODO(gvisor.dev/issue/1238): floating-point is not supported.
+ return 0, nil
+}
+
+// Register sets defined in include/uapi/linux/elf.h.
+const (
+ _NT_PRSTATUS = 1
+ _NT_PRFPREG = 2
+ _NT_ARM_TLS = 0x401
+)
+
+// PtraceGetRegSet implements Context.PtraceGetRegSet.
+func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegistersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceGetRegs(dst)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// PtraceSetRegSet implements Context.PtraceSetRegSet.
+func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
+ switch regset {
+ case _NT_PRSTATUS:
+ if maxlen < ptraceRegistersSize {
+ return 0, syserror.EFAULT
+ }
+ return s.PtraceSetRegs(src)
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// FullRestore indicates whether a full restore is required.
+func (s *State) FullRestore() bool {
+ return false
+}
+
+// New returns a new architecture context.
+func New(arch Arch, fs *cpuid.FeatureSet) Context {
+ switch arch {
+ case ARM64:
+ return &context64{
+ State{
+ aarch64FPState: newAarch64FPState(),
+ FeatureSet: fs,
+ },
+ []aarch64FPState(nil),
+ }
+ }
+ panic(fmt.Sprintf("unknown architecture %v", arch))
+}
diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go
index 9e7db8b30..1c3e3c14c 100644
--- a/pkg/sentry/arch/arch_amd64.go
+++ b/pkg/sentry/arch/arch_amd64.go
@@ -22,10 +22,9 @@ import (
"math/rand"
"syscall"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Host specifies the host architecture.
@@ -174,8 +173,8 @@ func (c *context64) SetTLS(value uintptr) bool {
return true
}
-// SetRSEQInterruptedIP implements Context.SetRSEQInterruptedIP.
-func (c *context64) SetRSEQInterruptedIP(value uintptr) {
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
c.Regs.R10 = uint64(value)
}
@@ -301,11 +300,13 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
// PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and
// u_debugreg, returning 0 or silently no-oping for other fields
// respectively.
- if addr < uintptr(ptraceRegsSize) {
- buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ if addr < uintptr(ptraceRegistersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
return c.Native(uintptr(usermem.ByteOrder.Uint64(buf[addr:]))), nil
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return c.Native(0), nil
}
@@ -314,12 +315,14 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error {
if addr&7 != 0 || addr >= userStructSize {
return syscall.EIO
}
- if addr < uintptr(ptraceRegsSize) {
- buf := binary.Marshal(nil, usermem.ByteOrder, c.ptraceGetRegs())
+ if addr < uintptr(ptraceRegistersSize) {
+ regs := c.ptraceGetRegs()
+ buf := make([]byte, regs.SizeBytes())
+ regs.MarshalUnsafe(buf)
usermem.ByteOrder.PutUint64(buf[addr:], uint64(data))
_, err := c.PtraceSetRegs(bytes.NewBuffer(buf))
return err
}
- // TODO(b/34088053): debug registers
+ // Note: x86 debug registers are missing.
return nil
}
diff --git a/pkg/sentry/arch/arch_amd64.s b/pkg/sentry/arch/arch_amd64.s
index bd61402cf..6c10336e7 100644
--- a/pkg/sentry/arch/arch_amd64.s
+++ b/pkg/sentry/arch/arch_amd64.s
@@ -26,10 +26,11 @@
//
// func initX86FPState(data *FloatingPointData, useXsave bool)
//
-// We need to clear out and initialize an empty fp state area since the sentry
-// may have left sensitive information in the floating point registers.
+// We need to clear out and initialize an empty fp state area since the sentry,
+// or any previous loader, may have left sensitive information in the floating
+// point registers.
//
-// Preconditions: data is zeroed
+// Preconditions: data is zeroed.
TEXT ·initX86FPState(SB), $24-16
// Save MXCSR (callee-save)
STMXCSR mxcsr-8(SP)
diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go
new file mode 100644
index 000000000..550741d8c
--- /dev/null
+++ b/pkg/sentry/arch/arch_arm64.go
@@ -0,0 +1,286 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+import (
+ "fmt"
+ "math/rand"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Host specifies the host architecture.
+const Host = ARM64
+
+// These constants come directly from Linux.
+const (
+ // maxAddr64 is the maximum userspace address. It is TASK_SIZE in Linux
+ // for a 64-bit process.
+ maxAddr64 usermem.Addr = (1 << 48)
+
+ // maxStackRand64 is the maximum randomization to apply to the stack.
+ // It is defined by arch/arm64/mm/mmap.c:(STACK_RND_MASK << PAGE_SHIFT) in Linux.
+ maxStackRand64 = 0x3ffff << 12 // 16 GB
+
+ // maxMmapRand64 is the maximum randomization to apply to the mmap
+ // layout. It is defined by arch/arm64/mm/mmap.c:arch_mmap_rnd in Linux.
+ maxMmapRand64 = (1 << 33) * usermem.PageSize
+
+ // minGap64 is the minimum gap to leave at the top of the address space
+ // for the stack. It is defined by arch/arm64/mm/mmap.c:MIN_GAP in Linux.
+ minGap64 = (128 << 20) + maxStackRand64
+
+ // preferredPIELoadAddr is the standard Linux position-independent
+ // executable base load address. It is ELF_ET_DYN_BASE in Linux.
+ //
+ // The Platform {Min,Max}UserAddress() may preclude loading at this
+ // address. See other preferredFoo comments below.
+ preferredPIELoadAddr usermem.Addr = maxAddr64 / 6 * 5
+)
+
+var (
+ // CPUIDInstruction doesn't exist on ARM64.
+ CPUIDInstruction = []byte{}
+)
+
+// These constants are selected as heuristics to help make the Platform's
+// potentially limited address space conform as closely to Linux as possible.
+const (
+ preferredTopDownAllocMin usermem.Addr = 0x7e8000000000
+ preferredAllocationGap = 128 << 30 // 128 GB
+ preferredTopDownBaseMin = preferredTopDownAllocMin + preferredAllocationGap
+
+ // minMmapRand64 is the smallest we are willing to make the
+ // randomization to stay above preferredTopDownBaseMin.
+ minMmapRand64 = (1 << 18) * usermem.PageSize
+)
+
+// context64 represents an ARM64 context.
+//
+// +stateify savable
+type context64 struct {
+ State
+ sigFPState []aarch64FPState // fpstate to be restored on sigreturn.
+}
+
+// Arch implements Context.Arch.
+func (c *context64) Arch() Arch {
+ return ARM64
+}
+
+func (c *context64) copySigFPState() []aarch64FPState {
+ var sigfps []aarch64FPState
+ for _, s := range c.sigFPState {
+ sigfps = append(sigfps, s.fork())
+ }
+ return sigfps
+}
+
+// Fork returns an exact copy of this context.
+func (c *context64) Fork() Context {
+ return &context64{
+ State: c.State.Fork(),
+ sigFPState: c.copySigFPState(),
+ }
+}
+
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary rgisters.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+
+// Return returns the current syscall return value.
+func (c *context64) Return() uintptr {
+ return uintptr(c.Regs.Regs[0])
+}
+
+// SetReturn sets the syscall return value.
+func (c *context64) SetReturn(value uintptr) {
+ c.Regs.Regs[0] = uint64(value)
+}
+
+// IP returns the current instruction pointer.
+func (c *context64) IP() uintptr {
+ return uintptr(c.Regs.Pc)
+}
+
+// SetIP sets the current instruction pointer.
+func (c *context64) SetIP(value uintptr) {
+ c.Regs.Pc = uint64(value)
+}
+
+// Stack returns the current stack pointer.
+func (c *context64) Stack() uintptr {
+ return uintptr(c.Regs.Sp)
+}
+
+// SetStack sets the current stack pointer.
+func (c *context64) SetStack(value uintptr) {
+ c.Regs.Sp = uint64(value)
+}
+
+// TLS returns the current TLS pointer.
+func (c *context64) TLS() uintptr {
+ return uintptr(c.Regs.TPIDR_EL0)
+}
+
+// SetTLS sets the current TLS pointer. Returns false if value is invalid.
+func (c *context64) SetTLS(value uintptr) bool {
+ if value >= uintptr(maxAddr64) {
+ return false
+ }
+
+ c.Regs.TPIDR_EL0 = uint64(value)
+ return true
+}
+
+// SetOldRSeqInterruptedIP implements Context.SetOldRSeqInterruptedIP.
+func (c *context64) SetOldRSeqInterruptedIP(value uintptr) {
+ c.Regs.Regs[3] = uint64(value)
+}
+
+// Native returns the native type for the given val.
+func (c *context64) Native(val uintptr) interface{} {
+ v := uint64(val)
+ return &v
+}
+
+// Value returns the generic val for the given native type.
+func (c *context64) Value(val interface{}) uintptr {
+ return uintptr(*val.(*uint64))
+}
+
+// Width returns the byte width of this architecture.
+func (c *context64) Width() uint {
+ return 8
+}
+
+// FeatureSet returns the FeatureSet in use.
+func (c *context64) FeatureSet() *cpuid.FeatureSet {
+ return c.State.FeatureSet
+}
+
+// mmapRand returns a random adjustment for randomizing an mmap layout.
+func mmapRand(max uint64) usermem.Addr {
+ return usermem.Addr(rand.Int63n(int64(max))).RoundDown()
+}
+
+// NewMmapLayout implements Context.NewMmapLayout consistently with Linux.
+func (c *context64) NewMmapLayout(min, max usermem.Addr, r *limits.LimitSet) (MmapLayout, error) {
+ min, ok := min.RoundUp()
+ if !ok {
+ return MmapLayout{}, syscall.EINVAL
+ }
+ if max > maxAddr64 {
+ max = maxAddr64
+ }
+ max = max.RoundDown()
+
+ if min > max {
+ return MmapLayout{}, syscall.EINVAL
+ }
+
+ stackSize := r.Get(limits.Stack)
+
+ // MAX_GAP in Linux.
+ maxGap := (max / 6) * 5
+ gap := usermem.Addr(stackSize.Cur)
+ if gap < minGap64 {
+ gap = minGap64
+ }
+ if gap > maxGap {
+ gap = maxGap
+ }
+ defaultDir := MmapTopDown
+ if stackSize.Cur == limits.Infinity {
+ defaultDir = MmapBottomUp
+ }
+
+ topDownMin := max - gap - maxMmapRand64
+ maxRand := usermem.Addr(maxMmapRand64)
+ if topDownMin < preferredTopDownBaseMin {
+ // Try to keep TopDownBase above preferredTopDownBaseMin by
+ // shrinking maxRand.
+ maxAdjust := maxRand - minMmapRand64
+ needAdjust := preferredTopDownBaseMin - topDownMin
+ if needAdjust <= maxAdjust {
+ maxRand -= needAdjust
+ }
+ }
+
+ rnd := mmapRand(uint64(maxRand))
+ l := MmapLayout{
+ MinAddr: min,
+ MaxAddr: max,
+ // TASK_UNMAPPED_BASE in Linux.
+ BottomUpBase: (max/3 + rnd).RoundDown(),
+ TopDownBase: (max - gap - rnd).RoundDown(),
+ DefaultDirection: defaultDir,
+ // We may have reduced the maximum randomization to keep
+ // TopDownBase above preferredTopDownBaseMin while maintaining
+ // our stack gap. Stack allocations must use that max
+ // randomization to avoiding eating into the gap.
+ MaxStackRand: uint64(maxRand),
+ }
+
+ // Final sanity check on the layout.
+ if !l.Valid() {
+ panic(fmt.Sprintf("Invalid MmapLayout: %+v", l))
+ }
+
+ return l, nil
+}
+
+// PIELoadAddress implements Context.PIELoadAddress.
+func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr {
+ base := preferredPIELoadAddr
+ max, ok := base.AddLength(maxMmapRand64)
+ if !ok {
+ panic(fmt.Sprintf("preferredPIELoadAddr %#x too large", base))
+ }
+
+ if max > l.MaxAddr {
+ // preferredPIELoadAddr won't fit; fall back to the standard
+ // Linux behavior of 2/3 of TopDownBase. TSAN won't like this.
+ //
+ // Don't bother trying to shrink the randomization for now.
+ base = l.TopDownBase / 3 * 2
+ }
+
+ return base + mmapRand(maxMmapRand64)
+}
+
+// PtracePeekUser implements Context.PtracePeekUser.
+func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) {
+ // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
+ return c.Native(0), nil
+}
+
+// PtracePokeUser implements Context.PtracePokeUser.
+func (c *context64) PtracePokeUser(addr, data uintptr) error {
+ // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64.
+ return nil
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 9061fcc86..19ce99d25 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64 386
+
package arch
import (
"fmt"
- "syscall"
"gvisor.dev/gvisor/pkg/cpuid"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// ErrFloatingPoint indicates a failed restore due to unusable floating point
@@ -41,8 +42,8 @@ func (e ErrFloatingPoint) Error() string {
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
-// afterLoad is invoked by stateify.
-func (s *State) afterLoad() {
+// afterLoadFPState is invoked by afterLoad.
+func (s *State) afterLoadFPState() {
old := s.x86FPState
// Recreate the slice. This is done to ensure that it is aligned
@@ -88,44 +89,3 @@ func (s *State) afterLoad() {
// Copy to the new, aligned location.
copy(s.x86FPState, old)
}
-
-// +stateify savable
-type syscallPtraceRegs struct {
- R15 uint64
- R14 uint64
- R13 uint64
- R12 uint64
- Rbp uint64
- Rbx uint64
- R11 uint64
- R10 uint64
- R9 uint64
- R8 uint64
- Rax uint64
- Rcx uint64
- Rdx uint64
- Rsi uint64
- Rdi uint64
- Orig_rax uint64
- Rip uint64
- Cs uint64
- Eflags uint64
- Rsp uint64
- Ss uint64
- Fs_base uint64
- Gs_base uint64
- Ds uint64
- Es uint64
- Fs uint64
- Gs uint64
-}
-
-// saveRegs is invoked by stateify.
-func (s *State) saveRegs() syscallPtraceRegs {
- return syscallPtraceRegs(s.Regs)
-}
-
-// loadRegs is invoked by stateify.
-func (s *State) loadRegs(r syscallPtraceRegs) {
- s.Regs = syscall.PtraceRegs(r)
-}
diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go
index 9294ac773..b9405b320 100644
--- a/pkg/sentry/arch/arch_x86.go
+++ b/pkg/sentry/arch/arch_x86.go
@@ -12,24 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package arch
import (
"fmt"
"io"
- "sync"
"syscall"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// Registers represents the CPU registers for this architecture.
+//
+// +stateify savable
+type Registers struct {
+ linux.PtraceRegs
+}
+
// System-related constants for x86.
const (
// SyscallWidth is the width of syscall, sysenter, and int 80 insturctions.
@@ -114,6 +121,10 @@ func newX86FPStateSlice() []byte {
size, align := cpuid.HostFeatureSet().ExtendedStateSize()
capacity := size
// Always use at least 4096 bytes.
+ //
+ // For the KVM platform, this state is a fixed 4096 bytes, so make sure
+ // that the underlying array is at _least_ that size otherwise we will
+ // corrupt random memory. This is not a pleasant thing to debug.
if capacity < 4096 {
capacity = 4096
}
@@ -151,21 +162,6 @@ func NewFloatingPointData() *FloatingPointData {
return (*FloatingPointData)(&(newX86FPState()[0]))
}
-// State contains the common architecture bits for X86 (the build tag of this
-// file ensures it's only built on x86).
-//
-// +stateify savable
-type State struct {
- // The system registers.
- Regs syscall.PtraceRegs `state:".(syscallPtraceRegs)"`
-
- // Our floating point state.
- x86FPState `state:"wait"`
-
- // FeatureSet is a pointer to the currently active feature set.
- FeatureSet *cpuid.FeatureSet
-}
-
// Proto returns a protobuf representation of the system registers in State.
func (s State) Proto() *rpb.Registers {
regs := &rpb.AMD64Registers{
@@ -278,10 +274,12 @@ func (s *State) RegisterMap() (map[string]uintptr, error) {
// PtraceGetRegs implements Context.PtraceGetRegs.
func (s *State) PtraceGetRegs(dst io.Writer) (int, error) {
- return dst.Write(binary.Marshal(nil, usermem.ByteOrder, s.ptraceGetRegs()))
+ regs := s.ptraceGetRegs()
+ n, err := regs.WriteTo(dst)
+ return int(n), err
}
-func (s *State) ptraceGetRegs() syscall.PtraceRegs {
+func (s *State) ptraceGetRegs() Registers {
regs := s.Regs
// These may not be initialized.
if regs.Cs == 0 || regs.Ss == 0 || regs.Eflags == 0 {
@@ -317,16 +315,16 @@ func (s *State) ptraceGetRegs() syscall.PtraceRegs {
return regs
}
-var ptraceRegsSize = int(binary.Size(syscall.PtraceRegs{}))
+var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes()
// PtraceSetRegs implements Context.PtraceSetRegs.
func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
- var regs syscall.PtraceRegs
- buf := make([]byte, ptraceRegsSize)
+ var regs Registers
+ buf := make([]byte, ptraceRegistersSize)
if _, err := io.ReadFull(src, buf); err != nil {
return 0, err
}
- binary.Unmarshal(buf, usermem.ByteOrder, &regs)
+ regs.UnmarshalUnsafe(buf)
// Truncate segment registers to 16 bits.
regs.Cs = uint64(uint16(regs.Cs))
regs.Ds = uint64(uint16(regs.Ds))
@@ -380,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) {
}
regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable)
s.Regs = regs
- return ptraceRegsSize, nil
+ return ptraceRegistersSize, nil
}
// isUserSegmentSelector returns true if the given segment selector specifies a
@@ -549,7 +547,7 @@ const (
func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < ptraceRegsSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceGetRegs(dst)
@@ -569,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int,
func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) {
switch regset {
case _NT_PRSTATUS:
- if maxlen < ptraceRegsSize {
+ if maxlen < ptraceRegistersSize {
return 0, syserror.EFAULT
}
return s.PtraceSetRegs(src)
diff --git a/pkg/sentry/arch/arch_x86_impl.go b/pkg/sentry/arch/arch_x86_impl.go
new file mode 100644
index 000000000..0c73fcbfb
--- /dev/null
+++ b/pkg/sentry/arch/arch_x86_impl.go
@@ -0,0 +1,41 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 386
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/cpuid"
+)
+
+// State contains the common architecture bits for X86 (the build tag of this
+// file ensures it's only built on x86).
+//
+// +stateify savable
+type State struct {
+ // The system registers.
+ Regs Registers
+
+ // Our floating point state.
+ x86FPState `state:"wait"`
+
+ // FeatureSet is a pointer to the currently active feature set.
+ FeatureSet *cpuid.FeatureSet
+}
+
+// afterLoad is invoked by stateify.
+func (s *State) afterLoad() {
+ s.afterLoadFPState()
+}
diff --git a/pkg/sentry/arch/auxv.go b/pkg/sentry/arch/auxv.go
index 4546b2ef9..2b4c8f3fc 100644
--- a/pkg/sentry/arch/auxv.go
+++ b/pkg/sentry/arch/auxv.go
@@ -15,7 +15,7 @@
package arch
import (
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// An AuxEntry represents an entry in an ELF auxiliary vector.
diff --git a/pkg/sentry/arch/registers.proto b/pkg/sentry/arch/registers.proto
index 9dc83e241..60c027aab 100644
--- a/pkg/sentry/arch/registers.proto
+++ b/pkg/sentry/arch/registers.proto
@@ -48,8 +48,45 @@ message AMD64Registers {
uint64 gs_base = 27;
}
+message ARM64Registers {
+ uint64 r0 = 1;
+ uint64 r1 = 2;
+ uint64 r2 = 3;
+ uint64 r3 = 4;
+ uint64 r4 = 5;
+ uint64 r5 = 6;
+ uint64 r6 = 7;
+ uint64 r7 = 8;
+ uint64 r8 = 9;
+ uint64 r9 = 10;
+ uint64 r10 = 11;
+ uint64 r11 = 12;
+ uint64 r12 = 13;
+ uint64 r13 = 14;
+ uint64 r14 = 15;
+ uint64 r15 = 16;
+ uint64 r16 = 17;
+ uint64 r17 = 18;
+ uint64 r18 = 19;
+ uint64 r19 = 20;
+ uint64 r20 = 21;
+ uint64 r21 = 22;
+ uint64 r22 = 23;
+ uint64 r23 = 24;
+ uint64 r24 = 25;
+ uint64 r25 = 26;
+ uint64 r26 = 27;
+ uint64 r27 = 28;
+ uint64 r28 = 29;
+ uint64 r29 = 30;
+ uint64 r30 = 31;
+ uint64 sp = 32;
+ uint64 pc = 33;
+ uint64 pstate = 34;
+}
message Registers {
oneof arch {
AMD64Registers amd64 = 1;
+ ARM64Registers arm64 = 2;
}
}
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
new file mode 100644
index 000000000..c9fb55d00
--- /dev/null
+++ b/pkg/sentry/arch/signal.go
@@ -0,0 +1,253 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SignalAct represents the action that should be taken when a signal is
+// delivered, and is equivalent to struct sigaction.
+//
+// +marshal
+// +stateify savable
+type SignalAct struct {
+ Handler uint64
+ Flags uint64
+ Restorer uint64 // Only used on amd64.
+ Mask linux.SignalSet
+}
+
+// SerializeFrom implements NativeSignalAct.SerializeFrom.
+func (s *SignalAct) SerializeFrom(other *SignalAct) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalAct.DeserializeTo.
+func (s *SignalAct) DeserializeTo(other *SignalAct) {
+ *other = *s
+}
+
+// SignalStack represents information about a user stack, and is equivalent to
+// stack_t.
+//
+// +marshal
+// +stateify savable
+type SignalStack struct {
+ Addr uint64
+ Flags uint32
+ _ uint32
+ Size uint64
+}
+
+// SerializeFrom implements NativeSignalStack.SerializeFrom.
+func (s *SignalStack) SerializeFrom(other *SignalStack) {
+ *s = *other
+}
+
+// DeserializeTo implements NativeSignalStack.DeserializeTo.
+func (s *SignalStack) DeserializeTo(other *SignalStack) {
+ *other = *s
+}
+
+// SignalInfo represents information about a signal being delivered, and is
+// equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h).
+//
+// +marshal
+// +stateify savable
+type SignalInfo struct {
+ Signo int32 // Signal number
+ Errno int32 // Errno value
+ Code int32 // Signal code
+ _ uint32
+
+ // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
+ // are accessed through methods.
+ //
+ // For reference, here is the definition of _sifields: (_sigfault._trapno,
+ // which does not exist on x86, omitted for clarity)
+ //
+ // union {
+ // int _pad[SI_PAD_SIZE];
+ //
+ // /* kill() */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // } _kill;
+ //
+ // /* POSIX.1b timers */
+ // struct {
+ // __kernel_timer_t _tid; /* timer id */
+ // int _overrun; /* overrun count */
+ // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
+ // sigval_t _sigval; /* same as below */
+ // int _sys_private; /* not to be passed to user */
+ // } _timer;
+ //
+ // /* POSIX.1b signals */
+ // struct {
+ // __kernel_pid_t _pid; /* sender's pid */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // sigval_t _sigval;
+ // } _rt;
+ //
+ // /* SIGCHLD */
+ // struct {
+ // __kernel_pid_t _pid; /* which child */
+ // __ARCH_SI_UID_T _uid; /* sender's uid */
+ // int _status; /* exit code */
+ // __ARCH_SI_CLOCK_T _utime;
+ // __ARCH_SI_CLOCK_T _stime;
+ // } _sigchld;
+ //
+ // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
+ // struct {
+ // void *_addr; /* faulting insn/memory ref. */
+ // short _addr_lsb; /* LSB of the reported address */
+ // } _sigfault;
+ //
+ // /* SIGPOLL */
+ // struct {
+ // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
+ // int _fd;
+ // } _sigpoll;
+ //
+ // /* SIGSYS */
+ // struct {
+ // void *_call_addr; /* calling user insn */
+ // int _syscall; /* triggering system call number */
+ // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
+ // } _sigsys;
+ // } _sifields;
+ //
+ // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
+ // bytes.
+ Fields [128 - 16]byte
+}
+
+// FixSignalCodeForUser fixes up si_code.
+//
+// The si_code we get from Linux may contain the kernel-specific code in the
+// top 16 bits if it's positive (e.g., from ptrace). Linux's
+// copy_siginfo_to_user does
+// err |= __put_user((short)from->si_code, &to->si_code);
+// to mask out those bits and we need to do the same.
+func (s *SignalInfo) FixSignalCodeForUser() {
+ if s.Code > 0 {
+ s.Code &= 0x0000ffff
+ }
+}
+
+// Pid returns the si_pid field.
+func (s *SignalInfo) Pid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetPid mutates the si_pid field.
+func (s *SignalInfo) SetPid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Uid returns the si_uid field.
+func (s *SignalInfo) Uid() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetUid mutates the si_uid field.
+func (s *SignalInfo) SetUid(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
+func (s *SignalInfo) Sigval() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[8:16])
+}
+
+// SetSigval mutates the sigval field.
+func (s *SignalInfo) SetSigval(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
+}
+
+// TimerID returns the si_timerid field.
+func (s *SignalInfo) TimerID() linux.TimerID {
+ return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
+}
+
+// SetTimerID sets the si_timerid field.
+func (s *SignalInfo) SetTimerID(val linux.TimerID) {
+ usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
+}
+
+// Overrun returns the si_overrun field.
+func (s *SignalInfo) Overrun() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
+}
+
+// SetOverrun sets the si_overrun field.
+func (s *SignalInfo) SetOverrun(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
+}
+
+// Addr returns the si_addr field.
+func (s *SignalInfo) Addr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetAddr sets the si_addr field.
+func (s *SignalInfo) SetAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Status returns the si_status field.
+func (s *SignalInfo) Status() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetStatus mutates the si_status field.
+func (s *SignalInfo) SetStatus(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// CallAddr returns the si_call_addr field.
+func (s *SignalInfo) CallAddr() uint64 {
+ return usermem.ByteOrder.Uint64(s.Fields[0:8])
+}
+
+// SetCallAddr mutates the si_call_addr field.
+func (s *SignalInfo) SetCallAddr(val uint64) {
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
+}
+
+// Syscall returns the si_syscall field.
+func (s *SignalInfo) Syscall() int32 {
+ return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
+}
+
+// SetSyscall mutates the si_syscall field.
+func (s *SignalInfo) SetSyscall(val int32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
+}
+
+// Arch returns the si_arch field.
+func (s *SignalInfo) Arch() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[12:16])
+}
+
+// SetArch mutates the si_arch field.
+func (s *SignalInfo) SetArch(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
+}
diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go
index f9ca2e74e..32173aa20 100644
--- a/pkg/sentry/arch/signal_act.go
+++ b/pkg/sentry/arch/signal_act.go
@@ -14,6 +14,8 @@
package arch
+import "gvisor.dev/gvisor/tools/go_marshal/marshal"
+
// Special values for SignalAct.Handler.
const (
// SignalActDefault is SIG_DFL and specifies that the default behavior for
@@ -71,6 +73,8 @@ func (s SignalAct) HasRestorer() bool {
// NativeSignalAct is a type that is equivalent to struct sigaction in the
// guest architecture.
type NativeSignalAct interface {
+ marshal.Marshallable
+
// SerializeFrom copies the data in the host SignalAct s into this object.
SerializeFrom(s *SignalAct)
diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go
index febd6f9b9..6fb756f0e 100644
--- a/pkg/sentry/arch/signal_amd64.go
+++ b/pkg/sentry/arch/signal_amd64.go
@@ -23,239 +23,9 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// SignalAct represents the action that should be taken when a signal is
-// delivered, and is equivalent to struct sigaction on 64-bit x86.
-//
-// +stateify savable
-type SignalAct struct {
- Handler uint64
- Flags uint64
- Restorer uint64
- Mask linux.SignalSet
-}
-
-// SerializeFrom implements NativeSignalAct.SerializeFrom.
-func (s *SignalAct) SerializeFrom(other *SignalAct) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalAct.DeserializeTo.
-func (s *SignalAct) DeserializeTo(other *SignalAct) {
- *other = *s
-}
-
-// SignalStack represents information about a user stack, and is equivalent to
-// stack_t on 64-bit x86.
-//
-// +stateify savable
-type SignalStack struct {
- Addr uint64
- Flags uint32
- _ uint32
- Size uint64
-}
-
-// SerializeFrom implements NativeSignalStack.SerializeFrom.
-func (s *SignalStack) SerializeFrom(other *SignalStack) {
- *s = *other
-}
-
-// DeserializeTo implements NativeSignalStack.DeserializeTo.
-func (s *SignalStack) DeserializeTo(other *SignalStack) {
- *other = *s
-}
-
-// SignalInfo represents information about a signal being delivered, and is
-// equivalent to struct siginfo on 64-bit x86.
-//
-// +stateify savable
-type SignalInfo struct {
- Signo int32 // Signal number
- Errno int32 // Errno value
- Code int32 // Signal code
- _ uint32
-
- // struct siginfo::_sifields is a union. In SignalInfo, fields in the union
- // are accessed through methods.
- //
- // For reference, here is the definition of _sifields: (_sigfault._trapno,
- // which does not exist on x86, omitted for clarity)
- //
- // union {
- // int _pad[SI_PAD_SIZE];
- //
- // /* kill() */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // } _kill;
- //
- // /* POSIX.1b timers */
- // struct {
- // __kernel_timer_t _tid; /* timer id */
- // int _overrun; /* overrun count */
- // char _pad[sizeof( __ARCH_SI_UID_T) - sizeof(int)];
- // sigval_t _sigval; /* same as below */
- // int _sys_private; /* not to be passed to user */
- // } _timer;
- //
- // /* POSIX.1b signals */
- // struct {
- // __kernel_pid_t _pid; /* sender's pid */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // sigval_t _sigval;
- // } _rt;
- //
- // /* SIGCHLD */
- // struct {
- // __kernel_pid_t _pid; /* which child */
- // __ARCH_SI_UID_T _uid; /* sender's uid */
- // int _status; /* exit code */
- // __ARCH_SI_CLOCK_T _utime;
- // __ARCH_SI_CLOCK_T _stime;
- // } _sigchld;
- //
- // /* SIGILL, SIGFPE, SIGSEGV, SIGBUS */
- // struct {
- // void *_addr; /* faulting insn/memory ref. */
- // short _addr_lsb; /* LSB of the reported address */
- // } _sigfault;
- //
- // /* SIGPOLL */
- // struct {
- // __ARCH_SI_BAND_T _band; /* POLL_IN, POLL_OUT, POLL_MSG */
- // int _fd;
- // } _sigpoll;
- //
- // /* SIGSYS */
- // struct {
- // void *_call_addr; /* calling user insn */
- // int _syscall; /* triggering system call number */
- // unsigned int _arch; /* AUDIT_ARCH_* of syscall */
- // } _sigsys;
- // } _sifields;
- //
- // _sifields is padded so that the size of siginfo is SI_MAX_SIZE = 128
- // bytes.
- Fields [128 - 16]byte
-}
-
-// FixSignalCodeForUser fixes up si_code.
-//
-// The si_code we get from Linux may contain the kernel-specific code in the
-// top 16 bits if it's positive (e.g., from ptrace). Linux's
-// copy_siginfo_to_user does
-// err |= __put_user((short)from->si_code, &to->si_code);
-// to mask out those bits and we need to do the same.
-func (s *SignalInfo) FixSignalCodeForUser() {
- if s.Code > 0 {
- s.Code &= 0x0000ffff
- }
-}
-
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Sigval returns the sigval field, which is aliased to both si_int and si_ptr.
-func (s *SignalInfo) Sigval() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[8:16])
-}
-
-// SetSigval mutates the sigval field.
-func (s *SignalInfo) SetSigval(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[8:16], val)
-}
-
-// TimerID returns the si_timerid field.
-func (s *SignalInfo) TimerID() linux.TimerID {
- return linux.TimerID(usermem.ByteOrder.Uint32(s.Fields[0:4]))
-}
-
-// SetTimerID sets the si_timerid field.
-func (s *SignalInfo) SetTimerID(val linux.TimerID) {
- usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
-}
-
-// Overrun returns the si_overrun field.
-func (s *SignalInfo) Overrun() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
-}
-
-// SetOverrun sets the si_overrun field.
-func (s *SignalInfo) SetOverrun(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
-}
-
-// Addr returns the si_addr field.
-func (s *SignalInfo) Addr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetAddr sets the si_addr field.
-func (s *SignalInfo) SetAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Status returns the si_status field.
-func (s *SignalInfo) Status() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetStatus mutates the si_status field.
-func (s *SignalInfo) SetStatus(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// CallAddr returns the si_call_addr field.
-func (s *SignalInfo) CallAddr() uint64 {
- return usermem.ByteOrder.Uint64(s.Fields[0:8])
-}
-
-// SetCallAddr mutates the si_call_addr field.
-func (s *SignalInfo) SetCallAddr(val uint64) {
- usermem.ByteOrder.PutUint64(s.Fields[0:8], val)
-}
-
-// Syscall returns the si_syscall field.
-func (s *SignalInfo) Syscall() int32 {
- return int32(usermem.ByteOrder.Uint32(s.Fields[8:12]))
-}
-
-// SetSyscall mutates the si_syscall field.
-func (s *SignalInfo) SetSyscall(val int32) {
- usermem.ByteOrder.PutUint32(s.Fields[8:12], uint32(val))
-}
-
-// Arch returns the si_arch field.
-func (s *SignalInfo) Arch() uint32 {
- return usermem.ByteOrder.Uint32(s.Fields[12:16])
-}
-
-// SetArch mutates the si_arch field.
-func (s *SignalInfo) SetArch(val uint32) {
- usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
-}
-
// SignalContext64 is equivalent to struct sigcontext, the type passed as the
// second argument to signal handlers set by signal(2).
type SignalContext64 struct {
@@ -285,7 +55,7 @@ type SignalContext64 struct {
Trapno uint64
Oldmask linux.SignalSet
Cr2 uint64
- // Pointer to a struct _fpstate.
+ // Pointer to a struct _fpstate. See b/33003106#comment8.
Fpstate uint64
Reserved [8]uint64
}
diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go
new file mode 100644
index 000000000..642c79dda
--- /dev/null
+++ b/pkg/sentry/arch/signal_arm64.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package arch
+
+import (
+ "encoding/binary"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// SignalContext64 is equivalent to struct sigcontext, the type passed as the
+// second argument to signal handlers set by signal(2).
+type SignalContext64 struct {
+ FaultAddr uint64
+ Regs [31]uint64
+ Sp uint64
+ Pc uint64
+ Pstate uint64
+ _pad [8]byte // __attribute__((__aligned__(16)))
+ Fpsimd64 FpsimdContext // size = 528
+ Reserved [3568]uint8
+}
+
+type aarch64Ctx struct {
+ Magic uint32
+ Size uint32
+}
+
+// FpsimdContext is equivalent to struct fpsimd_context on arm64
+// (arch/arm64/include/uapi/asm/sigcontext.h).
+type FpsimdContext struct {
+ Head aarch64Ctx
+ Fpsr uint32
+ Fpcr uint32
+ Vregs [64]uint64 // actually [32]uint128
+}
+
+// UContext64 is equivalent to ucontext on arm64(arch/arm64/include/uapi/asm/ucontext.h).
+type UContext64 struct {
+ Flags uint64
+ Link uint64
+ Stack SignalStack
+ Sigset linux.SignalSet
+ // glibc uses a 1024-bit sigset_t
+ _pad [(1024 - 64) / 8]byte
+ // sigcontext must be aligned to 16-byte
+ _pad2 [8]byte
+ // last for future expansion
+ MContext SignalContext64
+}
+
+// NewSignalAct implements Context.NewSignalAct.
+func (c *context64) NewSignalAct() NativeSignalAct {
+ return &SignalAct{}
+}
+
+// NewSignalStack implements Context.NewSignalStack.
+func (c *context64) NewSignalStack() NativeSignalStack {
+ return &SignalStack{}
+}
+
+// SignalSetup implements Context.SignalSetup.
+func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error {
+ sp := st.Bottom
+
+ if !(alt.IsEnabled() && sp == alt.Top()) {
+ sp -= 128
+ }
+
+ // Construct the UContext64 now since we need its size.
+ uc := &UContext64{
+ Flags: 0,
+ Stack: *alt,
+ MContext: SignalContext64{
+ Regs: c.Regs.Regs,
+ Sp: c.Regs.Sp,
+ Pc: c.Regs.Pc,
+ Pstate: c.Regs.Pstate,
+ },
+ Sigset: sigset,
+ }
+
+ ucSize := binary.Size(uc)
+ if ucSize < 0 {
+ panic("can't get size of UContext64")
+ }
+
+ // frameSize = ucSize + sizeof(siginfo).
+ // sizeof(siginfo) == 128.
+ // R30 stores the restorer address.
+ frameSize := ucSize + 128
+ frameBottom := (sp - usermem.Addr(frameSize)) & ^usermem.Addr(15)
+ sp = frameBottom + usermem.Addr(frameSize)
+ st.Bottom = sp
+
+ // Prior to proceeding, figure out if the frame will exhaust the range
+ // for the signal stack. This is not allowed, and should immediately
+ // force signal delivery (reverting to the default handler).
+ if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) {
+ return syscall.EFAULT
+ }
+
+ // Adjust the code.
+ info.FixSignalCodeForUser()
+
+ // Set up the stack frame.
+ infoAddr, err := st.Push(info)
+ if err != nil {
+ return err
+ }
+ ucAddr, err := st.Push(uc)
+ if err != nil {
+ return err
+ }
+
+ // Set up registers.
+ c.Regs.Sp = uint64(st.Bottom)
+ c.Regs.Pc = act.Handler
+ c.Regs.Regs[0] = uint64(info.Signo)
+ c.Regs.Regs[1] = uint64(infoAddr)
+ c.Regs.Regs[2] = uint64(ucAddr)
+ c.Regs.Regs[30] = uint64(act.Restorer)
+
+ // Save the thread's floating point state.
+ c.sigFPState = append(c.sigFPState, c.aarch64FPState)
+ // Signal handler gets a clean floating point state.
+ c.aarch64FPState = newAarch64FPState()
+ return nil
+}
+
+// SignalRestore implements Context.SignalRestore.
+func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) {
+ // Copy out the stack frame.
+ var uc UContext64
+ if _, err := st.Pop(&uc); err != nil {
+ return 0, SignalStack{}, err
+ }
+ var info SignalInfo
+ if _, err := st.Pop(&info); err != nil {
+ return 0, SignalStack{}, err
+ }
+
+ // Restore registers.
+ c.Regs.Regs = uc.MContext.Regs
+ c.Regs.Pc = uc.MContext.Pc
+ c.Regs.Sp = uc.MContext.Sp
+ c.Regs.Pstate = uc.MContext.Pstate
+
+ // Restore floating point state.
+ l := len(c.sigFPState)
+ if l > 0 {
+ c.aarch64FPState = c.sigFPState[l-1]
+ // NOTE(cl/133042258): State save requires that any slice
+ // elements from '[len:cap]' to be zero value.
+ c.sigFPState[l-1] = nil
+ c.sigFPState = c.sigFPState[0 : l-1]
+ } else {
+ // This might happen if sigreturn(2) calls are unbalanced with
+ // respect to signal handler entries. This is not expected so
+ // don't bother to do anything fancy with the floating point
+ // state.
+ log.Warningf("sigreturn unable to restore application fpstate")
+ }
+
+ return uc.Sigset, uc.Stack, nil
+}
diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go
index 5a3228113..0fa738a1d 100644
--- a/pkg/sentry/arch/signal_stack.go
+++ b/pkg/sentry/arch/signal_stack.go
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64 arm64
package arch
import (
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
const (
@@ -55,6 +56,8 @@ func (s *SignalStack) Contains(sp usermem.Addr) bool {
// NativeSignalStack is a type that is equivalent to stack_t in the guest
// architecture.
type NativeSignalStack interface {
+ marshal.Marshallable
+
// SerializeFrom copies the data in the host SignalStack s into this
// object.
SerializeFrom(s *SignalStack)
diff --git a/pkg/sentry/arch/stack.go b/pkg/sentry/arch/stack.go
index 7472c3c61..1108fa0bd 100644
--- a/pkg/sentry/arch/stack.go
+++ b/pkg/sentry/arch/stack.go
@@ -18,8 +18,8 @@ import (
"encoding/binary"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Stack is a simple wrapper around a usermem.IO and an address.
@@ -97,7 +97,6 @@ func (s *Stack) Push(vals ...interface{}) (usermem.Addr, error) {
if c < 0 {
return 0, fmt.Errorf("bad binary.Size for %T", v)
}
- // TODO(b/38173783): Use a real context.Context.
n, err := usermem.CopyObjectOut(context.Background(), s.IO, s.Bottom-usermem.Addr(c), norm, usermem.IOOpts{})
if err != nil || c != n {
return 0, err
@@ -121,11 +120,9 @@ func (s *Stack) Pop(vals ...interface{}) (usermem.Addr, error) {
var err error
if isVaddr {
value := s.Arch.Native(uintptr(0))
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, value, usermem.IOOpts{})
*vaddr = usermem.Addr(s.Arch.Value(value))
} else {
- // TODO(b/38173783): Use a real context.Context.
n, err = usermem.CopyObjectIn(context.Background(), s.IO, s.Bottom, v, usermem.IOOpts{})
}
if err != nil {
diff --git a/pkg/sentry/arch/syscalls_amd64.go b/pkg/sentry/arch/syscalls_amd64.go
index 8b4f23007..3859f41ee 100644
--- a/pkg/sentry/arch/syscalls_amd64.go
+++ b/pkg/sentry/arch/syscalls_amd64.go
@@ -18,6 +18,13 @@ package arch
const restartSyscallNr = uintptr(219)
+// SyscallSaveOrig save the value of the register which is clobbered in
+// syscall handler(doSyscall()).
+//
+// Noop on x86.
+func (c *context64) SyscallSaveOrig() {
+}
+
// SyscallNo returns the syscall number according to the 64-bit convention.
func (c *context64) SyscallNo() uintptr {
return uintptr(c.Regs.Orig_rax)
diff --git a/pkg/sentry/arch/syscalls_arm64.go b/pkg/sentry/arch/syscalls_arm64.go
new file mode 100644
index 000000000..95dfd1e90
--- /dev/null
+++ b/pkg/sentry/arch/syscalls_arm64.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package arch
+
+const restartSyscallNr = uintptr(128)
+
+// SyscallSaveOrig save the value of the register R0 which is clobbered in
+// syscall handler(doSyscall()).
+//
+// In linux, at the entry of the syscall handler(el0_svc_common()), value of R0
+// is saved to the pt_regs.orig_x0 in kernel code. But currently, the orig_x0
+// was not accessible to the userspace application, so we have to do the same
+// operation in the sentry code to save the R0 value into the App context.
+func (c *context64) SyscallSaveOrig() {
+ c.OrigR0 = c.Regs.Regs[0]
+}
+
+// SyscallNo returns the syscall number according to the 64-bit convention.
+func (c *context64) SyscallNo() uintptr {
+ return uintptr(c.Regs.Regs[8])
+}
+
+// SyscallArgs provides syscall arguments according to the 64-bit convention.
+//
+// Due to the way addresses are mapped for the sentry this binary *must* be
+// built in 64-bit mode. So we can just assume the syscall numbers that come
+// back match the expected host system call numbers.
+// General purpose registers usage on Arm64:
+// R0...R7: parameter/result registers.
+// R8: indirect result location register.
+// R9...R15: temporary registers.
+// R16: the first intra-procedure-call scratch register.
+// R17: the second intra-procedure-call scratch register.
+// R18: the platform register.
+// R19...R28: callee-saved registers.
+// R29: the frame pointer.
+// R30: the link register.
+func (c *context64) SyscallArgs() SyscallArguments {
+ return SyscallArguments{
+ SyscallArgument{Value: uintptr(c.OrigR0)},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[1])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[2])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[3])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[4])},
+ SyscallArgument{Value: uintptr(c.Regs.Regs[5])},
+ }
+}
+
+// RestartSyscall implements Context.RestartSyscall.
+// Prepare for system call restart, OrigR0 will be restored to R0.
+// Please see the linux code as reference:
+// arch/arm64/kernel/signal.c:do_signal()
+func (c *context64) RestartSyscall() {
+ c.Regs.Pc -= SyscallWidth
+ // R0 will be backed up into OrigR0 when entering doSyscall().
+ // Please see the linux code as reference:
+ // arch/arm64/kernel/syscall.c:el0_svc_common().
+ // Here we restore it back.
+ c.Regs.Regs[0] = uint64(c.OrigR0)
+}
+
+// RestartSyscallWithRestartBlock implements Context.RestartSyscallWithRestartBlock.
+func (c *context64) RestartSyscallWithRestartBlock() {
+ c.Regs.Pc -= SyscallWidth
+ c.Regs.Regs[0] = uint64(c.OrigR0)
+ c.Regs.Regs[8] = uint64(restartSyscallNr)
+}
diff --git a/pkg/sentry/context/contexttest/BUILD b/pkg/sentry/contexttest/BUILD
index 581e7aa96..6f4c86684 100644
--- a/pkg/sentry/context/contexttest/BUILD
+++ b/pkg/sentry/contexttest/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,11 +6,10 @@ go_library(
name = "contexttest",
testonly = 1,
srcs = ["contexttest.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/context/contexttest",
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/context",
"//pkg/memutil",
- "//pkg/sentry/context",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
diff --git a/pkg/sentry/context/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go
index 15cf086a9..8e5658c7a 100644
--- a/pkg/sentry/context/contexttest/contexttest.go
+++ b/pkg/sentry/contexttest/contexttest.go
@@ -21,8 +21,8 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/memutil"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -97,7 +97,7 @@ type hostClock struct {
}
// Now implements ktime.Clock.Now.
-func (hostClock) Now() ktime.Time {
+func (*hostClock) Now() ktime.Time {
return ktime.FromNanoseconds(time.Now().UnixNano())
}
@@ -127,7 +127,7 @@ func (t *TestContext) Value(key interface{}) interface{} {
case uniqueid.CtxInotifyCookie:
return atomic.AddUint32(&lastInotifyCookie, 1)
case ktime.CtxRealtimeClock:
- return hostClock{}
+ return &hostClock{}
default:
if val, ok := t.otherValues[key]; ok {
return val
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index 5522cecd0..2c5d14be5 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,16 +11,18 @@ go_library(
"proc.go",
"state.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/control",
visibility = [
- "//pkg/sentry:internal",
+ "//:sandbox",
],
deps = [
"//pkg/abi/linux",
"//pkg/fd",
"//pkg/log",
+ "//pkg/sentry/fdimport",
"//pkg/sentry/fs",
"//pkg/sentry/fs/host",
+ "//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/host",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
@@ -29,9 +30,12 @@ go_library(
"//pkg/sentry/state",
"//pkg/sentry/strace",
"//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -39,7 +43,7 @@ go_test(
name = "control_test",
size = "small",
srcs = ["proc_test.go"],
- embed = [":control"],
+ library = ":control",
deps = [
"//pkg/log",
"//pkg/sentry/kernel/time",
diff --git a/pkg/sentry/control/logging.go b/pkg/sentry/control/logging.go
index 811f24324..8a500a515 100644
--- a/pkg/sentry/control/logging.go
+++ b/pkg/sentry/control/logging.go
@@ -70,8 +70,8 @@ type LoggingArgs struct {
type Logging struct{}
// Change will change the log level and strace arguments. Although
-// this functions signature requires an error it never acctually
-// return san error. It's required by the URPC interface.
+// this functions signature requires an error it never actually
+// returns an error. It's required by the URPC interface.
// Additionally, it may look odd that this is the only method
// attached to an empty struct but this is also part of how
// URPC dispatches.
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 1f78d54a2..2bf3c45e1 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -19,9 +19,10 @@ import (
"runtime"
"runtime/pprof"
"runtime/trace"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -48,6 +49,9 @@ type ProfileOpts struct {
// - dump out the stack trace of current go routines.
// sentryctl -pid <pid> pprof-goroutine
type Profile struct {
+ // Kernel is the kernel under profile. It's immutable.
+ Kernel *kernel.Kernel
+
// mu protects the fields below.
mu sync.Mutex
@@ -113,9 +117,9 @@ func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
return nil
}
-// Goroutine is an RPC stub which dumps out the stack trace for all running
-// goroutines.
-func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
+// GoroutineProfile is an RPC stub which dumps out the stack trace for all
+// running goroutines.
+func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
return errNoOutput
}
@@ -127,6 +131,34 @@ func (p *Profile) Goroutine(o *ProfileOpts, _ *struct{}) error {
return nil
}
+// BlockProfile is an RPC stub which dumps out the stack trace that led to
+// blocking on synchronization primitives.
+func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
+// MutexProfile is an RPC stub which dumps out the stack trace of holders of
+// contended mutexes.
+func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+ if len(o.FilePayload.Files) < 1 {
+ return errNoOutput
+ }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
+ if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
+ return err
+ }
+ return nil
+}
+
// StartTrace is an RPC stub which starts collection of an execution trace.
func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
@@ -147,6 +179,9 @@ func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
return err
}
+ // Ensure all trace contexts are registered.
+ p.Kernel.RebuildTraceContexts()
+
p.traceFile = output
return nil
}
@@ -158,9 +193,15 @@ func (p *Profile) StopTrace(_, _ *struct{}) error {
defer p.mu.Unlock()
if p.traceFile == nil {
- return errors.New("Execution tracing not start")
+ return errors.New("Execution tracing not started")
}
+ // Similarly to the case above, if tasks have not ended traces, we will
+ // lose information. Thus we need to rebuild the tasks in order to have
+ // complete information. This will not lose information if multiple
+ // traces are overlapping.
+ p.Kernel.RebuildTraceContexts()
+
trace.Stop()
p.traceFile.Close()
p.traceFile = nil
diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go
index c35faeb4c..dfa936563 100644
--- a/pkg/sentry/control/proc.go
+++ b/pkg/sentry/control/proc.go
@@ -23,14 +23,19 @@ import (
"text/tabwriter"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fdimport"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/urpc"
)
@@ -60,6 +65,12 @@ type ExecArgs struct {
// process's MountNamespace.
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 is the mount namespace to execute the new process in.
+ // A reference on MountNamespace must be held for the lifetime of the
+ // ExecArgs. If MountNamespace is nil, it will default to the init
+ // process's MountNamespace.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// WorkingDirectory defines the working directory for the new process.
WorkingDirectory string `json:"wd"`
@@ -71,15 +82,13 @@ type ExecArgs struct {
// the root group if not set explicitly.
KGID auth.KGID
- // ExtraKGIDs is the list of additional groups to which the user
- // belongs.
+ // ExtraKGIDs is the list of additional groups to which the user belongs.
ExtraKGIDs []auth.KGID
// Capabilities is the list of capabilities to give to the process.
Capabilities *auth.TaskCapabilities
- // StdioIsPty indicates that FDs 0, 1, and 2 are connected to a host
- // pty FD.
+ // StdioIsPty indicates that FDs 0, 1, and 2 are connected to a host pty FD.
StdioIsPty bool
// FilePayload determines the files to give to the new process.
@@ -94,6 +103,9 @@ type ExecArgs struct {
// String prints the arguments as a string.
func (args ExecArgs) String() string {
+ if len(args.Argv) == 0 {
+ return args.Filename
+ }
a := make([]string, len(args.Argv))
copy(a, args.Argv)
if args.Filename != "" {
@@ -104,7 +116,7 @@ func (args ExecArgs) String() string {
// Exec runs a new task.
func (proc *Proc) Exec(args *ExecArgs, waitStatus *uint32) error {
- newTG, _, _, err := proc.execAsync(args)
+ newTG, _, _, _, err := proc.execAsync(args)
if err != nil {
return err
}
@@ -117,25 +129,16 @@ func (proc *Proc) Exec(args *ExecArgs, waitStatus *uint32) error {
// ExecAsync runs a new task, but doesn't wait for it to finish. It is defined
// as a function rather than a method to avoid exposing execAsync as an RPC.
-func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, error) {
+func ExecAsync(proc *Proc, args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
return proc.execAsync(args)
}
// execAsync runs a new task, but doesn't wait for it to finish. It returns the
// newly created thread group and its PID. If the stdio FDs are TTYs, then a
// TTYFileOperations that wraps the TTY is also returned.
-func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, error) {
+func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadID, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
// Import file descriptors.
fdTable := proc.Kernel.NewFDTable()
- defer fdTable.DecRef()
-
- // No matter what happens, we should close all files in the FilePayload
- // before returning. Any files that are imported will be duped.
- defer func() {
- for _, f := range args.FilePayload.Files {
- f.Close()
- }
- }()
creds := auth.NewUserCredentials(
args.KUID,
@@ -150,6 +153,7 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
Envv: args.Envv,
WorkingDirectory: args.WorkingDirectory,
MountNamespace: args.MountNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
Credentials: creds,
FDTable: fdTable,
Umask: 0022,
@@ -166,81 +170,80 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI
// be donated to the new process in CreateProcess.
initArgs.MountNamespace.IncRef()
}
+ if initArgs.MountNamespaceVFS2 != nil {
+ // initArgs must hold a reference on MountNamespaceVFS2, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespaceVFS2.IncRef()
+ }
ctx := initArgs.NewContext(proc.Kernel)
+ defer fdTable.DecRef(ctx)
- if initArgs.Filename == "" {
+ if kernel.VFS2Enabled {
// Get the full path to the filename from the PATH env variable.
- paths := fs.GetPath(initArgs.Envv)
- mns := initArgs.MountNamespace
- if mns == nil {
- mns = proc.Kernel.GlobalInit().Leader().MountNamespace()
+ if initArgs.MountNamespaceVFS2 == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ //
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ initArgs.MountNamespaceVFS2 = proc.Kernel.GlobalInit().Leader().MountNamespaceVFS2()
}
- f, err := mns.ResolveExecutablePath(ctx, initArgs.WorkingDirectory, initArgs.Argv[0], paths)
- if err != nil {
- return nil, 0, nil, fmt.Errorf("error finding executable %q in PATH %v: %v", initArgs.Argv[0], paths, err)
+ } else {
+ if initArgs.MountNamespace == nil {
+ // Set initArgs so that 'ctx' returns the namespace.
+ initArgs.MountNamespace = proc.Kernel.GlobalInit().Leader().MountNamespace()
+
+ // initArgs must hold a reference on MountNamespace, which will
+ // be donated to the new process in CreateProcess.
+ initArgs.MountNamespace.IncRef()
}
- initArgs.Filename = f
}
+ resolved, err := user.ResolveExecutablePath(ctx, &initArgs)
+ if err != nil {
+ return nil, 0, nil, nil, err
+ }
+ initArgs.Filename = resolved
- mounter := fs.FileOwnerFromContext(ctx)
-
- var ttyFile *fs.File
- for appFD, hostFile := range args.FilePayload.Files {
- var appFile *fs.File
-
- if args.StdioIsPty && appFD < 3 {
- // Import the file as a host TTY file.
- if ttyFile == nil {
- var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, true /* isTTY */)
- if err != nil {
- return nil, 0, nil, err
- }
- defer appFile.DecRef()
-
- // Remember this in the TTY file, as we will
- // use it for the other stdio FDs.
- ttyFile = appFile
- } else {
- // Re-use the existing TTY file, as all three
- // stdio FDs must point to the same fs.File in
- // order to share TTY state, specifically the
- // foreground process group id.
- appFile = ttyFile
- }
- } else {
- // Import the file as a regular host file.
- var err error
- appFile, err = host.ImportFile(ctx, int(hostFile.Fd()), mounter, false /* isTTY */)
+ fds := make([]int, len(args.FilePayload.Files))
+ for i, file := range args.FilePayload.Files {
+ if kernel.VFS2Enabled {
+ // Need to dup to remove ownership from os.File.
+ dup, err := unix.Dup(int(file.Fd()))
if err != nil {
- return nil, 0, nil, err
+ return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err)
}
- defer appFile.DecRef()
+ fds[i] = dup
+ } else {
+ // VFS1 dups the file on import.
+ fds[i] = int(file.Fd())
}
-
- // Add the file to the FD map.
- if err := fdTable.NewFDAt(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
- return nil, 0, nil, err
+ }
+ ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds)
+ if err != nil {
+ if kernel.VFS2Enabled {
+ for _, fd := range fds {
+ unix.Close(fd)
+ }
}
+ return nil, 0, nil, nil, err
}
tg, tid, err := proc.Kernel.CreateProcess(initArgs)
if err != nil {
- return nil, 0, nil, err
+ return nil, 0, nil, nil, err
}
- var ttyFileOps *host.TTYFileOperations
- if ttyFile != nil {
- // Set the foreground process group on the TTY before starting
- // the process.
- ttyFileOps = ttyFile.FileOperations.(*host.TTYFileOperations)
- ttyFileOps.InitForegroundProcessGroup(tg.ProcessGroup())
+ // Set the foreground process group on the TTY before starting the process.
+ switch {
+ case ttyFile != nil:
+ ttyFile.InitForegroundProcessGroup(tg.ProcessGroup())
+ case ttyFileVFS2 != nil:
+ ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup())
}
// Start the newly created process.
proc.Kernel.StartProcess(tg)
- return tg, tid, ttyFileOps, nil
+ return tg, tid, ttyFile, ttyFileVFS2, nil
}
// PsArgs is the set of arguments to ps.
@@ -268,14 +271,17 @@ func (proc *Proc) Ps(args *PsArgs, out *string) error {
}
// Process contains information about a single process in a Sandbox.
-// TODO(b/117881927): Implement TTY field.
type Process struct {
UID auth.KUID `json:"uid"`
PID kernel.ThreadID `json:"pid"`
// Parent PID
- PPID kernel.ThreadID `json:"ppid"`
+ PPID kernel.ThreadID `json:"ppid"`
+ Threads []kernel.ThreadID `json:"threads"`
// Processor utilization
C int32 `json:"c"`
+ // TTY name of the process. Will be of the form "pts/N" if there is a
+ // TTY, or "?" if there is not.
+ TTY string `json:"tty"`
// Start time
STime string `json:"stime"`
// CPU time
@@ -285,18 +291,19 @@ type Process struct {
}
// ProcessListToTable prints a table with the following format:
-// UID PID PPID C STIME TIME CMD
-// 0 1 0 0 14:04 505262ns tail
+// UID PID PPID C TTY STIME TIME CMD
+// 0 1 0 0 pty/4 14:04 505262ns tail
func ProcessListToTable(pl []*Process) string {
var buf bytes.Buffer
tw := tabwriter.NewWriter(&buf, 10, 1, 3, ' ', 0)
- fmt.Fprint(tw, "UID\tPID\tPPID\tC\tSTIME\tTIME\tCMD")
+ fmt.Fprint(tw, "UID\tPID\tPPID\tC\tTTY\tSTIME\tTIME\tCMD")
for _, d := range pl {
- fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s",
+ fmt.Fprintf(tw, "\n%d\t%d\t%d\t%d\t%s\t%s\t%s\t%s",
d.UID,
d.PID,
d.PPID,
d.C,
+ d.TTY,
d.STime,
d.Time,
d.Cmd)
@@ -307,7 +314,7 @@ func ProcessListToTable(pl []*Process) string {
// ProcessListToJSON will return the JSON representation of ps.
func ProcessListToJSON(pl []*Process) (string, error) {
- b, err := json.Marshal(pl)
+ b, err := json.MarshalIndent(pl, "", " ")
if err != nil {
return "", fmt.Errorf("couldn't marshal process list %v: %v", pl, err)
}
@@ -334,7 +341,9 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ts := k.TaskSet()
now := k.RealtimeClock().Now()
for _, tg := range ts.Root.ThreadGroups() {
- pid := tg.PIDNamespace().IDOfThreadGroup(tg)
+ pidns := tg.PIDNamespace()
+ pid := pidns.IDOfThreadGroup(tg)
+
// If tg has already been reaped ignore it.
if pid == 0 {
continue
@@ -345,16 +354,19 @@ func Processes(k *kernel.Kernel, containerID string, out *[]*Process) error {
ppid := kernel.ThreadID(0)
if p := tg.Leader().Parent(); p != nil {
- ppid = p.PIDNamespace().IDOfThreadGroup(p.ThreadGroup())
+ ppid = pidns.IDOfThreadGroup(p.ThreadGroup())
}
+ threads := tg.MemberIDs(pidns)
*out = append(*out, &Process{
- UID: tg.Leader().Credentials().EffectiveKUID,
- PID: pid,
- PPID: ppid,
- STime: formatStartTime(now, tg.Leader().StartTime()),
- C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
- Time: tg.CPUStats().SysTime.String(),
- Cmd: tg.Leader().Name(),
+ UID: tg.Leader().Credentials().EffectiveKUID,
+ PID: pid,
+ PPID: ppid,
+ Threads: threads,
+ STime: formatStartTime(now, tg.Leader().StartTime()),
+ C: percentCPU(tg.CPUStats(), tg.Leader().StartTime(), now),
+ Time: tg.CPUStats().SysTime.String(),
+ Cmd: tg.Leader().Name(),
+ TTY: ttyName(tg.TTY()),
})
}
sort.Slice(*out, func(i, j int) bool { return (*out)[i].PID < (*out)[j].PID })
@@ -395,3 +407,10 @@ func percentCPU(stats usage.CPUStats, startTime, now ktime.Time) int32 {
}
return int32(percentCPU)
}
+
+func ttyName(tty *kernel.TTY) string {
+ if tty == nil {
+ return "?"
+ }
+ return fmt.Sprintf("pts/%d", tty.Index)
+}
diff --git a/pkg/sentry/control/proc_test.go b/pkg/sentry/control/proc_test.go
index d8ada2694..0a88459b2 100644
--- a/pkg/sentry/control/proc_test.go
+++ b/pkg/sentry/control/proc_test.go
@@ -34,7 +34,7 @@ func TestProcessListTable(t *testing.T) {
}{
{
pl: []*Process{},
- expected: "UID PID PPID C STIME TIME CMD",
+ expected: "UID PID PPID C TTY STIME TIME CMD",
},
{
pl: []*Process{
@@ -43,6 +43,7 @@ func TestProcessListTable(t *testing.T) {
PID: 0,
PPID: 0,
C: 0,
+ TTY: "?",
STime: "0",
Time: "0",
Cmd: "zero",
@@ -52,14 +53,15 @@ func TestProcessListTable(t *testing.T) {
PID: 1,
PPID: 1,
C: 1,
+ TTY: "pts/4",
STime: "1",
Time: "1",
Cmd: "one",
},
},
- expected: `UID PID PPID C STIME TIME CMD
-0 0 0 0 0 0 zero
-1 1 1 1 1 1 one`,
+ expected: `UID PID PPID C TTY STIME TIME CMD
+0 0 0 0 ? 0 0 zero
+1 1 1 1 pts/4 1 1 one`,
},
}
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 1098ed777..e403cbd8b 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,19 +1,20 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "device",
srcs = ["device.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/device",
visibility = ["//pkg/sentry:internal"],
- deps = ["//pkg/abi/linux"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sync",
+ ],
)
go_test(
name = "device_test",
size = "small",
srcs = ["device_test.go"],
- embed = [":device"],
+ library = ":device",
)
diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go
index 47945d1a7..f45b2bd2b 100644
--- a/pkg/sentry/device/device.go
+++ b/pkg/sentry/device/device.go
@@ -19,10 +19,10 @@ package device
import (
"bytes"
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Registry tracks all simple devices and related state on the system for
@@ -188,6 +188,9 @@ type MultiDevice struct {
// String stringifies MultiDevice.
func (m *MultiDevice) String() string {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
buf := bytes.NewBuffer(nil)
buf.WriteString("cache{")
for k, v := range m.cache {
diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD
new file mode 100644
index 000000000..abe58f818
--- /dev/null
+++ b/pkg/sentry/devices/memdev/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "memdev",
+ srcs = [
+ "full.go",
+ "memdev.go",
+ "null.go",
+ "random.go",
+ "zero.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/rand",
+ "//pkg/safemem",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/devices/memdev/full.go b/pkg/sentry/devices/memdev/full.go
new file mode 100644
index 000000000..511179e31
--- /dev/null
+++ b/pkg/sentry/devices/memdev/full.go
@@ -0,0 +1,76 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const fullDevMinor = 7
+
+// fullDevice implements vfs.Device for /dev/full.
+type fullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &fullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// fullFD implements vfs.FileDescriptionImpl for /dev/full.
+type fullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *fullFD) Release(context.Context) {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *fullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *fullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *fullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *fullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ENOSPC
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *fullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
diff --git a/pkg/sentry/devices/memdev/memdev.go b/pkg/sentry/devices/memdev/memdev.go
new file mode 100644
index 000000000..5759900c4
--- /dev/null
+++ b/pkg/sentry/devices/memdev/memdev.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memdev implements "mem" character devices, as implemented in Linux
+// by drivers/char/mem.c and drivers/char/random.c.
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ for minor, dev := range map[uint32]vfs.Device{
+ nullDevMinor: nullDevice{},
+ zeroDevMinor: zeroDevice{},
+ fullDevMinor: fullDevice{},
+ randomDevMinor: randomDevice{},
+ urandomDevMinor: randomDevice{},
+ } {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MEM_MAJOR, minor, dev, &vfs.RegisterDeviceOptions{
+ GroupName: "mem",
+ }); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ for minor, name := range map[uint32]string{
+ nullDevMinor: "null",
+ zeroDevMinor: "zero",
+ fullDevMinor: "full",
+ randomDevMinor: "random",
+ urandomDevMinor: "urandom",
+ } {
+ if err := dev.CreateDeviceFile(ctx, name, vfs.CharDevice, linux.MEM_MAJOR, minor, 0666 /* mode */); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/pkg/sentry/devices/memdev/null.go b/pkg/sentry/devices/memdev/null.go
new file mode 100644
index 000000000..4918dbeeb
--- /dev/null
+++ b/pkg/sentry/devices/memdev/null.go
@@ -0,0 +1,77 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const nullDevMinor = 3
+
+// nullDevice implements vfs.Device for /dev/null.
+type nullDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (nullDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &nullFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// nullFD implements vfs.FileDescriptionImpl for /dev/null.
+type nullFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nullFD) Release(context.Context) {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nullFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nullFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return 0, io.EOF
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nullFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nullFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nullFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
diff --git a/pkg/sentry/devices/memdev/random.go b/pkg/sentry/devices/memdev/random.go
new file mode 100644
index 000000000..5e7fe0280
--- /dev/null
+++ b/pkg/sentry/devices/memdev/random.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ randomDevMinor = 8
+ urandomDevMinor = 9
+)
+
+// randomDevice implements vfs.Device for /dev/random and /dev/urandom.
+type randomDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (randomDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &randomFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// randomFD implements vfs.FileDescriptionImpl for /dev/random.
+type randomFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // off is the "file offset". off is accessed using atomic memory
+ // operations.
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *randomFD) Release(context.Context) {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *randomFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *randomFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.FromIOReader{rand.Reader})
+ atomic.AddInt64(&fd.off, n)
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *randomFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // In Linux, this mixes the written bytes into the entropy pool; we just
+ // throw them away.
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *randomFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ atomic.AddInt64(&fd.off, src.NumBytes())
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *randomFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Linux: drivers/char/random.c:random_fops.llseek == urandom_fops.llseek
+ // == noop_llseek
+ return atomic.LoadInt64(&fd.off), nil
+}
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
new file mode 100644
index 000000000..2e631a252
--- /dev/null
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package memdev
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const zeroDevMinor = 5
+
+// zeroDevice implements vfs.Device for /dev/zero.
+type zeroDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (zeroDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &zeroFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// zeroFD implements vfs.FileDescriptionImpl for /dev/zero.
+type zeroFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *zeroFD) Release(context.Context) {
+ // noop
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *zeroFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *zeroFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return dst.ZeroOut(ctx, dst.NumBytes())
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *zeroFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *zeroFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return src.NumBytes(), nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return 0, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx))
+ if err != nil {
+ return err
+ }
+ opts.MappingIdentity = m
+ opts.Mappable = m
+ return nil
+}
diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD
new file mode 100644
index 000000000..b4b6ca38a
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "ttydev",
+ srcs = ["ttydev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go
new file mode 100644
index 000000000..664e54498
--- /dev/null
+++ b/pkg/sentry/devices/ttydev/ttydev.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ttydev implements an unopenable vfs.Device for /dev/tty.
+package ttydev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // See drivers/tty/tty_io.c:tty_init().
+ ttyDevMinor = 0
+ consoleDevMinor = 1
+)
+
+// ttyDevice implements vfs.Device for /dev/tty.
+type ttyDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.EIO
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, ttyDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "tty",
+ })
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "tty", vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, 0666 /* mode */)
+}
diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD
new file mode 100644
index 000000000..71c59287c
--- /dev/null
+++ b/pkg/sentry/devices/tundev/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "tundev",
+ srcs = ["tundev.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip/link/tun",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go
new file mode 100644
index 000000000..a40625e19
--- /dev/null
+++ b/pkg/sentry/devices/tundev/tundev.go
@@ -0,0 +1,178 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tundev implements the /dev/net/tun device.
+package tundev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// tunDevice implements vfs.Device for /dev/net/tun.
+type tunDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &tunFD{}
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun.
+type tunFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ device tun.Device
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fd.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fd.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fd.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *tunFD) Release(ctx context.Context) {
+ fd.device.Release(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *tunFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.Read(ctx, dst, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *tunFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ data, err := fd.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *tunFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.Write(ctx, src, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *tunFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fd.device.Write(data)
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fd *tunFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fd.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fd *tunFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fd *tunFD) EventUnregister(e *waiter.Entry) {
+ fd.device.EventUnregister(e)
+}
+
+// IsNetTunSupported returns whether /dev/net/tun device is supported for s.
+func IsNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
+
+// Register registers all devices implemented by this package in vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ return vfsObj.RegisterDevice(vfs.CharDevice, netTunDevMajor, netTunDevMinor, tunDevice{}, &vfs.RegisterDeviceOptions{})
+}
+
+// CreateDevtmpfsFiles creates device special files in dev representing all
+// devices implemented by this package.
+func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error {
+ return dev.CreateDeviceFile(ctx, "net/tun", vfs.CharDevice, netTunDevMajor, netTunDevMinor, 0666 /* mode */)
+}
diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD
new file mode 100644
index 000000000..5e41ceb4e
--- /dev/null
+++ b/pkg/sentry/fdimport/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fdimport",
+ srcs = [
+ "fdimport.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/host",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
new file mode 100644
index 000000000..1b7cb94c0
--- /dev/null
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -0,0 +1,134 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdimport
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/host"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Import imports a slice of FDs into the given FDTable. If console is true,
+// sets up TTY for the first 3 FDs in the slice representing stdin, stdout,
+// stderr. Upon success, Import takes ownership of all FDs.
+func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ if kernel.VFS2Enabled {
+ ttyFile, err := importVFS2(ctx, fdTable, console, fds)
+ return nil, ttyFile, err
+ }
+ ttyFile, err := importFS(ctx, fdTable, console, fds)
+ return ttyFile, nil, err
+}
+
+func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) {
+ var ttyFile *fs.File
+ for appFD, hostFD := range fds {
+ var appFile *fs.File
+
+ if console && appFD < 3 {
+ // Import the file as a host TTY file.
+ if ttyFile == nil {
+ var err error
+ appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef(ctx)
+
+ // Remember this in the TTY file, as we will
+ // use it for the other stdio FDs.
+ ttyFile = appFile
+ } else {
+ // Re-use the existing TTY file, as all three
+ // stdio FDs must point to the same fs.File in
+ // order to share TTY state, specifically the
+ // foreground process group id.
+ appFile = ttyFile
+ }
+ } else {
+ // Import the file as a regular host file.
+ var err error
+ appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef(ctx)
+ }
+
+ // Add the file to the FD map.
+ if err := fdTable.NewFDAt(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ return nil, err
+ }
+ }
+
+ if ttyFile == nil {
+ return nil, nil
+ }
+ return ttyFile.FileOperations.(*host.TTYFileOperations), nil
+}
+
+func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, fmt.Errorf("cannot find kernel from context")
+ }
+
+ var ttyFile *vfs.FileDescription
+ for appFD, hostFD := range stdioFDs {
+ var appFile *vfs.FileDescription
+
+ if console && appFD < 3 {
+ // Import the file as a host TTY file.
+ if ttyFile == nil {
+ var err error
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef(ctx)
+
+ // Remember this in the TTY file, as we will use it for the other stdio
+ // FDs.
+ ttyFile = appFile
+ } else {
+ // Re-use the existing TTY file, as all three stdio FDs must point to
+ // the same fs.File in order to share TTY state, specifically the
+ // foreground process group id.
+ appFile = ttyFile
+ }
+ } else {
+ var err error
+ appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */)
+ if err != nil {
+ return nil, err
+ }
+ defer appFile.DecRef(ctx)
+ }
+
+ if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
+ return nil, err
+ }
+ }
+
+ if ttyFile == nil {
+ return nil, nil
+ }
+ return ttyFile.Impl().(*hostvfs2.TTYFileDescription), nil
+}
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index 378602cc9..ea85ab33c 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -44,18 +43,17 @@ go_library(
"splice.go",
"sync.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
+ "//pkg/context",
"//pkg/log",
"//pkg/metric",
"//pkg/p9",
"//pkg/refs",
"//pkg/secio",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
@@ -66,11 +64,11 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
- "//third_party/gvsync",
],
)
@@ -109,13 +107,14 @@ go_test(
],
deps = [
":fs",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/kernel/contexttest",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
@@ -128,9 +127,9 @@ go_test(
"mount_test.go",
"path_test.go",
],
- embed = [":fs"],
+ library = ":fs",
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
],
)
diff --git a/pkg/sentry/fs/anon/BUILD b/pkg/sentry/fs/anon/BUILD
index ae1c9cf76..aedcecfa1 100644
--- a/pkg/sentry/fs/anon/BUILD
+++ b/pkg/sentry/fs/anon/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,14 +8,13 @@ go_library(
"anon.go",
"device.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/anon",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/anon/anon.go b/pkg/sentry/fs/anon/anon.go
index 7323c7222..5c421f5fb 100644
--- a/pkg/sentry/fs/anon/anon.go
+++ b/pkg/sentry/fs/anon/anon.go
@@ -18,10 +18,10 @@ package anon
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// NewInode constructs an anonymous Inode that is not associated
diff --git a/pkg/sentry/fs/attr.go b/pkg/sentry/fs/attr.go
index 4f3d6410e..f60bd423d 100644
--- a/pkg/sentry/fs/attr.go
+++ b/pkg/sentry/fs/attr.go
@@ -20,8 +20,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
@@ -206,6 +206,11 @@ func IsPipe(s StableAttr) bool {
return s.Type == Pipe
}
+// IsAnonymous returns true if StableAttr.Type matches any type of anonymous.
+func IsAnonymous(s StableAttr) bool {
+ return s.Type == Anonymous
+}
+
// IsSocket returns true if StableAttr.Type matches any type of socket.
func IsSocket(s StableAttr) bool {
return s.Type == Socket
diff --git a/pkg/sentry/fs/context.go b/pkg/sentry/fs/context.go
index dd427de5d..0fbd60056 100644
--- a/pkg/sentry/fs/context.go
+++ b/pkg/sentry/fs/context.go
@@ -16,7 +16,7 @@ package fs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index 9ac62c84d..735452b07 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -17,13 +17,14 @@ package fs
import (
"fmt"
"io"
- "sync"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// copyUp copies a file in an overlay from a lower filesystem to an
@@ -200,7 +201,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
parentUpper := parent.Inode.overlay.upper
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
// Create the file in the upper filesystem and get an Inode for it.
@@ -211,7 +212,7 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
log.Warningf("copy up failed to create file: %v", err)
return syserror.EIO
}
- defer childFile.DecRef()
+ defer childFile.DecRef(ctx)
childUpperInode = childFile.Dirent.Inode
case Directory:
@@ -221,11 +222,11 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
- log.Warningf("copy up failed to lookup directory: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to lookup directory: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
- defer childUpper.DecRef()
+ defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
case Symlink:
@@ -241,11 +242,11 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
}
childUpper, err := parentUpper.Lookup(ctx, next.name)
if err != nil {
- log.Warningf("copy up failed to lookup symlink: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to lookup symlink: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
- defer childUpper.DecRef()
+ defer childUpper.DecRef(ctx)
childUpperInode = childUpper.Inode
default:
@@ -255,23 +256,23 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
// Bring file attributes up to date. This does not include size, which will be
// brought up to date with copyContentsLocked.
if err := copyAttributesLocked(ctx, childUpperInode, next.Inode.overlay.lower); err != nil {
- log.Warningf("copy up failed to copy up attributes: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to copy up attributes: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
// Copy the entire file.
if err := copyContentsLocked(ctx, childUpperInode, next.Inode.overlay.lower, attrs.Size); err != nil {
- log.Warningf("copy up failed to copy up contents: %v", err)
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed to copy up contents: %v", err)
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
lowerMappable := next.Inode.overlay.lower.Mappable()
upperMappable := childUpperInode.Mappable()
if lowerMappable != nil && upperMappable == nil {
- log.Warningf("copy up failed: cannot ensure memory mapping coherence")
- cleanupUpper(ctx, parentUpper, next.name)
+ werr := fmt.Errorf("copy up failed: cannot ensure memory mapping coherence")
+ cleanupUpper(ctx, parentUpper, next.name, werr)
return syserror.EIO
}
@@ -323,12 +324,14 @@ func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error {
return nil
}
-// cleanupUpper removes name from parent, and panics if it is unsuccessful.
-func cleanupUpper(ctx context.Context, parent *Inode, name string) {
+// cleanupUpper is called when copy-up fails. It logs the copy-up error and
+// attempts to remove name from parent. If that fails, then it panics.
+func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr error) {
+ log.Warningf(copyUpErr.Error())
if err := parent.InodeOperations.Remove(ctx, parent, name); err != nil {
// Unfortunately we don't have much choice. We shouldn't
// willingly give the caller access to a nonsense filesystem.
- panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: failed to remove %q from upper filesystem: %v", name, err))
+ panic(fmt.Sprintf("overlay filesystem is in an inconsistent state: copyUp got error: %v; then cleanup failed to remove %q from upper filesystem: %v.", copyUpErr, name, err))
}
}
@@ -349,14 +352,14 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
if err != nil {
return err
}
- defer upperFile.DecRef()
+ defer upperFile.DecRef(ctx)
// Get a handle to the lower filesystem, which we will read from.
lowerFile, err := overlayFile(ctx, lower, FileFlags{Read: true})
if err != nil {
return err
}
- defer lowerFile.DecRef()
+ defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
buf := copyUpBuffers.Get().([]byte)
@@ -395,12 +398,12 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// Size and permissions are set on upper when the file content is copied
// and when the file is created respectively.
func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error {
- // Extract attributes fro the lower filesystem.
+ // Extract attributes from the lower filesystem.
lowerAttr, err := lower.UnstableAttr(ctx)
if err != nil {
return err
}
- lowerXattr, err := lower.Listxattr()
+ lowerXattr, err := lower.ListXattr(ctx, linux.XATTR_SIZE_MAX)
if err != nil && err != syserror.EOPNOTSUPP {
return err
}
@@ -421,11 +424,11 @@ func copyAttributesLocked(ctx context.Context, upper *Inode, lower *Inode) error
if isXattrOverlay(name) {
continue
}
- value, err := lower.Getxattr(name)
+ value, err := lower.GetXattr(ctx, name, linux.XATTR_SIZE_MAX)
if err != nil {
return err
}
- if err := upper.InodeOperations.Setxattr(upper, name, value); err != nil {
+ if err := upper.InodeOperations.SetXattr(ctx, upper, name, value, 0 /* flags */); err != nil {
return err
}
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index 1d80bf15a..c7a11eec1 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -19,13 +19,13 @@ import (
"crypto/rand"
"fmt"
"io"
- "sync"
"testing"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -126,7 +126,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile {
if err != nil {
t.Fatalf("failed to create file %q: %v", name, err)
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
relname, _ := f.Dirent.FullName(lowerRoot)
@@ -171,7 +171,7 @@ func makeOverlayTestFiles(t *testing.T) []*overlayTestFile {
if err != nil {
t.Fatalf("failed to find %q: %v", f.name, err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
f.File, err = d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
if err != nil {
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index a0d9e8496..9379a4d7b 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,27 +9,32 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/dev",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/rand",
- "//pkg/sentry/context",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/safemem",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index f739c476c..acbd401a0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -18,11 +18,12 @@ package dev
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Memory device numbers are from Linux's drivers/char/mem.c
@@ -66,8 +67,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +112,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -126,6 +127,12 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
}
+ if isNetTunSupported(inet.StackFromContext(ctx)) {
+ contents["net"] = newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc)
+ }
+
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
diff --git a/pkg/sentry/fs/dev/fs.go b/pkg/sentry/fs/dev/fs.go
index 55f8af704..5e518fb63 100644
--- a/pkg/sentry/fs/dev/fs.go
+++ b/pkg/sentry/fs/dev/fs.go
@@ -15,7 +15,7 @@
package dev
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
diff --git a/pkg/sentry/fs/dev/full.go b/pkg/sentry/fs/dev/full.go
index 07e0ea010..deb9c6ad8 100644
--- a/pkg/sentry/fs/dev/full.go
+++ b/pkg/sentry/fs/dev/full.go
@@ -16,11 +16,11 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..ec474e554
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,177 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release(ctx context.Context) {
+ fops.device.Release(ctx)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
+
+// isNetTunSupported returns whether /dev/net/tun device is supported for s.
+func isNetTunSupported(s inet.Stack) bool {
+ _, ok := s.(*netstack.Stack)
+ return ok
+}
diff --git a/pkg/sentry/fs/dev/null.go b/pkg/sentry/fs/dev/null.go
index 4404b97ef..aec33d0d9 100644
--- a/pkg/sentry/fs/dev/null.go
+++ b/pkg/sentry/fs/dev/null.go
@@ -16,7 +16,7 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
diff --git a/pkg/sentry/fs/dev/random.go b/pkg/sentry/fs/dev/random.go
index 49cb92f6e..2a9bbeb18 100644
--- a/pkg/sentry/fs/dev/random.go
+++ b/pkg/sentry/fs/dev/random.go
@@ -16,12 +16,12 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/dev/tty.go b/pkg/sentry/fs/dev/tty.go
index 87d80e292..760ca563d 100644
--- a/pkg/sentry/fs/dev/tty.go
+++ b/pkg/sentry/fs/dev/tty.go
@@ -16,7 +16,7 @@ package dev
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/waiter"
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index 3cb73bd78..a2f751068 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -17,17 +17,16 @@ package fs
import (
"fmt"
"path"
- "sort"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -121,9 +120,6 @@ type Dirent struct {
// deleted may be set atomically when removed.
deleted int32
- // frozen indicates this entry can't walk to unknown nodes.
- frozen bool
-
// mounted is true if Dirent is a mount point, similar to include/linux/dcache.h:DCACHE_MOUNTED.
mounted bool
@@ -253,8 +249,7 @@ func (d *Dirent) IsNegative() bool {
return d.Inode == nil
}
-// hashChild will hash child into the children list of its new parent d, carrying over
-// any "frozen" state from d.
+// hashChild will hash child into the children list of its new parent d.
//
// Returns (*WeakRef, true) if hashing child caused a Dirent to be unhashed. The caller must
// validate the returned unhashed weak reference. Common cases:
@@ -282,9 +277,6 @@ func (d *Dirent) hashChild(child *Dirent) (*refs.WeakRef, bool) {
d.IncRef()
}
- // Carry over parent's frozen state.
- child.frozen = d.frozen
-
return d.hashChildParentSet(child)
}
@@ -320,9 +312,9 @@ func (d *Dirent) SyncAll(ctx context.Context) {
// There is nothing to sync for a read-only filesystem.
if !d.Inode.MountSource.Flags.ReadOnly {
- // FIXME(b/34856369): This should be a mount traversal, not a
- // Dirent traversal, because some Inodes that need to be synced
- // may no longer be reachable by name (after sys_unlink).
+ // NOTE(b/34856369): This should be a mount traversal, not a Dirent
+ // traversal, because some Inodes that need to be synced may no longer
+ // be reachable by name (after sys_unlink).
//
// Write out metadata, dirty page cached pages, and sync disk/remote
// caches.
@@ -333,7 +325,7 @@ func (d *Dirent) SyncAll(ctx context.Context) {
for _, w := range d.children {
if child := w.Get(); child != nil {
child.(*Dirent).SyncAll(ctx)
- child.DecRef()
+ child.DecRef(ctx)
}
}
}
@@ -400,38 +392,6 @@ func (d *Dirent) MountRoot() *Dirent {
return mountRoot
}
-// Freeze prevents this dirent from walking to more nodes. Freeze is applied
-// recursively to all children.
-//
-// If this particular Dirent represents a Virtual node, then Walks and Creates
-// may proceed as before.
-//
-// Freeze can only be called before the application starts running, otherwise
-// the root it might be out of sync with the application root if modified by
-// sys_chroot.
-func (d *Dirent) Freeze() {
- d.mu.Lock()
- defer d.mu.Unlock()
- if d.frozen {
- // Already frozen.
- return
- }
- d.frozen = true
-
- // Take a reference when freezing.
- for _, w := range d.children {
- if child := w.Get(); child != nil {
- // NOTE: We would normally drop the reference here. But
- // instead we're hanging on to it.
- ch := child.(*Dirent)
- ch.Freeze()
- }
- }
-
- // Drop all expired weak references.
- d.flush()
-}
-
// descendantOf returns true if the receiver dirent is equal to, or a
// descendant of, the argument dirent.
//
@@ -491,7 +451,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// which don't hold a hard reference on their parent (their parent holds a
// hard reference on them, and they contain virtually no state). But this is
// good house-keeping.
- child.DecRef()
+ child.DecRef(ctx)
return nil, syscall.ENOENT
}
@@ -508,25 +468,20 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// their pins on the child. Inotify doesn't properly support filesystems that
// revalidate dirents (since watches are lost on revalidation), but if we fail
// to unpin the watches child will never be GCed.
- cd.Inode.Watches.Unpin(cd)
+ cd.Inode.Watches.Unpin(ctx, cd)
// This child needs to be revalidated, fallthrough to unhash it. Make sure
// to not leak a reference from Get().
//
// Note that previous lookups may still have a reference to this stale child;
// this can't be helped, but we can ensure that *new* lookups are up-to-date.
- child.DecRef()
+ child.DecRef(ctx)
}
// Either our weak reference expired or we need to revalidate it. Unhash child first, we're
// about to replace it.
delete(d.children, name)
- w.Drop()
- }
-
- // Are we allowed to do the lookup?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
+ w.Drop(ctx)
}
// Slow path: load the InodeOperations into memory. Since this is a hot path and the lookup may be
@@ -557,12 +512,12 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// There are active references to the existing child, prefer it to the one we
// retrieved from Lookup. Likely the Lookup happened very close to the insertion
// of child, so considering one stale over the other is fairly arbitrary.
- c.DecRef()
+ c.DecRef(ctx)
// The child that was installed could be negative.
if cd.IsNegative() {
// If so, don't leak a reference and short circuit.
- child.DecRef()
+ child.DecRef(ctx)
return nil, syscall.ENOENT
}
@@ -576,7 +531,7 @@ func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnl
// we did the Inode.Lookup. Fully drop the weak reference and fallback to using the child
// we looked up.
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Give the looked up child a parent. We cannot kick out entries, since we just checked above
@@ -632,7 +587,7 @@ func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool {
return false
}
// Child exists.
- child.DecRef()
+ child.DecRef(ctx)
return true
}
@@ -659,11 +614,6 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
return nil, syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return nil, syscall.ENOENT
- }
-
// Try the create. We need to trust the file system to return EEXIST (or something
// that will translate to EEXIST) if name already exists.
file, err := d.Inode.Create(ctx, d, name, flags, perms)
@@ -672,7 +622,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
}
child := file.Dirent
- d.finishCreate(child, name)
+ d.finishCreate(ctx, child, name)
// Return the reference and the new file. When the last reference to
// the file is dropped, file.Dirent may no longer be cached.
@@ -681,7 +631,7 @@ func (d *Dirent) Create(ctx context.Context, root *Dirent, name string, flags Fi
// finishCreate validates the created file, adds it as a child of this dirent,
// and notifies any watchers.
-func (d *Dirent) finishCreate(child *Dirent, name string) {
+func (d *Dirent) finishCreate(ctx context.Context, child *Dirent, name string) {
// Sanity check c, its name must be consistent.
if child.name != name {
panic(fmt.Sprintf("create from %q to %q returned unexpected name %q", d.name, name, child.name))
@@ -700,14 +650,14 @@ func (d *Dirent) finishCreate(child *Dirent, name string) {
panic(fmt.Sprintf("hashed child %q over a positive child", child.name))
}
// Don't leak a reference.
- old.DecRef()
+ old.DecRef(ctx)
// Drop d's reference.
- old.DecRef()
+ old.DecRef(ctx)
}
// Finally drop the useless weak reference on the floor.
- w.Drop()
+ w.Drop(ctx)
}
d.Inode.Watches.Notify(name, linux.IN_CREATE, 0)
@@ -727,11 +677,6 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c
return syscall.EEXIST
}
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Remove any negative Dirent. We've already asserted above with d.exists
// that the only thing remaining here can be a negative Dirent.
if w, ok := d.children[name]; ok {
@@ -741,17 +686,17 @@ func (d *Dirent) genericCreate(ctx context.Context, root *Dirent, name string, c
panic(fmt.Sprintf("hashed over a positive child %q", old.(*Dirent).name))
}
// Don't leak a reference.
- old.DecRef()
+ old.DecRef(ctx)
// Drop d's reference.
- old.DecRef()
+ old.DecRef(ctx)
}
// Unhash the negative Dirent, name needs to exist now.
delete(d.children, name)
// Finally drop the useless weak reference on the floor.
- w.Drop()
+ w.Drop(ctx)
}
// Execute the create operation.
@@ -811,7 +756,7 @@ func (d *Dirent) Bind(ctx context.Context, root *Dirent, name string, data trans
if e != nil {
return e
}
- d.finishCreate(childDir, name)
+ d.finishCreate(ctx, childDir, name)
return nil
})
if err == syscall.EEXIST {
@@ -862,49 +807,6 @@ func (d *Dirent) GetDotAttrs(root *Dirent) (DentAttr, DentAttr) {
return dot, dot
}
-// readdirFrozen returns readdir results based solely on the frozen children.
-func (d *Dirent) readdirFrozen(root *Dirent, offset int64, dirCtx *DirCtx) (int64, error) {
- // Collect attrs for "." and "..".
- attrs := make(map[string]DentAttr)
- names := []string{".", ".."}
- attrs["."], attrs[".."] = d.GetDotAttrs(root)
-
- // Get info from all children.
- d.mu.Lock()
- defer d.mu.Unlock()
- for name, w := range d.children {
- if child := w.Get(); child != nil {
- defer child.DecRef()
-
- // Skip negative children.
- if child.(*Dirent).IsNegative() {
- continue
- }
-
- sattr := child.(*Dirent).Inode.StableAttr
- attrs[name] = DentAttr{
- Type: sattr.Type,
- InodeID: sattr.InodeID,
- }
- names = append(names, name)
- }
- }
-
- sort.Strings(names)
-
- if int(offset) >= len(names) {
- return offset, nil
- }
- names = names[int(offset):]
- for _, name := range names {
- if err := dirCtx.DirEmit(name, attrs[name]); err != nil {
- return offset, err
- }
- offset++
- }
- return offset, nil
-}
-
// DirIterator is an open directory containing directory entries that can be read.
type DirIterator interface {
// IterateDir emits directory entries by calling dirCtx.EmitDir, beginning
@@ -964,10 +866,6 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent,
return offset, nil
}
- if d.frozen {
- return d.readdirFrozen(root, offset, dirCtx)
- }
-
// Collect attrs for "." and "..".
dot, dotdot := d.GetDotAttrs(root)
@@ -1003,7 +901,7 @@ func direntReaddir(ctx context.Context, d *Dirent, it DirIterator, root *Dirent,
// references to children.
//
// Preconditions: d.mu must be held.
-func (d *Dirent) flush() {
+func (d *Dirent) flush(ctx context.Context) {
expired := make(map[string]*refs.WeakRef)
for n, w := range d.children {
// Call flush recursively on each child before removing our
@@ -1014,7 +912,7 @@ func (d *Dirent) flush() {
if !cd.IsNegative() {
// Flush the child.
cd.mu.Lock()
- cd.flush()
+ cd.flush(ctx)
cd.mu.Unlock()
// Allow the file system to drop extra references on child.
@@ -1022,13 +920,13 @@ func (d *Dirent) flush() {
}
// Don't leak a reference.
- child.DecRef()
+ child.DecRef(ctx)
}
// Check if the child dirent is closed, and mark it as expired if it is.
// We must call w.Get() again here, since the child could have been closed
// by the calls to flush() and cache.Remove() in the above if-block.
if child := w.Get(); child != nil {
- child.DecRef()
+ child.DecRef(ctx)
} else {
expired[n] = w
}
@@ -1037,7 +935,7 @@ func (d *Dirent) flush() {
// Remove expired entries.
for n, w := range expired {
delete(d.children, n)
- w.Drop()
+ w.Drop(ctx)
}
}
@@ -1068,11 +966,6 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
return nil, syserror.EINVAL
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return nil, syserror.ENOENT
- }
-
// Dirent that'll replace d.
//
// Note that NewDirent returns with one reference taken; the reference
@@ -1084,7 +977,7 @@ func (d *Dirent) mount(ctx context.Context, inode *Inode) (newChild *Dirent, err
if !ok {
panic("mount must mount over an existing dirent")
}
- weakRef.Drop()
+ weakRef.Drop(ctx)
// Note that even though `d` is now hidden, it still holds a reference
// to its parent.
@@ -1101,11 +994,6 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
return syserror.ENOENT
}
- // Are we frozen?
- if d.parent.frozen && !d.parent.Inode.IsVirtual() {
- return syserror.ENOENT
- }
-
// Remount our former child in its place.
//
// As replacement used to be our child, it must already have the right
@@ -1114,13 +1002,13 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
if !ok {
panic("mount must mount over an existing dirent")
}
- weakRef.Drop()
+ weakRef.Drop(ctx)
// d is not reachable anymore, and hence not mounted anymore.
d.mounted = false
// Drop mount reference.
- d.DecRef()
+ d.DecRef(ctx)
return nil
}
@@ -1135,18 +1023,13 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Try to walk to the node.
child, err := d.walk(ctx, root, name, false /* may unlock */)
if err != nil {
// Child does not exist.
return err
}
- defer child.DecRef()
+ defer child.DecRef(ctx)
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
@@ -1172,7 +1055,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
atomic.StoreInt32(&child.deleted, 1)
if w, ok := d.children[name]; ok {
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Allow the file system to drop extra references on child.
@@ -1184,7 +1067,7 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath
// inode may have other links. If this was the last link, the events for the
// watch removal will be queued by the inode destructor.
child.Inode.Watches.MarkUnlinked()
- child.Inode.Watches.Unpin(child)
+ child.Inode.Watches.Unpin(ctx, child)
d.Inode.Watches.Notify(name, linux.IN_DELETE, 0)
return nil
@@ -1201,11 +1084,6 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
unlock := d.lockDirectory()
defer unlock()
- // Are we frozen?
- if d.frozen && !d.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Check for dots.
if name == "." {
// Rejected as the last component by rmdir(2).
@@ -1222,7 +1100,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
// Child does not exist.
return err
}
- defer child.DecRef()
+ defer child.DecRef(ctx)
// RemoveDirectory can only remove directories.
if !IsDir(child.Inode.StableAttr) {
@@ -1243,7 +1121,7 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
atomic.StoreInt32(&child.deleted, 1)
if w, ok := d.children[name]; ok {
delete(d.children, name)
- w.Drop()
+ w.Drop(ctx)
}
// Allow the file system to drop extra references on child.
@@ -1252,14 +1130,14 @@ func (d *Dirent) RemoveDirectory(ctx context.Context, root *Dirent, name string)
// Finally, let inotify know the child is being unlinked. Drop any extra
// refs from inotify to this child dirent.
child.Inode.Watches.MarkUnlinked()
- child.Inode.Watches.Unpin(child)
+ child.Inode.Watches.Unpin(ctx, child)
d.Inode.Watches.Notify(name, linux.IN_ISDIR|linux.IN_DELETE, 0)
return nil
}
// destroy closes this node and all children.
-func (d *Dirent) destroy() {
+func (d *Dirent) destroy(ctx context.Context) {
if d.IsNegative() {
// Nothing to tear-down and no parent references to drop, since a negative
// Dirent does not take a references on its parent, has no Inode and no children.
@@ -1275,19 +1153,19 @@ func (d *Dirent) destroy() {
if c.(*Dirent).IsNegative() {
// The parent holds both weak and strong refs in the case of
// negative dirents.
- c.DecRef()
+ c.DecRef(ctx)
}
// Drop the reference we just acquired in WeakRef.Get.
- c.DecRef()
+ c.DecRef(ctx)
}
- w.Drop()
+ w.Drop(ctx)
}
d.children = nil
allDirents.remove(d)
// Drop our reference to the Inode.
- d.Inode.DecRef()
+ d.Inode.DecRef(ctx)
// Allow the Dirent to be GC'ed after this point, since the Inode may still
// be referenced after the Dirent is destroyed (for instance by filesystem
@@ -1297,7 +1175,7 @@ func (d *Dirent) destroy() {
// Drop the reference we have on our parent if we took one. renameMu doesn't need to be
// held because d can't be reparented without any references to it left.
if d.parent != nil {
- d.parent.DecRef()
+ d.parent.DecRef(ctx)
}
}
@@ -1323,14 +1201,14 @@ func (d *Dirent) TryIncRef() bool {
// DecRef decreases the Dirent's refcount and drops its reference on its mount.
//
// DecRef implements RefCounter.DecRef with destructor d.destroy.
-func (d *Dirent) DecRef() {
+func (d *Dirent) DecRef(ctx context.Context) {
if d.Inode != nil {
// Keep mount around, since DecRef may destroy d.Inode.
msrc := d.Inode.MountSource
- d.DecRefWithDestructor(d.destroy)
+ d.DecRefWithDestructor(ctx, d.destroy)
msrc.DecDirentRefs()
} else {
- d.DecRefWithDestructor(d.destroy)
+ d.DecRefWithDestructor(ctx, d.destroy)
}
}
@@ -1438,8 +1316,8 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName
}, nil
}
-func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
- uattr, err := dir.Inode.UnstableAttr(ctx)
+func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error {
+ uattr, err := d.Inode.UnstableAttr(ctx)
if err != nil {
return syserror.EPERM
}
@@ -1465,30 +1343,33 @@ func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
return syserror.EPERM
}
-// MayDelete determines whether `name`, a child of `dir`, can be deleted or
+// MayDelete determines whether `name`, a child of `d`, can be deleted or
// renamed by `ctx`.
//
// Compare Linux kernel fs/namei.c:may_delete.
-func MayDelete(ctx context.Context, root, dir *Dirent, name string) error {
- if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error {
+ if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
return err
}
- victim, err := dir.Walk(ctx, root, name)
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ victim, err := d.walk(ctx, root, name, true /* may unlock */)
if err != nil {
return err
}
- defer victim.DecRef()
+ defer victim.DecRef(ctx)
- return mayDelete(ctx, dir, victim)
+ return d.mayDelete(ctx, victim)
}
// mayDelete determines whether `victim`, a child of `dir`, can be deleted or
// renamed by `ctx`.
//
// Preconditions: `dir` is writable and executable by `ctx`.
-func mayDelete(ctx context.Context, dir, victim *Dirent) error {
- if err := checkSticky(ctx, dir, victim); err != nil {
+func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error {
+ if err := d.checkSticky(ctx, victim); err != nil {
return err
}
@@ -1516,15 +1397,6 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
return err
}
- // Are we frozen?
- // TODO(jamieliu): Is this the right errno?
- if oldParent.frozen && !oldParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
- if newParent.frozen && !newParent.Inode.IsVirtual() {
- return syscall.ENOENT
- }
-
// Do we have general permission to remove from oldParent and
// create/replace in newParent?
if err := oldParent.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
@@ -1539,10 +1411,10 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
if err != nil {
return err
}
- defer renamed.DecRef()
+ defer renamed.DecRef(ctx)
// Check that the renamed dirent is deletable.
- if err := mayDelete(ctx, oldParent, renamed); err != nil {
+ if err := oldParent.mayDelete(ctx, renamed); err != nil {
return err
}
@@ -1580,14 +1452,14 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// across the Rename, so must call DecRef manually (no defer).
// Check that we can delete replaced.
- if err := mayDelete(ctx, newParent, replaced); err != nil {
- replaced.DecRef()
+ if err := newParent.mayDelete(ctx, replaced); err != nil {
+ replaced.DecRef(ctx)
return err
}
// Target should not be an ancestor of source.
if oldParent.descendantOf(replaced) {
- replaced.DecRef()
+ replaced.DecRef(ctx)
// Note that Linux returns EINVAL if the source is an
// ancestor of target, but ENOTEMPTY if the target is
@@ -1598,7 +1470,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// Check that replaced is not a mount point.
if replaced.isMountPointLocked() {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.EBUSY
}
@@ -1606,11 +1478,11 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
oldIsDir := IsDir(renamed.Inode.StableAttr)
newIsDir := IsDir(replaced.Inode.StableAttr)
if !newIsDir && oldIsDir {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.ENOTDIR
}
if !oldIsDir && newIsDir {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syscall.EISDIR
}
@@ -1621,13 +1493,13 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// open across renames is currently broken for multiple
// reasons, so we flush all references on the replaced node and
// its children.
- replaced.Inode.Watches.Unpin(replaced)
+ replaced.Inode.Watches.Unpin(ctx, replaced)
replaced.mu.Lock()
- replaced.flush()
+ replaced.flush(ctx)
replaced.mu.Unlock()
// Done with replaced.
- replaced.DecRef()
+ replaced.DecRef(ctx)
}
if err := renamed.Inode.Rename(ctx, oldParent, renamed, newParent, newName, replaced != nil); err != nil {
@@ -1641,14 +1513,14 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// can't destroy oldParent (and try to retake its lock) because
// Rename's caller must be holding a reference.
newParent.IncRef()
- oldParent.DecRef()
+ oldParent.DecRef(ctx)
}
if w, ok := newParent.children[newName]; ok {
- w.Drop()
+ w.Drop(ctx)
delete(newParent.children, newName)
}
if w, ok := oldParent.children[oldName]; ok {
- w.Drop()
+ w.Drop(ctx)
delete(oldParent.children, oldName)
}
@@ -1679,7 +1551,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// Same as replaced.flush above.
renamed.mu.Lock()
- renamed.flush()
+ renamed.flush(ctx)
renamed.mu.Unlock()
return nil
diff --git a/pkg/sentry/fs/dirent_cache.go b/pkg/sentry/fs/dirent_cache.go
index 60a15a275..7d9dd717e 100644
--- a/pkg/sentry/fs/dirent_cache.go
+++ b/pkg/sentry/fs/dirent_cache.go
@@ -16,7 +16,9 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCache is an LRU cache of Dirents. The Dirent's refCount is
@@ -100,9 +102,7 @@ func (c *DirentCache) remove(d *Dirent) {
panic(fmt.Sprintf("trying to remove %v, which is not in the dirent cache", d))
}
c.list.Remove(d)
- d.SetPrev(nil)
- d.SetNext(nil)
- d.DecRef()
+ d.DecRef(context.Background())
c.currentSize--
if c.limit != nil {
c.limit.dec()
diff --git a/pkg/sentry/fs/dirent_cache_limiter.go b/pkg/sentry/fs/dirent_cache_limiter.go
index ebb80bd50..525ee25f9 100644
--- a/pkg/sentry/fs/dirent_cache_limiter.go
+++ b/pkg/sentry/fs/dirent_cache_limiter.go
@@ -16,7 +16,8 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// DirentCacheLimiter acts as a global limit for all dirent caches in the
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 47bc72a88..176b894ba 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -18,8 +18,8 @@ import (
"syscall"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
)
func newMockDirInode(ctx context.Context, cache *DirentCache) *Inode {
@@ -51,7 +51,7 @@ func TestWalkPositive(t *testing.T) {
t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 1)
}
- d.DecRef()
+ d.DecRef(ctx)
if got := root.ReadRefs(); got != 1 {
t.Fatalf("root has a ref count of %d, want %d", got, 1)
@@ -61,7 +61,7 @@ func TestWalkPositive(t *testing.T) {
t.Fatalf("child name = %q has a ref count of %d, want %d", d.name, got, 0)
}
- root.flush()
+ root.flush(ctx)
if got := len(root.children); got != 0 {
t.Fatalf("root has %d children, want %d", got, 0)
@@ -114,7 +114,7 @@ func TestWalkNegative(t *testing.T) {
t.Fatalf("child has a ref count of %d, want %d", got, 2)
}
- child.DecRef()
+ child.DecRef(ctx)
if got := child.(*Dirent).ReadRefs(); got != 1 {
t.Fatalf("child has a ref count of %d, want %d", got, 1)
@@ -124,7 +124,7 @@ func TestWalkNegative(t *testing.T) {
t.Fatalf("root has %d children, want %d", got, 1)
}
- root.DecRef()
+ root.DecRef(ctx)
if got := root.ReadRefs(); got != 0 {
t.Fatalf("root has a ref count of %d, want %d", got, 0)
@@ -351,9 +351,9 @@ func TestRemoveExtraRefs(t *testing.T) {
t.Fatalf("dirent has a ref count of %d, want %d", got, 1)
}
- d.DecRef()
+ d.DecRef(ctx)
- test.root.flush()
+ test.root.flush(ctx)
if got := len(test.root.children); got != 0 {
t.Errorf("root has %d children, want %d", got, 0)
@@ -403,8 +403,8 @@ func TestRenameExtraRefs(t *testing.T) {
t.Fatalf("Rename got error %v, want nil", err)
}
- oldParent.flush()
- newParent.flush()
+ oldParent.flush(ctx)
+ newParent.flush(ctx)
// Expect to have only active references.
if got := renamed.ReadRefs(); got != 1 {
diff --git a/pkg/sentry/fs/dirent_state.go b/pkg/sentry/fs/dirent_state.go
index f623d6c0e..67a35f0b2 100644
--- a/pkg/sentry/fs/dirent_state.go
+++ b/pkg/sentry/fs/dirent_state.go
@@ -18,6 +18,7 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
)
@@ -48,7 +49,7 @@ func (d *Dirent) saveChildren() map[string]*Dirent {
for name, w := range d.children {
if rc := w.Get(); rc != nil {
// Drop the reference count obtain in w.Get()
- rc.DecRef()
+ rc.DecRef(context.Background())
cd := rc.(*Dirent)
if cd.IsNegative() {
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index 277ee4c31..1d09e983c 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,20 +9,20 @@ go_library(
"pipe_opener.go",
"pipe_state.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fdpipe",
imports = ["gvisor.dev/gvisor/pkg/sentry/fs"],
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/secio",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/safemem",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -35,15 +34,15 @@ go_test(
"pipe_opener_test.go",
"pipe_test.go",
],
- embed = [":fdpipe"],
+ library = ":fdpipe",
deps = [
+ "//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/usermem",
"@com_github_google_uuid//:go_default_library",
],
)
diff --git a/pkg/sentry/fs/fdpipe/pipe.go b/pkg/sentry/fs/fdpipe/pipe.go
index 669ffcb75..b99199798 100644
--- a/pkg/sentry/fs/fdpipe/pipe.go
+++ b/pkg/sentry/fs/fdpipe/pipe.go
@@ -17,19 +17,19 @@ package fdpipe
import (
"os"
- "sync"
"syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -115,7 +115,7 @@ func (p *pipeOperations) Readiness(mask waiter.EventMask) (eventMask waiter.Even
}
// Release implements fs.FileOperations.Release.
-func (p *pipeOperations) Release() {
+func (p *pipeOperations) Release(context.Context) {
fdnotifier.RemoveFD(int32(p.file.FD()))
p.file.Close()
p.file = nil
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener.go b/pkg/sentry/fs/fdpipe/pipe_opener.go
index 64b558975..0c3595998 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener.go
@@ -20,8 +20,8 @@ import (
"syscall"
"time"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/fdpipe/pipe_opener_test.go b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
index 8e4d839e1..b9cec4b13 100644
--- a/pkg/sentry/fs/fdpipe/pipe_opener_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_opener_test.go
@@ -25,12 +25,13 @@ import (
"time"
"github.com/google/uuid"
+
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type hostOpener struct {
@@ -181,7 +182,7 @@ func TestTryOpen(t *testing.T) {
// Cleanup the state of the pipe, and remove the fd from the
// fdnotifier. Sadly this needed to maintain the correctness
// of other tests because the fdnotifier is global.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
continue
}
@@ -190,7 +191,7 @@ func TestTryOpen(t *testing.T) {
}
if pipeOps != nil {
// Same as above.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
}
}
@@ -278,7 +279,7 @@ func TestPipeOpenUnblocksEventually(t *testing.T) {
pipeOps, err := Open(ctx, opener, flags)
if pipeOps != nil {
// Same as TestTryOpen.
- pipeOps.Release()
+ pipeOps.Release(ctx)
}
// Check that the partner opened the file successfully.
@@ -324,7 +325,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
ctx := contexttest.Context(t)
pipeOps, err := pipeOpenState.TryOpen(ctx, opener, fs.FileFlags{Read: true})
if pipeOps != nil {
- pipeOps.Release()
+ pipeOps.Release(ctx)
t.Fatalf("open(%s, %o) got file, want nil", name, syscall.O_RDONLY)
}
if err != syserror.ErrWouldBlock {
@@ -350,7 +351,7 @@ func TestCopiedReadAheadBuffer(t *testing.T) {
if pipeOps == nil {
t.Fatalf("open(%s, %o) got nil file, want not nil", name, syscall.O_RDONLY)
}
- defer pipeOps.Release()
+ defer pipeOps.Release(ctx)
if err != nil {
t.Fatalf("open(%s, %o) got error %v, want nil", name, syscall.O_RDONLY, err)
@@ -470,14 +471,14 @@ func TestPipeHangup(t *testing.T) {
f := <-fdchan
if f < 0 {
t.Errorf("%s: partner routine got fd %d, want > 0", test.desc, f)
- pipeOps.Release()
+ pipeOps.Release(ctx)
continue
}
if test.hangupSelf {
// Hangup self and assert that our partner got the expected hangup
// error.
- pipeOps.Release()
+ pipeOps.Release(ctx)
if test.flags.Read {
// Partner is writer.
@@ -489,7 +490,7 @@ func TestPipeHangup(t *testing.T) {
} else {
// Hangup our partner and expect us to get the hangup error.
syscall.Close(f)
- defer pipeOps.Release()
+ defer pipeOps.Release(ctx)
if test.flags.Read {
assertReaderHungup(t, test.desc, pipeOps.(*pipeOperations).file)
diff --git a/pkg/sentry/fs/fdpipe/pipe_state.go b/pkg/sentry/fs/fdpipe/pipe_state.go
index 29175fb3d..af8230a7d 100644
--- a/pkg/sentry/fs/fdpipe/pipe_state.go
+++ b/pkg/sentry/fs/fdpipe/pipe_state.go
@@ -17,10 +17,10 @@ package fdpipe
import (
"fmt"
"io/ioutil"
- "sync"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// beforeSave is invoked by stateify.
diff --git a/pkg/sentry/fs/fdpipe/pipe_test.go b/pkg/sentry/fs/fdpipe/pipe_test.go
index 69abc1e71..1c9e82562 100644
--- a/pkg/sentry/fs/fdpipe/pipe_test.go
+++ b/pkg/sentry/fs/fdpipe/pipe_test.go
@@ -23,10 +23,10 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func singlePipeFD() (int, error) {
@@ -98,10 +98,11 @@ func TestNewPipe(t *testing.T) {
}
f := fd.New(gfd)
- p, err := newPipeOperations(contexttest.Context(t), nil, test.flags, f, test.readAheadBuffer)
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, test.flags, f, test.readAheadBuffer)
if p != nil {
// This is necessary to remove the fd from the global fd notifier.
- defer p.Release()
+ defer p.Release(ctx)
} else {
// If there is no p to DecRef on, because newPipeOperations failed, then the
// file still needs to be closed.
@@ -119,7 +120,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if flags := p.flags; test.flags != flags {
- t.Errorf("%s: got file flags %s, want %s", test.desc, flags, test.flags)
+ t.Errorf("%s: got file flags %v, want %v", test.desc, flags, test.flags)
continue
}
if len(test.readAheadBuffer) != len(p.readAheadBuffer) {
@@ -136,7 +137,7 @@ func TestNewPipe(t *testing.T) {
continue
}
if !fdnotifier.HasFD(int32(f.FD())) {
- t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD)
+ t.Errorf("%s: pipe fd %d is not registered for events", test.desc, f.FD())
}
}
}
@@ -153,13 +154,14 @@ func TestPipeDestruction(t *testing.T) {
syscall.Close(fds[1])
// Test the read end, but it doesn't really matter which.
- p, err := newPipeOperations(contexttest.Context(t), nil, fs.FileFlags{Read: true}, f, nil)
+ ctx := contexttest.Context(t)
+ p, err := newPipeOperations(ctx, nil, fs.FileFlags{Read: true}, f, nil)
if err != nil {
f.Close()
t.Fatalf("newPipeOperations got error %v, want nil", err)
}
// Drop our only reference, which should trigger the destructor.
- p.Release()
+ p.Release(ctx)
if fdnotifier.HasFD(int32(fds[0])) {
t.Fatalf("after DecRef fdnotifier has fd %d, want no longer registered", fds[0])
@@ -282,7 +284,7 @@ func TestPipeRequest(t *testing.T) {
if err != nil {
t.Fatalf("%s: newPipeOperations got error %v, want nil", test.desc, err)
}
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{Type: fs.Pipe})
file := fs.NewFile(ctx, fs.NewDirent(ctx, inode, "pipe"), fs.FileFlags{Read: true}, p)
@@ -334,7 +336,7 @@ func TestPipeReadAheadBuffer(t *testing.T) {
rfile.Close()
t.Fatalf("newPipeOperations got error %v, want nil", err)
}
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
@@ -380,7 +382,7 @@ func TestPipeReadsAccumulate(t *testing.T) {
}
// Don't forget to remove the fd from the fd notifier. Otherwise other tests will
// likely be borked, because it's global :(
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
@@ -448,7 +450,7 @@ func TestPipeWritesAccumulate(t *testing.T) {
}
// Don't forget to remove the fd from the fd notifier. Otherwise other tests
// will likely be borked, because it's global :(
- defer p.Release()
+ defer p.Release(ctx)
inode := fs.NewMockInode(ctx, fs.NewMockMountSource(nil), fs.StableAttr{
Type: fs.Pipe,
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index c0a6e884b..72ea70fcf 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -16,20 +16,20 @@ package fs
import (
"math"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -44,7 +44,7 @@ var (
RecordWaitTime = false
reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.")
- readWait = metric.MustCreateNewUint64Metric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
+ readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.")
)
// IncrementWait increments the given wait time metric, if enabled.
@@ -142,17 +142,17 @@ func NewFile(ctx context.Context, dirent *Dirent, flags FileFlags, fops FileOper
}
// DecRef destroys the File when it is no longer referenced.
-func (f *File) DecRef() {
- f.DecRefWithDestructor(func() {
+func (f *File) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, func(context.Context) {
// Drop BSD style locks.
lockRng := lock.LockRange{Start: 0, End: lock.LockEOF}
- f.Dirent.Inode.LockCtx.BSD.UnlockRegion(lock.UniqueID(f.UniqueID), lockRng)
+ f.Dirent.Inode.LockCtx.BSD.UnlockRegion(f, lockRng)
// Release resources held by the FileOperations.
- f.FileOperations.Release()
+ f.FileOperations.Release(ctx)
// Release a reference on the Dirent.
- f.Dirent.DecRef()
+ f.Dirent.DecRef(ctx)
// Only unregister if we are currently registered. There is nothing
// to register if f.async is nil (this happens when async mode is
@@ -310,7 +310,6 @@ func (f *File) Writev(ctx context.Context, src usermem.IOSequence) (int64, error
if !f.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
-
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
// Handle append mode.
if f.Flags().Append {
@@ -355,7 +354,6 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
// offset."
unlockAppendMu := f.Dirent.Inode.lockAppendMu(f.Flags().Append)
defer unlockAppendMu()
-
if f.Flags().Append {
if err := f.offsetForAppend(ctx, &offset); err != nil {
return 0, err
@@ -374,9 +372,10 @@ func (f *File) Pwritev(ctx context.Context, src usermem.IOSequence, offset int64
return f.FileOperations.Write(ctx, f, src, offset)
}
-// offsetForAppend sets the given offset to the end of the file.
+// offsetForAppend atomically sets the given offset to the end of the file.
//
-// Precondition: the file.Dirent.Inode.appendMu mutex should be held for writing.
+// Precondition: the file.Dirent.Inode.appendMu mutex should be held for
+// writing.
func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
uattr, err := f.Dirent.Inode.UnstableAttr(ctx)
if err != nil {
@@ -386,7 +385,7 @@ func (f *File) offsetForAppend(ctx context.Context, offset *int64) error {
}
// Update the offset.
- *offset = uattr.Size
+ atomic.StoreInt64(offset, uattr.Size)
return nil
}
@@ -461,7 +460,7 @@ func (f *File) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
func (f *File) MappedName(ctx context.Context) string {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
name, _ := f.Dirent.FullName(root)
return name
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index b88303f17..305c0f840 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -17,10 +17,10 @@ package fs
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -67,7 +67,7 @@ type SpliceOpts struct {
// - File.Flags(): This value may change during the operation.
type FileOperations interface {
// Release release resources held by FileOperations.
- Release()
+ Release(ctx context.Context)
// Waitable defines how this File can be waited on for read and
// write readiness.
@@ -160,6 +160,7 @@ type FileOperations interface {
// refer.
//
// Preconditions: The AddressSpace (if any) that io refers to is activated.
+ // Must only be called from a task goroutine.
Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error)
}
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 225e40186..9dc58d5ff 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -16,14 +16,14 @@ package fs
import (
"io"
- "sync"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -54,7 +54,7 @@ func overlayFile(ctx context.Context, inode *Inode, flags FileFlags) (*File, err
// Drop the extra reference on the Dirent. Now there's only one reference
// on the dirent, either owned by f (if non-nil), or the Dirent is about
// to be destroyed (if GetFile failed).
- dirent.DecRef()
+ dirent.DecRef(ctx)
return f, err
}
@@ -89,12 +89,12 @@ type overlayFileOperations struct {
}
// Release implements FileOperations.Release.
-func (f *overlayFileOperations) Release() {
+func (f *overlayFileOperations) Release(ctx context.Context) {
if f.upper != nil {
- f.upper.DecRef()
+ f.upper.DecRef(ctx)
}
if f.lower != nil {
- f.lower.DecRef()
+ f.lower.DecRef(ctx)
}
}
@@ -164,7 +164,7 @@ func (f *overlayFileOperations) Seek(ctx context.Context, file *File, whence See
func (f *overlayFileOperations) Readdir(ctx context.Context, file *File, serializer DentrySerializer) (int64, error) {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &DirCtx{
@@ -475,7 +475,7 @@ func readdirEntries(ctx context.Context, o *overlayEntry) (*SortedDentryMap, err
// Skip this name if it is a negative entry in the
// upper or there exists a whiteout for it.
if o.upper != nil {
- if overlayHasWhiteout(o.upper, name) {
+ if overlayHasWhiteout(ctx, o.upper, name) {
continue
}
}
@@ -497,7 +497,7 @@ func readdirOne(ctx context.Context, d *Dirent) (map[string]DentAttr, error) {
if err != nil {
return nil, err
}
- defer dir.DecRef()
+ defer dir.DecRef(ctx)
// Use a stub serializer to read the entries into memory.
stubSerializer := &CollectEntriesSerializer{}
@@ -521,10 +521,10 @@ type overlayMappingIdentity struct {
}
// DecRef implements AtomicRefCount.DecRef.
-func (omi *overlayMappingIdentity) DecRef() {
- omi.AtomicRefCount.DecRefWithDestructor(func() {
- omi.overlayFile.DecRef()
- omi.id.DecRef()
+func (omi *overlayMappingIdentity) DecRef(ctx context.Context) {
+ omi.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
+ omi.overlayFile.DecRef(ctx)
+ omi.id.DecRef(ctx)
})
}
@@ -544,7 +544,7 @@ func (omi *overlayMappingIdentity) InodeID() uint64 {
func (omi *overlayMappingIdentity) MappedName(ctx context.Context) string {
root := RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
name, _ := omi.overlayFile.Dirent.FullName(root)
return name
diff --git a/pkg/sentry/fs/file_overlay_test.go b/pkg/sentry/fs/file_overlay_test.go
index 2fb824d5c..1971cc680 100644
--- a/pkg/sentry/fs/file_overlay_test.go
+++ b/pkg/sentry/fs/file_overlay_test.go
@@ -18,7 +18,7 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -175,89 +175,6 @@ func TestReaddirRevalidation(t *testing.T) {
}
}
-// TestReaddirOverlayFrozen tests that calling Readdir on an overlay file with
-// a frozen dirent tree does not make Readdir calls to the underlying files.
-func TestReaddirOverlayFrozen(t *testing.T) {
- ctx := contexttest.Context(t)
-
- // Create an overlay with two directories, each with two files.
- upper := newTestRamfsDir(ctx, []dirContent{{name: "upper-file1"}, {name: "upper-file2"}}, nil)
- lower := newTestRamfsDir(ctx, []dirContent{{name: "lower-file1"}, {name: "lower-file2"}}, nil)
- overlayInode := fs.NewTestOverlayDir(ctx, upper, lower, false)
-
- // Set that overlay as the root.
- root := fs.NewDirent(ctx, overlayInode, "root")
- ctx = &rootContext{
- Context: ctx,
- root: root,
- }
-
- // Check that calling Readdir on the root now returns all 4 files (2
- // from each layer in the overlay).
- rootFile, err := root.Inode.GetFile(ctx, root, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("root.Inode.GetFile failed: %v", err)
- }
- defer rootFile.DecRef()
- ser := &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "lower-file2", "upper-file1", "upper-file2"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should have been called on upper and lower.
- upperDir := upper.InodeOperations.(*dir)
- lowerDir := lower.InodeOperations.(*dir)
- if !upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want true", upperDir.ReaddirCalled)
- }
- if !lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want true", lowerDir.ReaddirCalled)
- }
-
- // Reset.
- upperDir.ReaddirCalled = false
- lowerDir.ReaddirCalled = false
-
- // Take references on "upper-file1" and "lower-file1", pinning them in
- // the dirent tree.
- for _, name := range []string{"upper-file1", "lower-file1"} {
- if _, err := root.Walk(ctx, root, name); err != nil {
- t.Fatalf("root.Walk(%q) failed: %v", name, err)
- }
- // Don't drop a reference on the returned dirent so that it
- // will stay in the tree.
- }
-
- // Freeze the dirent tree.
- root.Freeze()
-
- // Seek back to the beginning of the file.
- if _, err := rootFile.Seek(ctx, fs.SeekSet, 0); err != nil {
- t.Fatalf("error seeking to beginning of directory: %v", err)
- }
-
- // Calling Readdir on the root now will return only the pinned
- // children.
- ser = &fs.CollectEntriesSerializer{}
- if err := rootFile.Readdir(ctx, ser); err != nil {
- t.Fatalf("rootFile.Readdir failed: %v", err)
- }
- if got, want := ser.Order, []string{".", "..", "lower-file1", "upper-file1"}; !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got names %v, want %v", got, want)
- }
-
- // Readdir should NOT have been called on upper or lower.
- if upperDir.ReaddirCalled {
- t.Errorf("upperDir.ReaddirCalled got %v, want false", upperDir.ReaddirCalled)
- }
- if lowerDir.ReaddirCalled {
- t.Errorf("lowerDir.ReaddirCalled got %v, want false", lowerDir.ReaddirCalled)
- }
-}
-
type rootContext struct {
context.Context
root *fs.Dirent
diff --git a/pkg/sentry/fs/filesystems.go b/pkg/sentry/fs/filesystems.go
index b157fd228..d41f30bbb 100644
--- a/pkg/sentry/fs/filesystems.go
+++ b/pkg/sentry/fs/filesystems.go
@@ -18,9 +18,9 @@ import (
"fmt"
"sort"
"strings"
- "sync"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FilesystemFlags matches include/linux/fs.h:file_system_type.fs_flags.
@@ -87,20 +87,6 @@ func RegisterFilesystem(f Filesystem) {
filesystems.registered[f.Name()] = f
}
-// UnregisterFilesystem removes a file system from the global set. To keep the
-// file system set compatible with save/restore, UnregisterFilesystem must be
-// called before save/restore methods.
-//
-// For instance, packages may unregister their file system after it is mounted.
-// This makes sense for pseudo file systems that should not be visible or
-// mountable. See whitelistfs in fs/host/fs.go for one example.
-func UnregisterFilesystem(name string) {
- filesystems.mu.Lock()
- defer filesystems.mu.Unlock()
-
- delete(filesystems.registered, name)
-}
-
// FindFilesystem returns a Filesystem registered at name or (nil, false) if name
// is not a file system type that can be found in /proc/filesystems.
func FindFilesystem(name string) (Filesystem, bool) {
diff --git a/pkg/sentry/fs/filetest/BUILD b/pkg/sentry/fs/filetest/BUILD
index 358dc2be3..a8000e010 100644
--- a/pkg/sentry/fs/filetest/BUILD
+++ b/pkg/sentry/fs/filetest/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,15 +6,14 @@ go_library(
name = "filetest",
testonly = 1,
srcs = ["filetest.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/filetest",
visibility = ["//pkg/sentry:internal"],
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
index 22270a494..8049538f2 100644
--- a/pkg/sentry/fs/filetest/filetest.go
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -19,12 +19,12 @@ import (
"fmt"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/flags.go b/pkg/sentry/fs/flags.go
index 0fab876a9..4338ae1fa 100644
--- a/pkg/sentry/fs/flags.go
+++ b/pkg/sentry/fs/flags.go
@@ -64,6 +64,10 @@ type FileFlags struct {
// NonSeekable indicates that file.offset isn't used.
NonSeekable bool
+
+ // Truncate indicates that the file should be truncated before opened.
+ // This is only applicable if the file is regular.
+ Truncate bool
}
// SettableFileFlags is a subset of FileFlags above that can be changed
@@ -118,6 +122,9 @@ func (f FileFlags) ToLinux() (mask uint) {
if f.LargeFile {
mask |= linux.O_LARGEFILE
}
+ if f.Truncate {
+ mask |= linux.O_TRUNC
+ }
switch {
case f.Read && f.Write:
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index 8b2a5e6b2..d2dbff268 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -42,9 +42,10 @@
// Dirent.dirMu
// Dirent.mu
// DirentCache.mu
-// Locks in InodeOperations implementations or overlayEntry
// Inode.Watches.mu (see `Inotify` for other lock ordering)
// MountSource.mu
+// Inode.appendMu
+// Locks in InodeOperations implementations or overlayEntry
//
// If multiple Dirent or MountSource locks must be taken, locks in the parent must be
// taken before locks in their children.
@@ -54,10 +55,9 @@
package fs
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
var (
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index b2e8d9c77..5fb419bcd 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_template_instance(
out = "dirty_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "Dirty",
@@ -26,16 +24,16 @@ go_template_instance(
name = "frame_ref_set_impl",
out = "frame_ref_set_impl.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "fsutil",
- prefix = "frameRef",
+ prefix = "FrameRef",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "uint64",
- "Functions": "frameRefSetFunctions",
+ "Functions": "FrameRefSetFunctions",
},
)
@@ -44,7 +42,6 @@ go_template_instance(
out = "file_range_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "FileRange",
@@ -53,7 +50,7 @@ go_template_instance(
"Key": "uint64",
"Range": "memmap.MappableRange",
"Value": "uint64",
- "Functions": "fileRangeSetFunctions",
+ "Functions": "FileRangeSetFunctions",
},
)
@@ -75,25 +72,24 @@ go_library(
"inode.go",
"inode_cached.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/fsutil",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -105,15 +101,15 @@ go_test(
"dirty_set_test.go",
"inode_cached_test.go",
],
- embed = [":fsutil"],
+ library = ":fsutil",
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
- "//pkg/sentry/safemem",
- "//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
index 12132680b..2c9446c1d 100644
--- a/pkg/sentry/fs/fsutil/dirty_set.go
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -17,11 +17,10 @@ package fsutil
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// DirtySet maps offsets into a memmap.Mappable to DirtyInfo. It is used to
@@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
// repeatedly until all bytes have been written. max is the true size of the
// cached object; offsets beyond max will not be passed to writeAt, even if
// they are marked dirty.
-func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
var changedDirty bool
defer func() {
if changedDirty {
@@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet
// successful partial write, SyncDirtyAll will call it repeatedly until all
// bytes have been written. max is the true size of the cached object; offsets
// beyond max will not be passed to writeAt, even if they are marked dirty.
-func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
dseg := dirty.FirstSegment()
for dseg.Ok() {
if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
@@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max
}
// Preconditions: mr must be page-aligned.
-func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
wbr := cseg.Range().Intersect(mr)
if max < wbr.Start {
diff --git a/pkg/sentry/fs/fsutil/dirty_set_test.go b/pkg/sentry/fs/fsutil/dirty_set_test.go
index 75575d994..e3579c23c 100644
--- a/pkg/sentry/fs/fsutil/dirty_set_test.go
+++ b/pkg/sentry/fs/fsutil/dirty_set_test.go
@@ -19,7 +19,7 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func TestDirtySet(t *testing.T) {
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index fc5b3b1a1..dc9efa5df 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -17,12 +17,12 @@ package fsutil
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -31,7 +31,7 @@ import (
type FileNoopRelease struct{}
// Release is a no-op.
-func (FileNoopRelease) Release() {}
+func (FileNoopRelease) Release(context.Context) {}
// SeekWithDirCursor is used to implement fs.FileOperations.Seek. If dirCursor
// is not nil and the seek was on a directory, the cursor will be updated.
@@ -296,7 +296,7 @@ func (sdfo *StaticDirFileOperations) IterateDir(ctx context.Context, d *fs.Diren
func (sdfo *StaticDirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 0a5466b0a..bbafebf03 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -19,40 +19,39 @@ import (
"io"
"math"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
-// platform.File. It is used to implement Mappables that store data in
+// memmap.File. It is used to implement Mappables that store data in
// sparsely-allocated memory.
//
// type FileRangeSet <generated by go_generics>
-// fileRangeSetFunctions implements segment.Functions for FileRangeSet.
-type fileRangeSetFunctions struct{}
+// FileRangeSetFunctions implements segment.Functions for FileRangeSet.
+type FileRangeSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (fileRangeSetFunctions) MinKey() uint64 {
+func (FileRangeSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (fileRangeSetFunctions) MaxKey() uint64 {
+func (FileRangeSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (fileRangeSetFunctions) ClearValue(_ *uint64) {
+func (FileRangeSetFunctions) ClearValue(_ *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
+func (FileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _ memmap.MappableRange, frstart2 uint64) (uint64, bool) {
if frstart1+mr1.Length() != frstart2 {
return 0, false
}
@@ -60,25 +59,25 @@ func (fileRangeSetFunctions) Merge(mr1 memmap.MappableRange, frstart1 uint64, _
}
// Split implements segment.Functions.Split.
-func (fileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
+func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, split uint64) (uint64, uint64) {
return frstart, frstart + (split - mr.Start)
}
// FileRange returns the FileRange mapped by seg.
-func (seg FileRangeIterator) FileRange() platform.FileRange {
+func (seg FileRangeIterator) FileRange() memmap.FileRange {
return seg.FileRangeOf(seg.Range())
}
// FileRangeOf returns the FileRange mapped by mr.
//
// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
-func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange {
frstart := seg.Value() + (mr.Start - seg.Start())
- return platform.FileRange{frstart, frstart + mr.Length()}
+ return memmap.FileRange{frstart, frstart + mr.Length()}
}
// Fill attempts to ensure that all memmap.Mappable offsets in required are
-// mapped to a platform.File offset, by allocating from mf with the given
+// mapped to a memmap.File offset, by allocating from mf with the given
// memory usage kind and invoking readAt to store data into memory. (If readAt
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
@@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
}
// Drop removes segments for memmap.Mappable offsets in mr, freeing the
-// corresponding platform.FileRanges.
+// corresponding memmap.FileRanges.
//
// Preconditions: mr must be page-aligned.
func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
@@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
}
// DropAll removes all segments in mr, freeing the corresponding
-// platform.FileRanges.
+// memmap.FileRanges.
func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
mf.DecRef(seg.FileRange())
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index dd63db32b..a808894df 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -17,27 +17,29 @@ package fsutil
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
)
-type frameRefSetFunctions struct{}
+// FrameRefSetFunctions implements segment.Functions for FrameRefSet.
+type FrameRefSetFunctions struct{}
// MinKey implements segment.Functions.MinKey.
-func (frameRefSetFunctions) MinKey() uint64 {
+func (FrameRefSetFunctions) MinKey() uint64 {
return 0
}
// MaxKey implements segment.Functions.MaxKey.
-func (frameRefSetFunctions) MaxKey() uint64 {
+func (FrameRefSetFunctions) MaxKey() uint64 {
return math.MaxUint64
}
// ClearValue implements segment.Functions.ClearValue.
-func (frameRefSetFunctions) ClearValue(val *uint64) {
+func (FrameRefSetFunctions) ClearValue(val *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (frameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) {
if val1 != val2 {
return 0, false
}
@@ -45,6 +47,45 @@ func (frameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
}
// Split implements segment.Functions.Split.
-func (frameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
+
+// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
+// are accounted as host page cache memory mappings.
+func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) {
+ seg, gap := refs.Find(fr.Start)
+ for {
+ switch {
+ case seg.Ok() && seg.Start() < fr.End:
+ seg = refs.Isolate(seg, fr)
+ seg.SetValue(seg.Value() + 1)
+ seg, gap = seg.NextNonEmpty()
+ case gap.Ok() && gap.Start() < fr.End:
+ newRange := gap.Range().Intersect(fr)
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
+ seg, gap = refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
+ default:
+ refs.MergeAdjacent(fr)
+ return
+ }
+ }
+}
+
+// DecRefAndAccount removes a reference on the range fr and untracks segments
+// that are removed from memory accounting.
+func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) {
+ seg := refs.FindSegment(fr.Start)
+
+ for seg.Ok() && seg.Start() < fr.End {
+ seg = refs.Isolate(seg, fr)
+ if old := seg.Value(); old == 1 {
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
+ seg = refs.Remove(seg).NextSegment()
+ } else {
+ seg.SetValue(old - 1)
+ seg = seg.NextSegment()
+ }
+ }
+ refs.MergeAdjacent(fr)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index b06a71cc2..ef0113b52 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -16,14 +16,13 @@ package fsutil
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// HostFileMapper caches mappings of an arbitrary host file descriptor. It is
@@ -65,13 +64,18 @@ type mapping struct {
writable bool
}
-// NewHostFileMapper returns a HostFileMapper with no references or cached
-// mappings.
+// Init must be called on zero-value HostFileMappers before first use.
+func (f *HostFileMapper) Init() {
+ f.refs = make(map[uint64]int32)
+ f.mappings = make(map[uint64]mapping)
+}
+
+// NewHostFileMapper returns an initialized HostFileMapper allocated on the
+// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
- return &HostFileMapper{
- refs: make(map[uint64]int32),
- mappings: make(map[uint64]mapping),
- }
+ f := &HostFileMapper{}
+ f.Init()
+ return f
}
// IncRefOn increments the reference count on all offsets in mr.
@@ -121,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
// offsets in fr or until the next call to UnmapAll.
//
// Preconditions: The caller must hold a reference on all offsets in fr.
-func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -141,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool)
}
// Preconditions: f.mapsMu must be locked.
-func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error {
prot := syscall.PROT_READ
if write {
prot |= syscall.PROT_WRITE
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
index ad11a0573..2d4778d64 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go
@@ -17,7 +17,7 @@ package fsutil
import (
"unsafe"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/safemem"
)
func (*HostFileMapper) unsafeBlockFromChunkMapping(addr uintptr) safemem.Block {
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 30475f340..c15d8a946 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -16,23 +16,22 @@ package fsutil
import (
"math"
- "sync"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// HostMappable implements memmap.Mappable and platform.File over a
+// HostMappable implements memmap.Mappable and memmap.File over a
// CachedFileObject.
//
// Lock order (compare the lock order model in mm/mm.go):
// truncateMu ("fs locks")
// mu ("memmap.Mappable locks not taken by Translate")
-// ("platform.File locks")
+// ("memmap.File locks")
// backingFile ("CachedFileObject locks")
//
// +stateify savable
@@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error {
return nil
}
-// MapInternal implements platform.File.MapInternal.
-func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (h *HostMappable) FD() int {
return h.backingFile.FD()
}
-// IncRef implements platform.File.IncRef.
-func (h *HostMappable) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (h *HostMappable) IncRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.IncRefOn(mr)
}
-// DecRef implements platform.File.DecRef.
-func (h *HostMappable) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (h *HostMappable) DecRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.DecRefOn(mr)
}
diff --git a/pkg/sentry/fs/fsutil/inode.go b/pkg/sentry/fs/fsutil/inode.go
index 4e100a402..1922ff08c 100644
--- a/pkg/sentry/fs/fsutil/inode.go
+++ b/pkg/sentry/fs/fsutil/inode.go
@@ -15,13 +15,13 @@
package fsutil
import (
- "sync"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -203,7 +203,7 @@ func (i *InodeSimpleAttributes) NotifyModificationAndStatusChange(ctx context.Co
}
// InodeSimpleExtendedAttributes implements
-// fs.InodeOperations.{Get,Set,List}xattr.
+// fs.InodeOperations.{Get,Set,List}Xattr.
//
// +stateify savable
type InodeSimpleExtendedAttributes struct {
@@ -212,8 +212,8 @@ type InodeSimpleExtendedAttributes struct {
xattrs map[string]string
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *InodeSimpleExtendedAttributes) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
i.mu.RLock()
value, ok := i.xattrs[name]
i.mu.RUnlock()
@@ -223,19 +223,31 @@ func (i *InodeSimpleExtendedAttributes) Getxattr(_ *fs.Inode, name string) (stri
return value, nil
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (i *InodeSimpleExtendedAttributes) Setxattr(_ *fs.Inode, name, value string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *InodeSimpleExtendedAttributes) SetXattr(_ context.Context, _ *fs.Inode, name, value string, flags uint32) error {
i.mu.Lock()
+ defer i.mu.Unlock()
if i.xattrs == nil {
+ if flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
i.xattrs = make(map[string]string)
}
+
+ _, ok := i.xattrs[name]
+ if ok && flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
i.xattrs[name] = value
- i.mu.Unlock()
return nil
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (i *InodeSimpleExtendedAttributes) Listxattr(_ *fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *InodeSimpleExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
i.mu.RLock()
names := make(map[string]struct{}, len(i.xattrs))
for name := range i.xattrs {
@@ -245,6 +257,17 @@ func (i *InodeSimpleExtendedAttributes) Listxattr(_ *fs.Inode) (map[string]struc
return names, nil
}
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *InodeSimpleExtendedAttributes) RemoveXattr(_ context.Context, _ *fs.Inode, name string) error {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if _, ok := i.xattrs[name]; ok {
+ delete(i.xattrs, name)
+ return nil
+ }
+ return syserror.ENOATTR
+}
+
// staticFile is a file with static contents. It is returned by
// InodeStaticFileGetter.GetFile.
//
@@ -437,21 +460,26 @@ func (InodeNotSymlink) Getlink(context.Context, *fs.Inode) (*fs.Dirent, error) {
// extended attributes.
type InodeNoExtendedAttributes struct{}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (InodeNoExtendedAttributes) Getxattr(*fs.Inode, string) (string, error) {
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (InodeNoExtendedAttributes) GetXattr(context.Context, *fs.Inode, string, uint64) (string, error) {
return "", syserror.EOPNOTSUPP
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (InodeNoExtendedAttributes) Setxattr(*fs.Inode, string, string) error {
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (InodeNoExtendedAttributes) SetXattr(context.Context, *fs.Inode, string, string, uint32) error {
return syserror.EOPNOTSUPP
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (InodeNoExtendedAttributes) Listxattr(*fs.Inode) (map[string]struct{}, error) {
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (InodeNoExtendedAttributes) ListXattr(context.Context, *fs.Inode, uint64) (map[string]struct{}, error) {
return nil, syserror.EOPNOTSUPP
}
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (InodeNoExtendedAttributes) RemoveXattr(context.Context, *fs.Inode, string) error {
+ return syserror.EOPNOTSUPP
+}
+
// InodeNoopRelease implements fs.InodeOperations.Release as a noop.
type InodeNoopRelease struct{}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 798920d18..fe8b0b6ac 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -17,19 +17,18 @@ package fsutil
import (
"fmt"
"io"
- "sync"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Lock order (compare the lock order model in mm/mm.go):
@@ -111,7 +110,7 @@ type CachingInodeOperations struct {
// refs tracks active references to data in the cache.
//
// refs is protected by dataMu.
- refs frameRefSet
+ refs FrameRefSet
}
// CachingInodeOperationsOptions configures a CachingInodeOperations.
@@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
c.mapsMu.Lock()
@@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable
}
}
-// IncRef implements platform.File.IncRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// IncRef implements memmap.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg, gap := c.refs.Find(fr.Start)
@@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
}
}
-// DecRef implements platform.File.DecRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// DecRef implements memmap.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg := c.refs.FindSegment(fr.Start)
@@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
c.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal. This is used when we
+// MapInternal implements memmap.File.MapInternal. This is used when we
// directly map an underlying host fd and CachingInodeOperations is used as the
-// platform.File during translation.
-func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// memmap.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// FD implements memmap.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
func (c *CachingInodeOperations) FD() int {
return c.backingFile.FD()
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index 129f314c8..1547584c5 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -19,14 +19,14 @@ import (
"io"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type noopBackingFile struct{}
diff --git a/pkg/sentry/fs/g3doc/.gitignore b/pkg/sentry/fs/g3doc/.gitignore
new file mode 100644
index 000000000..2d19fc766
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/.gitignore
@@ -0,0 +1 @@
+*.html
diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md
new file mode 100644
index 000000000..2ca84dd74
--- /dev/null
+++ b/pkg/sentry/fs/g3doc/fuse.md
@@ -0,0 +1,263 @@
+# Foreword
+
+This document describes an on-going project to support FUSE filesystems within
+the sentry. This is intended to become the final documentation for this
+subsystem, and is therefore written in the past tense. However FUSE support is
+currently incomplete and the document will be updated as things progress.
+
+# FUSE: Filesystem in Userspace
+
+The sentry supports dispatching filesystem operations to a FUSE server, allowing
+FUSE filesystem to be used with a sandbox.
+
+## Overview
+
+FUSE has two main components:
+
+1. A client kernel driver (canonically `fuse.ko` in Linux), which forwards
+ filesystem operations (usually initiated by syscalls) to the server.
+
+2. A server, which is a userspace daemon that implements the actual filesystem.
+
+The sentry implements the client component, which allows a server daemon running
+within the sandbox to implement a filesystem within the sandbox.
+
+A FUSE filesystem is initialized with `mount(2)`, typically with the help of a
+utility like `fusermount(1)`. Various mount options exist for establishing
+ownership and access permissions on the filesystem, but the most important mount
+option is a file descriptor used to establish communication between the client
+and server.
+
+The FUSE device FD is obtained by opening `/dev/fuse`. During regular operation,
+the client and server use the FUSE protocol described in `fuse(4)` to service
+filesystem operations. See the "Protocol" section below for more information
+about this protocol. The core of the sentry support for FUSE is the client-side
+implementation of this protocol.
+
+## FUSE in the Sentry
+
+The sentry's FUSE client targets VFS2 and has the following components:
+
+- An implementation of `/dev/fuse`.
+
+- A VFS2 filesystem for mapping syscalls to FUSE ops. Since we're targeting
+ VFS2, one point of contention may be the lack of inodes in VFS2. We can
+ tentatively implement a kernfs-based filesystem to bridge the gap in APIs.
+ The kernfs base functionality can serve the role of the Linux inode cache
+ and, the filesystem can map VFS2 syscalls to kernfs inode operations; see
+ the `kernfs.Inode` interface.
+
+The FUSE protocol lends itself well to marshaling with `go_marshal`. The various
+request and response packets can be defined in the ABI package and converted to
+and from the wire format using `go_marshal`.
+
+### Design Goals
+
+- While filesystem performance is always important, the sentry's FUSE support
+ is primarily concerned with compatibility, with performance as a secondary
+ concern.
+
+- Avoiding deadlocks from a hung server daemon.
+
+- Consider the potential for denial of service from a malicious server daemon.
+ Protecting itself from userspace is already a design goal for the sentry,
+ but needs additional consideration for FUSE. Normally, an operating system
+ doesn't rely on userspace to make progress with filesystem operations. Since
+ this changes with FUSE, it opens up the possibility of creating a chain of
+ dependencies controlled by userspace, which could affect an entire sandbox.
+ For example: a FUSE op can block a syscall, which could be holding a
+ subsystem lock, which can then block another task goroutine.
+
+### Milestones
+
+Below are some broad goals to aim for while implementing FUSE in the sentry.
+Many FUSE ops can be grouped into broad categories of functionality, and most
+ops can be implemented in parallel.
+
+#### Minimal client that can mount a trivial FUSE filesystem.
+
+- Implement `/dev/fuse` - a character device used to establish an FD for
+ communication between the sentry and the server daemon.
+
+- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`.
+
+#### Read-only mount with basic file operations
+
+- Implement the majority of file, directory and file descriptor FUSE ops. For
+ this milestone, we can skip uncommon or complex operations like mmap, mknod,
+ file locking, poll, and extended attributes. We can stub these out along
+ with any ops that modify the filesystem. The exact list of required ops are
+ to be determined, but the goal is to mount a real filesystem as read-only,
+ and be able to read contents from the filesystem in the sentry.
+
+#### Full read-write support
+
+- Implement the remaining FUSE ops and decide if we can omit rarely used
+ operations like ioctl.
+
+# Appendix
+
+## FUSE Protocol
+
+The FUSE protocol is a request-response protocol. All requests are initiated by
+the client. The wire-format for the protocol is raw C structs serialized to
+memory.
+
+All FUSE requests begin with the following request header:
+
+```c
+struct fuse_in_header {
+ uint32_t len; // Length of the request, including this header.
+ uint32_t opcode; // Requested operation.
+ uint64_t unique; // A unique identifier for this request.
+ uint64_t nodeid; // ID of the filesystem object being operated on.
+ uint32_t uid; // UID of the requesting process.
+ uint32_t gid; // GID of the requesting process.
+ uint32_t pid; // PID of the requesting process.
+ uint32_t padding;
+};
+```
+
+The request is then followed by a payload specific to the `opcode`.
+
+All responses begin with this response header:
+
+```c
+struct fuse_out_header {
+ uint32_t len; // Length of the response, including this header.
+ int32_t error; // Status of the request, 0 if success.
+ uint64_t unique; // The unique identifier from the corresponding request.
+};
+```
+
+The response payload also depends on the request `opcode`. If `error != 0`, the
+response payload must be empty.
+
+### Operations
+
+The following is a list of all FUSE operations used in `fuse_in_header.opcode`
+as of Linux v4.4, and a brief description of their purpose. These are defined in
+`uapi/linux/fuse.h`. Many of these have a corresponding request and response
+payload struct; `fuse(4)` has details for some of these. We also note how these
+operations map to the sentry virtual filesystem.
+
+#### FUSE meta-operations
+
+These operations are specific to FUSE and don't have a corresponding action in a
+generic filesystem.
+
+- `FUSE_INIT`: This operation initializes a new FUSE filesystem, and is the
+ first message sent by the client after mount. This is used for version and
+ feature negotiation. This is related to `mount(2)`.
+- `FUSE_DESTROY`: Teardown a FUSE filesystem, related to `unmount(2)`.
+- `FUSE_INTERRUPT`: Interrupts an in-flight operation, specified by the
+ `fuse_in_header.unique` value provided in the corresponding request header.
+ The client can send at most one of these per request, and will enter an
+ uninterruptible wait for a reply. The server is expected to reply promptly.
+- `FUSE_FORGET`: A hint to the server that server should evict the indicate
+ node from any caches. This is wired up to `(struct
+ super_operations).evict_inode` in Linux, which is in turned hooked as the
+ inode cache shrinker which is typically triggered by system memory pressure.
+- `FUSE_BATCH_FORGET`: Batch version of `FUSE_FORGET`.
+
+#### Filesystem Syscalls
+
+These FUSE ops map directly to an equivalent filesystem syscall, or family of
+syscalls. The relevant syscalls have a similar name to the operation, unless
+otherwise noted.
+
+Node creation:
+
+- `FUSE_MKNOD`
+- `FUSE_MKDIR`
+- `FUSE_CREATE`: This is equivalent to `open(2)` and `creat(2)`, which
+ atomically creates and opens a node.
+
+Node attributes and extended attributes:
+
+- `FUSE_GETATTR`
+- `FUSE_SETATTR`
+- `FUSE_SETXATTR`
+- `FUSE_GETXATTR`
+- `FUSE_LISTXATTR`
+- `FUSE_REMOVEXATTR`
+
+Node link manipulation:
+
+- `FUSE_READLINK`
+- `FUSE_LINK`
+- `FUSE_SYMLINK`
+- `FUSE_UNLINK`
+
+Directory operations:
+
+- `FUSE_RMDIR`
+- `FUSE_RENAME`
+- `FUSE_RENAME2`
+- `FUSE_OPENDIR`: `open(2)` for directories.
+- `FUSE_RELEASEDIR`: `close(2)` for directories.
+- `FUSE_READDIR`
+- `FUSE_READDIRPLUS`
+- `FUSE_FSYNCDIR`: `fsync(2)` for directories.
+- `FUSE_LOOKUP`: Establishes a unique identifier for a FS node. This is
+ reminiscent of `VirtualFilesystem.GetDentryAt` in that it resolves a path
+ component to a node. However the returned identifier is opaque to the
+ client. The server must remember this mapping, as this is how the client
+ will reference the node in the future.
+
+File operations:
+
+- `FUSE_OPEN`: `open(2)` for files.
+- `FUSE_RELEASE`: `close(2)` for files.
+- `FUSE_FSYNC`
+- `FUSE_FALLOCATE`
+- `FUSE_SETUPMAPPING`: Creates a memory map on a file for `mmap(2)`.
+- `FUSE_REMOVEMAPPING`: Removes a memory map for `munmap(2)`.
+
+File locking:
+
+- `FUSE_GETLK`
+- `FUSE_SETLK`
+- `FUSE_SETLKW`
+- `FUSE_COPY_FILE_RANGE`
+
+File descriptor operations:
+
+- `FUSE_IOCTL`
+- `FUSE_POLL`
+- `FUSE_LSEEK`
+
+Filesystem operations:
+
+- `FUSE_STATFS`
+
+#### Permissions
+
+- `FUSE_ACCESS` is used to check if a node is accessible, as part of many
+ syscall implementations. Maps to `vfs.FilesystemImpl.AccessAt` in the
+ sentry.
+
+#### I/O Operations
+
+These ops are used to read and write file pages. They're used to implement both
+I/O syscalls like `read(2)`, `write(2)` and `mmap(2)`.
+
+- `FUSE_READ`
+- `FUSE_WRITE`
+
+#### Miscellaneous
+
+- `FUSE_FLUSH`: Used by the client to indicate when a file descriptor is
+ closed. Distinct from `FUSE_FSYNC`, which corresponds to an `fsync(2)`
+ syscall from the user. Maps to `vfs.FileDescriptorImpl.Release` in the
+ sentry.
+- `FUSE_BMAP`: Old address space API for block defrag. Probably not needed.
+- `FUSE_NOTIFY_REPLY`: [TODO: what does this do?]
+
+# References
+
+- [fuse(4) Linux manual page](https://www.man7.org/linux/man-pages/man4/fuse.4.html)
+- [Linux kernel FUSE documentation](https://www.kernel.org/doc/html/latest/filesystems/fuse.html)
+- [The reference implementation of the Linux FUSE (Filesystem in Userspace)
+ interface](https://github.com/libfuse/libfuse)
+- [The kernel interface of FUSE](https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/include/uapi/linux/fuse.h)
diff --git a/pkg/sentry/fs/g3doc/inotify.md b/pkg/sentry/fs/g3doc/inotify.md
index 71a577d9d..85063d4e6 100644
--- a/pkg/sentry/fs/g3doc/inotify.md
+++ b/pkg/sentry/fs/g3doc/inotify.md
@@ -112,11 +112,11 @@ attempts to queue a new event, it is already holding `fs.Watches.mu`. If we used
`Inotify.mu` to also protect the event queue, this would violate the above lock
ordering.
-[dirent]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/dirent.go
-[event]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_event.go
-[fd_table]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/kernel/fd_table.go
-[inode]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode.go
-[inode_watches]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inode_inotify.go
-[inotify]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify.go
-[syscall_dir]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/syscalls/linux/
-[watch]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/fs/inotify_watch.go
+[dirent]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/dirent.go
+[event]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_event.go
+[fd_table]: https://github.com/google/gvisor/blob/master/pkg/sentry/kernel/fd_table.go
+[inode]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode.go
+[inode_watches]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inode_inotify.go
+[inotify]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify.go
+[syscall_dir]: https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/
+[watch]: https://github.com/google/gvisor/blob/master/pkg/sentry/fs/inotify_watch.go
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 4a005c605..fea135eea 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,6 +9,7 @@ go_library(
"cache_policy.go",
"context_file.go",
"device.go",
+ "fifo.go",
"file.go",
"file_state.go",
"fs.go",
@@ -22,31 +22,32 @@ go_library(
"socket.go",
"util.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/gofer",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
"//pkg/log",
"//pkg/metric",
"//pkg/p9",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/secio",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fdpipe",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/host",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/unet",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -55,12 +56,12 @@ go_test(
name = "gofer_test",
size = "small",
srcs = ["gofer_test.go"],
- embed = [":gofer"],
+ library = ":gofer",
deps = [
+ "//pkg/context",
"//pkg/p9",
"//pkg/p9/p9test",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
],
)
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index 4848e2374..d481baf77 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -17,12 +17,12 @@ package gofer
import (
"syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// getattr returns the 9p attributes of the p9.File. On success, Mode, Size, and RDev
@@ -75,10 +75,18 @@ func owner(mounter fs.FileOwner, valid p9.AttrMask, pattr p9.Attr) fs.FileOwner
// task's EUID/EGID.
owner := mounter
if valid.UID {
- owner.UID = auth.KUID(pattr.UID)
+ if pattr.UID.Ok() {
+ owner.UID = auth.KUID(pattr.UID)
+ } else {
+ owner.UID = auth.KUID(auth.OverflowUID)
+ }
}
if valid.GID {
- owner.GID = auth.KGID(pattr.GID)
+ if pattr.GID.Ok() {
+ owner.GID = auth.KGID(pattr.GID)
+ } else {
+ owner.GID = auth.KGID(auth.OverflowGID)
+ }
}
return owner
}
@@ -88,8 +96,9 @@ func bsize(pattr p9.Attr) int64 {
if pattr.BlockSize > 0 {
return int64(pattr.BlockSize)
}
- // Some files may have no clue of their block size. Better not to report
- // something misleading or buggy and have a safe default.
+ // Some files, particularly those that are not on a local file system,
+ // may have no clue of their block size. Better not to report something
+ // misleading or buggy and have a safe default.
return usermem.PageSize
}
@@ -149,6 +158,7 @@ func links(valid p9.AttrMask, pattr p9.Attr) uint64 {
}
// This node is likely backed by a file system that doesn't support links.
+ //
// We could readdir() and count children directories to provide an accurate
// link count. However this may be expensive since the gofer may be backed by remote
// storage. Instead, simply return 2 links for directories and 1 for everything else
diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go
index cc11c6339..07a564e92 100644
--- a/pkg/sentry/fs/gofer/cache_policy.go
+++ b/pkg/sentry/fs/gofer/cache_policy.go
@@ -17,7 +17,7 @@ package gofer
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -127,6 +127,9 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child
childIops, ok := child.InodeOperations.(*inodeOperations)
if !ok {
+ if _, ok := child.InodeOperations.(*fifo); ok {
+ return false
+ }
panic(fmt.Sprintf("revalidating inode operations of unknown type %T", child.InodeOperations))
}
parentIops, ok := parent.InodeOperations.(*inodeOperations)
diff --git a/pkg/sentry/fs/gofer/context_file.go b/pkg/sentry/fs/gofer/context_file.go
index 44b72582a..125907d70 100644
--- a/pkg/sentry/fs/gofer/context_file.go
+++ b/pkg/sentry/fs/gofer/context_file.go
@@ -15,9 +15,9 @@
package gofer
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
)
// contextFile is a wrapper around p9.File that notifies the context that
@@ -59,6 +59,34 @@ func (c *contextFile) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9
return err
}
+func (c *contextFile) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := c.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (c *contextFile) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (c *contextFile) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := c.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
+func (c *contextFile) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := c.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
func (c *contextFile) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
ctx.UninterruptibleSleepStart(false)
err := c.file.Allocate(mode, offset, length)
diff --git a/pkg/sentry/fs/gofer/fifo.go b/pkg/sentry/fs/gofer/fifo.go
new file mode 100644
index 000000000..456557058
--- /dev/null
+++ b/pkg/sentry/fs/gofer/fifo.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+)
+
+// +stateify savable
+type fifo struct {
+ fs.InodeOperations
+ fileIops *inodeOperations
+}
+
+var _ fs.InodeOperations = (*fifo)(nil)
+
+// Rename implements fs.InodeOperations. It forwards the call to the underlying
+// file inode to handle the file rename. Note that file key remains the same
+// after the rename to keep the endpoint mapping.
+func (i *fifo) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
+ return i.fileIops.Rename(ctx, inode, oldParent, oldName, newParent, newName, replacement)
+}
+
+// StatFS implements fs.InodeOperations.
+func (i *fifo) StatFS(ctx context.Context) (fs.Info, error) {
+ return i.fileIops.StatFS(ctx)
+}
diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go
index 7960b9c7b..c0bc63a32 100644
--- a/pkg/sentry/fs/gofer/file.go
+++ b/pkg/sentry/fs/gofer/file.go
@@ -19,16 +19,16 @@ import (
"syscall"
"time"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -37,9 +37,9 @@ var (
opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a 9P file was opened from a gofer.")
opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.")
reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.")
- readWait9P = metric.MustCreateNewUint64Metric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
+ readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.")
readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.")
- readWaitHost = metric.MustCreateNewUint64Metric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
+ readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.")
)
// fileOperations implements fs.FileOperations for a remote file system.
@@ -114,7 +114,7 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF
}
// Release implements fs.FileOpeations.Release.
-func (f *fileOperations) Release() {
+func (f *fileOperations) Release(context.Context) {
f.handles.DecRef()
}
@@ -122,7 +122,7 @@ func (f *fileOperations) Release() {
func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
diff --git a/pkg/sentry/fs/gofer/file_state.go b/pkg/sentry/fs/gofer/file_state.go
index c2fbb4be9..edd6576aa 100644
--- a/pkg/sentry/fs/gofer/file_state.go
+++ b/pkg/sentry/fs/gofer/file_state.go
@@ -17,7 +17,7 @@ package gofer
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -28,8 +28,13 @@ func (f *fileOperations) afterLoad() {
// Manually load the open handles.
var err error
- // TODO(b/38173783): Context is not plumbed to save/restore.
- f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), f.flags, f.inodeOperations.cachingInodeOps)
+
+ // The file may have been opened with Truncate, but we don't
+ // want to re-open it with Truncate or we will lose data.
+ flags := f.flags
+ flags.Truncate = false
+
+ f.handles, err = f.inodeOperations.fileState.getHandles(context.Background(), flags, f.inodeOperations.cachingInodeOps)
if err != nil {
return fmt.Errorf("failed to re-open handle: %v", err)
}
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index cf96dd9fa..8ae2d78d7 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -20,8 +20,8 @@ import (
"fmt"
"strconv"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -60,8 +60,7 @@ const (
limitHostFDTranslationKey = "limit_host_fd_translation"
// overlayfsStaleRead if present closes cached readonly file after the first
- // write. This is done to workaround a limitation of overlayfs in kernels
- // before 4.19 where open FDs are not updated after the file is copied up.
+ // write. This is done to workaround a limitation of Linux overlayfs.
overlayfsStaleRead = "overlayfs_stale_read"
)
diff --git a/pkg/sentry/fs/gofer/gofer_test.go b/pkg/sentry/fs/gofer/gofer_test.go
index 7fc3c32ae..326fed954 100644
--- a/pkg/sentry/fs/gofer/gofer_test.go
+++ b/pkg/sentry/fs/gofer/gofer_test.go
@@ -20,10 +20,10 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/p9/p9test"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -61,7 +61,7 @@ func rootTest(t *testing.T, name string, cp cachePolicy, fn func(context.Context
ctx := contexttest.Context(t)
sattr, rootInodeOperations := newInodeOperations(ctx, s, contextFile{
file: rootFile,
- }, root.QID, p9.AttrMaskAll(), root.Attr, false /* socket */)
+ }, root.QID, p9.AttrMaskAll(), root.Attr)
m := fs.NewMountSource(ctx, s, &filesystem{}, fs.MountSourceFlags{})
rootInode := fs.NewInode(ctx, rootInodeOperations, m, sattr)
@@ -232,7 +232,7 @@ func TestRevalidation(t *testing.T) {
// We must release the dirent, of the test will fail
// with a reference leak. This is tracked by p9test.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
// Walk again. Depending on the cache policy, we may
// get a new dirent.
@@ -246,7 +246,7 @@ func TestRevalidation(t *testing.T) {
if !test.preModificationWantReload && dirent != newDirent {
t.Errorf("Lookup with cachePolicy=%s got new dirent %+v, wanted old dirent %+v", test.cachePolicy, newDirent, dirent)
}
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
// Modify the underlying mocked file's modification
// time for the next walk that occurs.
@@ -287,7 +287,7 @@ func TestRevalidation(t *testing.T) {
if test.postModificationWantUpdatedAttrs && gotModTimeSeconds != nowSeconds {
t.Fatalf("Lookup with cachePolicy=%s got new modification time %v, wanted %v", test.cachePolicy, gotModTimeSeconds, nowSeconds)
}
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
// Remove the file from the remote fs, subsequent walks
// should now fail to find anything.
@@ -303,7 +303,7 @@ func TestRevalidation(t *testing.T) {
t.Errorf("Lookup with cachePolicy=%s got new dirent and error %v, wanted old dirent and nil error", test.cachePolicy, err)
}
if err == nil {
- newDirent.DecRef() // See above.
+ newDirent.DecRef(ctx) // See above.
}
})
}
diff --git a/pkg/sentry/fs/gofer/handles.go b/pkg/sentry/fs/gofer/handles.go
index 39c8ec33d..f324dbf26 100644
--- a/pkg/sentry/fs/gofer/handles.go
+++ b/pkg/sentry/fs/gofer/handles.go
@@ -17,14 +17,14 @@ package gofer
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
)
// handles are the open handles of a gofer file. They are reference counted to
@@ -47,7 +47,8 @@ type handles struct {
// DecRef drops a reference on handles.
func (h *handles) DecRef() {
- h.DecRefWithDestructor(func() {
+ ctx := context.Background()
+ h.DecRefWithDestructor(ctx, func(context.Context) {
if h.Host != nil {
if h.isHostBorrowed {
h.Host.Release()
@@ -57,14 +58,13 @@ func (h *handles) DecRef() {
}
}
}
- // FIXME(b/38173783): Context is not plumbed here.
- if err := h.File.close(context.Background()); err != nil {
+ if err := h.File.close(ctx); err != nil {
log.Warningf("error closing p9 file: %v", err)
}
})
}
-func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*handles, error) {
+func newHandles(ctx context.Context, client *p9.Client, file contextFile, flags fs.FileFlags) (*handles, error) {
_, newFile, err := file.walk(ctx, nil)
if err != nil {
return nil, err
@@ -81,6 +81,9 @@ func newHandles(ctx context.Context, file contextFile, flags fs.FileFlags) (*han
default:
panic("impossible fs.FileFlags")
}
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(client.Version()) {
+ p9flags |= p9.OpenTruncate
+ }
hostFile, _, _, err := newFile.open(ctx, p9flags)
if err != nil {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 99910388f..3a225fd39 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -16,21 +16,21 @@ package gofer
import (
"errors"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fdpipe"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -38,8 +38,7 @@ import (
//
// +stateify savable
type inodeOperations struct {
- fsutil.InodeNotVirtual `state:"nosave"`
- fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNotVirtual `state:"nosave"`
// fileState implements fs.CachedFileObject. It exists
// to break a circular load dependency between inodeOperations
@@ -180,7 +179,7 @@ func (i *inodeFileState) setSharedHandlesLocked(flags fs.FileFlags, h *handles)
// given flags.
func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cache *fsutil.CachingInodeOperations) (*handles, error) {
if !i.canShareHandles() {
- return newHandles(ctx, i.file, flags)
+ return newHandles(ctx, i.s.client, i.file, flags)
}
i.handlesMu.Lock()
@@ -201,19 +200,25 @@ func (i *inodeFileState) getHandles(ctx context.Context, flags fs.FileFlags, cac
// whether previously open read handle was recreated. Host mappings must be
// invalidated if so.
func (i *inodeFileState) getHandlesLocked(ctx context.Context, flags fs.FileFlags) (*handles, bool, error) {
- // Do we already have usable shared handles?
- if flags.Write {
+ // Check if we are able to use cached handles.
+ if flags.Truncate && p9.VersionSupportsOpenTruncateFlag(i.s.client.Version()) {
+ // If we are truncating (and the gofer supports it), then we
+ // always need a new handle. Don't return one from the cache.
+ } else if flags.Write {
if i.writeHandles != nil && (i.writeHandlesRW || !flags.Read) {
+ // File is opened for writing, and we have cached write
+ // handles that we can use.
i.writeHandles.IncRef()
return i.writeHandles, false, nil
}
} else if i.readHandles != nil {
+ // File is opened for reading and we have cached handles.
i.readHandles.IncRef()
return i.readHandles, false, nil
}
- // No; get new handles and cache them for future sharing.
- h, err := newHandles(ctx, i.file, flags)
+ // Get new handles and cache them for future sharing.
+ h, err := newHandles(ctx, i.s.client, i.file, flags)
if err != nil {
return nil, false, err
}
@@ -239,7 +244,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
if !flags.Read {
// Writer can't be used for read, must create a new handle.
var err error
- h, err = newHandles(ctx, i.file, fs.FileFlags{Read: true})
+ h, err = newHandles(ctx, i.s.client, i.file, fs.FileFlags{Read: true})
if err != nil {
return err
}
@@ -268,7 +273,7 @@ func (i *inodeFileState) recreateReadHandles(ctx context.Context, writer *handle
// operations on the old will see the new data. Then, make the new handle take
// ownereship of the old FD and mark the old readHandle to not close the FD
// when done.
- if err := syscall.Dup2(h.Host.FD(), i.readHandles.Host.FD()); err != nil {
+ if err := syscall.Dup3(h.Host.FD(), i.readHandles.Host.FD(), syscall.O_CLOEXEC); err != nil {
return err
}
@@ -436,8 +441,9 @@ func (i *inodeOperations) Release(ctx context.Context) {
// asynchronously.
//
// We use AsyncWithContext to avoid needing to allocate an extra
- // anonymous function on the heap.
- fs.AsyncWithContext(ctx, i.fileState.Release)
+ // anonymous function on the heap. We must use background context
+ // because the async work cannot happen on the task context.
+ fs.AsyncWithContext(context.Background(), i.fileState.Release)
}
// Mappable implements fs.InodeOperations.Mappable.
@@ -598,6 +604,26 @@ func (i *inodeOperations) Truncate(ctx context.Context, inode *fs.Inode, length
return i.fileState.file.setAttr(ctx, p9.SetAttrMask{Size: true}, p9.SetAttr{Size: uint64(length)})
}
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (i *inodeOperations) GetXattr(ctx context.Context, _ *fs.Inode, name string, size uint64) (string, error) {
+ return i.fileState.file.getXattr(ctx, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (i *inodeOperations) SetXattr(ctx context.Context, _ *fs.Inode, name string, value string, flags uint32) error {
+ return i.fileState.file.setXattr(ctx, name, value, flags)
+}
+
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (i *inodeOperations) ListXattr(ctx context.Context, _ *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return i.fileState.file.listXattr(ctx, size)
+}
+
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (i *inodeOperations) RemoveXattr(ctx context.Context, _ *fs.Inode, name string) error {
+ return i.fileState.file.removeXattr(ctx, name)
+}
+
// Allocate implements fs.InodeOperations.Allocate.
func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset, length int64) error {
// This can only be called for files anyway.
@@ -615,7 +641,7 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
- if !i.session().cachePolicy.cacheUAttrs(inode) {
+ if inode.MountSource.Flags.ReadOnly || !i.session().cachePolicy.cacheUAttrs(inode) {
return nil
}
@@ -685,13 +711,10 @@ func init() {
}
// AddLink implements InodeOperations.AddLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) AddLink() {}
// DropLink implements InodeOperations.DropLink, but is currently a noop.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (*inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
diff --git a/pkg/sentry/fs/gofer/inode_state.go b/pkg/sentry/fs/gofer/inode_state.go
index 0b2eedb7c..a3402e343 100644
--- a/pkg/sentry/fs/gofer/inode_state.go
+++ b/pkg/sentry/fs/gofer/inode_state.go
@@ -20,8 +20,8 @@ import (
"path/filepath"
"strings"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -123,7 +123,6 @@ func (i *inodeFileState) afterLoad() {
// beforeSave.
return fmt.Errorf("failed to find path for inode number %d. Device %s contains %s", i.sattr.InodeID, i.s.connID, fs.InodeMappings(i.s.inodeMappings))
}
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
_, i.file, err = i.s.attach.walk(ctx, splitAbsolutePath(name))
diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go
index 8c17603f8..3c66dc3c2 100644
--- a/pkg/sentry/fs/gofer/path.go
+++ b/pkg/sentry/fs/gofer/path.go
@@ -16,21 +16,30 @@ package gofer
import (
"fmt"
- "syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
// encoding of strings, which uses 2 bytes for the length prefix.
const maxFilenameLen = (1 << 16) - 1
+func changeType(mode p9.FileMode, newType p9.FileMode) p9.FileMode {
+ if newType&^p9.FileModeMask != 0 {
+ panic(fmt.Sprintf("newType contained more bits than just file mode: %x", newType))
+ }
+ clear := mode &^ p9.FileModeMask
+ return clear | newType
+}
+
// Lookup loads an Inode at name into a Dirent based on the session's cache
// policy.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
@@ -58,7 +67,7 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
// Get a p9.File for name.
qids, newFile, mask, p9attr, err := i.fileState.file.walkGetAttr(ctx, []string{name})
if err != nil {
- if err == syscall.ENOENT {
+ if err == syserror.ENOENT {
if cp.cacheNegativeDirents() {
// Return a negative Dirent. It will stay cached until something
// is created over it.
@@ -69,8 +78,25 @@ func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string
return nil, err
}
+ if i.session().overrides != nil {
+ // Check if file belongs to a internal named pipe. Note that it doesn't need
+ // to check for sockets because it's done in newInodeOperations below.
+ deviceKey := device.MultiDeviceKey{
+ Device: p9attr.RDev,
+ SecondaryDevice: i.session().connID,
+ Inode: qids[0].Path,
+ }
+ unlock := i.session().overrides.lock()
+ if pipeInode := i.session().overrides.getPipe(deviceKey); pipeInode != nil {
+ unlock()
+ pipeInode.IncRef()
+ return fs.NewDirent(ctx, pipeInode, name), nil
+ }
+ unlock()
+ }
+
// Construct the Inode operations.
- sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr, false)
+ sattr, node := newInodeOperations(ctx, i.fileState.s, newFile, qids[0], mask, p9attr)
// Construct a positive Dirent.
return fs.NewDirent(ctx, fs.NewInode(ctx, node, dir.MountSource, sattr), name), nil
@@ -138,11 +164,11 @@ func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string
qid := qids[0]
// Construct the InodeOperations.
- sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr, false)
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, p9attr)
// Construct the positive Dirent.
d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- defer d.DecRef()
+ defer d.DecRef(ctx)
// Construct the new file, caching the handles if allowed.
h := handles{
@@ -180,7 +206,7 @@ func (i *inodeOperations) CreateHardLink(ctx context.Context, inode *fs.Inode, t
targetOpts, ok := target.InodeOperations.(*inodeOperations)
if !ok {
- return syscall.EXDEV
+ return syserror.EXDEV
}
if err := i.fileState.file.link(ctx, &targetOpts.fileState.file, newName); err != nil {
@@ -223,82 +249,115 @@ func (i *inodeOperations) Bind(ctx context.Context, dir *fs.Inode, name string,
return nil, syserror.ENAMETOOLONG
}
- if i.session().endpoints == nil {
- return nil, syscall.EOPNOTSUPP
+ if i.session().overrides == nil {
+ return nil, syserror.EOPNOTSUPP
}
- // Create replaces the directory fid with the newly created/opened
- // file, so clone this directory so it doesn't change out from under
- // this node.
- _, newFile, err := i.fileState.file.walk(ctx, nil)
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
+
+ sattr, iops, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeSocket)
if err != nil {
return nil, err
}
- // Stabilize the endpoint map while creation is in progress.
- unlock := i.session().endpoints.lock()
- defer unlock()
+ // Construct the positive Dirent.
+ childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
+ i.session().overrides.addBoundEndpoint(iops.fileState.key, childDir, ep)
+ return childDir, nil
+}
- // Create a regular file in the gofer and then mark it as a socket by
- // adding this inode key in the 'endpoints' map.
- owner := fs.FileOwnerFromContext(ctx)
- hostFile, err := newFile.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
- if err != nil {
- return nil, err
+// CreateFifo implements fs.InodeOperations.CreateFifo.
+func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
}
- // We're not going to use this file.
- hostFile.Close()
- i.touchModificationAndStatusChangeTime(ctx, dir)
+ owner := fs.FileOwnerFromContext(ctx)
+ mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
- // Get the attributes of the file to create inode key.
- qid, mask, attr, err := getattr(ctx, newFile)
- if err != nil {
- newFile.close(ctx)
- return nil, err
+ // N.B. FIFOs use major/minor numbers 0.
+ if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
+ if i.session().overrides == nil || err != syserror.EPERM {
+ return err
+ }
+ // If gofer doesn't support mknod, check if we can create an internal fifo.
+ return i.createInternalFifo(ctx, dir, name, owner, perm)
}
- key := device.MultiDeviceKey{
- Device: attr.RDev,
- SecondaryDevice: i.session().connID,
- Inode: qid.Path,
+ i.touchModificationAndStatusChangeTime(ctx, dir)
+ return nil
+}
+
+func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, name string, owner fs.FileOwner, perm fs.FilePermissions) error {
+ if i.session().overrides == nil {
+ return syserror.EPERM
}
- // Create child dirent.
+ // Stabilize the override map while creation is in progress.
+ unlock := i.session().overrides.lock()
+ defer unlock()
- // Get an unopened p9.File for the file we created so that it can be
- // cloned and re-opened multiple times after creation.
- _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ sattr, fileOps, err := i.createEndpointFile(ctx, dir, name, perm, p9.ModeNamedPipe)
if err != nil {
- newFile.close(ctx)
- return nil, err
+ return err
}
- // Construct the InodeOperations.
- sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr, true)
+ // First create a pipe.
+ p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)
+
+ // Wrap the fileOps with our Fifo.
+ iops := &fifo{
+ InodeOperations: pipe.NewInodeOperations(ctx, perm, p),
+ fileIops: fileOps,
+ }
+ inode := fs.NewInode(ctx, iops, dir.MountSource, sattr)
// Construct the positive Dirent.
childDir := fs.NewDirent(ctx, fs.NewInode(ctx, iops, dir.MountSource, sattr), name)
- i.session().endpoints.add(key, childDir, ep)
- return childDir, nil
+ i.session().overrides.addPipe(fileOps.fileState.key, childDir, inode)
+ return nil
}
-// CreateFifo implements fs.InodeOperations.CreateFifo.
-func (i *inodeOperations) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
- if len(name) > maxFilenameLen {
- return syserror.ENAMETOOLONG
+// Caller must hold Session.endpoint lock.
+func (i *inodeOperations) createEndpointFile(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions, fileType p9.FileMode) (fs.StableAttr, *inodeOperations, error) {
+ _, dirClone, err := i.fileState.file.walk(ctx, nil)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
}
+ // We're not going to use dirClone after return.
+ defer dirClone.close(ctx)
+ // Create a regular file in the gofer and then mark it as a socket by
+ // adding this inode key in the 'overrides' map.
owner := fs.FileOwnerFromContext(ctx)
- mode := p9.FileMode(perm.LinuxMode()) | p9.ModeNamedPipe
-
- // N.B. FIFOs use major/minor numbers 0.
- if _, err := i.fileState.file.mknod(ctx, name, mode, 0, 0, p9.UID(owner.UID), p9.GID(owner.GID)); err != nil {
- return err
+ hostFile, err := dirClone.create(ctx, name, p9.ReadWrite, p9.FileMode(perm.LinuxMode()), p9.UID(owner.UID), p9.GID(owner.GID))
+ if err != nil {
+ return fs.StableAttr{}, nil, err
}
+ // We're not going to use this file.
+ hostFile.Close()
i.touchModificationAndStatusChangeTime(ctx, dir)
- return nil
+
+ // Get the attributes of the file to create inode key.
+ qid, mask, attr, err := getattr(ctx, dirClone)
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Get an unopened p9.File for the file we created so that it can be
+ // cloned and re-opened multiple times after creation.
+ _, unopened, err := i.fileState.file.walk(ctx, []string{name})
+ if err != nil {
+ return fs.StableAttr{}, nil, err
+ }
+
+ // Construct new inode with file type overridden.
+ attr.Mode = changeType(attr.Mode, fileType)
+ sattr, iops := newInodeOperations(ctx, i.fileState.s, unopened, qid, mask, attr)
+ return sattr, iops, nil
}
// Remove implements InodeOperations.Remove.
@@ -307,20 +366,23 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
return syserror.ENAMETOOLONG
}
- var key device.MultiDeviceKey
- removeSocket := false
- if i.session().endpoints != nil {
- // Find out if file being deleted is a socket that needs to be
+ var key *device.MultiDeviceKey
+ if i.session().overrides != nil {
+ // Find out if file being deleted is a socket or pipe that needs to be
// removed from endpoint map.
if d, err := i.Lookup(ctx, dir, name); err == nil {
- defer d.DecRef()
- if fs.IsSocket(d.Inode.StableAttr) {
- child := d.Inode.InodeOperations.(*inodeOperations)
- key = child.fileState.key
- removeSocket = true
-
- // Stabilize the endpoint map while deletion is in progress.
- unlock := i.session().endpoints.lock()
+ defer d.DecRef(ctx)
+
+ if fs.IsSocket(d.Inode.StableAttr) || fs.IsPipe(d.Inode.StableAttr) {
+ switch iops := d.Inode.InodeOperations.(type) {
+ case *inodeOperations:
+ key = &iops.fileState.key
+ case *fifo:
+ key = &iops.fileIops.fileState.key
+ }
+
+ // Stabilize the override map while deletion is in progress.
+ unlock := i.session().overrides.lock()
defer unlock()
}
}
@@ -329,8 +391,8 @@ func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string
if err := i.fileState.file.unlinkAt(ctx, name, 0); err != nil {
return err
}
- if removeSocket {
- i.session().endpoints.remove(key)
+ if key != nil {
+ i.session().overrides.remove(ctx, *key)
}
i.touchModificationAndStatusChangeTime(ctx, dir)
@@ -364,17 +426,16 @@ func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent
return syserror.ENAMETOOLONG
}
- // Unwrap the new parent to a *inodeOperations.
- newParentInodeOperations, ok := newParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
+ // Don't allow renames across different mounts.
+ if newParent.MountSource != oldParent.MountSource {
+ return syserror.EXDEV
}
+ // Unwrap the new parent to a *inodeOperations.
+ newParentInodeOperations := newParent.InodeOperations.(*inodeOperations)
+
// Unwrap the old parent to a *inodeOperations.
- oldParentInodeOperations, ok := oldParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
+ oldParentInodeOperations := oldParent.InodeOperations.(*inodeOperations)
// Do the rename.
if err := i.fileState.file.rename(ctx, newParentInodeOperations.fileState.file, newName); err != nil {
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 0da608548..7cf3522ff 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -16,15 +16,15 @@ package gofer
import (
"fmt"
- "sync"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -33,60 +33,107 @@ import (
var DefaultDirentCacheSize uint64 = fs.DefaultDirentCacheSize
// +stateify savable
-type endpointMaps struct {
- // mu protexts the direntMap, the keyMap, and the pathMap below.
- mu sync.RWMutex `state:"nosave"`
+type overrideInfo struct {
+ dirent *fs.Dirent
+
+ // endpoint is set when dirent points to a socket. inode must not be set.
+ endpoint transport.BoundEndpoint
+
+ // inode is set when dirent points to a pipe. endpoint must not be set.
+ inode *fs.Inode
+}
- // direntMap links sockets to their dirents.
- // It is filled concurrently with the keyMap and is stored upon save.
- // Before saving, this map is used to populate the pathMap.
- direntMap map[transport.BoundEndpoint]*fs.Dirent
+func (l *overrideInfo) inodeType() fs.InodeType {
+ switch {
+ case l.endpoint != nil:
+ return fs.Socket
+ case l.inode != nil:
+ return fs.Pipe
+ }
+ panic("endpoint or node must be set")
+}
- // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets.
+// +stateify savable
+type overrideMaps struct {
+ // mu protexts the keyMap, and the pathMap below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // keyMap links MultiDeviceKeys (containing inode IDs) to their sockets/pipes.
// It is not stored during save because the inode ID may change upon restore.
- keyMap map[device.MultiDeviceKey]transport.BoundEndpoint `state:"nosave"`
+ keyMap map[device.MultiDeviceKey]*overrideInfo `state:"nosave"`
- // pathMap links the sockets to their paths.
+ // pathMap links the sockets/pipes to their paths.
// It is filled before saving from the direntMap and is stored upon save.
// Upon restore, this map is used to re-populate the keyMap.
- pathMap map[transport.BoundEndpoint]string
+ pathMap map[*overrideInfo]string
+}
+
+// addBoundEndpoint adds the bound endpoint to the map.
+// A reference is taken on the dirent argument.
+//
+// Precondition: maps must have been locked with 'lock'.
+func (e *overrideMaps) addBoundEndpoint(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
+ d.IncRef()
+ e.keyMap[key] = &overrideInfo{dirent: d, endpoint: ep}
}
-// add adds the endpoint to the maps.
+// addPipe adds the pipe inode to the map.
// A reference is taken on the dirent argument.
//
// Precondition: maps must have been locked with 'lock'.
-func (e *endpointMaps) add(key device.MultiDeviceKey, d *fs.Dirent, ep transport.BoundEndpoint) {
- e.keyMap[key] = ep
+func (e *overrideMaps) addPipe(key device.MultiDeviceKey, d *fs.Dirent, inode *fs.Inode) {
d.IncRef()
- e.direntMap[ep] = d
+ e.keyMap[key] = &overrideInfo{dirent: d, inode: inode}
}
// remove deletes the key from the maps.
//
// Precondition: maps must have been locked with 'lock'.
-func (e *endpointMaps) remove(key device.MultiDeviceKey) {
- endpoint := e.get(key)
+func (e *overrideMaps) remove(ctx context.Context, key device.MultiDeviceKey) {
+ endpoint := e.keyMap[key]
delete(e.keyMap, key)
-
- d := e.direntMap[endpoint]
- d.DecRef()
- delete(e.direntMap, endpoint)
+ endpoint.dirent.DecRef(ctx)
}
// lock blocks other addition and removal operations from happening while
// the backing file is being created or deleted. Returns a function that unlocks
// the endpoint map.
-func (e *endpointMaps) lock() func() {
+func (e *overrideMaps) lock() func() {
e.mu.Lock()
return func() { e.mu.Unlock() }
}
-// get returns the endpoint mapped to the given key.
+// getBoundEndpoint returns the bound endpoint mapped to the given key.
//
-// Precondition: maps must have been locked for reading.
-func (e *endpointMaps) get(key device.MultiDeviceKey) transport.BoundEndpoint {
- return e.keyMap[key]
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getBoundEndpoint(key device.MultiDeviceKey) transport.BoundEndpoint {
+ if v := e.keyMap[key]; v != nil {
+ return v.endpoint
+ }
+ return nil
+}
+
+// getPipe returns the pipe inode mapped to the given key.
+//
+// Precondition: maps must have been locked.
+func (e *overrideMaps) getPipe(key device.MultiDeviceKey) *fs.Inode {
+ if v := e.keyMap[key]; v != nil {
+ return v.inode
+ }
+ return nil
+}
+
+// getType returns the inode type if there is a corresponding endpoint for the
+// given key. Returns false otherwise.
+func (e *overrideMaps) getType(key device.MultiDeviceKey) (fs.InodeType, bool) {
+ e.mu.Lock()
+ v := e.keyMap[key]
+ e.mu.Unlock()
+
+ if v != nil {
+ return v.inodeType(), true
+ }
+ return 0, false
}
// session holds state for each 9p session established during sys_mount.
@@ -137,20 +184,20 @@ type session struct {
// mounter is the EUID/EGID that mounted this file system.
mounter fs.FileOwner `state:"wait"`
- // endpoints is used to map inodes that represent socket files to their
- // corresponding endpoint. Socket files are created as regular files in the
- // gofer and their presence in this map indicate that they should indeed be
- // socket files. This allows unix domain sockets to be used with paths that
- // belong to a gofer.
+ // overrides is used to map inodes that represent socket/pipes files to their
+ // corresponding endpoint/iops. These files are created as regular files in
+ // the gofer and their presence in this map indicate that they should indeed
+ // be socket/pipe files. This allows unix domain sockets and named pipes to
+ // be used with paths that belong to a gofer.
//
- // TODO(b/77154739): there are few possible races with someone stat'ing the
- // file and another deleting it concurrently, where the file will not be
- // reported as socket file.
- endpoints *endpointMaps `state:"wait"`
+ // There are a few possible races with someone stat'ing the file and another
+ // deleting it concurrently, where the file will not be reported as socket
+ // file.
+ overrides *overrideMaps `state:"wait"`
}
// Destroy tears down the session.
-func (s *session) Destroy() {
+func (s *session) Destroy(ctx context.Context) {
s.client.Close()
}
@@ -179,15 +226,21 @@ func (s *session) SaveInodeMapping(inode *fs.Inode, path string) {
// This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
// because overlay copyUp may have changed them out from under us.
// So much for "immutable".
- sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
- s.inodeMappings[sattr.InodeID] = path
+ switch iops := inode.InodeOperations.(type) {
+ case *inodeOperations:
+ s.inodeMappings[iops.fileState.sattr.InodeID] = path
+ case *fifo:
+ s.inodeMappings[iops.fileIops.fileState.sattr.InodeID] = path
+ default:
+ panic(fmt.Sprintf("Invalid type: %T", iops))
+ }
}
-// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File and attributes
-// (p9.QID, p9.AttrMask, p9.Attr).
+// newInodeOperations creates a new 9p fs.InodeOperations backed by a p9.File
+// and attributes (p9.QID, p9.AttrMask, p9.Attr).
//
// Endpoints lock must not be held if socket == false.
-func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr, socket bool) (fs.StableAttr, *inodeOperations) {
+func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p9.QID, valid p9.AttrMask, attr p9.Attr) (fs.StableAttr, *inodeOperations) {
deviceKey := device.MultiDeviceKey{
Device: attr.RDev,
SecondaryDevice: s.connID,
@@ -201,17 +254,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
BlockSize: bsize(attr),
}
- if s.endpoints != nil {
- if socket {
- sattr.Type = fs.Socket
- } else {
- // If unix sockets are allowed on this filesystem, check if this file is
- // supposed to be a socket file.
- unlock := s.endpoints.lock()
- if s.endpoints.get(deviceKey) != nil {
- sattr.Type = fs.Socket
- }
- unlock()
+ if s.overrides != nil && sattr.Type == fs.RegularFile {
+ // If overrides are allowed on this filesystem, check if this file is
+ // supposed to be of a different type, e.g. socket.
+ if t, ok := s.overrides.getType(deviceKey); ok {
+ sattr.Type = t
}
}
@@ -267,7 +314,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
s.EnableLeakCheck("gofer.session")
if o.privateunixsocket {
- s.endpoints = newEndpointMaps()
+ s.overrides = newOverrideMaps()
}
// Construct the MountSource with the session and superBlockFlags.
@@ -282,7 +329,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
s.client, err = p9.NewClient(conn, s.msize, s.version)
if err != nil {
// Drop our reference on the session, it needs to be torn down.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
@@ -293,7 +340,7 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
ctx.UninterruptibleSleepFinish(false)
if err != nil {
// Same as above.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
@@ -301,30 +348,28 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
if err != nil {
s.attach.close(ctx)
// Same as above, but after we execute the Close request.
- s.DecRef()
+ s.DecRef(ctx)
return nil, err
}
- sattr, iops := newInodeOperations(ctx, &s, s.attach, qid, valid, attr, false)
+ sattr, iops := newInodeOperations(ctx, &s, s.attach, qid, valid, attr)
return fs.NewInode(ctx, iops, m, sattr), nil
}
-// newEndpointMaps creates a new endpointMaps.
-func newEndpointMaps() *endpointMaps {
- return &endpointMaps{
- direntMap: make(map[transport.BoundEndpoint]*fs.Dirent),
- keyMap: make(map[device.MultiDeviceKey]transport.BoundEndpoint),
- pathMap: make(map[transport.BoundEndpoint]string),
+// newOverrideMaps creates a new overrideMaps.
+func newOverrideMaps() *overrideMaps {
+ return &overrideMaps{
+ keyMap: make(map[device.MultiDeviceKey]*overrideInfo),
+ pathMap: make(map[*overrideInfo]string),
}
}
-// fillKeyMap populates key and dirent maps upon restore from saved
-// pathmap.
+// fillKeyMap populates key and dirent maps upon restore from saved pathmap.
func (s *session) fillKeyMap(ctx context.Context) error {
- unlock := s.endpoints.lock()
+ unlock := s.overrides.lock()
defer unlock()
- for ep, dirPath := range s.endpoints.pathMap {
+ for ep, dirPath := range s.overrides.pathMap {
_, file, err := s.attach.walk(ctx, splitAbsolutePath(dirPath))
if err != nil {
return fmt.Errorf("error filling endpointmaps, failed to walk to %q: %v", dirPath, err)
@@ -341,25 +386,25 @@ func (s *session) fillKeyMap(ctx context.Context) error {
Inode: qid.Path,
}
- s.endpoints.keyMap[key] = ep
+ s.overrides.keyMap[key] = ep
}
return nil
}
-// fillPathMap populates paths for endpoints from dirents in direntMap
+// fillPathMap populates paths for overrides from dirents in direntMap
// before save.
-func (s *session) fillPathMap() error {
- unlock := s.endpoints.lock()
+func (s *session) fillPathMap(ctx context.Context) error {
+ unlock := s.overrides.lock()
defer unlock()
- for ep, dir := range s.endpoints.direntMap {
- mountRoot := dir.MountRoot()
- defer mountRoot.DecRef()
- dirPath, _ := dir.FullName(mountRoot)
+ for _, endpoint := range s.overrides.keyMap {
+ mountRoot := endpoint.dirent.MountRoot()
+ defer mountRoot.DecRef(ctx)
+ dirPath, _ := endpoint.dirent.FullName(mountRoot)
if dirPath == "" {
return fmt.Errorf("error getting path from dirent")
}
- s.endpoints.pathMap[ep] = dirPath
+ s.overrides.pathMap[endpoint] = dirPath
}
return nil
}
@@ -368,7 +413,7 @@ func (s *session) fillPathMap() error {
func (s *session) restoreEndpointMaps(ctx context.Context) error {
// When restoring, only need to create the keyMap because the dirent and path
// maps got stored through the save.
- s.endpoints.keyMap = make(map[device.MultiDeviceKey]transport.BoundEndpoint)
+ s.overrides.keyMap = make(map[device.MultiDeviceKey]*overrideInfo)
if err := s.fillKeyMap(ctx); err != nil {
return fmt.Errorf("failed to insert sockets into endpoint map: %v", err)
}
@@ -376,6 +421,6 @@ func (s *session) restoreEndpointMaps(ctx context.Context) error {
// Re-create pathMap because it can no longer be trusted as socket paths can
// change while process continues to run. Empty pathMap will be re-filled upon
// next save.
- s.endpoints.pathMap = make(map[transport.BoundEndpoint]string)
+ s.overrides.pathMap = make(map[*overrideInfo]string)
return nil
}
diff --git a/pkg/sentry/fs/gofer/session_state.go b/pkg/sentry/fs/gofer/session_state.go
index d045e04ff..48b423dd8 100644
--- a/pkg/sentry/fs/gofer/session_state.go
+++ b/pkg/sentry/fs/gofer/session_state.go
@@ -17,17 +17,18 @@ package gofer
import (
"fmt"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/unet"
)
// beforeSave is invoked by stateify.
func (s *session) beforeSave() {
- if s.endpoints != nil {
- if err := s.fillPathMap(); err != nil {
- panic("failed to save paths to endpoint map before saving" + err.Error())
+ if s.overrides != nil {
+ ctx := &dummyClockContext{context.Background()}
+ if err := s.fillPathMap(ctx); err != nil {
+ panic("failed to save paths to override map before saving" + err.Error())
}
}
}
@@ -74,10 +75,10 @@ func (s *session) afterLoad() {
panic(fmt.Sprintf("new attach name %v, want %v", opts.aname, s.aname))
}
- // Check if endpointMaps exist when uds sockets are enabled
- // (only pathmap will actualy have been saved).
- if opts.privateunixsocket != (s.endpoints != nil) {
- panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.endpoints != nil))
+ // Check if overrideMaps exist when uds sockets are enabled (only pathmaps
+ // will actually have been saved).
+ if opts.privateunixsocket != (s.overrides != nil) {
+ panic(fmt.Sprintf("new privateunixsocket option %v, want %v", opts.privateunixsocket, s.overrides != nil))
}
if args.Flags != s.superBlockFlags {
panic(fmt.Sprintf("new mount flags %v, want %v", args.Flags, s.superBlockFlags))
@@ -104,7 +105,6 @@ func (s *session) afterLoad() {
// If private unix sockets are enabled, create and fill the session's endpoint
// maps.
if opts.privateunixsocket {
- // TODO(b/38173783): Context is not plumbed to save/restore.
ctx := &dummyClockContext{context.Background()}
if err = s.restoreEndpointMaps(ctx); err != nil {
diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go
index a45a8f36c..8a1c69ac2 100644
--- a/pkg/sentry/fs/gofer/socket.go
+++ b/pkg/sentry/fs/gofer/socket.go
@@ -16,9 +16,9 @@ package gofer
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -32,21 +32,23 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
return nil
}
- if i.session().endpoints != nil {
- unlock := i.session().endpoints.lock()
+ if i.session().overrides != nil {
+ unlock := i.session().overrides.lock()
defer unlock()
- ep := i.session().endpoints.get(i.fileState.key)
+ ep := i.session().overrides.getBoundEndpoint(i.fileState.key)
if ep != nil {
return ep
}
- // Not found in endpoints map, it may be a gofer backed unix socket...
+ // Not found in overrides map, it may be a gofer backed unix socket...
}
inode.IncRef()
return &endpoint{inode, i.fileState.file.file, path}
}
+// LINT.IfChange
+
// endpoint is a Gofer-backed transport.BoundEndpoint.
//
// An endpoint's lifetime is the time between when InodeOperations.BoundEndpoint()
@@ -132,17 +134,19 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect
// We don't need the receiver.
c.CloseRecv()
- c.Release()
+ c.Release(ctx)
return c, nil
}
// Release implements transport.BoundEndpoint.Release.
-func (e *endpoint) Release() {
- e.inode.DecRef()
+func (e *endpoint) Release(ctx context.Context) {
+ e.inode.DecRef(ctx)
}
// Passcred implements transport.BoundEndpoint.Passcred.
func (e *endpoint) Passcred() bool {
return false
}
+
+// LINT.ThenChange(../../fsimpl/gofer/socket.go)
diff --git a/pkg/sentry/fs/gofer/util.go b/pkg/sentry/fs/gofer/util.go
index 848e6812b..47a6c69bf 100644
--- a/pkg/sentry/fs/gofer/util.go
+++ b/pkg/sentry/fs/gofer/util.go
@@ -17,20 +17,32 @@ package gofer
import (
"syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
func utimes(ctx context.Context, file contextFile, ts fs.TimeSpec) error {
if ts.ATimeOmit && ts.MTimeOmit {
return nil
}
+
+ // Replace requests to use the "system time" with the current time to
+ // ensure that timestamps remain consistent with the remote
+ // filesystem.
+ now := ktime.NowFromContext(ctx)
+ if ts.ATimeSetSystemTime {
+ ts.ATime = now
+ }
+ if ts.MTimeSetSystemTime {
+ ts.MTime = now
+ }
mask := p9.SetAttrMask{
ATime: !ts.ATimeOmit,
- ATimeNotSystemTime: !ts.ATimeSetSystemTime,
+ ATimeNotSystemTime: true,
MTime: !ts.MTimeOmit,
- MTimeNotSystemTime: !ts.MTimeSetSystemTime,
+ MTimeNotSystemTime: true,
}
as, ans := ts.ATime.Unix()
ms, mns := ts.MTime.Unix()
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index 1cbed07ae..d41d23a43 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,7 +10,7 @@ go_library(
"descriptor_state.go",
"device.go",
"file.go",
- "fs.go",
+ "host.go",
"inode.go",
"inode_state.go",
"ioctl_unsafe.go",
@@ -21,19 +20,22 @@ go_library(
"socket_unsafe.go",
"tty.go",
"util.go",
+ "util_amd64_unsafe.go",
+ "util_arm64_unsafe.go",
"util_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/host",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
"//pkg/fdnotifier",
+ "//pkg/iovec",
"//pkg/log",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/secio",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
@@ -41,17 +43,17 @@ go_library(
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
"//pkg/sentry/uniqueid",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/unet",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -61,24 +63,21 @@ go_test(
size = "small",
srcs = [
"descriptor_test.go",
- "fs_test.go",
"inode_test.go",
"socket_test.go",
"wait_test.go",
],
- embed = [":host"],
+ library = ":host",
deps = [
"//pkg/fd",
"//pkg/fdnotifier",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/fs",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/host/control.go b/pkg/sentry/fs/host/control.go
index 5532ff5a0..0d8d36afa 100644
--- a/pkg/sentry/fs/host/control.go
+++ b/pkg/sentry/fs/host/control.go
@@ -17,12 +17,14 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
+// LINT.IfChange
+
type scmRights struct {
fds []int
}
@@ -55,7 +57,7 @@ func (c *scmRights) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (c *scmRights) Release() {
+func (c *scmRights) Release(ctx context.Context) {
for _, fd := range c.fds {
syscall.Close(fd)
}
@@ -76,7 +78,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
}
// Create the file backed by hostFD.
- file, err := NewFile(ctx, fd, fs.FileOwnerFromContext(ctx))
+ file, err := NewFile(ctx, fd)
if err != nil {
ctx.Warningf("Error creating file from host FD: %v", err)
break
@@ -91,3 +93,5 @@ func fdsToFiles(ctx context.Context, fds []int) []*fs.File {
}
return files
}
+
+// LINT.ThenChange(../../fsimpl/host/control.go)
diff --git a/pkg/sentry/fs/host/descriptor.go b/pkg/sentry/fs/host/descriptor.go
index 2a4d1b291..cfdce6a74 100644
--- a/pkg/sentry/fs/host/descriptor.go
+++ b/pkg/sentry/fs/host/descriptor.go
@@ -16,7 +16,6 @@ package host
import (
"fmt"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/fdnotifier"
@@ -28,12 +27,9 @@ import (
//
// +stateify savable
type descriptor struct {
- // donated is true if the host fd was donated by another process.
- donated bool
-
// If origFD >= 0, it is the host fd that this file was originally created
// from, which must be available at time of restore. The FD can be closed
- // after descriptor is created. Only set if donated is true.
+ // after descriptor is created.
origFD int
// wouldBlock is true if value (below) points to a file that can
@@ -41,15 +37,13 @@ type descriptor struct {
wouldBlock bool
// value is the wrapped host fd. It is never saved or restored
- // directly. How it is restored depends on whether it was
- // donated and the fs.MountSource it was originally
- // opened/created from.
+ // directly.
value int `state:"nosave"`
}
// newDescriptor returns a wrapped host file descriptor. On success,
// the descriptor is registered for event notifications with queue.
-func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
+func newDescriptor(fd int, saveable bool, wouldBlock bool, queue *waiter.Queue) (*descriptor, error) {
ownedFD := fd
origFD := -1
if saveable {
@@ -69,7 +63,6 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
}
return &descriptor{
- donated: donated,
origFD: origFD,
wouldBlock: wouldBlock,
value: ownedFD,
@@ -77,25 +70,11 @@ func newDescriptor(fd int, donated bool, saveable bool, wouldBlock bool, queue *
}
// initAfterLoad initializes the value of the descriptor after Load.
-func (d *descriptor) initAfterLoad(mo *superOperations, id uint64, queue *waiter.Queue) error {
- if d.donated {
- var err error
- d.value, err = syscall.Dup(d.origFD)
- if err != nil {
- return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
- }
- } else {
- name, ok := mo.inodeMappings[id]
- if !ok {
- return fmt.Errorf("failed to find path for inode number %d", id)
- }
- fullpath := path.Join(mo.root, name)
-
- var err error
- d.value, err = open(nil, fullpath)
- if err != nil {
- return fmt.Errorf("failed to open %q: %v", fullpath, err)
- }
+func (d *descriptor) initAfterLoad(id uint64, queue *waiter.Queue) error {
+ var err error
+ d.value, err = syscall.Dup(d.origFD)
+ if err != nil {
+ return fmt.Errorf("failed to dup restored fd %d: %v", d.origFD, err)
}
if d.wouldBlock {
if err := syscall.SetNonblock(d.value, true); err != nil {
diff --git a/pkg/sentry/fs/host/descriptor_state.go b/pkg/sentry/fs/host/descriptor_state.go
index 8167390a9..e880582ab 100644
--- a/pkg/sentry/fs/host/descriptor_state.go
+++ b/pkg/sentry/fs/host/descriptor_state.go
@@ -16,7 +16,7 @@ package host
// beforeSave is invoked by stateify.
func (d *descriptor) beforeSave() {
- if d.donated && d.origFD < 0 {
+ if d.origFD < 0 {
panic("donated file descriptor cannot be saved")
}
}
diff --git a/pkg/sentry/fs/host/descriptor_test.go b/pkg/sentry/fs/host/descriptor_test.go
index 4205981f5..d8e4605b6 100644
--- a/pkg/sentry/fs/host/descriptor_test.go
+++ b/pkg/sentry/fs/host/descriptor_test.go
@@ -47,10 +47,10 @@ func TestDescriptorRelease(t *testing.T) {
// FD ownership is transferred to the descritor.
queue := &waiter.Queue{}
- d, err := newDescriptor(fd, false /* donated*/, tc.saveable, tc.wouldBlock, queue)
+ d, err := newDescriptor(fd, tc.saveable, tc.wouldBlock, queue)
if err != nil {
syscall.Close(fd)
- t.Fatalf("newDescriptor(%d, %t, false, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
+ t.Fatalf("newDescriptor(%d, %t, %t, queue) failed, err: %v", fd, tc.saveable, tc.wouldBlock, err)
}
if tc.saveable {
if d.origFD < 0 {
diff --git a/pkg/sentry/fs/host/file.go b/pkg/sentry/fs/host/file.go
index f6c626f2c..86d1a87f0 100644
--- a/pkg/sentry/fs/host/file.go
+++ b/pkg/sentry/fs/host/file.go
@@ -18,17 +18,17 @@ import (
"fmt"
"syscall"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -60,8 +60,8 @@ var _ fs.FileOperations = (*fileOperations)(nil)
// The returned File cannot be saved, since there is no guarantee that the same
// FD will exist or represent the same file at time of restore. If such a
// guarantee does exist, use ImportFile instead.
-func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, false, false)
+func NewFile(ctx context.Context, fd int) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, false, false)
}
// ImportFile creates a new File backed by the provided host file descriptor.
@@ -71,13 +71,13 @@ func NewFile(ctx context.Context, fd int, mounter fs.FileOwner) (*fs.File, error
// If the returned file is saved, it will be restored by re-importing the FD
// originally passed to ImportFile. It is the restorer's responsibility to
// ensure that the FD represents the same file.
-func ImportFile(ctx context.Context, fd int, mounter fs.FileOwner, isTTY bool) (*fs.File, error) {
- return newFileFromDonatedFD(ctx, fd, mounter, true, isTTY)
+func ImportFile(ctx context.Context, fd int, isTTY bool) (*fs.File, error) {
+ return newFileFromDonatedFD(ctx, fd, true, isTTY)
}
// newFileFromDonatedFD returns an fs.File from a donated FD. If the FD is
// saveable, then saveable is true.
-func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner, saveable, isTTY bool) (*fs.File, error) {
+func newFileFromDonatedFD(ctx context.Context, donated int, saveable, isTTY bool) (*fs.File, error) {
var s syscall.Stat_t
if err := syscall.Fstat(donated, &s); err != nil {
return nil, err
@@ -101,8 +101,8 @@ func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner
})
return s, nil
default:
- msrc := newMountSource(ctx, "/", mounter, &Filesystem{}, fs.MountSourceFlags{}, false /* dontTranslateOwnership */)
- inode, err := newInode(ctx, msrc, donated, saveable, true /* donated */)
+ msrc := fs.NewNonCachingMountSource(ctx, &filesystem{}, fs.MountSourceFlags{})
+ inode, err := newInode(ctx, msrc, donated, saveable)
if err != nil {
return nil, err
}
@@ -110,7 +110,7 @@ func newFileFromDonatedFD(ctx context.Context, donated int, mounter fs.FileOwner
name := fmt.Sprintf("host:[%d]", inode.StableAttr.InodeID)
dirent := fs.NewDirent(ctx, inode, name)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
if isTTY {
return newTTYFile(ctx, dirent, flags, iops), nil
@@ -169,7 +169,7 @@ func (f *fileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
func (f *fileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/host/fs.go b/pkg/sentry/fs/host/fs.go
deleted file mode 100644
index 68d2697c0..000000000
--- a/pkg/sentry/fs/host/fs.go
+++ /dev/null
@@ -1,339 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package host implements an fs.Filesystem for files backed by host
-// file descriptors.
-package host
-
-import (
- "fmt"
- "path"
- "path/filepath"
- "strconv"
- "strings"
-
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// FilesystemName is the name under which Filesystem is registered.
-const FilesystemName = "whitelistfs"
-
-const (
- // whitelistKey is the mount option containing a comma-separated list
- // of host paths to whitelist.
- whitelistKey = "whitelist"
-
- // rootPathKey is the mount option containing the root path of the
- // mount.
- rootPathKey = "root"
-
- // dontTranslateOwnershipKey is the key to superOperations.dontTranslateOwnership.
- dontTranslateOwnershipKey = "dont_translate_ownership"
-)
-
-// maxTraversals determines link traversals in building the whitelist.
-const maxTraversals = 10
-
-// Filesystem is a pseudo file system that is only available during the setup
-// to lock down the configurations. This filesystem should only be mounted at root.
-//
-// Think twice before exposing this to applications.
-//
-// +stateify savable
-type Filesystem struct {
- // whitelist is a set of host paths to whitelist.
- paths []string
-}
-
-var _ fs.Filesystem = (*Filesystem)(nil)
-
-// Name is the identifier of this file system.
-func (*Filesystem) Name() string {
- return FilesystemName
-}
-
-// AllowUserMount prohibits users from using mount(2) with this file system.
-func (*Filesystem) AllowUserMount() bool {
- return false
-}
-
-// AllowUserList allows this filesystem to be listed in /proc/filesystems.
-func (*Filesystem) AllowUserList() bool {
- return true
-}
-
-// Flags returns that there is nothing special about this file system.
-func (*Filesystem) Flags() fs.FilesystemFlags {
- return 0
-}
-
-// Mount returns an fs.Inode exposing the host file system. It is intended to be locked
-// down in PreExec below.
-func (f *Filesystem) Mount(ctx context.Context, _ string, flags fs.MountSourceFlags, data string, _ interface{}) (*fs.Inode, error) {
- // Parse generic comma-separated key=value options.
- options := fs.GenericMountSourceOptions(data)
-
- // Grab the whitelist if one was specified.
- // TODO(edahlgren/mpratt/hzy): require another option "testonly" in order to allow
- // no whitelist.
- if wl, ok := options[whitelistKey]; ok {
- f.paths = strings.Split(wl, "|")
- delete(options, whitelistKey)
- }
-
- // If the rootPath was set, use it. Othewise default to the root of the
- // host fs.
- rootPath := "/"
- if rp, ok := options[rootPathKey]; ok {
- rootPath = rp
- delete(options, rootPathKey)
-
- // We must relativize the whitelisted paths to the new root.
- for i, p := range f.paths {
- rel, err := filepath.Rel(rootPath, p)
- if err != nil {
- return nil, fmt.Errorf("whitelist path %q must be a child of root path %q", p, rootPath)
- }
- f.paths[i] = path.Join("/", rel)
- }
- }
- fd, err := open(nil, rootPath)
- if err != nil {
- return nil, fmt.Errorf("failed to find root: %v", err)
- }
-
- var dontTranslateOwnership bool
- if v, ok := options[dontTranslateOwnershipKey]; ok {
- b, err := strconv.ParseBool(v)
- if err != nil {
- return nil, fmt.Errorf("invalid value for %q: %v", dontTranslateOwnershipKey, err)
- }
- dontTranslateOwnership = b
- delete(options, dontTranslateOwnershipKey)
- }
-
- // Fail if the caller passed us more options than we know about.
- if len(options) > 0 {
- return nil, fmt.Errorf("unsupported mount options: %v", options)
- }
-
- // The mounting EUID/EGID will be cached by this file system. This will
- // be used to assign ownership to files that we own.
- owner := fs.FileOwnerFromContext(ctx)
-
- // Construct the host file system mount and inode.
- msrc := newMountSource(ctx, rootPath, owner, f, flags, dontTranslateOwnership)
- return newInode(ctx, msrc, fd, false /* saveable */, false /* donated */)
-}
-
-// InstallWhitelist locks down the MountNamespace to only the currently installed
-// Dirents and the given paths.
-func (f *Filesystem) InstallWhitelist(ctx context.Context, m *fs.MountNamespace) error {
- return installWhitelist(ctx, m, f.paths)
-}
-
-func installWhitelist(ctx context.Context, m *fs.MountNamespace, paths []string) error {
- if len(paths) == 0 || (len(paths) == 1 && paths[0] == "") {
- // Warning will be logged during filter installation if the empty
- // whitelist matters (allows for host file access).
- return nil
- }
-
- // Done tracks entries already added.
- done := make(map[string]bool)
- root := m.Root()
- defer root.DecRef()
-
- for i := 0; i < len(paths); i++ {
- // Make sure the path is absolute. This is a sanity check.
- if !path.IsAbs(paths[i]) {
- return fmt.Errorf("path %q is not absolute", paths[i])
- }
-
- // We need to add all the intermediate paths, in case one of
- // them is a symlink that needs to be resolved.
- for j := 1; j <= len(paths[i]); j++ {
- if j < len(paths[i]) && paths[i][j] != '/' {
- continue
- }
- current := paths[i][:j]
-
- // Lookup the given component in the tree.
- remainingTraversals := uint(maxTraversals)
- d, err := m.FindLink(ctx, root, nil, current, &remainingTraversals)
- if err != nil {
- log.Warningf("populate failed for %q: %v", current, err)
- continue
- }
-
- // It's critical that this DecRef happens after the
- // freeze below. This ensures that the dentry is in
- // place to be frozen. Otherwise, we freeze without
- // these entries.
- defer d.DecRef()
-
- // Expand the last component if necessary.
- if current == paths[i] {
- // Is it a directory or symlink?
- sattr := d.Inode.StableAttr
- if fs.IsDir(sattr) {
- for name := range childDentAttrs(ctx, d) {
- paths = append(paths, path.Join(current, name))
- }
- }
- if fs.IsSymlink(sattr) {
- // Only expand symlinks once. The
- // folder structure may contain
- // recursive symlinks and we don't want
- // to end up infinitely expanding this
- // symlink. This is safe because this
- // is the last component. If a later
- // path wants to symlink something
- // beneath this symlink that will still
- // be handled by the FindLink above.
- if done[current] {
- continue
- }
-
- s, err := d.Inode.Readlink(ctx)
- if err != nil {
- log.Warningf("readlink failed for %q: %v", current, err)
- continue
- }
- if path.IsAbs(s) {
- paths = append(paths, s)
- } else {
- target := path.Join(path.Dir(current), s)
- paths = append(paths, target)
- }
- }
- }
-
- // Only report this one once even though we may look
- // it up more than once. If we whitelist /a/b,/a then
- // /a will be "done" when it is looked up for /a/b,
- // however we still need to expand all of its contents
- // when whitelisting /a.
- if !done[current] {
- log.Debugf("whitelisted: %s", current)
- }
- done[current] = true
- }
- }
-
- // Freeze the mount tree in place. This prevents any new paths from
- // being opened and any old ones from being removed. If we do provide
- // tmpfs mounts, we'll want to freeze/thaw those separately.
- m.Freeze()
- return nil
-}
-
-func childDentAttrs(ctx context.Context, d *fs.Dirent) map[string]fs.DentAttr {
- dirname, _ := d.FullName(nil /* root */)
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- log.Warningf("failed to open directory %q: %v", dirname, err)
- return nil
- }
- dir.DecRef()
- var stubSerializer fs.CollectEntriesSerializer
- if err := dir.Readdir(ctx, &stubSerializer); err != nil {
- log.Warningf("failed to iterate on host directory %q: %v", dirname, err)
- return nil
- }
- delete(stubSerializer.Entries, ".")
- delete(stubSerializer.Entries, "..")
- return stubSerializer.Entries
-}
-
-// newMountSource constructs a new host fs.MountSource
-// relative to a root path. The root should match the mount point.
-func newMountSource(ctx context.Context, root string, mounter fs.FileOwner, filesystem fs.Filesystem, flags fs.MountSourceFlags, dontTranslateOwnership bool) *fs.MountSource {
- return fs.NewMountSource(ctx, &superOperations{
- root: root,
- inodeMappings: make(map[uint64]string),
- mounter: mounter,
- dontTranslateOwnership: dontTranslateOwnership,
- }, filesystem, flags)
-}
-
-// superOperations implements fs.MountSourceOperations.
-//
-// +stateify savable
-type superOperations struct {
- fs.SimpleMountSourceOperations
-
- // root is the path of the mount point. All inode mappings
- // are relative to this root.
- root string
-
- // inodeMappings contains mappings of fs.Inodes associated
- // with this MountSource to paths under root.
- inodeMappings map[uint64]string
-
- // mounter is the cached EUID/EGID that mounted this file system.
- mounter fs.FileOwner
-
- // dontTranslateOwnership indicates whether to not translate file
- // ownership.
- //
- // By default, files/directories owned by the sandbox uses UID/GID
- // of the mounter. For files/directories that are not owned by the
- // sandbox, file UID/GID is translated to a UID/GID which cannot
- // be mapped in the sandboxed application's user namespace. The
- // UID/GID will look like the nobody UID/GID (65534) but is not
- // strictly owned by the user "nobody".
- //
- // If whitelistfs is a lower filesystem in an overlay, set
- // dont_translate_ownership=true in mount options.
- dontTranslateOwnership bool
-}
-
-var _ fs.MountSourceOperations = (*superOperations)(nil)
-
-// ResetInodeMappings implements fs.MountSourceOperations.ResetInodeMappings.
-func (m *superOperations) ResetInodeMappings() {
- m.inodeMappings = make(map[uint64]string)
-}
-
-// SaveInodeMapping implements fs.MountSourceOperations.SaveInodeMapping.
-func (m *superOperations) SaveInodeMapping(inode *fs.Inode, path string) {
- // This is very unintuitive. We *CANNOT* trust the inode's StableAttrs,
- // because overlay copyUp may have changed them out from under us.
- // So much for "immutable".
- sattr := inode.InodeOperations.(*inodeOperations).fileState.sattr
- m.inodeMappings[sattr.InodeID] = path
-}
-
-// Keep implements fs.MountSourceOperations.Keep.
-//
-// TODO(b/72455313,b/77596690): It is possible to change the permissions on a
-// host file while it is in the dirent cache (say from RO to RW), but it is not
-// possible to re-open the file with more relaxed permissions, since the host
-// FD is already open and stored in the inode.
-//
-// Using the dirent LRU cache increases the odds that this bug is encountered.
-// Since host file access is relatively fast anyways, we disable the LRU cache
-// for host fs files. Once we can properly deal with permissions changes and
-// re-opening host files, we should revisit whether or not to make use of the
-// LRU cache.
-func (*superOperations) Keep(*fs.Dirent) bool {
- return false
-}
-
-func init() {
- fs.RegisterFilesystem(&Filesystem{})
-}
diff --git a/pkg/sentry/fs/host/fs_test.go b/pkg/sentry/fs/host/fs_test.go
deleted file mode 100644
index c6852ee30..000000000
--- a/pkg/sentry/fs/host/fs_test.go
+++ /dev/null
@@ -1,380 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package host
-
-import (
- "fmt"
- "io/ioutil"
- "os"
- "path"
- "reflect"
- "sort"
- "testing"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
-)
-
-// newTestMountNamespace creates a MountNamespace with a ramfs root.
-// It returns the host folder created, which should be removed when done.
-func newTestMountNamespace(t *testing.T) (*fs.MountNamespace, string, error) {
- p, err := ioutil.TempDir("", "root")
- if err != nil {
- return nil, "", err
- }
-
- fd, err := open(nil, p)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- ctx := contexttest.Context(t)
- root, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- mm, err := fs.NewMountNamespace(ctx, root)
- if err != nil {
- os.RemoveAll(p)
- return nil, "", err
- }
- return mm, p, nil
-}
-
-// createTestDirs populates the root with some test files and directories.
-// /a/a1.txt
-// /a/a2.txt
-// /b/b1.txt
-// /b/c/c1.txt
-// /symlinks/normal.txt
-// /symlinks/to_normal.txt -> /symlinks/normal.txt
-// /symlinks/recursive -> /symlinks
-func createTestDirs(ctx context.Context, t *testing.T, m *fs.MountNamespace) error {
- r := m.Root()
- defer r.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "a", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- a, err := r.Walk(ctx, r, "a")
- if err != nil {
- return err
- }
- defer a.DecRef()
-
- a1, err := a.Create(ctx, r, "a1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a1.DecRef()
-
- a2, err := a.Create(ctx, r, "a2.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- a2.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "b", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- b, err := r.Walk(ctx, r, "b")
- if err != nil {
- return err
- }
- defer b.DecRef()
-
- b1, err := b.Create(ctx, r, "b1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- b1.DecRef()
-
- if err := b.CreateDirectory(ctx, r, "c", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- c, err := b.Walk(ctx, r, "c")
- if err != nil {
- return err
- }
- defer c.DecRef()
-
- c1, err := c.Create(ctx, r, "c1.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- c1.DecRef()
-
- if err := r.CreateDirectory(ctx, r, "symlinks", fs.FilePermsFromMode(0777)); err != nil {
- return err
- }
-
- symlinks, err := r.Walk(ctx, r, "symlinks")
- if err != nil {
- return err
- }
- defer symlinks.DecRef()
-
- normal, err := symlinks.Create(ctx, r, "normal.txt", fs.FileFlags{Read: true, Write: true}, fs.FilePermsFromMode(0666))
- if err != nil {
- return err
- }
- normal.DecRef()
-
- if err := symlinks.CreateLink(ctx, r, "/symlinks/normal.txt", "to_normal.txt"); err != nil {
- return err
- }
-
- return symlinks.CreateLink(ctx, r, "/symlinks", "recursive")
-}
-
-// allPaths returns a slice of all paths of entries visible in the rootfs.
-func allPaths(ctx context.Context, t *testing.T, m *fs.MountNamespace, base string) ([]string, error) {
- var paths []string
- root := m.Root()
- defer root.DecRef()
-
- maxTraversals := uint(1)
- d, err := m.FindLink(ctx, root, nil, base, &maxTraversals)
- if err != nil {
- t.Logf("FindLink failed for %q", base)
- return paths, err
- }
- defer d.DecRef()
-
- if fs.IsDir(d.Inode.StableAttr) {
- dir, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
- if err != nil {
- return nil, fmt.Errorf("failed to open directory %q: %v", base, err)
- }
- iter, ok := dir.FileOperations.(fs.DirIterator)
- if !ok {
- return nil, fmt.Errorf("cannot directly iterate on host directory %q", base)
- }
- dirCtx := &fs.DirCtx{
- Serializer: noopDentrySerializer{},
- }
- if _, err := fs.DirentReaddir(ctx, d, iter, root, dirCtx, 0); err != nil {
- return nil, err
- }
- for name := range dirCtx.DentAttrs() {
- if name == "." || name == ".." {
- continue
- }
-
- fullName := path.Join(base, name)
- paths = append(paths, fullName)
-
- // Recurse.
- subpaths, err := allPaths(ctx, t, m, fullName)
- if err != nil {
- return paths, err
- }
- paths = append(paths, subpaths...)
- }
- }
-
- return paths, nil
-}
-
-type noopDentrySerializer struct{}
-
-func (noopDentrySerializer) CopyOut(string, fs.DentAttr) error {
- return nil
-}
-func (noopDentrySerializer) Written() int {
- return 4096
-}
-
-// pathsEqual returns true if the two string slices contain the same entries.
-func pathsEqual(got, want []string) bool {
- sort.Strings(got)
- sort.Strings(want)
-
- if len(got) != len(want) {
- return false
- }
-
- for i := range got {
- if got[i] != want[i] {
- return false
- }
- }
-
- return true
-}
-
-func TestWhitelist(t *testing.T) {
- for _, test := range []struct {
- // description of the test.
- desc string
- // paths are the paths to whitelist
- paths []string
- // want are all of the directory entries that should be
- // visible (nothing beyond this set should be visible).
- want []string
- }{
- {
- desc: "root",
- paths: []string{"/"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt", "/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt", "/symlinks/recursive"},
- },
- {
- desc: "top-level directories",
- paths: []string{"/a", "/b"},
- want: []string{"/a", "/a/a1.txt", "/a/a2.txt", "/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (1/2)",
- paths: []string{"/b", "/b/c"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "nested directories (2/2)",
- paths: []string{"/b/c", "/b"},
- want: []string{"/b", "/b/b1.txt", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file",
- paths: []string{"/b/c/c1.txt"},
- want: []string{"/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "single file and directory",
- paths: []string{"/a/a1.txt", "/b/c"},
- want: []string{"/a", "/a/a1.txt", "/b", "/b/c", "/b/c/c1.txt"},
- },
- {
- desc: "symlink",
- paths: []string{"/symlinks/to_normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/to_normal.txt"},
- },
- {
- desc: "recursive symlink",
- paths: []string{"/symlinks/recursive/normal.txt"},
- want: []string{"/symlinks", "/symlinks/normal.txt", "/symlinks/recursive"},
- },
- } {
- t.Run(test.desc, func(t *testing.T) {
- m, p, err := newTestMountNamespace(t)
- if err != nil {
- t.Errorf("Failed to create MountNamespace: %v", err)
- }
- defer os.RemoveAll(p)
-
- ctx := withRoot(contexttest.RootContext(t), m.Root())
- if err := createTestDirs(ctx, t, m); err != nil {
- t.Errorf("Failed to create test dirs: %v", err)
- }
-
- if err := installWhitelist(ctx, m, test.paths); err != nil {
- t.Errorf("installWhitelist(%v) err got %v want nil", test.paths, err)
- }
-
- got, err := allPaths(ctx, t, m, "/")
- if err != nil {
- t.Fatalf("Failed to lookup paths (whitelisted: %v): %v", test.paths, err)
- }
-
- if !pathsEqual(got, test.want) {
- t.Errorf("For paths %v got %v want %v", test.paths, got, test.want)
- }
- })
- }
-}
-
-func TestRootPath(t *testing.T) {
- // Create a temp dir, which will be the root of our mounted fs.
- rootPath, err := ioutil.TempDir(os.TempDir(), "root")
- if err != nil {
- t.Fatalf("TempDir failed: %v", err)
- }
- defer os.RemoveAll(rootPath)
-
- // Create two files inside the new root, one which will be whitelisted
- // and one not.
- whitelisted, err := ioutil.TempFile(rootPath, "white")
- if err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
- if _, err := ioutil.TempFile(rootPath, "black"); err != nil {
- t.Fatalf("TempFile failed: %v", err)
- }
-
- // Create a mount with a root path and single whitelisted file.
- hostFS := &Filesystem{}
- ctx := contexttest.Context(t)
- data := fmt.Sprintf("%s=%s,%s=%s", rootPathKey, rootPath, whitelistKey, whitelisted.Name())
- inode, err := hostFS.Mount(ctx, "", fs.MountSourceFlags{}, data, nil)
- if err != nil {
- t.Fatalf("Mount failed: %v", err)
- }
- mm, err := fs.NewMountNamespace(ctx, inode)
- if err != nil {
- t.Fatalf("NewMountNamespace failed: %v", err)
- }
- if err := hostFS.InstallWhitelist(ctx, mm); err != nil {
- t.Fatalf("InstallWhitelist failed: %v", err)
- }
-
- // Get the contents of the root directory.
- rootDir := mm.Root()
- rctx := withRoot(ctx, rootDir)
- f, err := rootDir.Inode.GetFile(rctx, rootDir, fs.FileFlags{})
- if err != nil {
- t.Fatalf("GetFile failed: %v", err)
- }
- c := &fs.CollectEntriesSerializer{}
- if err := f.Readdir(rctx, c); err != nil {
- t.Fatalf("Readdir failed: %v", err)
- }
-
- // We should have only our whitelisted file, plus the dots.
- want := []string{path.Base(whitelisted.Name()), ".", ".."}
- got := c.Order
- sort.Strings(want)
- sort.Strings(got)
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Readdir got %v, wanted %v", got, want)
- }
-}
-
-type rootContext struct {
- context.Context
- root *fs.Dirent
-}
-
-// withRoot returns a copy of ctx with the given root.
-func withRoot(ctx context.Context, root *fs.Dirent) context.Context {
- return &rootContext{
- Context: ctx,
- root: root,
- }
-}
-
-// Value implements Context.Value.
-func (rc rootContext) Value(key interface{}) interface{} {
- switch key {
- case fs.CtxRoot:
- rc.root.IncRef()
- return rc.root
- default:
- return rc.Context.Value(key)
- }
-}
diff --git a/pkg/sentry/fs/host/host.go b/pkg/sentry/fs/host/host.go
new file mode 100644
index 000000000..081ba1dd8
--- /dev/null
+++ b/pkg/sentry/fs/host/host.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host supports file descriptors imported directly.
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystem is a host filesystem.
+//
+// +stateify savable
+type filesystem struct{}
+
+func init() {
+ fs.RegisterFilesystem(&filesystem{})
+}
+
+// FilesystemName is the name under which the filesystem is registered.
+const FilesystemName = "host"
+
+// Name is the name of the filesystem.
+func (*filesystem) Name() string {
+ return FilesystemName
+}
+
+// Mount returns an error. Mounting hostfs is not allowed.
+func (*filesystem) Mount(ctx context.Context, device string, flags fs.MountSourceFlags, data string, dataObj interface{}) (*fs.Inode, error) {
+ return nil, syserror.EPERM
+}
+
+// AllowUserMount prohibits users from using mount(2) with this file system.
+func (*filesystem) AllowUserMount() bool {
+ return false
+}
+
+// AllowUserList prohibits this filesystem to be listed in /proc/filesystems.
+func (*filesystem) AllowUserList() bool {
+ return false
+}
+
+// Flags returns that there is nothing special about this file system.
+func (*filesystem) Flags() fs.FilesystemFlags {
+ return 0
+}
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index a6e4a09e3..fbfba1b58 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -15,19 +15,17 @@
package host
import (
- "sync"
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/secio"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -69,9 +67,6 @@ type inodeOperations struct {
//
// +stateify savable
type inodeFileState struct {
- // Common file system state.
- mops *superOperations `state:"wait"`
-
// descriptor is the backing host FD.
descriptor *descriptor `state:"wait"`
@@ -160,7 +155,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
if err := syscall.Fstat(i.FD(), &s); err != nil {
return fs.UnstableAttr{}, err
}
- return unstableAttr(i.mops, &s), nil
+ return unstableAttr(&s), nil
}
// Allocate implements fsutil.CachedFileObject.Allocate.
@@ -172,7 +167,7 @@ func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error
var _ fs.InodeOperations = (*inodeOperations)(nil)
// newInode returns a new fs.Inode backed by the host FD.
-func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool, donated bool) (*fs.Inode, error) {
+func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool) (*fs.Inode, error) {
// Retrieve metadata.
var s syscall.Stat_t
err := syscall.Fstat(fd, &s)
@@ -181,24 +176,17 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
}
fileState := &inodeFileState{
- mops: msrc.MountSourceOperations.(*superOperations),
sattr: stableAttr(&s),
}
// Initialize the wrapped host file descriptor.
- fileState.descriptor, err = newDescriptor(
- fd,
- donated,
- saveable,
- wouldBlock(&s),
- &fileState.queue,
- )
+ fileState.descriptor, err = newDescriptor(fd, saveable, wouldBlock(&s), &fileState.queue)
if err != nil {
return nil, err
}
// Build the fs.InodeOperations.
- uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
+ uattr := unstableAttr(&s)
iops := &inodeOperations{
fileState: fileState,
cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
@@ -232,54 +220,23 @@ func (i *inodeOperations) Release(context.Context) {
// Lookup implements fs.InodeOperations.Lookup.
func (i *inodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name string) (*fs.Dirent, error) {
- // Get a new FD relative to i at name.
- fd, err := open(i, name)
- if err != nil {
- if err == syserror.ENOENT {
- return nil, syserror.ENOENT
- }
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
-
- // Return the fs.Dirent.
- return fs.NewDirent(ctx, inode, name), nil
+ return nil, syserror.ENOENT
}
// Create implements fs.InodeOperations.Create.
func (i *inodeOperations) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.FileFlags, perm fs.FilePermissions) (*fs.File, error) {
- // Create a file relative to i at name.
- //
- // N.B. We always open this file O_RDWR regardless of flags because a
- // future GetFile might want more access. Open allows this regardless
- // of perm.
- fd, err := openAt(i, name, syscall.O_RDWR|syscall.O_CREAT|syscall.O_EXCL, perm.LinuxMode())
- if err != nil {
- return nil, err
- }
-
- inode, err := newInode(ctx, dir.MountSource, fd, false /* saveable */, false /* donated */)
- if err != nil {
- return nil, err
- }
+ return nil, syserror.EPERM
- d := fs.NewDirent(ctx, inode, name)
- defer d.DecRef()
- return inode.GetFile(ctx, d, flags)
}
// CreateDirectory implements fs.InodeOperations.CreateDirectory.
func (i *inodeOperations) CreateDirectory(ctx context.Context, dir *fs.Inode, name string, perm fs.FilePermissions) error {
- return syscall.Mkdirat(i.fileState.FD(), name, uint32(perm.LinuxMode()))
+ return syserror.EPERM
}
// CreateLink implements fs.InodeOperations.CreateLink.
func (i *inodeOperations) CreateLink(ctx context.Context, dir *fs.Inode, oldname string, newname string) error {
- return createLink(i.fileState.FD(), oldname, newname)
+ return syserror.EPERM
}
// CreateHardLink implements fs.InodeOperations.CreateHardLink.
@@ -294,25 +251,17 @@ func (*inodeOperations) CreateFifo(context.Context, *fs.Inode, string, fs.FilePe
// Remove implements fs.InodeOperations.Remove.
func (i *inodeOperations) Remove(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, false /* dir */)
+ return syserror.EPERM
}
// RemoveDirectory implements fs.InodeOperations.RemoveDirectory.
func (i *inodeOperations) RemoveDirectory(ctx context.Context, dir *fs.Inode, name string) error {
- return unlinkAt(i.fileState.FD(), name, true /* dir */)
+ return syserror.EPERM
}
// Rename implements fs.InodeOperations.Rename.
func (i *inodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
- op, ok := oldParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- np, ok := newParent.InodeOperations.(*inodeOperations)
- if !ok {
- return syscall.EXDEV
- }
- return syscall.Renameat(op.fileState.FD(), oldName, np.fileState.FD(), newName)
+ return syserror.EPERM
}
// Bind implements fs.InodeOperations.Bind.
@@ -419,6 +368,9 @@ func (i *inodeOperations) Allocate(ctx context.Context, inode *fs.Inode, offset,
// WriteOut implements fs.InodeOperations.WriteOut.
func (i *inodeOperations) WriteOut(ctx context.Context, inode *fs.Inode) error {
+ if inode.MountSource.Flags.ReadOnly {
+ return nil
+ }
// Have we been using host kernel metadata caches?
if !inode.MountSource.Flags.ForcePageCache || !canMap(inode) {
// Then the metadata is already up to date on the host.
@@ -448,82 +400,17 @@ func (i *inodeOperations) StatFS(context.Context) (fs.Info, error) {
}
// AddLink implements fs.InodeOperations.AddLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) AddLink() {}
// DropLink implements fs.InodeOperations.DropLink.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) DropLink() {}
// NotifyStatusChange implements fs.InodeOperations.NotifyStatusChange.
-// FIXME(b/63117438): Remove this from InodeOperations altogether.
func (i *inodeOperations) NotifyStatusChange(ctx context.Context) {}
// readdirAll returns all of the directory entries in i.
func (i *inodeOperations) readdirAll(d *dirInfo) (map[string]fs.DentAttr, error) {
- i.readdirMu.Lock()
- defer i.readdirMu.Unlock()
-
- fd := i.fileState.FD()
-
- // syscall.ReadDirent will use getdents, which will seek the file past
- // the last directory entry. To read the directory entries a second
- // time, we need to seek back to the beginning.
- if _, err := syscall.Seek(fd, 0, 0); err != nil {
- if err == syscall.ESPIPE {
- // All directories should be seekable. If this file
- // isn't seekable, it is not a directory and we should
- // return that more sane error.
- err = syscall.ENOTDIR
- }
- return nil, err
- }
-
- names := make([]string, 0, 100)
- for {
- // Refill the buffer if necessary
- if d.bufp >= d.nbuf {
- d.bufp = 0
- // ReadDirent will just do a sys_getdents64 to the kernel.
- n, err := syscall.ReadDirent(fd, d.buf)
- if err != nil {
- return nil, err
- }
- if n == 0 {
- break // EOF
- }
- d.nbuf = n
- }
-
- var nb int
- // Parse the dirent buffer we just get and return the directory names along
- // with the number of bytes consumed in the buffer.
- nb, _, names = syscall.ParseDirent(d.buf[d.bufp:d.nbuf], -1, names)
- d.bufp += nb
- }
-
- entries := make(map[string]fs.DentAttr)
- for _, filename := range names {
- // Lookup the type and host device and inode.
- stat, lerr := fstatat(fd, filename, linux.AT_SYMLINK_NOFOLLOW)
- if lerr == syscall.ENOENT {
- // File disappeared between readdir and lstat.
- // Just treat it as if it didn't exist.
- continue
- }
-
- // There was a serious problem, we should probably report it.
- if lerr != nil {
- return nil, lerr
- }
-
- entries[filename] = fs.DentAttr{
- Type: nodeType(&stat),
- InodeID: hostFileDevice.Map(device.MultiDeviceKey{
- Device: stat.Dev,
- Inode: stat.Ino,
- }),
- }
- }
- return entries, nil
+ // We only support non-directory file descriptors that have been
+ // imported, so just claim that this isn't a directory, even if it is.
+ return nil, syscall.ENOTDIR
}
diff --git a/pkg/sentry/fs/host/inode_state.go b/pkg/sentry/fs/host/inode_state.go
index b267ec305..1adbd4562 100644
--- a/pkg/sentry/fs/host/inode_state.go
+++ b/pkg/sentry/fs/host/inode_state.go
@@ -18,29 +18,14 @@ import (
"fmt"
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
-// beforeSave is invoked by stateify.
-func (i *inodeFileState) beforeSave() {
- if !i.queue.IsEmpty() {
- panic("event queue must be empty")
- }
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- uattr, err := i.unstableAttr(context.Background())
- if err != nil {
- panic(fs.ErrSaveRejection{fmt.Errorf("failed to get unstable atttribute of %s: %v", i.mops.inodeMappings[i.sattr.InodeID], err)})
- }
- i.savedUAttr = &uattr
- }
-}
-
// afterLoad is invoked by stateify.
func (i *inodeFileState) afterLoad() {
// Initialize the descriptor value.
- if err := i.descriptor.initAfterLoad(i.mops, i.sattr.InodeID, &i.queue); err != nil {
+ if err := i.descriptor.initAfterLoad(i.sattr.InodeID, &i.queue); err != nil {
panic(fmt.Sprintf("failed to load value of descriptor: %v", err))
}
@@ -61,19 +46,4 @@ func (i *inodeFileState) afterLoad() {
// change across save and restore, error out.
panic(fs.ErrCorruption{fmt.Errorf("host %s conflict in host device mappings: %s", key, hostFileDevice)})
}
-
- if !i.descriptor.donated && i.sattr.Type == fs.RegularFile {
- env, ok := fs.CurrentRestoreEnvironment()
- if !ok {
- panic("missing restore environment")
- }
- uattr := unstableAttr(i.mops, &s)
- if env.ValidateFileSize && uattr.Size != i.savedUAttr.Size {
- panic(fs.ErrCorruption{fmt.Errorf("file size has changed for %s: previously %d, now %d", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.Size, uattr.Size)})
- }
- if env.ValidateFileTimestamp && uattr.ModificationTime != i.savedUAttr.ModificationTime {
- panic(fs.ErrCorruption{fmt.Errorf("file modification time has changed for %s: previously %v, now %v", i.mops.inodeMappings[i.sattr.InodeID], i.savedUAttr.ModificationTime, uattr.ModificationTime)})
- }
- i.savedUAttr = nil
- }
}
diff --git a/pkg/sentry/fs/host/inode_test.go b/pkg/sentry/fs/host/inode_test.go
index 2d959f10d..41a23b5da 100644
--- a/pkg/sentry/fs/host/inode_test.go
+++ b/pkg/sentry/fs/host/inode_test.go
@@ -15,79 +15,12 @@
package host
import (
- "io/ioutil"
- "os"
- "path"
"syscall"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
)
-// TestMultipleReaddir verifies that multiple Readdir calls return the same
-// thing if they use different dir contexts.
-func TestMultipleReaddir(t *testing.T) {
- p, err := ioutil.TempDir("", "readdir")
- if err != nil {
- t.Fatalf("Failed to create test dir: %v", err)
- }
- defer os.RemoveAll(p)
-
- f, err := os.Create(path.Join(p, "a.txt"))
- if err != nil {
- t.Fatalf("Failed to create a.txt: %v", err)
- }
- f.Close()
-
- f, err = os.Create(path.Join(p, "b.txt"))
- if err != nil {
- t.Fatalf("Failed to create b.txt: %v", err)
- }
- f.Close()
-
- fd, err := open(nil, p)
- if err != nil {
- t.Fatalf("Failed to open %q: %v", p, err)
- }
- ctx := contexttest.Context(t)
- n, err := newInode(ctx, newMountSource(ctx, p, fs.RootOwner, &Filesystem{}, fs.MountSourceFlags{}, false), fd, false, false)
- if err != nil {
- t.Fatalf("Failed to create inode: %v", err)
- }
-
- dirent := fs.NewDirent(ctx, n, "readdir")
- openFile, err := n.GetFile(ctx, dirent, fs.FileFlags{Read: true})
- if err != nil {
- t.Fatalf("Failed to get file: %v", err)
- }
- defer openFile.DecRef()
-
- c1 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c1, 0); err != nil {
- t.Fatalf("First Readdir failed: %v", err)
- }
-
- c2 := &fs.DirCtx{DirCursor: new(string)}
- if _, err := openFile.FileOperations.(*fileOperations).IterateDir(ctx, dirent, c2, 0); err != nil {
- t.Errorf("Second Readdir failed: %v", err)
- }
-
- if _, ok := c1.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in first Readdir, got %v", c1.DentAttrs())
- }
- if _, ok := c1.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in first Readdir, got %v", c1.DentAttrs())
- }
-
- if _, ok := c2.DentAttrs()["a.txt"]; !ok {
- t.Errorf("want a.txt in second Readdir, got %v", c2.DentAttrs())
- }
- if _, ok := c2.DentAttrs()["b.txt"]; !ok {
- t.Errorf("want b.txt in second Readdir, got %v", c2.DentAttrs())
- }
-}
-
// TestCloseFD verifies fds will be closed.
func TestCloseFD(t *testing.T) {
var p [2]int
@@ -99,11 +32,11 @@ func TestCloseFD(t *testing.T) {
// Use the write-end because we will detect if it's closed on the read end.
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, p[1], fs.RootOwner)
+ file, err := NewFile(ctx, p[1])
if err != nil {
t.Fatalf("Failed to create File: %v", err)
}
- file.DecRef()
+ file.DecRef(ctx)
s := make([]byte, 10)
if c, err := syscall.Read(p[0], s); c != 0 || err != nil {
diff --git a/pkg/sentry/fs/host/ioctl_unsafe.go b/pkg/sentry/fs/host/ioctl_unsafe.go
index 271582e54..150ac8e19 100644
--- a/pkg/sentry/fs/host/ioctl_unsafe.go
+++ b/pkg/sentry/fs/host/ioctl_unsafe.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
+// LINT.IfChange
+
func ioctlGetTermios(fd int) (*linux.Termios, error) {
var t linux.Termios
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
@@ -54,3 +56,5 @@ func ioctlSetWinsize(fd int, w *linux.Winsize) error {
}
return nil
}
+
+// LINT.ThenChange(../../fsimpl/host/ioctl_unsafe.go)
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 107336a3e..a2f3d5918 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -16,20 +16,19 @@ package host
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
- "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,10 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// maxSendBufferSize is the maximum host send buffer size allowed for endpoint.
-//
-// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max).
-const maxSendBufferSize = 8 << 20
+// LINT.IfChange
// ConnectedEndpoint is a host FD backed implementation of
// transport.ConnectedEndpoint and transport.Receiver.
@@ -101,10 +97,6 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
if err != nil {
return syserr.FromError(err)
}
- if sndbuf > maxSendBufferSize {
- log.Warningf("Socket send buffer too large: %d", sndbuf)
- return syserr.ErrInvalidEndpointState
- }
c.stype = linux.SockType(stype)
c.sndbuf = int64(sndbuf)
@@ -202,7 +194,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
}
// Send implements transport.ConnectedEndpoint.Send.
-func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -279,7 +271,7 @@ func (c *ConnectedEndpoint) EventUpdate() {
}
// Recv implements transport.Receiver.Recv.
-func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -326,7 +318,7 @@ func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights int, peek
}
// close releases all resources related to the endpoint.
-func (c *ConnectedEndpoint) close() {
+func (c *ConnectedEndpoint) close(context.Context) {
fdnotifier.RemoveFD(int32(c.file.FD()))
c.file.Close()
c.file = nil
@@ -382,9 +374,11 @@ func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
}
// Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
-func (c *ConnectedEndpoint) Release() {
- c.ref.DecRefWithDestructor(c.close)
+func (c *ConnectedEndpoint) Release(ctx context.Context) {
+ c.ref.DecRefWithDestructor(ctx, c.close)
}
// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
func (c *ConnectedEndpoint) CloseUnread() {}
+
+// LINT.ThenChange(../../fsimpl/host/socket.go)
diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go
index af6955675..905afb50d 100644
--- a/pkg/sentry/fs/host/socket_iovec.go
+++ b/pkg/sentry/fs/host/socket_iovec.go
@@ -17,12 +17,11 @@ package host
import (
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/iovec"
"gvisor.dev/gvisor/pkg/syserror"
)
-// maxIovs is the maximum number of iovecs to pass to the host.
-var maxIovs = linux.UIO_MAXIOV
+// LINT.IfChange
// copyToMulti copies as many bytes from src to dst as possible.
func copyToMulti(dst [][]byte, src []byte) {
@@ -74,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
}
}
- if iovsRequired > maxIovs {
+ if iovsRequired > iovec.MaxIovs {
// The kernel will reject our call if we pass this many iovs.
// Use a single intermediate buffer instead.
b := make([]byte, stopLen)
@@ -111,3 +110,5 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec
return total, iovecs, nil, err
}
+
+// LINT.ThenChange(../../fsimpl/host/socket_iovec.go)
diff --git a/pkg/sentry/fs/host/socket_test.go b/pkg/sentry/fs/host/socket_test.go
index 68b38fd1c..9d58ea448 100644
--- a/pkg/sentry/fs/host/socket_test.go
+++ b/pkg/sentry/fs/host/socket_test.go
@@ -21,13 +21,13 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/fdnotifier"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -67,11 +67,12 @@ func TestSocketIsBlocking(t *testing.T) {
if fl&syscall.O_NONBLOCK == syscall.O_NONBLOCK {
t.Fatalf("Expected socket %v to be blocking", pair[1])
}
- sock, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sock, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) failed => %v", pair[0], err)
}
- defer sock.DecRef()
+ defer sock.DecRef(ctx)
// Test that the socket now is non-blocking.
if fl, err = getFl(pair[0]); err != nil {
t.Fatalf("getFl: fcntl(%v, GETFL) => %v", pair[0], err)
@@ -93,11 +94,12 @@ func TestSocketWritev(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ socket, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer socket.DecRef()
+ defer socket.DecRef(ctx)
buf := []byte("hello world\n")
n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(buf))
if err != nil {
@@ -115,11 +117,12 @@ func TestSocketWritevLen0(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- socket, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ socket, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer socket.DecRef()
+ defer socket.DecRef(ctx)
n, err := socket.Writev(contexttest.Context(t), usermem.BytesIOSequence(nil))
if err != nil {
t.Fatalf("socket writev failed: %v", err)
@@ -136,11 +139,12 @@ func TestSocketSendMsgLen0(t *testing.T) {
if err != nil {
t.Fatalf("host socket creation failed: %v", err)
}
- sfile, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sfile, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer sfile.DecRef()
+ defer sfile.DecRef(ctx)
s := sfile.FileOperations.(socket.Socket)
n, terr := s.SendMsg(nil, usermem.BytesIOSequence(nil), []byte{}, 0, false, ktime.Time{}, socket.ControlMessages{})
@@ -158,18 +162,19 @@ func TestListen(t *testing.T) {
if err != nil {
t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
}
- sfile1, err := newSocket(contexttest.Context(t), pair[0], false)
+ ctx := contexttest.Context(t)
+ sfile1, err := newSocket(ctx, pair[0], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[0], err)
}
- defer sfile1.DecRef()
+ defer sfile1.DecRef(ctx)
socket1 := sfile1.FileOperations.(socket.Socket)
- sfile2, err := newSocket(contexttest.Context(t), pair[1], false)
+ sfile2, err := newSocket(ctx, pair[1], false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", pair[1], err)
}
- defer sfile2.DecRef()
+ defer sfile2.DecRef(ctx)
socket2 := sfile2.FileOperations.(socket.Socket)
// Socketpairs can not be listened to.
@@ -185,11 +190,11 @@ func TestListen(t *testing.T) {
if err != nil {
t.Fatalf("syscall.Socket(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) => %v", err)
}
- sfile3, err := newSocket(contexttest.Context(t), sock, false)
+ sfile3, err := newSocket(ctx, sock, false)
if err != nil {
t.Fatalf("newSocket(%v) => %v", sock, err)
}
- defer sfile3.DecRef()
+ defer sfile3.DecRef(ctx)
socket3 := sfile3.FileOperations.(socket.Socket)
// This socket is not bound so we can't listen on it.
@@ -199,14 +204,14 @@ func TestListen(t *testing.T) {
}
func TestPasscred(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
if got, want := e.Passcred(), false; got != want {
t.Errorf("Got %#v.Passcred() = %t, want = %t", e, got, want)
}
}
func TestGetLocalAddress(t *testing.T) {
- e := ConnectedEndpoint{path: "foo"}
+ e := &ConnectedEndpoint{path: "foo"}
want := tcpip.FullAddress{Addr: tcpip.Address("foo")}
if got, err := e.GetLocalAddress(); err != nil || got != want {
t.Errorf("Got %#v.GetLocalAddress() = %#v, %v, want = %#v, %v", e, got, err, want, nil)
@@ -214,7 +219,7 @@ func TestGetLocalAddress(t *testing.T) {
}
func TestQueuedSize(t *testing.T) {
- e := ConnectedEndpoint{}
+ e := &ConnectedEndpoint{}
tests := []struct {
name string
f func() int64
@@ -237,9 +242,10 @@ func TestRelease(t *testing.T) {
}
c := &ConnectedEndpoint{queue: &waiter.Queue{}, file: fd.New(f)}
want := &ConnectedEndpoint{queue: c.queue}
- want.ref.DecRef()
+ ctx := contexttest.Context(t)
+ want.ref.DecRef(ctx)
fdnotifier.AddFD(int32(c.file.FD()), nil)
- c.Release()
+ c.Release(ctx)
if !reflect.DeepEqual(c, want) {
t.Errorf("got = %#v, want = %#v", c, want)
}
diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go
index f3bbed7ea..5d4f312cf 100644
--- a/pkg/sentry/fs/host/socket_unsafe.go
+++ b/pkg/sentry/fs/host/socket_unsafe.go
@@ -19,6 +19,8 @@ import (
"unsafe"
)
+// LINT.IfChange
+
// fdReadVec receives from fd to bufs.
//
// If the total length of bufs is > maxlen, fdReadVec will do a partial read
@@ -99,3 +101,5 @@ func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int6
return int64(n), length, err
}
+
+// LINT.ThenChange(../../fsimpl/host/socket_unsafe.go)
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 90331e3b2..b5229098c 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -15,18 +15,19 @@
package host
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// TTYFileOperations implements fs.FileOperations for a host file descriptor
// that wraps a TTY FD.
//
@@ -44,6 +45,7 @@ type TTYFileOperations struct {
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+ // termios contains the terminal attributes for this TTY.
termios linux.KernelTermios
}
@@ -111,12 +113,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
}
// Release implements fs.FileOperations.Release.
-func (t *TTYFileOperations) Release() {
+func (t *TTYFileOperations) Release(ctx context.Context) {
t.mu.Lock()
t.fgProcessGroup = nil
t.mu.Unlock()
- t.fileOperations.Release()
+ t.fileOperations.Release(ctx)
}
// Ioctl implements fs.FileOperations.Ioctl.
@@ -306,9 +308,9 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
task := kernel.TaskFromContext(ctx)
if task == nil {
// No task? Linux does not have an analog for this case, but
- // tty_check_change is more of a blacklist of cases than a
- // whitelist, and is surprisingly permissive. Allowing the
- // change seems most appropriate.
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
return nil
}
@@ -358,3 +360,5 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e
_ = pg.SendSignal(kernel.SignalInfoPriv(sig))
return kernel.ERESTARTSYS
}
+
+// LINT.ThenChange(../../fsimpl/host/tty.go)
diff --git a/pkg/sentry/fs/host/util.go b/pkg/sentry/fs/host/util.go
index bad61a9a1..1b0356930 100644
--- a/pkg/sentry/fs/host/util.go
+++ b/pkg/sentry/fs/host/util.go
@@ -16,7 +16,6 @@ package host
import (
"os"
- "path"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -28,45 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-func open(parent *inodeOperations, name string) (int, error) {
- if parent == nil && !path.IsAbs(name) {
- return -1, syserror.EINVAL
- }
- name = path.Clean(name)
-
- // Don't follow through symlinks.
- flags := syscall.O_NOFOLLOW
-
- if fd, err := openAt(parent, name, flags|syscall.O_RDWR, 0); err == nil {
- return fd, nil
- }
- // Retry as read-only.
- if fd, err := openAt(parent, name, flags|syscall.O_RDONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as write-only.
- if fd, err := openAt(parent, name, flags|syscall.O_WRONLY, 0); err == nil {
- return fd, nil
- }
-
- // Retry as a symlink, by including O_PATH as an option.
- fd, err := openAt(parent, name, linux.O_PATH|flags, 0)
- if err == nil {
- return fd, nil
- }
-
- // Everything failed.
- return -1, err
-}
-
-func openAt(parent *inodeOperations, name string, flags int, perm linux.FileMode) (int, error) {
- if parent == nil {
- return syscall.Open(name, flags, uint32(perm))
- }
- return syscall.Openat(parent.fileState.FD(), name, flags, uint32(perm))
-}
-
func nodeType(s *syscall.Stat_t) fs.InodeType {
switch x := (s.Mode & syscall.S_IFMT); x {
case syscall.S_IFLNK:
@@ -107,55 +67,23 @@ func stableAttr(s *syscall.Stat_t) fs.StableAttr {
}
}
-func owner(mo *superOperations, s *syscall.Stat_t) fs.FileOwner {
- // User requested no translation, just return actual owner.
- if mo.dontTranslateOwnership {
- return fs.FileOwner{auth.KUID(s.Uid), auth.KGID(s.Gid)}
+func owner(s *syscall.Stat_t) fs.FileOwner {
+ return fs.FileOwner{
+ UID: auth.KUID(s.Uid),
+ GID: auth.KGID(s.Gid),
}
-
- // Show only IDs relevant to the sandboxed task. I.e. if we not own the
- // file, no sandboxed task can own the file. In that case, we
- // use OverflowID for UID, implying that the IDs are not mapped in the
- // "root" user namespace.
- //
- // E.g.
- // sandbox's host EUID/EGID is 1/1.
- // some_dir's host UID/GID is 2/1.
- // Task that mounted this fs has virtualized EUID/EGID 5/5.
- //
- // If you executed `ls -n` in the sandboxed task, it would show:
- // drwxwrxwrx [...] 65534 5 [...] some_dir
-
- // Files are owned by OverflowID by default.
- owner := fs.FileOwner{auth.KUID(auth.OverflowUID), auth.KGID(auth.OverflowGID)}
-
- // If we own file on host, let mounting task's initial EUID own
- // the file.
- if s.Uid == hostUID {
- owner.UID = mo.mounter.UID
- }
-
- // If our group matches file's group, make file's group match
- // the mounting task's initial EGID.
- for _, gid := range hostGIDs {
- if s.Gid == gid {
- owner.GID = mo.mounter.GID
- break
- }
- }
- return owner
}
-func unstableAttr(mo *superOperations, s *syscall.Stat_t) fs.UnstableAttr {
+func unstableAttr(s *syscall.Stat_t) fs.UnstableAttr {
return fs.UnstableAttr{
Size: s.Size,
Usage: s.Blocks * 512,
Perms: fs.FilePermsFromMode(linux.FileMode(s.Mode)),
- Owner: owner(mo, s),
+ Owner: owner(s),
AccessTime: ktime.FromUnix(s.Atim.Sec, s.Atim.Nsec),
ModificationTime: ktime.FromUnix(s.Mtim.Sec, s.Mtim.Nsec),
StatusChangeTime: ktime.FromUnix(s.Ctim.Sec, s.Ctim.Nsec),
- Links: s.Nlink,
+ Links: uint64(s.Nlink),
}
}
@@ -165,6 +93,8 @@ type dirInfo struct {
bufp int // location of next record in buf.
}
+// LINT.IfChange
+
// isBlockError unwraps os errors and checks if they are caused by EAGAIN or
// EWOULDBLOCK. This is so they can be transformed into syserror.ErrWouldBlock.
func isBlockError(err error) bool {
@@ -177,6 +107,8 @@ func isBlockError(err error) bool {
return false
}
+// LINT.ThenChange(../../fsimpl/host/util.go)
+
func hostEffectiveKIDs() (uint32, []uint32, error) {
gids, err := os.Getgroups()
if err != nil {
diff --git a/pkg/sentry/fs/host/util_amd64_unsafe.go b/pkg/sentry/fs/host/util_amd64_unsafe.go
new file mode 100644
index 000000000..66da6e9f5
--- /dev/null
+++ b/pkg/sentry/fs/host/util_amd64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
diff --git a/pkg/sentry/fs/host/util_arm64_unsafe.go b/pkg/sentry/fs/host/util_arm64_unsafe.go
new file mode 100644
index 000000000..e8cb94aeb
--- /dev/null
+++ b/pkg/sentry/fs/host/util_arm64_unsafe.go
@@ -0,0 +1,41 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ namePtr, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return stat, err
+ }
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_FSTATAT,
+ uintptr(fd),
+ uintptr(unsafe.Pointer(namePtr)),
+ uintptr(unsafe.Pointer(&stat)),
+ uintptr(flags),
+ 0, 0)
+ if errno != 0 {
+ return stat, errno
+ }
+ return stat, nil
+}
diff --git a/pkg/sentry/fs/host/util_unsafe.go b/pkg/sentry/fs/host/util_unsafe.go
index 2b76f1065..23bd35d64 100644
--- a/pkg/sentry/fs/host/util_unsafe.go
+++ b/pkg/sentry/fs/host/util_unsafe.go
@@ -26,26 +26,6 @@ import (
// NulByte is a single NUL byte. It is passed to readlinkat as an empty string.
var NulByte byte = '\x00'
-func createLink(fd int, name string, linkName string) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- linkNamePtr, err := syscall.BytePtrFromString(linkName)
- if err != nil {
- return err
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_SYMLINKAT,
- uintptr(unsafe.Pointer(namePtr)),
- uintptr(fd),
- uintptr(unsafe.Pointer(linkNamePtr)))
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func readLink(fd int) (string, error) {
// Buffer sizing copied from os.Readlink.
for l := 128; ; l *= 2 {
@@ -66,27 +46,6 @@ func readLink(fd int) (string, error) {
}
}
-func unlinkAt(fd int, name string, dir bool) error {
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return err
- }
- var flags uintptr
- if dir {
- flags = linux.AT_REMOVEDIR
- }
- _, _, errno := syscall.Syscall(
- syscall.SYS_UNLINKAT,
- uintptr(fd),
- uintptr(unsafe.Pointer(namePtr)),
- flags,
- )
- if errno != 0 {
- return errno
- }
- return nil
-}
-
func timespecFromTimestamp(t ktime.Time, omit, setSysTime bool) syscall.Timespec {
if omit {
return syscall.Timespec{0, linux.UTIME_OMIT}
@@ -116,22 +75,3 @@ func setTimestamps(fd int, ts fs.TimeSpec) error {
}
return nil
}
-
-func fstatat(fd int, name string, flags int) (syscall.Stat_t, error) {
- var stat syscall.Stat_t
- namePtr, err := syscall.BytePtrFromString(name)
- if err != nil {
- return stat, err
- }
- _, _, errno := syscall.Syscall6(
- syscall.SYS_NEWFSTATAT,
- uintptr(fd),
- uintptr(unsafe.Pointer(namePtr)),
- uintptr(unsafe.Pointer(&stat)),
- uintptr(flags),
- 0, 0)
- if errno != 0 {
- return stat, errno
- }
- return stat, nil
-}
diff --git a/pkg/sentry/fs/host/wait_test.go b/pkg/sentry/fs/host/wait_test.go
index 88d24d693..c143f4ce2 100644
--- a/pkg/sentry/fs/host/wait_test.go
+++ b/pkg/sentry/fs/host/wait_test.go
@@ -19,8 +19,7 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -34,13 +33,13 @@ func TestWait(t *testing.T) {
defer syscall.Close(fds[1])
ctx := contexttest.Context(t)
- file, err := NewFile(ctx, fds[0], fs.RootOwner)
+ file, err := NewFile(ctx, fds[0])
if err != nil {
syscall.Close(fds[0])
t.Fatalf("NewFile failed: %v", err)
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
r := file.Readiness(waiter.EventIn)
if r != 0 {
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go
index f4ddfa406..b79cd9877 100644
--- a/pkg/sentry/fs/inode.go
+++ b/pkg/sentry/fs/inode.go
@@ -15,17 +15,16 @@
package fs
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -97,14 +96,12 @@ func NewInode(ctx context.Context, iops InodeOperations, msrc *MountSource, satt
}
// DecRef drops a reference on the Inode.
-func (i *Inode) DecRef() {
- i.DecRefWithDestructor(i.destroy)
+func (i *Inode) DecRef(ctx context.Context) {
+ i.DecRefWithDestructor(ctx, i.destroy)
}
// destroy releases the Inode and releases the msrc reference taken.
-func (i *Inode) destroy() {
- // FIXME(b/38173783): Context is not plumbed here.
- ctx := context.Background()
+func (i *Inode) destroy(ctx context.Context) {
if err := i.WriteOut(ctx); err != nil {
// FIXME(b/65209558): Mark as warning again once noatime is
// properly supported.
@@ -124,12 +121,12 @@ func (i *Inode) destroy() {
i.Watches.targetDestroyed()
if i.overlay != nil {
- i.overlay.release()
+ i.overlay.release(ctx)
} else {
i.InodeOperations.Release(ctx)
}
- i.MountSource.DecRef()
+ i.MountSource.DecRef(ctx)
}
// Mappable calls i.InodeOperations.Mappable.
@@ -262,20 +259,36 @@ func (i *Inode) UnstableAttr(ctx context.Context) (UnstableAttr, error) {
return i.InodeOperations.UnstableAttr(ctx, i)
}
-// Getxattr calls i.InodeOperations.Getxattr with i as the Inode.
-func (i *Inode) Getxattr(name string) (string, error) {
+// GetXattr calls i.InodeOperations.GetXattr with i as the Inode.
+func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, error) {
+ if i.overlay != nil {
+ return overlayGetXattr(ctx, i.overlay, name, size)
+ }
+ return i.InodeOperations.GetXattr(ctx, i, name, size)
+}
+
+// SetXattr calls i.InodeOperations.SetXattr with i as the Inode.
+func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error {
+ if i.overlay != nil {
+ return overlaySetxattr(ctx, i.overlay, d, name, value, flags)
+ }
+ return i.InodeOperations.SetXattr(ctx, i, name, value, flags)
+}
+
+// ListXattr calls i.InodeOperations.ListXattr with i as the Inode.
+func (i *Inode) ListXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
if i.overlay != nil {
- return overlayGetxattr(i.overlay, name)
+ return overlayListXattr(ctx, i.overlay, size)
}
- return i.InodeOperations.Getxattr(i, name)
+ return i.InodeOperations.ListXattr(ctx, i, size)
}
-// Listxattr calls i.InodeOperations.Listxattr with i as the Inode.
-func (i *Inode) Listxattr() (map[string]struct{}, error) {
+// RemoveXattr calls i.InodeOperations.RemoveXattr with i as the Inode.
+func (i *Inode) RemoveXattr(ctx context.Context, d *Dirent, name string) error {
if i.overlay != nil {
- return overlayListxattr(i.overlay)
+ return overlayRemoveXattr(ctx, i.overlay, d, name)
}
- return i.InodeOperations.Listxattr(i)
+ return i.InodeOperations.RemoveXattr(ctx, i, name)
}
// CheckPermission will check if the caller may access this file in the
@@ -344,6 +357,10 @@ func (i *Inode) SetTimestamps(ctx context.Context, d *Dirent, ts TimeSpec) error
// Truncate calls i.InodeOperations.Truncate with i as the Inode.
func (i *Inode) Truncate(ctx context.Context, d *Dirent, size int64) error {
+ if IsDir(i.StableAttr) {
+ return syserror.EISDIR
+ }
+
if i.overlay != nil {
return overlayTruncate(ctx, i.overlay, d, size)
}
@@ -378,8 +395,6 @@ func (i *Inode) Getlink(ctx context.Context) (*Dirent, error) {
// AddLink calls i.InodeOperations.AddLink.
func (i *Inode) AddLink() {
if i.overlay != nil {
- // FIXME(b/63117438): Remove this from InodeOperations altogether.
- //
// This interface is only used by ramfs to update metadata of
// children. These filesystems should _never_ have overlay
// Inodes cached as children. So explicitly disallow this
diff --git a/pkg/sentry/fs/inode_inotify.go b/pkg/sentry/fs/inode_inotify.go
index 0f2a66a79..9911a00c2 100644
--- a/pkg/sentry/fs/inode_inotify.go
+++ b/pkg/sentry/fs/inode_inotify.go
@@ -16,7 +16,9 @@ package fs
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watches is the collection of inotify watches on an inode.
@@ -135,11 +137,11 @@ func (w *Watches) Notify(name string, events, cookie uint32) {
}
// Unpin unpins dirent from all watches in this set.
-func (w *Watches) Unpin(d *Dirent) {
+func (w *Watches) Unpin(ctx context.Context, d *Dirent) {
w.mu.RLock()
defer w.mu.RUnlock()
for _, watch := range w.ws {
- watch.Unpin(d)
+ watch.Unpin(ctx, d)
}
}
diff --git a/pkg/sentry/fs/inode_operations.go b/pkg/sentry/fs/inode_operations.go
index 5cde9d215..2bbfb72ef 100644
--- a/pkg/sentry/fs/inode_operations.go
+++ b/pkg/sentry/fs/inode_operations.go
@@ -17,7 +17,7 @@ package fs
import (
"errors"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
@@ -170,20 +170,38 @@ type InodeOperations interface {
// file system events.
UnstableAttr(ctx context.Context, inode *Inode) (UnstableAttr, error)
- // Getxattr retrieves the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP. Inodes that
- // support extended attributes but don't have a value at name return
+ // GetXattr retrieves the value of extended attribute specified by name.
+ // Inodes that do not support extended attributes return EOPNOTSUPP. Inodes
+ // that support extended attributes but don't have a value at name return
// ENODATA.
- Getxattr(inode *Inode, name string) (string, error)
+ //
+ // If this is called through the getxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute value. If the value is larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ GetXattr(ctx context.Context, inode *Inode, name string, size uint64) (string, error)
- // Setxattr sets the value of extended attribute name. Inodes that
- // do not support extended attributes return EOPNOTSUPP.
- Setxattr(inode *Inode, name, value string) error
+ // SetXattr sets the value of extended attribute specified by name. Inodes
+ // that do not support extended attributes return EOPNOTSUPP.
+ SetXattr(ctx context.Context, inode *Inode, name, value string, flags uint32) error
- // Listxattr returns the set of all extended attributes names that
+ // ListXattr returns the set of all extended attributes names that
// have values. Inodes that do not support extended attributes return
// EOPNOTSUPP.
- Listxattr(inode *Inode) (map[string]struct{}, error)
+ //
+ // If this is called through the listxattr(2) syscall, size indicates the
+ // size of the buffer that the application has allocated to hold the
+ // attribute list. If the list would be larger than size, implementations may
+ // return ERANGE to indicate that the buffer is too small, but they are also
+ // free to ignore the hint entirely. All size checking is done independently
+ // at the syscall layer.
+ ListXattr(ctx context.Context, inode *Inode, size uint64) (map[string]struct{}, error)
+
+ // RemoveXattr removes an extended attribute specified by name. Inodes that
+ // do not support extended attributes return EOPNOTSUPP.
+ RemoveXattr(ctx context.Context, inode *Inode, name string) error
// Check determines whether an Inode can be accessed with the
// requested permission mask using the context (which gives access
diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go
index 5a388dad1..dc2e353d9 100644
--- a/pkg/sentry/fs/inode_overlay.go
+++ b/pkg/sentry/fs/inode_overlay.go
@@ -19,19 +19,19 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
)
-func overlayHasWhiteout(parent *Inode, name string) bool {
- s, err := parent.Getxattr(XattrOverlayWhiteout(name))
+func overlayHasWhiteout(ctx context.Context, parent *Inode, name string) bool {
+ s, err := parent.GetXattr(ctx, XattrOverlayWhiteout(name), 1)
return err == nil && s == "y"
}
-func overlayCreateWhiteout(parent *Inode, name string) error {
- return parent.InodeOperations.Setxattr(parent, XattrOverlayWhiteout(name), "y")
+func overlayCreateWhiteout(ctx context.Context, parent *Inode, name string) error {
+ return parent.InodeOperations.SetXattr(ctx, parent, XattrOverlayWhiteout(name), "y", 0 /* flags */)
}
func overlayWriteOut(ctx context.Context, o *overlayEntry) error {
@@ -85,11 +85,11 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
upperInode = child.Inode
upperInode.IncRef()
}
- child.DecRef()
+ child.DecRef(ctx)
}
// Are we done?
- if overlayHasWhiteout(parent.upper, name) {
+ if overlayHasWhiteout(ctx, parent.upper, name) {
if upperInode == nil {
parent.copyMu.RUnlock()
if negativeUpperChild {
@@ -108,7 +108,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
entry, err := newOverlayEntry(ctx, upperInode, nil, false)
if err != nil {
// Don't leak resources.
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
parent.copyMu.RUnlock()
return nil, false, err
}
@@ -129,7 +129,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
if err != nil && err != syserror.ENOENT {
// Don't leak resources.
if upperInode != nil {
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
}
parent.copyMu.RUnlock()
return nil, false, err
@@ -152,7 +152,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
}
}
}
- child.DecRef()
+ child.DecRef(ctx)
}
}
@@ -183,7 +183,7 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// unnecessary because we don't need to copy-up and we will always
// operate (e.g. read/write) on the upper Inode.
if !IsDir(upperInode.StableAttr) {
- lowerInode.DecRef()
+ lowerInode.DecRef(ctx)
lowerInode = nil
}
}
@@ -194,10 +194,10 @@ func overlayLookup(ctx context.Context, parent *overlayEntry, inode *Inode, name
// Well, not quite, we failed at the last moment, how depressing.
// Be sure not to leak resources.
if upperInode != nil {
- upperInode.DecRef()
+ upperInode.DecRef(ctx)
}
if lowerInode != nil {
- lowerInode.DecRef()
+ lowerInode.DecRef(ctx)
}
parent.copyMu.RUnlock()
return nil, false, err
@@ -231,7 +231,8 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
upperFile.Dirent.Inode.IncRef()
entry, err := newOverlayEntry(ctx, upperFile.Dirent.Inode, nil, false)
if err != nil {
- cleanupUpper(ctx, o.upper, name)
+ werr := fmt.Errorf("newOverlayEntry failed: %v", err)
+ cleanupUpper(ctx, o.upper, name, werr)
return nil, err
}
@@ -247,7 +248,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
// user) will clobber the real path for the underlying Inode.
upperFile.Dirent.Inode.IncRef()
upperDirent := NewTransientDirent(upperFile.Dirent.Inode)
- upperFile.Dirent.DecRef()
+ upperFile.Dirent.DecRef(ctx)
upperFile.Dirent = upperDirent
// Create the overlay inode and dirent. We need this to construct the
@@ -258,7 +259,7 @@ func overlayCreate(ctx context.Context, o *overlayEntry, parent *Dirent, name st
// The overlay file created below with NewFile will take a reference on
// the overlayDirent, and it should be the only thing holding a
// reference at the time of creation, so we must drop this reference.
- defer overlayDirent.DecRef()
+ defer overlayDirent.DecRef(ctx)
// Create a new overlay file that wraps the upper file.
flags.Pread = upperFile.Flags().Pread
@@ -345,7 +346,7 @@ func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child *
}
}
if child.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(o.upper, child.name); err != nil {
+ if err := overlayCreateWhiteout(ctx, o.upper, child.name); err != nil {
return err
}
}
@@ -398,7 +399,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
if !replaced.IsNegative() && IsDir(replaced.Inode.StableAttr) {
children, err := readdirOne(ctx, replaced)
if err != nil {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return err
}
@@ -406,12 +407,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
// included among the returned children, so we don't
// need to bother checking for them.
if len(children) > 0 {
- replaced.DecRef()
+ replaced.DecRef(ctx)
return syserror.ENOTEMPTY
}
}
- replaced.DecRef()
+ replaced.DecRef(ctx)
}
}
@@ -426,7 +427,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
return err
}
if renamed.Inode.overlay.lowerExists {
- if err := overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName); err != nil {
+ if err := overlayCreateWhiteout(ctx, oldParent.Inode.overlay.upper, oldName); err != nil {
return err
}
}
@@ -436,7 +437,7 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena
}
func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name string, data transport.BoundEndpoint, perm FilePermissions) (*Dirent, error) {
- if err := copyUp(ctx, parent); err != nil {
+ if err := copyUpLockedForRename(ctx, parent); err != nil {
return nil, err
}
@@ -454,15 +455,17 @@ func overlayBind(ctx context.Context, o *overlayEntry, parent *Dirent, name stri
// Grab the inode and drop the dirent, we don't need it.
inode := d.Inode
inode.IncRef()
- d.DecRef()
+ d.DecRef(ctx)
// Create a new overlay entry and dirent for the socket.
entry, err := newOverlayEntry(ctx, inode, nil, false)
if err != nil {
- inode.DecRef()
+ inode.DecRef(ctx)
return nil, err
}
- return NewDirent(ctx, newOverlayInode(ctx, entry, inode.MountSource), name), nil
+ // Use the parent's MountSource, since that corresponds to the overlay,
+ // and not the upper filesystem.
+ return NewDirent(ctx, newOverlayInode(ctx, entry, parent.Inode.MountSource), name), nil
}
func overlayBoundEndpoint(o *overlayEntry, path string) transport.BoundEndpoint {
@@ -526,7 +529,7 @@ func overlayUnstableAttr(ctx context.Context, o *overlayEntry) (UnstableAttr, er
return attr, err
}
-func overlayGetxattr(o *overlayEntry, name string) (string, error) {
+func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uint64) (string, error) {
// Hot path. This is how the overlay checks for whiteout files.
// Avoid defers.
var (
@@ -542,26 +545,38 @@ func overlayGetxattr(o *overlayEntry, name string) (string, error) {
o.copyMu.RLock()
if o.upper != nil {
- s, err = o.upper.Getxattr(name)
+ s, err = o.upper.GetXattr(ctx, name, size)
} else {
- s, err = o.lower.Getxattr(name)
+ s, err = o.lower.GetXattr(ctx, name, size)
}
o.copyMu.RUnlock()
return s, err
}
-func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
+func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error {
+ // Don't allow changes to overlay xattrs through a setxattr syscall.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.SetXattr(ctx, d, name, value, flags)
+}
+
+func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[string]struct{}, error) {
o.copyMu.RLock()
defer o.copyMu.RUnlock()
var names map[string]struct{}
var err error
if o.upper != nil {
- names, err = o.upper.Listxattr()
+ names, err = o.upper.ListXattr(ctx, size)
} else {
- names, err = o.lower.Listxattr()
+ names, err = o.lower.ListXattr(ctx, size)
}
for name := range names {
- // Same as overlayGetxattr, we shouldn't forward along
+ // Same as overlayGetXattr, we shouldn't forward along
// overlay attributes.
if strings.HasPrefix(XattrOverlayPrefix, name) {
delete(names, name)
@@ -570,6 +585,18 @@ func overlayListxattr(o *overlayEntry) (map[string]struct{}, error) {
return names, err
}
+func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error {
+ // Don't allow changes to overlay xattrs through a removexattr syscall.
+ if strings.HasPrefix(XattrOverlayPrefix, name) {
+ return syserror.EPERM
+ }
+
+ if err := copyUp(ctx, d); err != nil {
+ return err
+ }
+ return o.upper.RemoveXattr(ctx, d, name)
+}
+
func overlayCheck(ctx context.Context, o *overlayEntry, p PermMask) error {
o.copyMu.RLock()
// Hot path. Avoid defers.
@@ -645,7 +672,7 @@ func overlayGetlink(ctx context.Context, o *overlayEntry) (*Dirent, error) {
// ground and claim that jumping around the filesystem like this
// is not supported.
name, _ := dirent.FullName(nil)
- dirent.DecRef()
+ dirent.DecRef(ctx)
// Claim that the path is not accessible.
err = syserror.EACCES
diff --git a/pkg/sentry/fs/inode_overlay_test.go b/pkg/sentry/fs/inode_overlay_test.go
index 8935aad65..aa9851b26 100644
--- a/pkg/sentry/fs/inode_overlay_test.go
+++ b/pkg/sentry/fs/inode_overlay_test.go
@@ -17,7 +17,7 @@ package fs_test
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -316,7 +316,7 @@ func TestCacheFlush(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
ctx = &rootContext{
Context: ctx,
@@ -345,7 +345,7 @@ func TestCacheFlush(t *testing.T) {
}
// Drop the file reference.
- file.DecRef()
+ file.DecRef(ctx)
// Dirent should have 2 refs left.
if got, want := dirent.ReadRefs(), 2; int(got) != want {
@@ -361,7 +361,7 @@ func TestCacheFlush(t *testing.T) {
}
// Drop our ref.
- dirent.DecRef()
+ dirent.DecRef(ctx)
// We should be back to zero refs.
if got, want := dirent.ReadRefs(), 0; int(got) != want {
@@ -382,8 +382,8 @@ type dir struct {
ReaddirCalled bool
}
-// Getxattr implements InodeOperations.Getxattr.
-func (d *dir) Getxattr(inode *fs.Inode, name string) (string, error) {
+// GetXattr implements InodeOperations.GetXattr.
+func (d *dir) GetXattr(_ context.Context, _ *fs.Inode, name string, _ uint64) (string, error) {
for _, n := range d.negative {
if name == fs.XattrOverlayWhiteout(n) {
return "y", nil
@@ -398,7 +398,7 @@ func (d *dir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags
if err != nil {
return nil, err
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
// Wrap the file's FileOperations in a dirFile.
fops := &dirFile{
FileOperations: file.FileOperations,
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index ba3e0233d..c5c07d564 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -16,16 +16,16 @@ package fs
import (
"io"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -80,7 +80,7 @@ func NewInotify(ctx context.Context) *Inotify {
// Release implements FileOperations.Release. Release removes all watches and
// frees all resources for an inotify instance.
-func (i *Inotify) Release() {
+func (i *Inotify) Release(ctx context.Context) {
// We need to hold i.mu to avoid a race with concurrent calls to
// Inotify.targetDestroyed from Watches. There's no risk of Watches
// accessing this Inotify after the destructor ends, because we remove all
@@ -93,7 +93,7 @@ func (i *Inotify) Release() {
// the owner's destructor.
w.target.Watches.Remove(w.ID())
// Don't leak any references to the target, held by pins in the watch.
- w.destroy()
+ w.destroy(ctx)
}
}
@@ -143,7 +143,10 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
var writeLen int64
- for event := i.events.Front(); event != nil; event = event.Next() {
+ for it := i.events.Front(); it != nil; {
+ event := it
+ it = it.Next()
+
// Does the buffer have enough remaining space to hold the event we're
// about to write out?
if dst.NumBytes() < int64(event.sizeOf()) {
@@ -318,7 +321,7 @@ func (i *Inotify) AddWatch(target *Dirent, mask uint32) int32 {
//
// RmWatch looks up an inotify watch for the given 'wd' and configures the
// target dirent to stop sending events to this inotify instance.
-func (i *Inotify) RmWatch(wd int32) error {
+func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
i.mu.Lock()
// Find the watch we were asked to removed.
@@ -343,7 +346,7 @@ func (i *Inotify) RmWatch(wd int32) error {
i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
// Remove all pins.
- watch.destroy()
+ watch.destroy(ctx)
return nil
}
diff --git a/pkg/sentry/fs/inotify_event.go b/pkg/sentry/fs/inotify_event.go
index 9f70a3e82..686e1b1cd 100644
--- a/pkg/sentry/fs/inotify_event.go
+++ b/pkg/sentry/fs/inotify_event.go
@@ -18,8 +18,8 @@ import (
"bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
diff --git a/pkg/sentry/fs/inotify_watch.go b/pkg/sentry/fs/inotify_watch.go
index 0aa0a5e9b..605423d22 100644
--- a/pkg/sentry/fs/inotify_watch.go
+++ b/pkg/sentry/fs/inotify_watch.go
@@ -15,10 +15,11 @@
package fs
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Watch represent a particular inotify watch created by inotify_add_watch.
@@ -105,12 +106,12 @@ func (w *Watch) Pin(d *Dirent) {
// Unpin drops any extra refs held on dirent due to a previous Pin
// call. Calling Unpin multiple times for the same dirent, or on a dirent
// without a corresponding Pin call is a no-op.
-func (w *Watch) Unpin(d *Dirent) {
+func (w *Watch) Unpin(ctx context.Context, d *Dirent) {
w.mu.Lock()
defer w.mu.Unlock()
if w.pins[d] {
delete(w.pins, d)
- d.DecRef()
+ d.DecRef(ctx)
}
}
@@ -125,11 +126,11 @@ func (w *Watch) TargetDestroyed() {
// this watch. Destroy does not cause any new events to be generated. The caller
// is responsible for ensuring there are no outstanding references to this
// watch.
-func (w *Watch) destroy() {
+func (w *Watch) destroy(ctx context.Context) {
w.mu.Lock()
defer w.mu.Unlock()
for d := range w.pins {
- d.DecRef()
+ d.DecRef(ctx)
}
w.pins = nil
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 8d62642e7..ae3331737 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -40,10 +39,10 @@ go_library(
"lock_set.go",
"lock_set_functions.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/lock",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/waiter",
],
)
@@ -55,5 +54,5 @@ go_test(
"lock_range_test.go",
"lock_test.go",
],
- embed = [":lock"],
+ library = ":lock",
)
diff --git a/pkg/sentry/fs/lock/lock.go b/pkg/sentry/fs/lock/lock.go
index 636484424..8a5d9c7eb 100644
--- a/pkg/sentry/fs/lock/lock.go
+++ b/pkg/sentry/fs/lock/lock.go
@@ -52,9 +52,9 @@ package lock
import (
"fmt"
"math"
- "sync"
"syscall"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -62,7 +62,7 @@ import (
type LockType int
// UniqueID is a unique identifier of the holder of a regional file lock.
-type UniqueID uint64
+type UniqueID interface{}
const (
// ReadLock describes a POSIX regional file lock to be taken
@@ -78,6 +78,9 @@ const (
)
// LockEOF is the maximal possible end of a regional file lock.
+//
+// A BSD-style full file lock can be represented as a regional file lock from
+// offset 0 to LockEOF.
const LockEOF = math.MaxUint64
// Lock is a regional file lock. It consists of either a single writer
@@ -95,12 +98,7 @@ type Lock struct {
// If len(Readers) > 0 then HasWriter must be false.
Readers map[UniqueID]bool
- // HasWriter indicates that this is a write lock held by a single
- // UniqueID.
- HasWriter bool
-
- // Writer is only valid if HasWriter is true. It identifies a
- // single write lock holder.
+ // Writer holds the writer unique ID. It's nil if there are no writers.
Writer UniqueID
}
@@ -183,7 +181,6 @@ func makeLock(uid UniqueID, t LockType) Lock {
case ReadLock:
value.Readers[uid] = true
case WriteLock:
- value.HasWriter = true
value.Writer = uid
default:
panic(fmt.Sprintf("makeLock: invalid lock type %d", t))
@@ -193,10 +190,7 @@ func makeLock(uid UniqueID, t LockType) Lock {
// isHeld returns true if uid is a holder of Lock.
func (l Lock) isHeld(uid UniqueID) bool {
- if l.HasWriter && l.Writer == uid {
- return true
- }
- return l.Readers[uid]
+ return l.Writer == uid || l.Readers[uid]
}
// lock sets uid as a holder of a typed lock on Lock.
@@ -211,20 +205,20 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// We cannot downgrade a write lock to a read lock unless the
// uid is the same.
- if l.HasWriter {
+ if l.Writer != nil {
if l.Writer != uid {
panic(fmt.Sprintf("lock: cannot downgrade write lock to read lock for uid %d, writer is %d", uid, l.Writer))
}
// Ensure that there is only one reader if upgrading.
l.Readers = make(map[UniqueID]bool)
// Ensure that there is no longer a writer.
- l.HasWriter = false
+ l.Writer = nil
}
l.Readers[uid] = true
return
case WriteLock:
// If we are already the writer, then this is a no-op.
- if l.HasWriter && l.Writer == uid {
+ if l.Writer == uid {
return
}
// We can only upgrade a read lock to a write lock if there
@@ -240,7 +234,6 @@ func (l *Lock) lock(uid UniqueID, t LockType) {
}
// Ensure that there is only a writer.
l.Readers = make(map[UniqueID]bool)
- l.HasWriter = true
l.Writer = uid
default:
panic(fmt.Sprintf("lock: invalid lock type %d", t))
@@ -274,9 +267,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
switch t {
case ReadLock:
return l.lockable(r, func(value Lock) bool {
- // If there is no writer, there's no problem adding
- // another reader.
- if !value.HasWriter {
+ // If there is no writer, there's no problem adding another reader.
+ if value.Writer == nil {
return true
}
// If there is a writer, then it must be the same uid
@@ -286,10 +278,9 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
case WriteLock:
return l.lockable(r, func(value Lock) bool {
// If there are only readers.
- if !value.HasWriter {
- // Then this uid can only take a write lock if
- // this is a private upgrade, meaning that the
- // only reader is uid.
+ if value.Writer == nil {
+ // Then this uid can only take a write lock if this is a private
+ // upgrade, meaning that the only reader is uid.
return len(value.Readers) == 1 && value.Readers[uid]
}
// If the uid is already a writer on this region, then
@@ -301,7 +292,8 @@ func (l LockSet) canLock(uid UniqueID, t LockType, r LockRange) bool {
}
}
-// lock returns true if uid took a lock of type t on the entire range of LockRange.
+// lock returns true if uid took a lock of type t on the entire range of
+// LockRange.
//
// Preconditions: r.Start <= r.End (will panic otherwise).
func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
@@ -336,7 +328,7 @@ func (l *LockSet) lock(uid UniqueID, t LockType, r LockRange) bool {
seg, _ = l.SplitUnchecked(seg, r.End)
}
- // Set the lock on the segment. This is guaranteed to
+ // Set the lock on the segment. This is guaranteed to
// always be safe, given canLock above.
value := seg.ValuePtr()
value.lock(uid, t)
@@ -383,7 +375,7 @@ func (l *LockSet) unlock(uid UniqueID, r LockRange) {
value := seg.Value()
var remove bool
- if value.HasWriter && value.Writer == uid {
+ if value.Writer == uid {
// If we are unlocking a writer, then since there can
// only ever be one writer and no readers, then this
// lock should always be removed from the set.
diff --git a/pkg/sentry/fs/lock/lock_set_functions.go b/pkg/sentry/fs/lock/lock_set_functions.go
index 8a3ace0c1..50a16e662 100644
--- a/pkg/sentry/fs/lock/lock_set_functions.go
+++ b/pkg/sentry/fs/lock/lock_set_functions.go
@@ -44,14 +44,9 @@ func (lockSetFunctions) Merge(r1 LockRange, val1 Lock, r2 LockRange, val2 Lock)
return Lock{}, false
}
}
- if val1.HasWriter != val2.HasWriter {
+ if val1.Writer != val2.Writer {
return Lock{}, false
}
- if val1.HasWriter {
- if val1.Writer != val2.Writer {
- return Lock{}, false
- }
- }
return val1, true
}
@@ -62,7 +57,6 @@ func (lockSetFunctions) Split(r LockRange, val Lock, split uint64) (Lock, Lock)
for k, v := range val.Readers {
val0.Readers[k] = v
}
- val0.HasWriter = val.HasWriter
val0.Writer = val.Writer
return val, val0
diff --git a/pkg/sentry/fs/lock/lock_test.go b/pkg/sentry/fs/lock/lock_test.go
index ba002aeb7..fad90984b 100644
--- a/pkg/sentry/fs/lock/lock_test.go
+++ b/pkg/sentry/fs/lock/lock_test.go
@@ -42,9 +42,6 @@ func equals(e0, e1 []entry) bool {
if !reflect.DeepEqual(e0[i].LockRange, e1[i].LockRange) {
return false
}
- if e0[i].Lock.HasWriter != e1[i].Lock.HasWriter {
- return false
- }
if e0[i].Lock.Writer != e1[i].Lock.Writer {
return false
}
@@ -105,7 +102,7 @@ func TestCanLock(t *testing.T) {
LockRange: LockRange{2048, 3072},
},
{
- Lock: Lock{HasWriter: true, Writer: 1},
+ Lock: Lock{Writer: 1},
LockRange: LockRange{3072, 4096},
},
})
@@ -241,7 +238,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -254,7 +251,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -273,7 +270,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 4096},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -301,7 +298,7 @@ func TestSetLock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
{
@@ -318,7 +315,7 @@ func TestSetLock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -550,7 +547,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -594,7 +591,7 @@ func TestSetLock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 3072},
},
{
@@ -633,7 +630,7 @@ func TestSetLock(t *testing.T) {
// 0 1024 2048 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -663,11 +660,11 @@ func TestSetLock(t *testing.T) {
// 0 1024 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, LockEOF},
},
},
@@ -675,28 +672,30 @@ func TestSetLock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- success := l.lock(test.uid, test.lockType, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
+ r := LockRange{Start: test.start, End: test.end}
+ success := l.lock(test.uid, test.lockType, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
- if success != test.success {
- t.Errorf("%s: setlock(%v, %+v, %d, %d) got success %v, want %v", test.name, test.before, r, test.uid, test.lockType, success, test.success)
- continue
- }
+ if success != test.success {
+ t.Errorf("setlock(%v, %+v, %d, %d) got success %v, want %v", test.before, r, test.uid, test.lockType, success, test.success)
+ return
+ }
- if success {
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
+ if success {
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
}
- }
+ })
}
}
@@ -782,7 +781,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -824,7 +823,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -837,7 +836,7 @@ func TestUnlock(t *testing.T) {
// 0 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{4096, LockEOF},
},
},
@@ -876,7 +875,7 @@ func TestUnlock(t *testing.T) {
// 0 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, LockEOF},
},
},
@@ -889,7 +888,7 @@ func TestUnlock(t *testing.T) {
// 0 4096
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 4096},
},
},
@@ -906,7 +905,7 @@ func TestUnlock(t *testing.T) {
LockRange: LockRange{0, 1024},
},
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{1024, 4096},
},
{
@@ -974,7 +973,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -991,7 +990,7 @@ func TestUnlock(t *testing.T) {
// 0 8 4096 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 8},
},
{
@@ -1008,7 +1007,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 max uint64
before: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1025,7 +1024,7 @@ func TestUnlock(t *testing.T) {
// 0 1024 4096 8192 max uint64
after: []entry{
{
- Lock: Lock{HasWriter: true, Writer: 0},
+ Lock: Lock{Writer: 0},
LockRange: LockRange{0, 1024},
},
{
@@ -1041,19 +1040,21 @@ func TestUnlock(t *testing.T) {
}
for _, test := range tests {
- l := fill(test.before)
+ t.Run(test.name, func(t *testing.T) {
+ l := fill(test.before)
- r := LockRange{Start: test.start, End: test.end}
- l.unlock(test.uid, r)
- var got []entry
- for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
- got = append(got, entry{
- Lock: seg.Value(),
- LockRange: seg.Range(),
- })
- }
- if !equals(got, test.after) {
- t.Errorf("%s: got set %+v, want %+v", test.name, got, test.after)
- }
+ r := LockRange{Start: test.start, End: test.end}
+ l.unlock(test.uid, r)
+ var got []entry
+ for seg := l.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
+ got = append(got, entry{
+ Lock: seg.Value(),
+ LockRange: seg.Range(),
+ })
+ }
+ if !equals(got, test.after) {
+ t.Errorf("got set %+v, want %+v", got, test.after)
+ }
+ })
}
}
diff --git a/pkg/sentry/fs/mock.go b/pkg/sentry/fs/mock.go
index 7a24c6f1b..1d6ea5736 100644
--- a/pkg/sentry/fs/mock.go
+++ b/pkg/sentry/fs/mock.go
@@ -15,7 +15,7 @@
package fs
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/mount.go b/pkg/sentry/fs/mount.go
index 7a9692800..ee69b10e8 100644
--- a/pkg/sentry/fs/mount.go
+++ b/pkg/sentry/fs/mount.go
@@ -19,8 +19,8 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
)
// DirentOperations provide file systems greater control over how long a Dirent
@@ -51,7 +51,7 @@ type MountSourceOperations interface {
DirentOperations
// Destroy destroys the MountSource.
- Destroy()
+ Destroy(ctx context.Context)
// Below are MountSourceOperations that do not conform to Linux.
@@ -165,16 +165,16 @@ func (msrc *MountSource) DecDirentRefs() {
}
}
-func (msrc *MountSource) destroy() {
+func (msrc *MountSource) destroy(ctx context.Context) {
if c := msrc.DirentRefs(); c != 0 {
panic(fmt.Sprintf("MountSource with non-zero direntRefs is being destroyed: %d", c))
}
- msrc.MountSourceOperations.Destroy()
+ msrc.MountSourceOperations.Destroy(ctx)
}
// DecRef drops a reference on the MountSource.
-func (msrc *MountSource) DecRef() {
- msrc.DecRefWithDestructor(msrc.destroy)
+func (msrc *MountSource) DecRef(ctx context.Context) {
+ msrc.DecRefWithDestructor(ctx, msrc.destroy)
}
// FlushDirentRefs drops all references held by the MountSource on Dirents.
@@ -264,7 +264,7 @@ func (*SimpleMountSourceOperations) ResetInodeMappings() {}
func (*SimpleMountSourceOperations) SaveInodeMapping(*Inode, string) {}
// Destroy implements MountSourceOperations.Destroy.
-func (*SimpleMountSourceOperations) Destroy() {}
+func (*SimpleMountSourceOperations) Destroy(context.Context) {}
// Info defines attributes of a filesystem.
type Info struct {
diff --git a/pkg/sentry/fs/mount_overlay.go b/pkg/sentry/fs/mount_overlay.go
index 299712cd7..7badc75d6 100644
--- a/pkg/sentry/fs/mount_overlay.go
+++ b/pkg/sentry/fs/mount_overlay.go
@@ -15,7 +15,7 @@
package fs
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// overlayMountSourceOperations implements MountSourceOperations for an overlay
@@ -115,9 +115,9 @@ func (o *overlayMountSourceOperations) SaveInodeMapping(inode *Inode, path strin
}
// Destroy drops references on the upper and lower MountSource.
-func (o *overlayMountSourceOperations) Destroy() {
- o.upper.DecRef()
- o.lower.DecRef()
+func (o *overlayMountSourceOperations) Destroy(ctx context.Context) {
+ o.upper.DecRef(ctx)
+ o.lower.DecRef(ctx)
}
// type overlayFilesystem is the filesystem for overlay mounts.
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index 0b84732aa..6c296f5d0 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -18,7 +18,8 @@ import (
"fmt"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
)
// cacheReallyContains iterates through the dirent cache to determine whether
@@ -32,15 +33,16 @@ func cacheReallyContains(cache *DirentCache, d *Dirent) bool {
return false
}
-func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
+func mountPathsAre(ctx context.Context, root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef(ctx)
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
@@ -68,7 +70,7 @@ func TestMountSourceOnlyCachedOnce(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Get a child of the root which we will mount over. Note that the
// MockInodeOperations causes Walk to always succeed.
@@ -124,7 +126,7 @@ func TestAllMountsUnder(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Add mounts at the following paths:
paths := []string{
@@ -149,14 +151,14 @@ func TestAllMountsUnder(t *testing.T) {
if err := mm.Mount(ctx, d, submountInode); err != nil {
t.Fatalf("could not mount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
}
// mm root should contain all submounts (and does not include the root mount).
rootMnt := mm.FindMount(rootDirent)
submounts := mm.AllMountsUnder(rootMnt)
allPaths := append(paths, "/")
- if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil {
t.Error(err)
}
@@ -180,9 +182,9 @@ func TestAllMountsUnder(t *testing.T) {
if err != nil {
t.Fatalf("could not find path %q in mount manager: %v", "/foo", err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
submounts = mm.AllMountsUnder(mm.FindMount(d))
- if err := mountPathsAre(rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, "/foo", "/foo/bar", "/foo/qux", "/foo/bar/baz"); err != nil {
t.Error(err)
}
@@ -192,9 +194,9 @@ func TestAllMountsUnder(t *testing.T) {
if err != nil {
t.Fatalf("could not find path %q in mount manager: %v", "/waldo", err)
}
- defer waldo.DecRef()
+ defer waldo.DecRef(ctx)
submounts = mm.AllMountsUnder(mm.FindMount(waldo))
- if err := mountPathsAre(rootDirent, submounts, "/waldo"); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, "/waldo"); err != nil {
t.Error(err)
}
}
@@ -211,7 +213,7 @@ func TestUnmount(t *testing.T) {
t.Fatalf("NewMountNamespace failed: %v", err)
}
rootDirent := mm.Root()
- defer rootDirent.DecRef()
+ defer rootDirent.DecRef(ctx)
// Add mounts at the following paths:
paths := []string{
@@ -239,7 +241,7 @@ func TestUnmount(t *testing.T) {
if err := mm.Mount(ctx, d, submountInode); err != nil {
t.Fatalf("could not mount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
}
allPaths := make([]string, len(paths)+1)
@@ -258,13 +260,13 @@ func TestUnmount(t *testing.T) {
if err := mm.Unmount(ctx, d, false); err != nil {
t.Fatalf("could not unmount at %q: %v", p, err)
}
- d.DecRef()
+ d.DecRef(ctx)
// Remove the path that has been unmounted and the check that the remaining
// mounts are still there.
allPaths = allPaths[:len(allPaths)-1]
submounts := mm.AllMountsUnder(rootMnt)
- if err := mountPathsAre(rootDirent, submounts, allPaths...); err != nil {
+ if err := mountPathsAre(ctx, rootDirent, submounts, allPaths...); err != nil {
t.Error(err)
}
}
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index ac0398bd9..d741c4339 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -17,16 +17,12 @@ package fs
import (
"fmt"
"math"
- "path"
- "strings"
- "sync"
"syscall"
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -100,10 +96,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
@@ -234,7 +234,7 @@ func (mns *MountNamespace) flushMountSourceRefsLocked() {
// After destroy is called, the MountNamespace may continue to be referenced (for
// example via /proc/mounts), but should free all resources and shouldn't have
// Find* methods called.
-func (mns *MountNamespace) destroy() {
+func (mns *MountNamespace) destroy(ctx context.Context) {
mns.mu.Lock()
defer mns.mu.Unlock()
@@ -247,13 +247,13 @@ func (mns *MountNamespace) destroy() {
for _, mp := range mns.mounts {
// Drop the mount reference on all mounted dirents.
for ; mp != nil; mp = mp.previous {
- mp.root.DecRef()
+ mp.root.DecRef(ctx)
}
}
mns.mounts = nil
// Drop reference on the root.
- mns.root.DecRef()
+ mns.root.DecRef(ctx)
// Ensure that root cannot be accessed via this MountNamespace any
// more.
@@ -265,21 +265,8 @@ func (mns *MountNamespace) destroy() {
}
// DecRef implements RefCounter.DecRef with destructor mns.destroy.
-func (mns *MountNamespace) DecRef() {
- mns.DecRefWithDestructor(mns.destroy)
-}
-
-// Freeze freezes the entire mount tree.
-func (mns *MountNamespace) Freeze() {
- mns.mu.Lock()
- defer mns.mu.Unlock()
-
- // We only want to freeze Dirents with active references, not Dirents referenced
- // by a mount's MountSource.
- mns.flushMountSourceRefsLocked()
-
- // Freeze the entire shebang.
- mns.root.Freeze()
+func (mns *MountNamespace) DecRef(ctx context.Context) {
+ mns.DecRefWithDestructor(ctx, mns.destroy)
}
// withMountLocked prevents further walks to `node`, because `node` is about to
@@ -325,7 +312,7 @@ func (mns *MountNamespace) Mount(ctx context.Context, mountPoint *Dirent, inode
if err != nil {
return err
}
- defer replacement.DecRef()
+ defer replacement.DecRef(ctx)
// Set the mount's root dirent and id.
parentMnt := mns.findMountLocked(mountPoint)
@@ -407,7 +394,7 @@ func (mns *MountNamespace) Unmount(ctx context.Context, node *Dirent, detachOnly
panic(fmt.Sprintf("Last mount in the chain must be a undo mount: %+v", prev))
}
// Drop mount reference taken at the end of MountNamespace.Mount.
- prev.root.DecRef()
+ prev.root.DecRef(ctx)
} else {
mns.mounts[prev.root] = prev
}
@@ -509,11 +496,11 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path
// non-directory root is hopeless.
if current != root {
if !IsDir(current.Inode.StableAttr) {
- current.DecRef() // Drop reference from above.
+ current.DecRef(ctx) // Drop reference from above.
return nil, syserror.ENOTDIR
}
if err := current.Inode.CheckPermission(ctx, PermMask{Execute: true}); err != nil {
- current.DecRef() // Drop reference from above.
+ current.DecRef(ctx) // Drop reference from above.
return nil, err
}
}
@@ -524,12 +511,12 @@ func (mns *MountNamespace) FindLink(ctx context.Context, root, wd *Dirent, path
// Allow failed walks to cache the dirent, because no
// children will acquire a reference at the end.
current.maybeExtendReference()
- current.DecRef()
+ current.DecRef(ctx)
return nil, err
}
// Drop old reference.
- current.DecRef()
+ current.DecRef(ctx)
if remainder != "" {
// Ensure it's resolved, unless it's the last level.
@@ -583,11 +570,11 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
case nil:
// Make sure we didn't exhaust the traversal budget.
if *remainingTraversals == 0 {
- target.DecRef()
+ target.DecRef(ctx)
return nil, syscall.ELOOP
}
- node.DecRef() // Drop the original reference.
+ node.DecRef(ctx) // Drop the original reference.
return target, nil
case syscall.ENOLINK:
@@ -595,7 +582,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
return node, nil
case ErrResolveViaReadlink:
- defer node.DecRef() // See above.
+ defer node.DecRef(ctx) // See above.
// First, check if we should traverse.
if *remainingTraversals == 0 {
@@ -609,8 +596,11 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
}
// Find the node; we resolve relative to the current symlink's parent.
+ renameMu.RLock()
+ parent := node.parent
+ renameMu.RUnlock()
*remainingTraversals--
- d, err := mns.FindInode(ctx, root, node.parent, targetPath, remainingTraversals)
+ d, err := mns.FindInode(ctx, root, parent, targetPath, remainingTraversals)
if err != nil {
return nil, err
}
@@ -618,7 +608,7 @@ func (mns *MountNamespace) resolve(ctx context.Context, root, node *Dirent, rema
return d, err
default:
- node.DecRef() // Drop for err; see above.
+ node.DecRef(ctx) // Drop for err; see above.
// Propagate the error.
return nil, err
@@ -631,71 +621,3 @@ func (mns *MountNamespace) SyncAll(ctx context.Context) {
defer mns.mu.Unlock()
mns.root.SyncAll(ctx)
}
-
-// ResolveExecutablePath resolves the given executable name given a set of
-// paths that might contain it.
-func (mns *MountNamespace) ResolveExecutablePath(ctx context.Context, wd, name string, paths []string) (string, error) {
- // Absolute paths can be used directly.
- if path.IsAbs(name) {
- return name, nil
- }
-
- // Paths with '/' in them should be joined to the working directory, or
- // to the root if working directory is not set.
- if strings.IndexByte(name, '/') > 0 {
- if wd == "" {
- wd = "/"
- }
- if !path.IsAbs(wd) {
- return "", fmt.Errorf("working directory %q must be absolute", wd)
- }
- return path.Join(wd, name), nil
- }
-
- // Otherwise, We must lookup the name in the paths, starting from the
- // calling context's root directory.
- root := RootFromContext(ctx)
- if root == nil {
- // Caller has no root. Don't bother traversing anything.
- return "", syserror.ENOENT
- }
- defer root.DecRef()
- for _, p := range paths {
- binPath := path.Join(p, name)
- traversals := uint(linux.MaxSymlinkTraversals)
- d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
- if err == syserror.ENOENT || err == syserror.EACCES {
- // Didn't find it here.
- continue
- }
- if err != nil {
- return "", err
- }
- defer d.DecRef()
-
- // Check that it is a regular file.
- if !IsRegular(d.Inode.StableAttr) {
- continue
- }
-
- // Check whether we can read and execute the found file.
- if err := d.Inode.CheckPermission(ctx, PermMask{Read: true, Execute: true}); err != nil {
- log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
- continue
- }
- return path.Join("/", p, name), nil
- }
- return "", syserror.ENOENT
-}
-
-// GetPath returns the PATH as a slice of strings given the environment
-// variables.
-func GetPath(env []string) []string {
- const prefix = "PATH="
- for _, e := range env {
- if strings.HasPrefix(e, prefix) {
- return strings.Split(strings.TrimPrefix(e, prefix), ":")
- }
- }
- return nil
-}
diff --git a/pkg/sentry/fs/mounts_test.go b/pkg/sentry/fs/mounts_test.go
index c4c771f2c..975d6cbc9 100644
--- a/pkg/sentry/fs/mounts_test.go
+++ b/pkg/sentry/fs/mounts_test.go
@@ -17,7 +17,7 @@ package fs_test
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -51,7 +51,7 @@ func TestFindLink(t *testing.T) {
}
root := mm.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
foo, err := root.Walk(ctx, root, "foo")
if err != nil {
t.Fatalf("Error walking to foo: %v", err)
diff --git a/pkg/sentry/fs/offset.go b/pkg/sentry/fs/offset.go
index f7d844ce7..53b5df175 100644
--- a/pkg/sentry/fs/offset.go
+++ b/pkg/sentry/fs/offset.go
@@ -17,7 +17,7 @@ package fs
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// OffsetPageEnd returns the file offset rounded up to the nearest
diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go
index 1d3ff39e0..35013a21b 100644
--- a/pkg/sentry/fs/overlay.go
+++ b/pkg/sentry/fs/overlay.go
@@ -17,14 +17,13 @@ package fs
import (
"fmt"
"strings"
- "sync"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// The virtual filesystem implements an overlay configuration. For a high-level
@@ -108,7 +107,7 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount
msrc := newOverlayMountSource(ctx, upper.MountSource, lower.MountSource, flags)
overlay, err := newOverlayEntry(ctx, upper, lower, true)
if err != nil {
- msrc.DecRef()
+ msrc.DecRef(ctx)
return nil, err
}
@@ -131,7 +130,7 @@ func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode,
msrc := newOverlayMountSource(ctx, upperMS, lower.MountSource, flags)
overlay, err := newOverlayEntry(ctx, nil, lower, true)
if err != nil {
- msrc.DecRef()
+ msrc.DecRef(ctx)
return nil, err
}
return newOverlayInode(ctx, overlay, msrc), nil
@@ -199,7 +198,7 @@ type overlayEntry struct {
upper *Inode
// dirCacheMu protects dirCache.
- dirCacheMu gvsync.DowngradableRWMutex `state:"nosave"`
+ dirCacheMu sync.RWMutex `state:"nosave"`
// dirCache is cache of DentAttrs from upper and lower Inodes.
dirCache *SortedDentryMap
@@ -231,16 +230,16 @@ func newOverlayEntry(ctx context.Context, upper *Inode, lower *Inode, lowerExist
}, nil
}
-func (o *overlayEntry) release() {
+func (o *overlayEntry) release(ctx context.Context) {
// We drop a reference on upper and lower file system Inodes
// rather than releasing them, because in-memory filesystems
// may hold an extra reference to these Inodes so that they
// stay in memory.
if o.upper != nil {
- o.upper.DecRef()
+ o.upper.DecRef(ctx)
}
if o.lower != nil {
- o.lower.DecRef()
+ o.lower.DecRef(ctx)
}
}
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index f21e2a65c..b8b2281a8 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,7 +17,6 @@ go_library(
"mounts.go",
"net.go",
"proc.go",
- "rpcinet_proc.go",
"stat.go",
"sys.go",
"sys_net.go",
@@ -28,17 +26,17 @@ go_library(
"uptime.go",
"version.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/fs/proc/seqfile",
"//pkg/sentry/fs/ramfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
@@ -46,14 +44,14 @@ go_library(
"//pkg/sentry/limits",
"//pkg/sentry/mm",
"//pkg/sentry/socket",
- "//pkg/sentry/socket/rpcinet",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/header",
"//pkg/tcpip/network/ipv4",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -65,11 +63,11 @@ go_test(
"net_test.go",
"sys_net_test.go",
],
- embed = [":proc"],
+ library = ":proc",
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/inet",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/proc/README.md b/pkg/sentry/fs/proc/README.md
index 5d4ec6c7b..6667a0916 100644
--- a/pkg/sentry/fs/proc/README.md
+++ b/pkg/sentry/fs/proc/README.md
@@ -11,6 +11,8 @@ inconsistency, please file a bug.
The following files are implemented:
+<!-- mdformat off(don't wrap the table) -->
+
| File /proc/ | Content |
| :------------------------ | :---------------------------------------------------- |
| [cpuinfo](#cpuinfo) | Info about the CPU |
@@ -22,6 +24,8 @@ The following files are implemented:
| [uptime](#uptime) | Wall clock since boot, combined idle time of all cpus |
| [version](#version) | Kernel version |
+<!-- mdformat on -->
+
### cpuinfo
```bash
diff --git a/pkg/sentry/fs/proc/cgroup.go b/pkg/sentry/fs/proc/cgroup.go
index 05e31c55d..7c1d9e7e9 100644
--- a/pkg/sentry/fs/proc/cgroup.go
+++ b/pkg/sentry/fs/proc/cgroup.go
@@ -17,10 +17,12 @@ package proc
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string]string) *fs.Inode {
// From man 7 cgroups: "For each cgroup hierarchy of which the process
// is a member, there is one entry containing three colon-separated
@@ -39,3 +41,5 @@ func newCGroupInode(ctx context.Context, msrc *fs.MountSource, cgroupControllers
return newStaticProcInode(ctx, msrc, []byte(data))
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/cpuinfo.go b/pkg/sentry/fs/proc/cpuinfo.go
index 3edf36780..c96533401 100644
--- a/pkg/sentry/fs/proc/cpuinfo.go
+++ b/pkg/sentry/fs/proc/cpuinfo.go
@@ -15,11 +15,15 @@
package proc
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
k := kernel.KernelFromContext(ctx)
features := k.FeatureSet()
@@ -27,9 +31,11 @@ func newCPUInfo(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// Kernel is always initialized with a FeatureSet.
panic("cpuinfo read with nil FeatureSet")
}
- contents := make([]byte, 0, 1024)
+ var buf bytes.Buffer
for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
- contents = append(contents, []byte(features.CPUInfo(i))...)
+ features.WriteCPUInfoTo(i, &buf)
}
- return newStaticProcInode(ctx, msrc, contents)
+ return newStaticProcInode(ctx, msrc, buf.Bytes())
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/device/BUILD b/pkg/sentry/fs/proc/device/BUILD
index 0394451d4..52c9aa93d 100644
--- a/pkg/sentry/fs/proc/device/BUILD
+++ b/pkg/sentry/fs/proc/device/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "device",
srcs = ["device.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/device",
visibility = ["//pkg/sentry:internal"],
deps = ["//pkg/sentry/device"],
)
diff --git a/pkg/sentry/fs/proc/exec_args.go b/pkg/sentry/fs/proc/exec_args.go
index 1d3a2d426..8fe626e1c 100644
--- a/pkg/sentry/fs/proc/exec_args.go
+++ b/pkg/sentry/fs/proc/exec_args.go
@@ -20,15 +20,17 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// execArgType enumerates the types of exec arguments that are exposed through
// proc.
type execArgType int
@@ -201,3 +203,5 @@ func (f *execArgFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequen
}
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go)
diff --git a/pkg/sentry/fs/proc/fds.go b/pkg/sentry/fs/proc/fds.go
index bee421d76..45523adf8 100644
--- a/pkg/sentry/fs/proc/fds.go
+++ b/pkg/sentry/fs/proc/fds.go
@@ -19,7 +19,7 @@ import (
"sort"
"strconv"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// walkDescriptors finds the descriptor (file-flag pair) for the fd identified
// by p, and calls the toInodeOperations callback with that descriptor. This is a helper
// method for implementing fs.InodeOperations.Lookup.
@@ -54,11 +56,11 @@ func walkDescriptors(t *kernel.Task, p string, toInode func(*fs.File, kernel.FDF
// readDescriptors reads fds in the task starting at offset, and calls the
// toDentAttr callback for each to get a DentAttr, which it then emits. This is
// a helper for implementing fs.InodeOperations.Readdir.
-func readDescriptors(t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
+func readDescriptors(ctx context.Context, t *kernel.Task, c *fs.DirCtx, offset int64, toDentAttr func(int) fs.DentAttr) (int64, error) {
var fds []int32
t.WithMuLocked(func(t *kernel.Task) {
if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.GetFDs()
+ fds = fdTable.GetFDs(ctx)
}
})
@@ -114,7 +116,7 @@ func (f *fd) GetFile(context.Context, *fs.Dirent, fs.FileFlags) (*fs.File, error
func (f *fd) Readlink(ctx context.Context, _ *fs.Inode) (string, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
n, _ := f.file.Dirent.FullName(root)
return n, nil
@@ -133,13 +135,7 @@ func (f *fd) Truncate(context.Context, *fs.Inode, int64) error {
func (f *fd) Release(ctx context.Context) {
f.Symlink.Release(ctx)
- f.file.DecRef()
-}
-
-// Close releases the reference on the file.
-func (f *fd) Close() error {
- f.file.DecRef()
- return nil
+ f.file.DecRef(ctx)
}
// fdDir is an InodeOperations for /proc/TID/fd.
@@ -225,7 +221,7 @@ func (f *fdDirFile) Readdir(ctx context.Context, file *fs.File, ser fs.DentrySer
if f.isInfoFile {
typ = fs.Symlink
}
- return readDescriptors(f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
+ return readDescriptors(ctx, f.t, dirCtx, file.Offset(), func(fd int) fs.DentAttr {
return fs.GenericDentAttr(typ, device.ProcDevice)
})
}
@@ -259,7 +255,7 @@ func (fdid *fdInfoDir) Lookup(ctx context.Context, dir *fs.Inode, p string) (*fs
// locks, and other data. For now we only have flags.
// See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
flags := file.Flags().ToLinux() | fdFlags.ToLinuxFileFlags()
- file.DecRef()
+ file.DecRef(ctx)
contents := []byte(fmt.Sprintf("flags:\t0%o\n", flags))
return newStaticProcInode(ctx, dir.MountSource, contents)
})
@@ -277,3 +273,5 @@ func (fdid *fdInfoDir) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.
}
return fs.NewFile(ctx, dirent, flags, fops), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/filesystems.go b/pkg/sentry/fs/proc/filesystems.go
index e9250c51c..0a58ac34c 100644
--- a/pkg/sentry/fs/proc/filesystems.go
+++ b/pkg/sentry/fs/proc/filesystems.go
@@ -18,11 +18,13 @@ import (
"bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// filesystemsData backs /proc/filesystems.
//
// +stateify savable
@@ -59,3 +61,5 @@ func (*filesystemsData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle
// Return the SeqData and advance the generation counter.
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*filesystemsData)(nil)}}, 1
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/fs.go b/pkg/sentry/fs/proc/fs.go
index f14833805..daf1ba781 100644
--- a/pkg/sentry/fs/proc/fs.go
+++ b/pkg/sentry/fs/proc/fs.go
@@ -17,10 +17,12 @@ package proc
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
+// LINT.IfChange
+
// filesystem is a procfs.
//
// +stateify savable
@@ -79,3 +81,5 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// never want them cached.
return New(ctx, fs.NewNonCachingMountSource(ctx, f, flags), cgroups)
}
+
+// LINT.ThenChange(../../fsimpl/proc/filesystem.go)
diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go
index 0c04f81fa..d2859a4c2 100644
--- a/pkg/sentry/fs/proc/inode.go
+++ b/pkg/sentry/fs/proc/inode.go
@@ -16,16 +16,18 @@ package proc
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr
// method to return either the task or root as the owner, depending on the
// task's dumpability.
@@ -131,3 +133,5 @@ func newProcInode(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
}
return fs.NewInode(ctx, iops, msrc, sattr)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/loadavg.go b/pkg/sentry/fs/proc/loadavg.go
index 8602b7426..139d49c34 100644
--- a/pkg/sentry/fs/proc/loadavg.go
+++ b/pkg/sentry/fs/proc/loadavg.go
@@ -18,10 +18,12 @@ import (
"bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
)
+// LINT.IfChange
+
// loadavgData backs /proc/loadavg.
//
// +stateify savable
@@ -53,3 +55,5 @@ func (d *loadavgData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/meminfo.go b/pkg/sentry/fs/proc/meminfo.go
index 495f3e3ba..91617267d 100644
--- a/pkg/sentry/fs/proc/meminfo.go
+++ b/pkg/sentry/fs/proc/meminfo.go
@@ -18,13 +18,15 @@ import (
"bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// meminfoData backs /proc/meminfo.
//
// +stateify savable
@@ -56,12 +58,16 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
var buf bytes.Buffer
fmt.Fprintf(&buf, "MemTotal: %8d kB\n", totalSize/1024)
- memFree := (totalSize - totalUsage) / 1024
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
// We use MemFree as MemAvailable because we don't swap.
// TODO(rahat): When reclaim is implemented the value of MemAvailable
// should change.
- fmt.Fprintf(&buf, "MemFree: %8d kB\n", memFree)
- fmt.Fprintf(&buf, "MemAvailable: %8d kB\n", memFree)
+ fmt.Fprintf(&buf, "MemFree: %8d kB\n", memFree/1024)
+ fmt.Fprintf(&buf, "MemAvailable: %8d kB\n", memFree/1024)
fmt.Fprintf(&buf, "Buffers: 0 kB\n") // memory usage by block devices
fmt.Fprintf(&buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
// Emulate a system with no swap, which disables inactivation of anon pages.
@@ -83,3 +89,5 @@ func (d *meminfoData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
fmt.Fprintf(&buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*meminfoData)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index e33c4a460..6a63c47b3 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -18,13 +18,16 @@ import (
"bytes"
"fmt"
"sort"
+ "strings"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// forEachMountSource runs f for the process root mount and each mount that is a
// descendant of the root.
func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
@@ -44,7 +47,7 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
// The task has been destroyed. Nothing to show here.
return
}
- defer rootDir.DecRef()
+ defer rootDir.DecRef(t)
mnt := t.MountNamespace().FindMount(rootDir)
if mnt == nil {
@@ -57,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
- mroot.DecRef()
+ mroot.DecRef(t)
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -88,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef(ctx)
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -104,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -144,14 +152,36 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (10) Mount source: filesystem-specific information or "none".
fmt.Fprintf(&buf, "none ")
- // (11) Superblock options. Only "ro/rw" is supported for now,
- // and is the same as the filesystem option.
- fmt.Fprintf(&buf, "%s\n", opts)
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(&buf, "%s\n", superBlockOpts(mountPath, mroot.Inode.MountSource))
})
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountInfoFile)(nil)}}, 0
}
+func superBlockOpts(mountPath string, msrc *fs.MountSource) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if msrc.Flags.ReadOnly {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
+ if msrc.FilesystemType == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
+
// mountsFile is used to implement /proc/[pid]/mounts.
//
// +stateify savable
@@ -183,7 +213,10 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
- defer root.DecRef()
+ if root == nil {
+ return // No longer valid.
+ }
+ defer root.DecRef(ctx)
flags := root.Inode.MountSource.Flags
opts := "rw"
@@ -195,3 +228,5 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*mountsFile)(nil)}}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 402919924..83a43aa26 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -22,8 +22,8 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -33,49 +33,55 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// newNet creates a new proc net entry.
-func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
+// LINT.IfChange
+
+// newNetDir creates a new proc net entry.
+func newNetDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ k := t.Kernel()
+
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ if s := t.NetworkNamespace().Stack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
contents = map[string]*fs.Inode{
- "dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
- "snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
+ "dev": seqfile.NewSeqFileInode(t, &netDev{s: s}, msrc),
+ "snmp": seqfile.NewSeqFileInode(t, &netSnmp{s: s}, msrc),
// The following files are simple stubs until they are
// implemented in netstack, if the file contains a
// header the stub is just the header otherwise it is
// an empty file.
- "arp": newStaticProcInode(ctx, msrc, []byte("IP address HW type Flags HW address Mask Device")),
+ "arp": newStaticProcInode(t, msrc, []byte("IP address HW type Flags HW address Mask Device\n")),
- "netlink": newStaticProcInode(ctx, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode")),
- "netstat": newStaticProcInode(ctx, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess")),
- "packet": newStaticProcInode(ctx, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode")),
- "protocols": newStaticProcInode(ctx, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em")),
+ "netlink": newStaticProcInode(t, msrc, []byte("sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n")),
+ "netstat": newStaticProcInode(t, msrc, []byte("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")),
+ "packet": newStaticProcInode(t, msrc, []byte("sk RefCnt Type Proto Iface R Rmem User Inode\n")),
+ "protocols": newStaticProcInode(t, msrc, []byte("protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n")),
// Linux sets psched values to: nsec per usec, psched
// tick in ns, 1000000, high res timer ticks per sec
// (ClockGetres returns 1ns resolution).
- "psched": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
- "ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
- "route": seqfile.NewSeqFileInode(ctx, &netRoute{s: s}, msrc),
- "tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "psched": newStaticProcInode(t, msrc, []byte(fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond)))),
+ "ptype": newStaticProcInode(t, msrc, []byte("Type Device Function\n")),
+ "route": seqfile.NewSeqFileInode(t, &netRoute{s: s}, msrc),
+ "tcp": seqfile.NewSeqFileInode(t, &netTCP{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(t, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(t, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
- contents["if_inet6"] = seqfile.NewSeqFileInode(ctx, &ifinet6{s: s}, msrc)
- contents["ipv6_route"] = newStaticProcInode(ctx, msrc, []byte(""))
- contents["tcp6"] = seqfile.NewSeqFileInode(ctx, &netTCP6{k: k}, msrc)
- contents["udp6"] = newStaticProcInode(ctx, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode"))
+ contents["if_inet6"] = seqfile.NewSeqFileInode(t, &ifinet6{s: s}, msrc)
+ contents["ipv6_route"] = newStaticProcInode(t, msrc, []byte(""))
+ contents["tcp6"] = seqfile.NewSeqFileInode(t, &netTCP6{k: k}, msrc)
+ contents["udp6"] = newStaticProcInode(t, msrc, []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"))
}
}
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
+ d := ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
+ return newProcInode(t, d, msrc, fs.SpecialDirectory, t)
}
// ifinet6 implements seqfile.SeqSource for /proc/net/if_inet6.
@@ -413,7 +419,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
}
sfile := s.(*fs.File)
if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
- s.DecRef()
+ s.DecRef(ctx)
// Not a unix socket.
continue
}
@@ -473,7 +479,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
}
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
@@ -568,7 +574,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
}
if family, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
- s.DecRef()
+ s.DecRef(ctx)
// Not tcp4 sockets.
continue
}
@@ -658,7 +664,7 @@ func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kerne
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
@@ -746,7 +752,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
}
if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
- s.DecRef()
+ s.DecRef(ctx)
// Not udp4 socket.
continue
}
@@ -816,7 +822,7 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
fmt.Fprintf(&buf, "\n")
- s.DecRef()
+ s.DecRef(ctx)
}
data := []seqfile.SeqData{
@@ -831,3 +837,5 @@ func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
return data, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_net.go)
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 56e92721e..77e0e1d26 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -20,17 +20,18 @@ import (
"sort"
"strconv"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// proc is a root proc node.
//
// +stateify savable
@@ -69,6 +70,7 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
"loadavg": seqfile.NewSeqFileInode(ctx, &loadavgData{}, msrc),
"meminfo": seqfile.NewSeqFileInode(ctx, &meminfoData{k}, msrc),
"mounts": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/mounts"), msrc, fs.Symlink, nil),
+ "net": newProcInode(ctx, ramfs.NewSymlink(ctx, fs.RootOwner, "self/net"), msrc, fs.Symlink, nil),
"self": newSelf(ctx, pidns, msrc),
"stat": seqfile.NewSeqFileInode(ctx, &statData{k}, msrc),
"thread-self": newThreadSelf(ctx, pidns, msrc),
@@ -87,13 +89,6 @@ func New(ctx context.Context, msrc *fs.MountSource, cgroupControllers map[string
// Add more contents that need proc to be initialized.
p.AddChild(ctx, "sys", p.newSysDir(ctx, msrc))
- // If we're using rpcinet we will let it manage /proc/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- p.AddChild(ctx, "net", newRPCInetProcNet(ctx, msrc))
- } else {
- p.AddChild(ctx, "net", p.newNetDir(ctx, k, msrc))
- }
-
return newProcInode(ctx, p, msrc, fs.SpecialDirectory, nil), nil
}
@@ -218,7 +213,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
// Add dot and dotdot.
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dot, dotdot := file.Dirent.GetDotAttrs(root)
names = append(names, ".", "..")
@@ -249,3 +244,5 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
}
return offset, nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks.go)
diff --git a/pkg/sentry/fs/proc/rpcinet_proc.go b/pkg/sentry/fs/proc/rpcinet_proc.go
deleted file mode 100644
index 01ac97530..000000000
--- a/pkg/sentry/fs/proc/rpcinet_proc.go
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "io"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// rpcInetInode implements fs.InodeOperations.
-type rpcInetInode struct {
- fsutil.SimpleFileInode
-
- // filepath is the full path of this rpcInetInode.
- filepath string
-
- k *kernel.Kernel
-}
-
-func newRPCInetInode(ctx context.Context, msrc *fs.MountSource, filepath string, mode linux.FileMode) *fs.Inode {
- f := &rpcInetInode{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(mode), linux.PROC_SUPER_MAGIC),
- filepath: filepath,
- k: kernel.KernelFromContext(ctx),
- }
- return newProcInode(ctx, f, msrc, fs.SpecialFile, nil)
-}
-
-// GetFile implements fs.InodeOperations.GetFile.
-func (i *rpcInetInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- flags.Pread = true
- flags.Pwrite = true
- fops := &rpcInetFile{
- inode: i,
- }
- return fs.NewFile(ctx, dirent, flags, fops), nil
-}
-
-// rpcInetFile implements fs.FileOperations as RPCs.
-type rpcInetFile struct {
- fsutil.FileGenericSeek `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopFsync `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
- inode *rpcInetInode
-}
-
-// Read implements fs.FileOperations.Read.
-//
-// This method can panic if an rpcInetInode was created without an rpcinet
-// stack.
-func (f *rpcInetFile) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- contents, se := s.RPCReadFile(f.inode.filepath)
- if se != nil || offset >= int64(len(contents)) {
- return 0, io.EOF
- }
-
- n, err := dst.CopyOut(ctx, contents[offset:])
- return int64(n), err
-}
-
-// Write implements fs.FileOperations.Write.
-//
-// This method can panic if an rpcInetInode was created without an rpcInet
-// stack.
-func (f *rpcInetFile) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- s, ok := f.inode.k.NetworkStack().(*rpcinet.Stack)
- if !ok {
- panic("Network stack is not a rpcinet.")
- }
-
- if src.NumBytes() == 0 {
- return 0, nil
- }
-
- b := make([]byte, src.NumBytes(), src.NumBytes())
- n, err := src.CopyIn(ctx, b)
- if err != nil {
- return int64(n), err
- }
-
- written, se := s.RPCWriteFile(f.inode.filepath, b)
- return int64(written), se.ToError()
-}
-
-// newRPCInetProcNet will build an inode for /proc/net.
-func newRPCInetProcNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "arp": newRPCInetInode(ctx, msrc, "/proc/net/arp", 0444),
- "dev": newRPCInetInode(ctx, msrc, "/proc/net/dev", 0444),
- "if_inet6": newRPCInetInode(ctx, msrc, "/proc/net/if_inet6", 0444),
- "ipv6_route": newRPCInetInode(ctx, msrc, "/proc/net/ipv6_route", 0444),
- "netlink": newRPCInetInode(ctx, msrc, "/proc/net/netlink", 0444),
- "netstat": newRPCInetInode(ctx, msrc, "/proc/net/netstat", 0444),
- "packet": newRPCInetInode(ctx, msrc, "/proc/net/packet", 0444),
- "protocols": newRPCInetInode(ctx, msrc, "/proc/net/protocols", 0444),
- "psched": newRPCInetInode(ctx, msrc, "/proc/net/psched", 0444),
- "ptype": newRPCInetInode(ctx, msrc, "/proc/net/ptype", 0444),
- "route": newRPCInetInode(ctx, msrc, "/proc/net/route", 0444),
- "tcp": newRPCInetInode(ctx, msrc, "/proc/net/tcp", 0444),
- "tcp6": newRPCInetInode(ctx, msrc, "/proc/net/tcp6", 0444),
- "udp": newRPCInetInode(ctx, msrc, "/proc/net/udp", 0444),
- "udp6": newRPCInetInode(ctx, msrc, "/proc/net/udp6", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetProcSysNet will build an inode for /proc/sys/net.
-func newRPCInetProcSysNet(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ipv4": newRPCInetSysNetIPv4Dir(ctx, msrc),
- "core": newRPCInetSysNetCore(ctx, msrc),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetCore builds the /proc/sys/net/core directory.
-func newRPCInetSysNetCore(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "default_qdisc": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/default_qdisc", 0444),
- "message_burst": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_burst", 0444),
- "message_cost": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/message_cost", 0444),
- "optmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/optmem_max", 0444),
- "rmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_default", 0444),
- "rmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/rmem_max", 0444),
- "somaxconn": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/somaxconn", 0444),
- "wmem_default": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_default", 0444),
- "wmem_max": newRPCInetInode(ctx, msrc, "/proc/sys/net/core/wmem_max", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
-
-// newRPCInetSysNetIPv4Dir builds the /proc/sys/net/ipv4 directory.
-func newRPCInetSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- contents := map[string]*fs.Inode{
- "ip_local_port_range": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_port_range", 0444),
- "ip_local_reserved_ports": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_local_reserved_ports", 0444),
- "ipfrag_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ipfrag_time", 0444),
- "ip_nonlocal_bind": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_nonlocal_bind", 0444),
- "ip_no_pmtu_disc": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/ip_no_pmtu_disc", 0444),
- "tcp_allowed_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_allowed_congestion_control", 0444),
- "tcp_available_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_available_congestion_control", 0444),
- "tcp_base_mss": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_base_mss", 0444),
- "tcp_congestion_control": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_congestion_control", 0644),
- "tcp_dsack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_dsack", 0644),
- "tcp_early_retrans": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_early_retrans", 0644),
- "tcp_fack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fack", 0644),
- "tcp_fastopen": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen", 0644),
- "tcp_fastopen_key": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fastopen_key", 0444),
- "tcp_fin_timeout": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_fin_timeout", 0644),
- "tcp_invalid_ratelimit": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_invalid_ratelimit", 0444),
- "tcp_keepalive_intvl": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_intvl", 0644),
- "tcp_keepalive_probes": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_probes", 0644),
- "tcp_keepalive_time": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_keepalive_time", 0644),
- "tcp_mem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mem", 0444),
- "tcp_mtu_probing": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_mtu_probing", 0644),
- "tcp_no_metrics_save": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_no_metrics_save", 0444),
- "tcp_probe_interval": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_interval", 0444),
- "tcp_probe_threshold": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_probe_threshold", 0444),
- "tcp_retries1": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries1", 0644),
- "tcp_retries2": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_retries2", 0644),
- "tcp_rfc1337": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rfc1337", 0444),
- "tcp_rmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_rmem", 0444),
- "tcp_sack": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_sack", 0644),
- "tcp_slow_start_after_idle": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_slow_start_after_idle", 0644),
- "tcp_synack_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_synack_retries", 0644),
- "tcp_syn_retries": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_syn_retries", 0644),
- "tcp_timestamps": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_timestamps", 0644),
- "tcp_wmem": newRPCInetInode(ctx, msrc, "/proc/sys/net/ipv4/tcp_wmem", 0444),
- }
-
- d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
- return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
-}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index fe7067be1..21338d912 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,22 +1,21 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "seqfile",
srcs = ["seqfile.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/proc/device",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -25,12 +24,12 @@ go_test(
name = "seqfile_test",
size = "small",
srcs = ["seqfile_test.go"],
- embed = [":seqfile"],
+ library = ":seqfile",
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/fs/ramfs",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile.go b/pkg/sentry/fs/proc/seqfile/seqfile.go
index 5fe823000..6121f0e95 100644
--- a/pkg/sentry/fs/proc/seqfile/seqfile.go
+++ b/pkg/sentry/fs/proc/seqfile/seqfile.go
@@ -17,16 +17,16 @@ package seqfile
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/proc/seqfile/seqfile_test.go b/pkg/sentry/fs/proc/seqfile/seqfile_test.go
index ebfeee835..98e394569 100644
--- a/pkg/sentry/fs/proc/seqfile/seqfile_test.go
+++ b/pkg/sentry/fs/proc/seqfile/seqfile_test.go
@@ -20,11 +20,11 @@ import (
"io"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type seqTest struct {
diff --git a/pkg/sentry/fs/proc/stat.go b/pkg/sentry/fs/proc/stat.go
index b641effbb..d4fbd76ac 100644
--- a/pkg/sentry/fs/proc/stat.go
+++ b/pkg/sentry/fs/proc/stat.go
@@ -19,11 +19,13 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// statData backs /proc/stat.
//
// +stateify savable
@@ -140,3 +142,5 @@ func (s *statData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index cd37776c8..f8aad2dbd 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -20,17 +20,18 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// mmapMinAddrData backs /proc/sys/vm/mmap_min_addr.
//
// +stateify savable
@@ -104,16 +105,10 @@ func (p *proc) newVMDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
func (p *proc) newSysDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
children := map[string]*fs.Inode{
"kernel": p.newKernelDir(ctx, msrc),
+ "net": p.newSysNetDir(ctx, msrc),
"vm": p.newVMDir(ctx, msrc),
}
- // If we're using rpcinet we will let it manage /proc/sys/net.
- if _, ok := p.k.NetworkStack().(*rpcinet.Stack); ok {
- children["net"] = newRPCInetProcSysNet(ctx, msrc)
- } else {
- children["net"] = p.newSysNetDir(ctx, msrc)
- }
-
d := ramfs.NewDir(ctx, children, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
@@ -160,3 +155,5 @@ func (hf *hostnameFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
}
var _ fs.FileOperations = (*hostnameFile)(nil)
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 794723d9c..f2f49a7f6 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -17,20 +17,22 @@ package proc
import (
"fmt"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
type tcpMemDir int
const (
@@ -65,7 +67,7 @@ var _ fs.InodeOperations = (*tcpMemInode)(nil)
func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir tcpMemDir) *fs.Inode {
tm := &tcpMemInode{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
s: s,
dir: dir,
}
@@ -78,6 +80,11 @@ func newTCPMemInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack, dir
return fs.NewInode(ctx, tm, msrc, sattr)
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
@@ -169,14 +176,15 @@ func writeSize(dirType tcpMemDir, s inet.Stack, size inet.TCPBufferSize) error {
// +stateify savable
type tcpSack struct {
+ fsutil.SimpleFileInode
+
stack inet.Stack `state:"wait"`
enabled *bool
- fsutil.SimpleFileInode
}
func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
ts := &tcpSack{
- SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0444), linux.PROC_SUPER_MAGIC),
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
stack: s,
}
sattr := fs.StableAttr{
@@ -188,6 +196,11 @@ func newTCPSackInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *f
return fs.NewInode(ctx, ts, msrc, sattr)
}
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpSack) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// GetFile implements fs.InodeOperations.GetFile.
func (s *tcpSack) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
flags.Pread = true
@@ -260,6 +273,96 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque
return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled)
}
+// +stateify savable
+type tcpRecovery struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+ recovery inet.TCPLossRecovery
+}
+
+func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ts := &tcpRecovery{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ts, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*tcpRecovery) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (r *tcpRecovery) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &tcpRecoveryFile{
+ tcpRecovery: r,
+ stack: r.stack,
+ }), nil
+}
+
+// +stateify savable
+type tcpRecoveryFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ tcpRecovery *tcpRecovery
+
+ stack inet.Stack `state:"wait"`
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *tcpRecoveryFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ recovery, err := f.stack.TCPRecovery()
+ if err != nil {
+ return 0, err
+ }
+ f.tcpRecovery.recovery = recovery
+ s := fmt.Sprintf("%d\n", f.tcpRecovery.recovery)
+ n, err := dst.CopyOut(ctx, []byte(s))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ f.tcpRecovery.recovery = inet.TCPLossRecovery(v)
+ if err := f.tcpRecovery.stack.SetTCPRecovery(f.tcpRecovery.recovery); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
@@ -318,15 +421,22 @@ func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stac
return fs.NewInode(ctx, ipf, msrc, sattr)
}
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*ipForwarding) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
// +stateify savable
type ipForwardingFile struct {
fsutil.FileGenericSeek `state:"nosave"`
fsutil.FileNoIoctl `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
fsutil.FileNotDirReaddir `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
waiter.AlwaysReady `state:"nosave"`
@@ -443,13 +553,20 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem)
}
+ // Add tcp_recovery.
+ if _, err := s.TCPRecovery(); err == nil {
+ contents["tcp_recovery"] = newTCPRecoveryInode(ctx, msrc, s)
+ }
+
d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
@@ -458,3 +575,5 @@ func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil)
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_sys.go)
diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go
index 02e43297f..3fadb870e 100644
--- a/pkg/sentry/fs/proc/sys_net_state.go
+++ b/pkg/sentry/fs/proc/sys_net_state.go
@@ -16,6 +16,7 @@ package proc
import (
"fmt"
+
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
)
diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go
index 6e51dfbb7..72c9857d0 100644
--- a/pkg/sentry/fs/proc/sys_net_test.go
+++ b/pkg/sentry/fs/proc/sys_net_test.go
@@ -17,9 +17,9 @@ package proc
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func TestQuerySendBufferSize(t *testing.T) {
@@ -124,7 +124,9 @@ func TestConfigureRecvBufferSize(t *testing.T) {
}
}
-func TestConfigureIPForwarding(t *testing.T) {
+// TestIPForwarding tests the implementation of
+// /proc/sys/net/ipv4/ip_forwarding
+func TestIPForwarding(t *testing.T) {
ctx := context.Background()
s := inet.NewTestStack()
diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go
index 87184ec67..9cf7f2a62 100644
--- a/pkg/sentry/fs/proc/task.go
+++ b/pkg/sentry/fs/proc/task.go
@@ -22,21 +22,24 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/device"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// getTaskMM returns t's MemoryManager. If getTaskMM succeeds, the MemoryManager's
// users count is incremented, and must be decremented by the caller when it is
// no longer in use.
@@ -54,42 +57,53 @@ func getTaskMM(t *kernel.Task) (*mm.MemoryManager, error) {
return m, nil
}
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
// taskDir represents a task-level directory.
//
// +stateify savable
type taskDir struct {
ramfs.Dir
- t *kernel.Task
- pidns *kernel.PIDNamespace
+ t *kernel.Task
}
var _ fs.InodeOperations = (*taskDir)(nil)
// newTaskDir creates a new proc task entry.
-func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks bool) *fs.Inode {
+func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
contents := map[string]*fs.Inode{
- "auxv": newAuxvec(t, msrc),
- "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
- "comm": newComm(t, msrc),
- "environ": newExecArgInode(t, msrc, environExecArg),
- "exe": newExe(t, msrc),
- "fd": newFdDir(t, msrc),
- "fdinfo": newFdInfoDir(t, msrc),
- "gid_map": newGIDMap(t, msrc),
- // FIXME(b/123511468): create the correct io file for threads.
- "io": newIO(t, msrc),
- "maps": newMaps(t, msrc),
- "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
- "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
- "ns": newNamespaceDir(t, msrc),
- "smaps": newSmaps(t, msrc),
- "stat": newTaskStat(t, msrc, showSubtasks, p.pidns),
- "statm": newStatm(t, msrc),
- "status": newStatus(t, msrc, p.pidns),
- "uid_map": newUIDMap(t, msrc),
- }
- if showSubtasks {
+ "auxv": newAuxvec(t, msrc),
+ "cmdline": newExecArgInode(t, msrc, cmdlineExecArg),
+ "comm": newComm(t, msrc),
+ "environ": newExecArgInode(t, msrc, environExecArg),
+ "exe": newExe(t, msrc),
+ "fd": newFdDir(t, msrc),
+ "fdinfo": newFdInfoDir(t, msrc),
+ "gid_map": newGIDMap(t, msrc),
+ "io": newIO(t, msrc, isThreadGroup),
+ "maps": newMaps(t, msrc),
+ "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc),
+ "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc),
+ "net": newNetDir(t, msrc),
+ "ns": newNamespaceDir(t, msrc),
+ "oom_score": newOOMScore(t, msrc),
+ "oom_score_adj": newOOMScoreAdj(t, msrc),
+ "smaps": newSmaps(t, msrc),
+ "stat": newTaskStat(t, msrc, isThreadGroup, p.pidns),
+ "statm": newStatm(t, msrc),
+ "status": newStatus(t, msrc, p.pidns),
+ "uid_map": newUIDMap(t, msrc),
+ }
+ if isThreadGroup {
contents["task"] = p.newSubtasks(t, msrc)
}
if len(p.cgroupControllers) > 0 {
@@ -171,7 +185,7 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry
// Serialize "." and "..".
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dot, dotdot := file.Dirent.GetDotAttrs(root)
if err := dirCtx.DirEmit(".", dot); err != nil {
@@ -248,12 +262,13 @@ func newExe(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
return newProcInode(t, exeSymlink, msrc, fs.Symlink, t)
}
-func (e *exe) executable() (d *fs.Dirent, err error) {
+func (e *exe) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(e.t); err != nil {
+ return nil, err
+ }
e.t.WithMuLocked(func(t *kernel.Task) {
mm := t.MemoryManager()
if mm == nil {
- // TODO(b/34851096): Check shouldn't allow Readlink once the
- // Task is zombied.
err = syserror.EACCES
return
}
@@ -261,9 +276,9 @@ func (e *exe) executable() (d *fs.Dirent, err error) {
// The MemoryManager may be destroyed, in which case
// MemoryManager.destroy will simply set the executable to nil
// (with locks held).
- d = mm.Executable()
- if d == nil {
- err = syserror.ENOENT
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ESRCH
}
})
return
@@ -280,17 +295,9 @@ func (e *exe) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
if err != nil {
return "", err
}
- defer exec.DecRef()
+ defer exec.DecRef(ctx)
- root := fs.RootFromContext(ctx)
- if root == nil {
- // This doesn't correspond to anything in Linux because the vfs is
- // global there.
- return "", syserror.EINVAL
- }
- defer root.DecRef()
- n, _ := exec.FullName(root)
- return n, nil
+ return exec.PathnameWithDeleted(ctx), nil
}
// namespaceSymlink represents a symlink in the namespacefs, such as the files
@@ -316,11 +323,22 @@ func newNamespaceSymlink(t *kernel.Task, msrc *fs.MountSource, name string) *fs.
return newProcInode(t, n, msrc, fs.Symlink, t)
}
+// Readlink reads the symlink value.
+func (n *namespaceSymlink) Readlink(ctx context.Context, inode *fs.Inode) (string, error) {
+ if err := checkTaskState(n.t); err != nil {
+ return "", err
+ }
+ return n.Symlink.Readlink(ctx, inode)
+}
+
// Getlink implements fs.InodeOperations.Getlink.
func (n *namespaceSymlink) Getlink(ctx context.Context, inode *fs.Inode) (*fs.Dirent, error) {
if !kernel.ContextCanTrace(ctx, n.t, false) {
return nil, syserror.EACCES
}
+ if err := checkTaskState(n.t); err != nil {
+ return nil, err
+ }
// Create a new regular file to fake the namespace file.
iops := fsutil.NewNoReadWriteFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0777), linux.PROC_SUPER_MAGIC)
@@ -605,6 +623,10 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) (
fmt.Fprintf(&buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
fmt.Fprintf(&buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
fmt.Fprintf(&buf, "Seccomp:\t%d\n", s.t.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(&buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(&buf, "Mems_allowed_list:\t0\n")
return []seqfile.SeqData{{Buf: buf.Bytes(), Handle: (*statusData)(nil)}}, 0
}
@@ -619,8 +641,11 @@ type ioData struct {
ioUsage
}
-func newIO(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
- return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+func newIO(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bool) *fs.Inode {
+ if isThreadGroup {
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t.ThreadGroup()}), msrc, fs.SpecialFile, t)
+ }
+ return newProcInode(t, seqfile.NewSeqFile(t, &ioData{t}), msrc, fs.SpecialFile, t)
}
// NeedsUpdate returns whether the generation is old or not.
@@ -639,7 +664,7 @@ func (i *ioData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
io.Accumulate(i.IOUsage())
var buf bytes.Buffer
- fmt.Fprintf(&buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(&buf, "rchar: %d\n", io.CharsRead)
fmt.Fprintf(&buf, "wchar: %d\n", io.CharsWritten)
fmt.Fprintf(&buf, "syscr: %d\n", io.ReadSyscalls)
fmt.Fprintf(&buf, "syscw: %d\n", io.WriteSyscalls)
@@ -794,3 +819,96 @@ func (f *auxvecFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, buf[offset:])
return int64(n), err
}
+
+// newOOMScore returns a oom_score file. It is a stub that always returns 0.
+// TODO(gvisor.dev/issue/1967)
+func newOOMScore(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ return newStaticProcInode(t, msrc, []byte("0\n"))
+}
+
+// oomScoreAdj is a file containing the oom_score adjustment for a task.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ fsutil.SimpleFileInode
+
+ t *kernel.Task
+}
+
+// +stateify savable
+type oomScoreAdjFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ t *kernel.Task
+}
+
+// newOOMScoreAdj returns a oom_score_adj file.
+func newOOMScoreAdj(t *kernel.Task, msrc *fs.MountSource) *fs.Inode {
+ i := &oomScoreAdj{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ t: t,
+ }
+ return newProcInode(t, i, msrc, fs.SpecialFile, t)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*oomScoreAdj) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (o *oomScoreAdj) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, dirent, flags, &oomScoreAdjFile{t: o.t}), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (f *oomScoreAdjFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "%d\n", f.t.OOMScoreAdj())
+ if offset >= int64(buf.Len()) {
+ return 0, io.EOF
+ }
+ n, err := dst.CopyOut(ctx, buf.Bytes()[offset:])
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (f *oomScoreAdjFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if f.t.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := f.t.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// LINT.ThenChange(../../fsimpl/proc/task.go|../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uid_gid_map.go b/pkg/sentry/fs/proc/uid_gid_map.go
index eea37d15c..8d9517b95 100644
--- a/pkg/sentry/fs/proc/uid_gid_map.go
+++ b/pkg/sentry/fs/proc/uid_gid_map.go
@@ -20,16 +20,18 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// idMapInodeOperations implements fs.InodeOperations for
// /proc/[pid]/{uid,gid}_map.
//
@@ -177,3 +179,5 @@ func (imfo *idMapFileOperations) Write(ctx context.Context, file *fs.File, src u
// count, even if fewer bytes were used.
return int64(srclen), nil
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/proc/uptime.go b/pkg/sentry/fs/proc/uptime.go
index 4e903917a..c0f6fb802 100644
--- a/pkg/sentry/fs/proc/uptime.go
+++ b/pkg/sentry/fs/proc/uptime.go
@@ -19,15 +19,17 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// uptime is a file containing the system uptime.
//
// +stateify savable
@@ -85,3 +87,5 @@ func (f *uptimeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequenc
n, err := dst.CopyOut(ctx, s[offset:])
return int64(n), err
}
+
+// LINT.ThenChange(../../fsimpl/proc/tasks_files.go)
diff --git a/pkg/sentry/fs/proc/version.go b/pkg/sentry/fs/proc/version.go
index a6d2c3cd3..35e258ff6 100644
--- a/pkg/sentry/fs/proc/version.go
+++ b/pkg/sentry/fs/proc/version.go
@@ -17,11 +17,13 @@ package proc
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/kernel"
)
+// LINT.IfChange
+
// versionData backs /proc/version.
//
// +stateify savable
@@ -76,3 +78,5 @@ func (v *versionData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle)
},
}, 0
}
+
+// LINT.ThenChange(../../fsimpl/proc/task_files.go)
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 012cb3e44..8ca823fb3 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,17 +10,17 @@ go_library(
"symlink.go",
"tree.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/ramfs",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -30,9 +29,9 @@ go_test(
name = "ramfs_test",
size = "small",
srcs = ["tree_test.go"],
- embed = [":ramfs"],
+ library = ":ramfs",
deps = [
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
],
)
diff --git a/pkg/sentry/fs/ramfs/dir.go b/pkg/sentry/fs/ramfs/dir.go
index 78e082b8e..f4fcddecb 100644
--- a/pkg/sentry/fs/ramfs/dir.go
+++ b/pkg/sentry/fs/ramfs/dir.go
@@ -17,14 +17,14 @@ package ramfs
import (
"fmt"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -219,7 +219,7 @@ func (d *Dir) Remove(ctx context.Context, _ *fs.Inode, name string) error {
}
// Remove our reference on the inode.
- inode.DecRef()
+ inode.DecRef(ctx)
return nil
}
@@ -250,7 +250,7 @@ func (d *Dir) RemoveDirectory(ctx context.Context, _ *fs.Inode, name string) err
}
// Remove our reference on the inode.
- inode.DecRef()
+ inode.DecRef(ctx)
return nil
}
@@ -326,7 +326,7 @@ func (d *Dir) Create(ctx context.Context, dir *fs.Inode, name string, flags fs.F
// Create the Dirent and corresponding file.
created := fs.NewDirent(ctx, inode, name)
- defer created.DecRef()
+ defer created.DecRef(ctx)
return created.Inode.GetFile(ctx, created, flags)
}
@@ -412,11 +412,11 @@ func (*Dir) Rename(ctx context.Context, inode *fs.Inode, oldParent *fs.Inode, ol
}
// Release implements fs.InodeOperation.Release.
-func (d *Dir) Release(_ context.Context) {
+func (d *Dir) Release(ctx context.Context) {
// Drop references on all children.
d.mu.Lock()
for _, i := range d.children {
- i.DecRef()
+ i.DecRef(ctx)
}
d.mu.Unlock()
}
@@ -456,7 +456,7 @@ func (dfo *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirC
func (dfo *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
@@ -473,13 +473,13 @@ func hasChildren(ctx context.Context, inode *fs.Inode) (bool, error) {
// dropped when that dirent is destroyed.
inode.IncRef()
d := fs.NewTransientDirent(inode)
- defer d.DecRef()
+ defer d.DecRef(ctx)
file, err := inode.GetFile(ctx, d, fs.FileFlags{Read: true})
if err != nil {
return false, err
}
- defer file.DecRef()
+ defer file.DecRef(ctx)
ser := &fs.CollectEntriesSerializer{}
if err := file.Readdir(ctx, ser); err != nil {
@@ -530,7 +530,7 @@ func Rename(ctx context.Context, oldParent fs.InodeOperations, oldName string, n
if err != nil {
return err
}
- inode.DecRef()
+ inode.DecRef(ctx)
}
// Be careful, we may have already grabbed this mutex above.
diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go
index a24fe2ea2..29ff004f2 100644
--- a/pkg/sentry/fs/ramfs/socket.go
+++ b/pkg/sentry/fs/ramfs/socket.go
@@ -16,7 +16,7 @@ package ramfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
diff --git a/pkg/sentry/fs/ramfs/symlink.go b/pkg/sentry/fs/ramfs/symlink.go
index fcfaa29aa..d988349aa 100644
--- a/pkg/sentry/fs/ramfs/symlink.go
+++ b/pkg/sentry/fs/ramfs/symlink.go
@@ -16,7 +16,7 @@ package ramfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/waiter"
diff --git a/pkg/sentry/fs/ramfs/tree.go b/pkg/sentry/fs/ramfs/tree.go
index 702cc4a1e..dfc9d3453 100644
--- a/pkg/sentry/fs/ramfs/tree.go
+++ b/pkg/sentry/fs/ramfs/tree.go
@@ -19,10 +19,10 @@ import (
"path"
"strings"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// MakeDirectoryTree constructs a ramfs tree of all directories containing
diff --git a/pkg/sentry/fs/ramfs/tree_test.go b/pkg/sentry/fs/ramfs/tree_test.go
index 61a7e2900..3e0d1e07e 100644
--- a/pkg/sentry/fs/ramfs/tree_test.go
+++ b/pkg/sentry/fs/ramfs/tree_test.go
@@ -17,7 +17,7 @@ package ramfs
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
@@ -67,7 +67,7 @@ func TestMakeDirectoryTree(t *testing.T) {
continue
}
root := mm.Root()
- defer mm.DecRef()
+ defer mm.DecRef(ctx)
for _, p := range test.subdirs {
maxTraversals := uint(0)
diff --git a/pkg/sentry/fs/restore.go b/pkg/sentry/fs/restore.go
index f10168125..64c6a6ae9 100644
--- a/pkg/sentry/fs/restore.go
+++ b/pkg/sentry/fs/restore.go
@@ -15,7 +15,7 @@
package fs
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// RestoreEnvironment is the restore environment for file systems. It consists
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index 311798811..33da82868 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -18,7 +18,7 @@ import (
"io"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/fs/sys/BUILD b/pkg/sentry/fs/sys/BUILD
index 25f0f124e..f2e8b9932 100644
--- a/pkg/sentry/fs/sys/BUILD
+++ b/pkg/sentry/fs/sys/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -10,16 +10,15 @@ go_library(
"fs.go",
"sys.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/sys",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/kernel",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/sys/devices.go b/pkg/sentry/fs/sys/devices.go
index 4f78ca8d2..b67065956 100644
--- a/pkg/sentry/fs/sys/devices.go
+++ b/pkg/sentry/fs/sys/devices.go
@@ -18,7 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
diff --git a/pkg/sentry/fs/sys/fs.go b/pkg/sentry/fs/sys/fs.go
index e60b63e75..fd03a4e38 100644
--- a/pkg/sentry/fs/sys/fs.go
+++ b/pkg/sentry/fs/sys/fs.go
@@ -15,7 +15,7 @@
package sys
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
)
diff --git a/pkg/sentry/fs/sys/sys.go b/pkg/sentry/fs/sys/sys.go
index b14bf3f55..0891645e4 100644
--- a/pkg/sentry/fs/sys/sys.go
+++ b/pkg/sentry/fs/sys/sys.go
@@ -16,10 +16,10 @@
package sys
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func newFile(ctx context.Context, node fs.InodeOperations, msrc *fs.MountSource) *fs.Inode {
diff --git a/pkg/sentry/fs/timerfd/BUILD b/pkg/sentry/fs/timerfd/BUILD
index a215c1b95..d16cdb4df 100644
--- a/pkg/sentry/fs/timerfd/BUILD
+++ b/pkg/sentry/fs/timerfd/BUILD
@@ -1,20 +1,19 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "timerfd",
srcs = ["timerfd.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/timerfd",
visibility = ["//pkg/sentry:internal"],
deps = [
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index f8bf663bb..f362ca9b6 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -19,13 +19,13 @@ package timerfd
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -55,7 +55,7 @@ type TimerOperations struct {
func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[timerfd]")
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
tops := &TimerOperations{}
tops.timer = ktime.NewTimer(c, tops)
// Timerfds reject writes, but the Write flag must be set in order to
@@ -65,7 +65,7 @@ func NewFile(ctx context.Context, c ktime.Clock) *fs.File {
}
// Release implements fs.FileOperations.Release.
-func (t *TimerOperations) Release() {
+func (t *TimerOperations) Release(context.Context) {
t.timer.Destroy()
}
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 59ce400c2..aa7199014 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,12 +11,12 @@ go_library(
"inode_file.go",
"tmpfs.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/metric",
- "//pkg/sentry/context",
+ "//pkg/safemem",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
@@ -27,11 +26,11 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -40,12 +39,12 @@ go_test(
name = "tmpfs_test",
size = "small",
srcs = ["file_test.go"],
- embed = [":tmpfs"],
+ library = ":tmpfs",
deps = [
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/contexttest",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/tmpfs/file_regular.go b/pkg/sentry/fs/tmpfs/file_regular.go
index 9a6943fe4..614f8f8a1 100644
--- a/pkg/sentry/fs/tmpfs/file_regular.go
+++ b/pkg/sentry/fs/tmpfs/file_regular.go
@@ -15,11 +15,11 @@
package tmpfs
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/fs/tmpfs/file_test.go b/pkg/sentry/fs/tmpfs/file_test.go
index 0075ef023..d4d613ea9 100644
--- a/pkg/sentry/fs/tmpfs/file_test.go
+++ b/pkg/sentry/fs/tmpfs/file_test.go
@@ -18,11 +18,11 @@ import (
"bytes"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func newFileInode(ctx context.Context) *fs.Inode {
@@ -46,7 +46,7 @@ func newFile(ctx context.Context) *fs.File {
func TestGrow(t *testing.T) {
ctx := contexttest.Context(t)
f := newFile(ctx)
- defer f.DecRef()
+ defer f.DecRef(ctx)
abuf := bytes.Repeat([]byte{'a'}, 68)
n, err := f.Pwritev(ctx, usermem.BytesIOSequence(abuf), 0)
diff --git a/pkg/sentry/fs/tmpfs/fs.go b/pkg/sentry/fs/tmpfs/fs.go
index be98ad751..bc117ca6a 100644
--- a/pkg/sentry/fs/tmpfs/fs.go
+++ b/pkg/sentry/fs/tmpfs/fs.go
@@ -19,7 +19,7 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -44,9 +44,6 @@ const (
// lookup.
cacheRevalidate = "revalidate"
- // TODO(edahlgren/mpratt): support a tmpfs size limit.
- // size = "size"
-
// Permissions that exceed modeMask will be rejected.
modeMask = 01777
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index f86dfaa36..1dc75291d 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -17,28 +17,29 @@ package tmpfs
import (
"fmt"
"io"
- "sync"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/metric"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var (
opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.")
opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.")
reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.")
- readWait = metric.MustCreateNewUint64Metric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
+ readWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.")
)
// fileInodeOperations implements fs.InodeOperations for a regular tmpfs file.
@@ -444,10 +445,15 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
defer rw.f.dataMu.Unlock()
// Compute the range to write.
- end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
- if end == rw.offset { // srcs.NumBytes() == 0?
+ if srcs.NumBytes() == 0 {
+ // Nothing to do.
return 0, nil
}
+ end := fs.WriteEndOffset(rw.offset, int64(srcs.NumBytes()))
+ if end == math.MaxInt64 {
+ // Overflow.
+ return 0, syserror.EINVAL
+ }
// Check if seals prevent either file growth or all writes.
switch {
diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go
index 69089c8a8..b095312fe 100644
--- a/pkg/sentry/fs/tmpfs/tmpfs.go
+++ b/pkg/sentry/fs/tmpfs/tmpfs.go
@@ -17,7 +17,7 @@ package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
@@ -25,8 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var fsInfo = fs.Info{
@@ -39,14 +39,13 @@ var fsInfo = fs.Info{
// rename implements fs.InodeOperations.Rename for tmpfs nodes.
func rename(ctx context.Context, oldParent *fs.Inode, oldName string, newParent *fs.Inode, newName string, replacement bool) error {
- op, ok := oldParent.InodeOperations.(*Dir)
- if !ok {
- return syserror.EXDEV
- }
- np, ok := newParent.InodeOperations.(*Dir)
- if !ok {
+ // Don't allow renames across different mounts.
+ if newParent.MountSource != oldParent.MountSource {
return syserror.EXDEV
}
+
+ op := oldParent.InodeOperations.(*Dir)
+ np := newParent.InodeOperations.(*Dir)
return ramfs.Rename(ctx, op.ramfsDir, oldName, np.ramfsDir, newName, replacement)
}
@@ -148,19 +147,24 @@ func (d *Dir) CreateFifo(ctx context.Context, dir *fs.Inode, name string, perms
return d.ramfsDir.CreateFifo(ctx, dir, name, perms)
}
-// Getxattr implements fs.InodeOperations.Getxattr.
-func (d *Dir) Getxattr(i *fs.Inode, name string) (string, error) {
- return d.ramfsDir.Getxattr(i, name)
+// GetXattr implements fs.InodeOperations.GetXattr.
+func (d *Dir) GetXattr(ctx context.Context, i *fs.Inode, name string, size uint64) (string, error) {
+ return d.ramfsDir.GetXattr(ctx, i, name, size)
+}
+
+// SetXattr implements fs.InodeOperations.SetXattr.
+func (d *Dir) SetXattr(ctx context.Context, i *fs.Inode, name, value string, flags uint32) error {
+ return d.ramfsDir.SetXattr(ctx, i, name, value, flags)
}
-// Setxattr implements fs.InodeOperations.Setxattr.
-func (d *Dir) Setxattr(i *fs.Inode, name, value string) error {
- return d.ramfsDir.Setxattr(i, name, value)
+// ListXattr implements fs.InodeOperations.ListXattr.
+func (d *Dir) ListXattr(ctx context.Context, i *fs.Inode, size uint64) (map[string]struct{}, error) {
+ return d.ramfsDir.ListXattr(ctx, i, size)
}
-// Listxattr implements fs.InodeOperations.Listxattr.
-func (d *Dir) Listxattr(i *fs.Inode) (map[string]struct{}, error) {
- return d.ramfsDir.Listxattr(i)
+// RemoveXattr implements fs.InodeOperations.RemoveXattr.
+func (d *Dir) RemoveXattr(ctx context.Context, i *fs.Inode, name string) error {
+ return d.ramfsDir.RemoveXattr(ctx, i, name)
}
// Lookup implements fs.InodeOperations.Lookup.
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 95ad98cb0..5cb0e0417 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -14,23 +13,23 @@ go_library(
"slave.go",
"terminal.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fs/tty",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/unimpl",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -39,10 +38,10 @@ go_test(
name = "tty_test",
size = "small",
srcs = ["tty_test.go"],
- embed = [":tty"],
+ library = ":tty",
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 2f639c823..463f6189e 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -19,16 +19,16 @@ import (
"fmt"
"math"
"strconv"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -132,7 +132,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) {
d.mu.Lock()
defer d.mu.Unlock()
- d.master.DecRef()
+ d.master.DecRef(ctx)
if len(d.slaves) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
}
@@ -263,7 +263,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e
}
// masterClose is called when the master end of t is closed.
-func (d *dirInodeOperations) masterClose(t *Terminal) {
+func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) {
d.mu.Lock()
defer d.mu.Unlock()
@@ -277,7 +277,7 @@ func (d *dirInodeOperations) masterClose(t *Terminal) {
panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d))
}
- s.DecRef()
+ s.DecRef(ctx)
delete(d.slaves, t.n)
d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10))
}
@@ -322,7 +322,7 @@ func (df *dirFileOperations) IterateDir(ctx context.Context, d *fs.Dirent, dirCt
func (df *dirFileOperations) Readdir(ctx context.Context, file *fs.File, serializer fs.DentrySerializer) (int64, error) {
root := fs.RootFromContext(ctx)
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(ctx)
}
dirCtx := &fs.DirCtx{
Serializer: serializer,
diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go
index edee56c12..2d4d44bf3 100644
--- a/pkg/sentry/fs/tty/fs.go
+++ b/pkg/sentry/fs/tty/fs.go
@@ -15,7 +15,7 @@
package tty
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -108,4 +108,4 @@ func (superOperations) ResetInodeMappings() {}
func (superOperations) SaveInodeMapping(*fs.Inode, string) {}
// Destroy implements MountSourceOperations.Destroy.
-func (superOperations) Destroy() {}
+func (superOperations) Destroy(context.Context) {}
diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go
index 7cc0eb409..2e9dd2d55 100644
--- a/pkg/sentry/fs/tty/line_discipline.go
+++ b/pkg/sentry/fs/tty/line_discipline.go
@@ -16,17 +16,19 @@ package tty
import (
"bytes"
- "sync"
"unicode/utf8"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// canonMaxBytes is the number of bytes that fit into a single line of
// terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
@@ -140,8 +142,10 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc
// buffer to its read buffer. Anything already in the read buffer is
// now readable.
if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
- l.inQueue.pushWaitBuf(l)
+ l.inQueue.mu.Lock()
+ l.inQueue.pushWaitBufLocked(l)
l.inQueue.readable = true
+ l.inQueue.mu.Unlock()
l.slaveWaiter.Notify(waiter.EventIn)
}
@@ -441,3 +445,5 @@ func (l *lineDiscipline) peek(b []byte) int {
}
return size
}
+
+// LINT.ThenChange(../../fsimpl/devpts/line_discipline.go)
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 19b7557d5..e00746017 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -16,16 +16,18 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// masterInodeOperations are the fs.InodeOperations for the master end of the
// Terminal (ptmx file).
//
@@ -73,7 +75,12 @@ func newMasterInode(ctx context.Context, d *dirInodeOperations, owner fs.FileOwn
}
// Release implements fs.InodeOperations.Release.
-func (mi *masterInodeOperations) Release(ctx context.Context) {
+func (mi *masterInodeOperations) Release(context.Context) {
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*masterInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
}
// GetFile implements fs.InodeOperations.GetFile.
@@ -113,9 +120,9 @@ type masterFileOperations struct {
var _ fs.FileOperations = (*masterFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (mf *masterFileOperations) Release() {
- mf.d.masterClose(mf.t)
- mf.t.DecRef()
+func (mf *masterFileOperations) Release(ctx context.Context) {
+ mf.d.masterClose(ctx, mf.t)
+ mf.t.DecRef(ctx)
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -227,3 +234,5 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
unimpl.EmitUnimplementedEvent(ctx)
}
}
+
+// LINT.ThenChange(../../fsimpl/devpts/master.go)
diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go
index 231e4e6eb..ceabb9b1e 100644
--- a/pkg/sentry/fs/tty/queue.go
+++ b/pkg/sentry/fs/tty/queue.go
@@ -15,17 +15,18 @@
package tty
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
// TTYB_DEFAULT_MEM_LIMIT.
const waitBufMaxBytes = 131072
@@ -198,18 +199,11 @@ func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
q.pushWaitBufLocked(l)
}
-// pushWaitBuf fills the queue's read buffer with data from the wait buffer.
+// pushWaitBufLocked fills the queue's read buffer with data from the wait
+// buffer.
//
// Preconditions:
// * l.termiosMu must be held for reading.
-func (q *queue) pushWaitBuf(l *lineDiscipline) int {
- q.mu.Lock()
- defer q.mu.Unlock()
- return q.pushWaitBufLocked(l)
-}
-
-// Preconditions:
-// * l.termiosMu must be held for reading.
// * q.mu must be locked.
func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
if q.waitBufLen == 0 {
@@ -242,3 +236,5 @@ func (q *queue) waitBufAppend(b []byte) {
q.waitBuf = append(q.waitBuf, b)
q.waitBufLen += uint64(len(b))
}
+
+// LINT.ThenChange(../../fsimpl/devpts/queue.go)
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index 944c4ada1..7c7292687 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -16,15 +16,17 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// slaveInodeOperations are the fs.InodeOperations for the slave end of the
// Terminal (pts file).
//
@@ -69,7 +71,12 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne
// Release implements fs.InodeOperations.Release.
func (si *slaveInodeOperations) Release(ctx context.Context) {
- si.t.DecRef()
+ si.t.DecRef(ctx)
+}
+
+// Truncate implements fs.InodeOperations.Truncate.
+func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
}
// GetFile implements fs.InodeOperations.GetFile.
@@ -99,7 +106,7 @@ type slaveFileOperations struct {
var _ fs.FileOperations = (*slaveFileOperations)(nil)
// Release implements fs.FileOperations.Release.
-func (sf *slaveFileOperations) Release() {
+func (sf *slaveFileOperations) Release(context.Context) {
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -167,3 +174,5 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem
return 0, syserror.ENOTTY
}
}
+
+// LINT.ThenChange(../../fsimpl/devpts/slave.go)
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index ff8138820..ddcccf4da 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -16,13 +16,15 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Terminal is a pseudoterminal.
//
// +stateify savable
@@ -53,8 +55,8 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal
d: d,
n: n,
ld: newLineDiscipline(termios),
- masterKTTY: &kernel.TTY{},
- slaveKTTY: &kernel.TTY{},
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
}
t.EnableLeakCheck("tty.Terminal")
return &t
@@ -126,3 +128,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
}
return tm.slaveKTTY
}
+
+// LINT.ThenChange(../../fsimpl/devpts/terminal.go)
diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go
index 59f07ff8e..2cbc05678 100644
--- a/pkg/sentry/fs/tty/tty_test.go
+++ b/pkg/sentry/fs/tty/tty_test.go
@@ -18,8 +18,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func TestSimpleMasterToSlave(t *testing.T) {
diff --git a/pkg/sentry/fs/user/BUILD b/pkg/sentry/fs/user/BUILD
new file mode 100644
index 000000000..66e949c95
--- /dev/null
+++ b/pkg/sentry/fs/user/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "user",
+ srcs = [
+ "path.go",
+ "user.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "user_test",
+ size = "small",
+ srcs = ["user_test.go"],
+ library = ":user",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fs/user/path.go b/pkg/sentry/fs/user/path.go
new file mode 100644
index 000000000..2f5a43b84
--- /dev/null
+++ b/pkg/sentry/fs/user/path.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package user
+
+import (
+ "fmt"
+ "path"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ResolveExecutablePath resolves the given executable name given the working
+// dir and environment.
+func ResolveExecutablePath(ctx context.Context, args *kernel.CreateProcessArgs) (string, error) {
+ name := args.Filename
+ if len(name) == 0 {
+ if len(args.Argv) == 0 {
+ return "", fmt.Errorf("no filename or command provided")
+ }
+ name = args.Argv[0]
+ }
+
+ // Absolute paths can be used directly.
+ if path.IsAbs(name) {
+ return name, nil
+ }
+
+ // Paths with '/' in them should be joined to the working directory, or
+ // to the root if working directory is not set.
+ if strings.IndexByte(name, '/') > 0 {
+ wd := args.WorkingDirectory
+ if wd == "" {
+ wd = "/"
+ }
+ if !path.IsAbs(wd) {
+ return "", fmt.Errorf("working directory %q must be absolute", wd)
+ }
+ return path.Join(wd, name), nil
+ }
+
+ // Otherwise, We must lookup the name in the paths.
+ paths := getPath(args.Envv)
+ if kernel.VFS2Enabled {
+ f, err := resolveVFS2(ctx, args.Credentials, args.MountNamespaceVFS2, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+ }
+
+ f, err := resolve(ctx, args.MountNamespace, paths, name)
+ if err != nil {
+ return "", fmt.Errorf("error finding executable %q in PATH %v: %v", name, paths, err)
+ }
+ return f, nil
+}
+
+func resolve(ctx context.Context, mns *fs.MountNamespace, paths []string, name string) (string, error) {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // Caller has no root. Don't bother traversing anything.
+ return "", syserror.ENOENT
+ }
+ defer root.DecRef(ctx)
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ traversals := uint(linux.MaxSymlinkTraversals)
+ d, err := mns.FindInode(ctx, root, nil, binPath, &traversals)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ defer d.DecRef(ctx)
+
+ // Check that it is a regular file.
+ if !fs.IsRegular(d.Inode.StableAttr) {
+ continue
+ }
+
+ // Check whether we can read and execute the found file.
+ if err := d.Inode.CheckPermission(ctx, fs.PermMask{Read: true, Execute: true}); err != nil {
+ log.Infof("Found executable at %q, but user cannot execute it: %v", binPath, err)
+ continue
+ }
+ return path.Join("/", p, name), nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+func resolveVFS2(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, paths []string, name string) (string, error) {
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ for _, p := range paths {
+ if !path.IsAbs(p) {
+ // Relative paths aren't safe, no one should be using them.
+ log.Warningf("Skipping relative path %q in $PATH", p)
+ continue
+ }
+
+ binPath := path.Join(p, name)
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(binPath),
+ FollowFinalSymlink: true,
+ }
+ opts := &vfs.OpenOptions{
+ FileExec: true,
+ Flags: linux.O_RDONLY,
+ }
+ dentry, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, pop, opts)
+ if err == syserror.ENOENT || err == syserror.EACCES {
+ // Didn't find it here.
+ continue
+ }
+ if err != nil {
+ return "", err
+ }
+ dentry.DecRef(ctx)
+
+ return binPath, nil
+ }
+
+ // Couldn't find it.
+ return "", syserror.ENOENT
+}
+
+// getPath returns the PATH as a slice of strings given the environment
+// variables.
+func getPath(env []string) []string {
+ const prefix = "PATH="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.Split(strings.TrimPrefix(e, prefix), ":")
+ }
+ }
+ return nil
+}
diff --git a/runsc/boot/user.go b/pkg/sentry/fs/user/user.go
index 56cc12ee0..936fd3932 100644
--- a/runsc/boot/user.go
+++ b/pkg/sentry/fs/user/user.go
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+// Package user contains methods for resolving filesystem paths based on the
+// user and their environment.
+package user
import (
"bufio"
@@ -22,10 +24,12 @@ import (
"strings"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type fileReader struct {
@@ -58,7 +62,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
// doesn't exist we will return the default home directory.
return defaultHome, nil
}
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
// Check read permissions on the file.
if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Read: true}); err != nil {
@@ -77,13 +81,55 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
if err != nil {
return "", err
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
r := &fileReader{
Ctx: ctx,
File: f,
}
+ return findHomeInPasswd(uint32(uid), r, defaultHome)
+}
+
+type fileReaderVFS2 struct {
+ ctx context.Context
+ fd *vfs.FileDescription
+}
+
+func (r *fileReaderVFS2) Read(buf []byte) (int, error) {
+ n, err := r.fd.Read(r.ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ return int(n), err
+}
+
+func getExecUserHomeVFS2(ctx context.Context, mns *vfs.MountNamespace, uid auth.KUID) (string, error) {
+ const defaultHome = "/"
+
+ root := mns.Root()
+ defer root.DecRef(ctx)
+
+ creds := auth.CredentialsFromContext(ctx)
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/etc/passwd"),
+ }
+
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ }
+
+ fd, err := root.Mount().Filesystem().VirtualFilesystem().OpenAt(ctx, creds, target, opts)
+ if err != nil {
+ return defaultHome, nil
+ }
+ defer fd.DecRef(ctx)
+
+ r := &fileReaderVFS2{
+ ctx: ctx,
+ fd: fd,
+ }
+
homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
@@ -92,10 +138,10 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.K
return homeDir, nil
}
-// maybeAddExecUserHome returns a new slice with the HOME enviroment variable
+// MaybeAddExecUserHome returns a new slice with the HOME enviroment variable
// set if the slice does not already contain it, otherwise it returns the
// original slice unmodified.
-func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+func MaybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
// Check if the envv already contains HOME.
for _, env := range envv {
if strings.HasPrefix(env, "HOME=") {
@@ -111,6 +157,29 @@ func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.
if err != nil {
return nil, fmt.Errorf("error reading exec user: %v", err)
}
+
+ return append(envv, "HOME="+homeDir), nil
+}
+
+// MaybeAddExecUserHomeVFS2 returns a new slice with the HOME enviroment
+// variable set if the slice does not already contain it, otherwise it returns
+// the original slice unmodified.
+func MaybeAddExecUserHomeVFS2(ctx context.Context, vmns *vfs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHomeVFS2(ctx, vmns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
return append(envv, "HOME="+homeDir), nil
}
diff --git a/runsc/boot/user_test.go b/pkg/sentry/fs/user/user_test.go
index 9aee2ad07..12b786224 100644
--- a/runsc/boot/user_test.go
+++ b/pkg/sentry/fs/user/user_test.go
@@ -12,167 +12,111 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+package user
import (
- "io/ioutil"
- "os"
- "path/filepath"
+ "fmt"
"strings"
- "syscall"
"testing"
- specs "github.com/opencontainers/runtime-spec/specs-go"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-func setupTempDir() (string, error) {
- tmpDir, err := ioutil.TempDir(os.TempDir(), "exec-user-test")
+// createEtcPasswd creates /etc/passwd with the given contents and mode. If
+// mode is empty, then no file will be created. If mode is not a regular file
+// mode, then contents is ignored.
+func createEtcPasswd(ctx context.Context, root *fs.Dirent, contents string, mode linux.FileMode) error {
+ if err := root.CreateDirectory(ctx, root, "etc", fs.FilePermsFromMode(0755)); err != nil {
+ return err
+ }
+ etc, err := root.Walk(ctx, root, "etc")
if err != nil {
- return "", err
+ return err
}
- return tmpDir, nil
-}
-
-func setupPasswd(contents string, perms os.FileMode) func() (string, error) {
- return func() (string, error) {
- tmpDir, err := setupTempDir()
- if err != nil {
- return "", err
- }
-
- if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil {
- return "", err
- }
-
- f, err := os.Create(filepath.Join(tmpDir, "etc", "passwd"))
- if err != nil {
- return "", err
- }
- defer f.Close()
-
- _, err = f.WriteString(contents)
+ defer etc.DecRef(ctx)
+ switch mode.FileType() {
+ case 0:
+ // Don't create anything.
+ return nil
+ case linux.S_IFREG:
+ passwd, err := etc.Create(ctx, root, "passwd", fs.FileFlags{Write: true}, fs.FilePermsFromMode(mode))
if err != nil {
- return "", err
+ return err
}
-
- err = f.Chmod(perms)
- if err != nil {
- return "", err
+ defer passwd.DecRef(ctx)
+ if _, err := passwd.Writev(ctx, usermem.BytesIOSequence([]byte(contents))); err != nil {
+ return err
}
- return tmpDir, nil
+ return nil
+ case linux.S_IFDIR:
+ return etc.CreateDirectory(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ case linux.S_IFIFO:
+ return etc.CreateFifo(ctx, root, "passwd", fs.FilePermsFromMode(mode))
+ default:
+ return fmt.Errorf("unknown file type %x", mode.FileType())
}
}
// TestGetExecUserHome tests the getExecUserHome function.
func TestGetExecUserHome(t *testing.T) {
tests := map[string]struct {
- uid auth.KUID
- createRoot func() (string, error)
- expected string
+ uid auth.KUID
+ passwdContents string
+ passwdMode linux.FileMode
+ expected string
}{
"success": {
- uid: 1000,
- createRoot: setupPasswd("adin::1000:1111::/home/adin:/bin/sh", 0666),
- expected: "/home/adin",
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG | 0666,
+ expected: "/home/adin",
+ },
+ "no_perms": {
+ uid: 1000,
+ passwdContents: "adin::1000:1111::/home/adin:/bin/sh",
+ passwdMode: linux.S_IFREG,
+ expected: "/",
},
"no_passwd": {
- uid: 1000,
- createRoot: setupTempDir,
- expected: "/",
+ uid: 1000,
+ expected: "/",
},
- "no_perms": {
+ "directory": {
uid: 1000,
- createRoot: setupPasswd("adin::1000:1111::/home/adin:/bin/sh", 0000),
+ passwdMode: linux.S_IFDIR | 0666,
expected: "/",
},
- "directory": {
- uid: 1000,
- createRoot: func() (string, error) {
- tmpDir, err := setupTempDir()
- if err != nil {
- return "", err
- }
-
- if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil {
- return "", err
- }
-
- if err := syscall.Mkdir(filepath.Join(tmpDir, "etc", "passwd"), 0666); err != nil {
- return "", err
- }
-
- return tmpDir, nil
- },
- expected: "/",
- },
// Currently we don't allow named pipes.
"named_pipe": {
- uid: 1000,
- createRoot: func() (string, error) {
- tmpDir, err := setupTempDir()
- if err != nil {
- return "", err
- }
-
- if err := os.Mkdir(filepath.Join(tmpDir, "etc"), 0777); err != nil {
- return "", err
- }
-
- if err := syscall.Mkfifo(filepath.Join(tmpDir, "etc", "passwd"), 0666); err != nil {
- return "", err
- }
-
- return tmpDir, nil
- },
- expected: "/",
+ uid: 1000,
+ passwdMode: linux.S_IFIFO | 0666,
+ expected: "/",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
- tmpDir, err := tc.createRoot()
- if err != nil {
- t.Fatalf("failed to create root dir: %v", err)
- }
-
- sandEnd, cleanup, err := startGofer(tmpDir)
- if err != nil {
- t.Fatalf("failed to create gofer: %v", err)
- }
- defer cleanup()
-
ctx := contexttest.Context(t)
- conf := &Config{
- RootDir: "unused_root_dir",
- Network: NetworkNone,
- DisableSeccomp: true,
- }
+ msrc := fs.NewPseudoMountSource(ctx)
+ rootInode := tmpfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0777), msrc)
- spec := &specs.Spec{
- Root: &specs.Root{
- Path: tmpDir,
- Readonly: true,
- },
- // Add /proc mount as tmpfs to avoid needing a kernel.
- Mounts: []specs.Mount{
- {
- Destination: "/proc",
- Type: "tmpfs",
- },
- },
- }
-
- mntr := newContainerMounter(spec, []int{sandEnd}, nil, &podMountHints{})
- mns, err := mntr.createMountNamespace(ctx, conf)
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
if err != nil {
- t.Fatalf("failed to create mount namespace: %v", err)
+ t.Fatalf("NewMountNamespace failed: %v", err)
}
- ctx = fs.WithRoot(ctx, mns.Root())
- if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
- t.Fatalf("failed to create mount namespace: %v", err)
+ defer mns.DecRef(ctx)
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ ctx = fs.WithRoot(ctx, root)
+
+ if err := createEtcPasswd(ctx, root, tc.passwdContents, tc.passwdMode); err != nil {
+ t.Fatalf("createEtcPasswd failed: %v", err)
}
got, err := getExecUserHome(ctx, mns, tc.uid)
diff --git a/pkg/sentry/fsbridge/BUILD b/pkg/sentry/fsbridge/BUILD
new file mode 100644
index 000000000..6c798f0bd
--- /dev/null
+++ b/pkg/sentry/fsbridge/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "fsbridge",
+ srcs = [
+ "bridge.go",
+ "fs.go",
+ "vfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsbridge/bridge.go b/pkg/sentry/fsbridge/bridge.go
new file mode 100644
index 000000000..7e61209ee
--- /dev/null
+++ b/pkg/sentry/fsbridge/bridge.go
@@ -0,0 +1,54 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fsbridge provides common interfaces to bridge between VFS1 and VFS2
+// files.
+package fsbridge
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// File provides a common interface to bridge between VFS1 and VFS2 files.
+type File interface {
+ // PathnameWithDeleted returns an absolute pathname to vd, consistent with
+ // Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+ // PathnameWithDeleted appends " (deleted)" to the returned pathname.
+ PathnameWithDeleted(ctx context.Context) string
+
+ // ReadFull read all contents from the file.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+
+ // ConfigureMMap mutates opts to implement mmap(2) for the file.
+ ConfigureMMap(context.Context, *memmap.MMapOpts) error
+
+ // Type returns the file type, e.g. linux.S_IFREG.
+ Type(context.Context) (linux.FileMode, error)
+
+ // IncRef increments reference.
+ IncRef()
+
+ // DecRef decrements reference.
+ DecRef(ctx context.Context)
+}
+
+// Lookup provides a common interface to open files.
+type Lookup interface {
+ // OpenPath opens a file.
+ OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error)
+}
diff --git a/pkg/sentry/fsbridge/fs.go b/pkg/sentry/fsbridge/fs.go
new file mode 100644
index 000000000..9785fd62a
--- /dev/null
+++ b/pkg/sentry/fsbridge/fs.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// fsFile implements File interface over fs.File.
+//
+// +stateify savable
+type fsFile struct {
+ file *fs.File
+}
+
+var _ File = (*fsFile)(nil)
+
+// NewFSFile creates a new File over fs.File.
+func NewFSFile(file *fs.File) File {
+ return &fsFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *fsFile) PathnameWithDeleted(ctx context.Context) string {
+ root := fs.RootFromContext(ctx)
+ if root == nil {
+ // This doesn't correspond to anything in Linux because the vfs is
+ // global there.
+ return ""
+ }
+ defer root.DecRef(ctx)
+
+ name, _ := f.file.Dirent.FullName(root)
+ return name
+}
+
+// ReadFull implements File.
+func (f *fsFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.Preadv(ctx, dst, offset+total)
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *fsFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *fsFile) Type(context.Context) (linux.FileMode, error) {
+ return linux.FileMode(f.file.Dirent.Inode.StableAttr.Type.LinuxType()), nil
+}
+
+// IncRef implements File.
+func (f *fsFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *fsFile) DecRef(ctx context.Context) {
+ f.file.DecRef(ctx)
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type fsLookup struct {
+ mntns *fs.MountNamespace
+
+ root *fs.Dirent
+ workingDir *fs.Dirent
+}
+
+var _ Lookup = (*fsLookup)(nil)
+
+// NewFSLookup creates a new Lookup using VFS1.
+func NewFSLookup(mntns *fs.MountNamespace, root, workingDir *fs.Dirent) Lookup {
+ return &fsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+func (l *fsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, remainingTraversals *uint, resolveFinal bool) (File, error) {
+ var d *fs.Dirent
+ var err error
+ if resolveFinal {
+ d, err = l.mntns.FindInode(ctx, l.root, l.workingDir, path, remainingTraversals)
+ } else {
+ d, err = l.mntns.FindLink(ctx, l.root, l.workingDir, path, remainingTraversals)
+ }
+ if err != nil {
+ return nil, err
+ }
+ defer d.DecRef(ctx)
+
+ if !resolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
+ return nil, syserror.ELOOP
+ }
+
+ fsPerm := openOptionsToPermMask(&opts)
+ if err := d.Inode.CheckPermission(ctx, fsPerm); err != nil {
+ return nil, err
+ }
+
+ // If they claim it's a directory, then make sure.
+ if strings.HasSuffix(path, "/") {
+ if d.Inode.StableAttr.Type != fs.Directory {
+ return nil, syserror.ENOTDIR
+ }
+ }
+
+ if opts.FileExec && d.Inode.StableAttr.Type != fs.RegularFile {
+ ctx.Infof("%q is not a regular file: %v", path, d.Inode.StableAttr.Type)
+ return nil, syserror.EACCES
+ }
+
+ f, err := d.Inode.GetFile(ctx, d, flagsToFileFlags(opts.Flags))
+ if err != nil {
+ return nil, err
+ }
+
+ return &fsFile{file: f}, nil
+}
+
+func openOptionsToPermMask(opts *vfs.OpenOptions) fs.PermMask {
+ mode := opts.Flags & linux.O_ACCMODE
+ return fs.PermMask{
+ Read: mode == linux.O_RDONLY || mode == linux.O_RDWR,
+ Write: mode == linux.O_WRONLY || mode == linux.O_RDWR,
+ Execute: opts.FileExec,
+ }
+}
+
+func flagsToFileFlags(flags uint32) fs.FileFlags {
+ return fs.FileFlags{
+ Direct: flags&linux.O_DIRECT != 0,
+ DSync: flags&(linux.O_DSYNC|linux.O_SYNC) != 0,
+ Sync: flags&linux.O_SYNC != 0,
+ NonBlocking: flags&linux.O_NONBLOCK != 0,
+ Read: (flags & linux.O_ACCMODE) != linux.O_WRONLY,
+ Write: (flags & linux.O_ACCMODE) != linux.O_RDONLY,
+ Append: flags&linux.O_APPEND != 0,
+ Directory: flags&linux.O_DIRECTORY != 0,
+ Async: flags&linux.O_ASYNC != 0,
+ LargeFile: flags&linux.O_LARGEFILE != 0,
+ Truncate: flags&linux.O_TRUNC != 0,
+ }
+}
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
new file mode 100644
index 000000000..323506d33
--- /dev/null
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -0,0 +1,142 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fsbridge
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// VFSFile implements File interface over vfs.FileDescription.
+//
+// +stateify savable
+type VFSFile struct {
+ file *vfs.FileDescription
+}
+
+var _ File = (*VFSFile)(nil)
+
+// NewVFSFile creates a new File over fs.File.
+func NewVFSFile(file *vfs.FileDescription) File {
+ return &VFSFile{file: file}
+}
+
+// PathnameWithDeleted implements File.
+func (f *VFSFile) PathnameWithDeleted(ctx context.Context) string {
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef(ctx)
+
+ vfsObj := f.file.VirtualDentry().Mount().Filesystem().VirtualFilesystem()
+ name, _ := vfsObj.PathnameWithDeleted(ctx, root, f.file.VirtualDentry())
+ return name
+}
+
+// ReadFull implements File.
+func (f *VFSFile) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
+ var total int64
+ for dst.NumBytes() > 0 {
+ n, err := f.file.PRead(ctx, dst, offset+total, vfs.ReadOptions{})
+ total += n
+ if err == io.EOF && total != 0 {
+ return total, io.ErrUnexpectedEOF
+ } else if err != nil {
+ return total, err
+ }
+ dst = dst.DropFirst64(n)
+ }
+ return total, nil
+}
+
+// ConfigureMMap implements File.
+func (f *VFSFile) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return f.file.ConfigureMMap(ctx, opts)
+}
+
+// Type implements File.
+func (f *VFSFile) Type(ctx context.Context) (linux.FileMode, error) {
+ stat, err := f.file.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ return 0, err
+ }
+ return linux.FileMode(stat.Mode).FileType(), nil
+}
+
+// IncRef implements File.
+func (f *VFSFile) IncRef() {
+ f.file.IncRef()
+}
+
+// DecRef implements File.
+func (f *VFSFile) DecRef(ctx context.Context) {
+ f.file.DecRef(ctx)
+}
+
+// FileDescription returns the FileDescription represented by f. It does not
+// take an additional reference on the returned FileDescription.
+func (f *VFSFile) FileDescription() *vfs.FileDescription {
+ return f.file
+}
+
+// fsLookup implements Lookup interface using fs.File.
+//
+// +stateify savable
+type vfsLookup struct {
+ mntns *vfs.MountNamespace
+
+ root vfs.VirtualDentry
+ workingDir vfs.VirtualDentry
+}
+
+var _ Lookup = (*vfsLookup)(nil)
+
+// NewVFSLookup creates a new Lookup using VFS2.
+func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry) Lookup {
+ return &vfsLookup{
+ mntns: mntns,
+ root: root,
+ workingDir: workingDir,
+ }
+}
+
+// OpenPath implements Lookup.
+//
+// remainingTraversals is not configurable in VFS2, all callers are using the
+// default anyways.
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+ vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
+ pop := &vfs.PathOperation{
+ Root: l.root,
+ Start: l.workingDir,
+ Path: path,
+ FollowFinalSymlink: resolveFinal,
+ }
+ if path.Absolute {
+ pop.Start = l.root
+ }
+ fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return &VFSFile{file: fd}, nil
+}
diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD
new file mode 100644
index 000000000..93512c9b6
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/BUILD
@@ -0,0 +1,44 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "devpts",
+ srcs = [
+ "devpts.go",
+ "line_discipline.go",
+ "master.go",
+ "queue.go",
+ "slave.go",
+ "terminal.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "devpts_test",
+ size = "small",
+ srcs = ["devpts_test.go"],
+ library = ":devpts",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
new file mode 100644
index 000000000..7169e91af
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -0,0 +1,233 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devpts provides a filesystem implementation that behaves like
+// devpts.
+package devpts
+
+import (
+ "fmt"
+ "math"
+ "sort"
+ "strconv"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the filesystem name.
+const Name = "devpts"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ // No data allowed.
+ if opts.Data != "" {
+ return nil, nil, syserror.EINVAL
+ }
+
+ fs, root, err := fstype.newFilesystem(vfsObj, creds)
+ if err != nil {
+ return nil, nil, err
+ }
+ return fs.Filesystem.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// newFilesystem creates a new devpts filesystem with root directory and ptmx
+// master inode. It returns the filesystem and root Dentry.
+func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, fstype, fs)
+
+ // Construct the root directory. This is always inode id 1.
+ root := &rootInode{
+ slaves: make(map[uint32]*slaveInode),
+ }
+ root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555)
+ root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ root.dentry.Init(root)
+
+ // Construct the pts master inode and dentry. Linux always uses inode
+ // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx.
+ master := &masterInode{
+ root: root,
+ }
+ master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666)
+ master.dentry.Init(master)
+
+ // Add the master as a child of the root.
+ links := root.OrderedChildren.Populate(&root.dentry, map[string]*kernfs.Dentry{
+ "ptmx": &master.dentry,
+ })
+ root.IncLinks(links)
+
+ return fs, &root.dentry, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// rootInode is the root directory inode for the devpts mounts.
+type rootInode struct {
+ kernfs.AlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNotSymlink
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // master is the master pty inode. Immutable.
+ master *masterInode
+
+ // root is the root directory inode for this filesystem. Immutable.
+ root *rootInode
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // slaves maps pty ids to slave inodes.
+ slaves map[uint32]*slaveInode
+
+ // nextIdx is the next pty index to use. Must be accessed atomically.
+ //
+ // TODO(b/29356795): reuse indices when ptys are closed.
+ nextIdx uint32
+}
+
+var _ kernfs.Inode = (*rootInode)(nil)
+
+// allocateTerminal creates a new Terminal and installs a pts node for it.
+func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if i.nextIdx == math.MaxUint32 {
+ return nil, syserror.ENOMEM
+ }
+ idx := i.nextIdx
+ i.nextIdx++
+
+ // Sanity check that slave with idx does not exist.
+ if _, ok := i.slaves[idx]; ok {
+ panic(fmt.Sprintf("pty index collision; index %d already exists", idx))
+ }
+
+ // Create the new terminal and slave.
+ t := newTerminal(idx)
+ slave := &slaveInode{
+ root: i,
+ t: t,
+ }
+ // Linux always uses pty index + 3 as the inode id. See
+ // fs/devpts/inode.c:devpts_pty_new().
+ slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600)
+ slave.dentry.Init(slave)
+ i.slaves[idx] = slave
+
+ return t, nil
+}
+
+// masterClose is called when the master end of t is closed.
+func (i *rootInode) masterClose(t *Terminal) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ // Sanity check that slave with idx exists.
+ if _, ok := i.slaves[t.n]; !ok {
+ panic(fmt.Sprintf("pty with index %d does not exist", t.n))
+ }
+ delete(i.slaves, t.n)
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// Lookup implements kernfs.Inode.Lookup.
+func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ idx, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ if si, ok := i.slaves[uint32(idx)]; ok {
+ si.dentry.IncRef()
+ return si.dentry.VFSDentry(), nil
+
+ }
+ return nil, syserror.ENOENT
+}
+
+// IterDirents implements kernfs.Inode.IterDirents.
+func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ ids := make([]int, 0, len(i.slaves))
+ for id := range i.slaves {
+ ids = append(ids, int(id))
+ }
+ sort.Ints(ids)
+ for _, id := range ids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(id), 10),
+ Type: linux.DT_CHR,
+ Ino: i.slaves[uint32(id)].InodeAttrs.Ino(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go
new file mode 100644
index 000000000..b7c149047
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/devpts_test.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestSimpleMasterToSlave(t *testing.T) {
+ ld := newLineDiscipline(linux.DefaultSlaveTermios)
+ ctx := contexttest.Context(t)
+ inBytes := []byte("hello, tty\n")
+ src := usermem.BytesIOSequence(inBytes)
+ outBytes := make([]byte, 32)
+ dst := usermem.BytesIOSequence(outBytes)
+
+ // Write to the input queue.
+ nw, err := ld.inputQueueWrite(ctx, src)
+ if err != nil {
+ t.Fatalf("error writing to input queue: %v", err)
+ }
+ if nw != int64(len(inBytes)) {
+ t.Fatalf("wrote wrong length: got %d, want %d", nw, len(inBytes))
+ }
+
+ // Read from the input queue.
+ nr, err := ld.inputQueueRead(ctx, dst)
+ if err != nil {
+ t.Fatalf("error reading from input queue: %v", err)
+ }
+ if nr != int64(len(inBytes)) {
+ t.Fatalf("read wrong length: got %d, want %d", nr, len(inBytes))
+ }
+
+ outStr := string(outBytes[:nr])
+ inStr := string(inBytes)
+ if outStr != inStr {
+ t.Fatalf("written and read strings do not match: got %q, want %q", outStr, inStr)
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go
new file mode 100644
index 000000000..f7bc325d1
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/line_discipline.go
@@ -0,0 +1,445 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "bytes"
+ "unicode/utf8"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // canonMaxBytes is the number of bytes that fit into a single line of
+ // terminal input in canonical mode. This corresponds to N_TTY_BUF_SIZE
+ // in include/linux/tty.h.
+ canonMaxBytes = 4096
+
+ // nonCanonMaxBytes is the maximum number of bytes that can be read at
+ // a time in noncanonical mode.
+ nonCanonMaxBytes = canonMaxBytes - 1
+
+ spacesPerTab = 8
+)
+
+// lineDiscipline dictates how input and output are handled between the
+// pseudoterminal (pty) master and slave. It can be configured to alter I/O,
+// modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man
+// pages are good resources for how to affect the line discipline:
+//
+// * termios(3)
+// * tty_ioctl(4)
+//
+// This file corresponds most closely to drivers/tty/n_tty.c.
+//
+// lineDiscipline has a simple structure but supports a multitude of options
+// (see the above man pages). It consists of two queues of bytes: one from the
+// terminal master to slave (the input queue) and one from slave to master (the
+// output queue). When bytes are written to one end of the pty, the line
+// discipline reads the bytes, modifies them or takes special action if
+// required, and enqueues them to be read by the other end of the pty:
+//
+// input from terminal +-------------+ input to process (e.g. bash)
+// +------------------------>| input queue |---------------------------+
+// | (inputQueueWrite) +-------------+ (inputQueueRead) |
+// | |
+// | v
+// masterFD slaveFD
+// ^ |
+// | |
+// | output to terminal +--------------+ output from process |
+// +------------------------| output queue |<--------------------------+
+// (outputQueueRead) +--------------+ (outputQueueWrite)
+//
+// Lock order:
+// termiosMu
+// inQueue.mu
+// outQueue.mu
+//
+// +stateify savable
+type lineDiscipline struct {
+ // sizeMu protects size.
+ sizeMu sync.Mutex `state:"nosave"`
+
+ // size is the terminal size (width and height).
+ size linux.WindowSize
+
+ // inQueue is the input queue of the terminal.
+ inQueue queue
+
+ // outQueue is the output queue of the terminal.
+ outQueue queue
+
+ // termiosMu protects termios.
+ termiosMu sync.RWMutex `state:"nosave"`
+
+ // termios is the terminal configuration used by the lineDiscipline.
+ termios linux.KernelTermios
+
+ // column is the location in a row of the cursor. This is important for
+ // handling certain special characters like backspace.
+ column int
+
+ // masterWaiter is used to wait on the master end of the TTY.
+ masterWaiter waiter.Queue `state:"zerovalue"`
+
+ // slaveWaiter is used to wait on the slave end of the TTY.
+ slaveWaiter waiter.Queue `state:"zerovalue"`
+}
+
+func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline {
+ ld := lineDiscipline{termios: termios}
+ ld.inQueue.transformer = &inputQueueTransformer{}
+ ld.outQueue.transformer = &outputQueueTransformer{}
+ return &ld
+}
+
+// getTermios gets the linux.Termios for the tty.
+func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ // We must copy a Termios struct, not KernelTermios.
+ t := l.termios.ToTermios()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// setTermios sets a linux.Termios for the tty.
+func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ l.termiosMu.Lock()
+ defer l.termiosMu.Unlock()
+ oldCanonEnabled := l.termios.LEnabled(linux.ICANON)
+ // We must copy a Termios struct, not KernelTermios.
+ var t linux.Termios
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ l.termios.FromTermios(t)
+
+ // If canonical mode is turned off, move bytes from inQueue's wait
+ // buffer to its read buffer. Anything already in the read buffer is
+ // now readable.
+ if oldCanonEnabled && !l.termios.LEnabled(linux.ICANON) {
+ l.inQueue.mu.Lock()
+ l.inQueue.pushWaitBufLocked(l)
+ l.inQueue.readable = true
+ l.inQueue.mu.Unlock()
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+
+ return 0, err
+}
+
+func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ l.sizeMu.Lock()
+ defer l.sizeMu.Unlock()
+ _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+}
+
+func (l *lineDiscipline) masterReadiness() waiter.EventMask {
+ // We don't have to lock a termios because the default master termios
+ // is immutable.
+ return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios)
+}
+
+func (l *lineDiscipline) slaveReadiness() waiter.EventMask {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios)
+}
+
+func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.inQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.inQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.inQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ return l.outQueue.readableSize(ctx, io, args)
+}
+
+func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, pushed, err := l.outQueue.read(ctx, dst, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.slaveWaiter.Notify(waiter.EventOut)
+ if pushed {
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+func (l *lineDiscipline) outputQueueWrite(ctx context.Context, src usermem.IOSequence) (int64, error) {
+ l.termiosMu.RLock()
+ defer l.termiosMu.RUnlock()
+ n, err := l.outQueue.write(ctx, src, l)
+ if err != nil {
+ return 0, err
+ }
+ if n > 0 {
+ l.masterWaiter.Notify(waiter.EventIn)
+ return n, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// transformer is a helper interface to make it easier to stateify queue.
+type transformer interface {
+ // transform functions require queue's mutex to be held.
+ transform(*lineDiscipline, *queue, []byte) int
+}
+
+// outputQueueTransformer implements transformer. It performs line discipline
+// transformations on the output queue.
+//
+// +stateify savable
+type outputQueueTransformer struct{}
+
+// transform does output processing for one end of the pty. See
+// drivers/tty/n_tty.c:do_output_char for an analogous kernel function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*outputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // transformOutput is effectively always in noncanonical mode, as the
+ // master termios never has ICANON set.
+
+ if !l.termios.OEnabled(linux.OPOST) {
+ q.readBuf = append(q.readBuf, buf...)
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return len(buf)
+ }
+
+ var ret int
+ for len(buf) > 0 {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ ret += size
+ buf = buf[size:]
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\n':
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ if l.termios.OEnabled(linux.ONLCR) {
+ q.readBuf = append(q.readBuf, '\r', '\n')
+ continue
+ }
+ case '\r':
+ if l.termios.OEnabled(linux.ONOCR) && l.column == 0 {
+ continue
+ }
+ if l.termios.OEnabled(linux.OCRNL) {
+ cBytes[0] = '\n'
+ if l.termios.OEnabled(linux.ONLRET) {
+ l.column = 0
+ }
+ break
+ }
+ l.column = 0
+ case '\t':
+ spaces := spacesPerTab - l.column%spacesPerTab
+ if l.termios.OutputFlags&linux.TABDLY == linux.XTABS {
+ l.column += spaces
+ q.readBuf = append(q.readBuf, bytes.Repeat([]byte{' '}, spacesPerTab)...)
+ continue
+ }
+ l.column += spaces
+ case '\b':
+ if l.column > 0 {
+ l.column--
+ }
+ default:
+ l.column++
+ }
+ q.readBuf = append(q.readBuf, cBytes...)
+ }
+ if len(q.readBuf) > 0 {
+ q.readable = true
+ }
+ return ret
+}
+
+// inputQueueTransformer implements transformer. It performs line discipline
+// transformations on the input queue.
+//
+// +stateify savable
+type inputQueueTransformer struct{}
+
+// transform does input processing for one end of the pty. Characters read are
+// transformed according to flags set in the termios struct. See
+// drivers/tty/n_tty.c:n_tty_receive_char_special for an analogous kernel
+// function.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (*inputQueueTransformer) transform(l *lineDiscipline, q *queue, buf []byte) int {
+ // If there's a line waiting to be read in canonical mode, don't write
+ // anything else to the read buffer.
+ if l.termios.LEnabled(linux.ICANON) && q.readable {
+ return 0
+ }
+
+ maxBytes := nonCanonMaxBytes
+ if l.termios.LEnabled(linux.ICANON) {
+ maxBytes = canonMaxBytes
+ }
+
+ var ret int
+ for len(buf) > 0 && len(q.readBuf) < canonMaxBytes {
+ size := l.peek(buf)
+ cBytes := append([]byte{}, buf[:size]...)
+ // We're guaranteed that cBytes has at least one element.
+ switch cBytes[0] {
+ case '\r':
+ if l.termios.IEnabled(linux.IGNCR) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+ if l.termios.IEnabled(linux.ICRNL) {
+ cBytes[0] = '\n'
+ }
+ case '\n':
+ if l.termios.IEnabled(linux.INLCR) {
+ cBytes[0] = '\r'
+ }
+ }
+
+ // In canonical mode, we discard non-terminating characters
+ // after the first 4095.
+ if l.shouldDiscard(q, cBytes) {
+ buf = buf[size:]
+ ret += size
+ continue
+ }
+
+ // Stop if the buffer would be overfilled.
+ if len(q.readBuf)+size > maxBytes {
+ break
+ }
+ buf = buf[size:]
+ ret += size
+
+ // If we get EOF, make the buffer available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsEOF(cBytes[0]) {
+ q.readable = true
+ break
+ }
+
+ q.readBuf = append(q.readBuf, cBytes...)
+
+ // Anything written to the readBuf will have to be echoed.
+ if l.termios.LEnabled(linux.ECHO) {
+ l.outQueue.writeBytes(cBytes, l)
+ l.masterWaiter.Notify(waiter.EventIn)
+ }
+
+ // If we finish a line, make it available for reading.
+ if l.termios.LEnabled(linux.ICANON) && l.termios.IsTerminating(cBytes) {
+ q.readable = true
+ break
+ }
+ }
+
+ // In noncanonical mode, everything is readable.
+ if !l.termios.LEnabled(linux.ICANON) && len(q.readBuf) > 0 {
+ q.readable = true
+ }
+
+ return ret
+}
+
+// shouldDiscard returns whether c should be discarded. In canonical mode, if
+// too many bytes are enqueued, we keep reading input and discarding it until
+// we find a terminating character. Signal/echo processing still occurs.
+//
+// Precondition:
+// * l.termiosMu must be held for reading.
+// * q.mu must be held.
+func (l *lineDiscipline) shouldDiscard(q *queue, cBytes []byte) bool {
+ return l.termios.LEnabled(linux.ICANON) && len(q.readBuf)+len(cBytes) >= canonMaxBytes && !l.termios.IsTerminating(cBytes)
+}
+
+// peek returns the size in bytes of the next character to process. As long as
+// b isn't empty, peek returns a value of at least 1.
+func (l *lineDiscipline) peek(b []byte) int {
+ size := 1
+ // If UTF-8 support is enabled, runes might be multiple bytes.
+ if l.termios.IEnabled(linux.IUTF8) {
+ _, size = utf8.DecodeRune(b)
+ }
+ return size
+}
diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go
new file mode 100644
index 000000000..3bb397f71
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/master.go
@@ -0,0 +1,237 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// masterInode is the inode for the master end of the Terminal.
+type masterInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+}
+
+var _ kernfs.Inode = (*masterInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ t, err := mi.root.allocateTerminal(rp.Credentials())
+ if err != nil {
+ return nil, err
+ }
+
+ mi.IncRef()
+ fd := &masterFileDescription{
+ inode: mi,
+ t: t,
+ }
+ fd.LockFD.Init(&mi.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ mi.DecRef(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.TTYAUX_MAJOR
+ statx.RdevMinor = linux.PTMX_MINOR
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat
+func (mi *masterInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return mi.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type masterFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *masterInode
+ t *Terminal
+}
+
+var _ vfs.FileDescriptionImpl = (*masterFileDescription)(nil)
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (mfd *masterFileDescription) Release(ctx context.Context) {
+ mfd.inode.root.masterClose(mfd.t)
+ mfd.inode.DecRef(ctx)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (mfd *masterFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ mfd.t.ld.masterWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (mfd *masterFileDescription) EventUnregister(e *waiter.Entry) {
+ mfd.t.ld.masterWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (mfd *masterFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mfd.t.ld.masterReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (mfd *masterFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return mfd.t.ld.outputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return mfd.t.ld.inputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the output queue read buffer.
+ return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ // N.B. TCGETS on the master actually returns the configuration
+ // of the slave end.
+ return mfd.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ // N.B. TCSETS on the master actually affects the configuration
+ // of the slave end.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return mfd.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCSPTLCK:
+ // TODO(b/29356795): Implement pty locking. For now just pretend we do.
+ return 0, nil
+ case linux.TIOCGWINSZ:
+ return 0, mfd.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, mfd.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return mfd.inode.Stat(ctx, fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (mfd *masterFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return mfd.Locks().LockPOSIX(ctx, &mfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (mfd *masterFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return mfd.Locks().UnlockPOSIX(ctx, &mfd.vfsfd, uid, start, length, whence)
+}
+
+// maybeEmitUnimplementedEvent emits unimplemented event if cmd is valid.
+func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
+ switch cmd {
+ case linux.TCGETS,
+ linux.TCSETS,
+ linux.TCSETSW,
+ linux.TCSETSF,
+ linux.TIOCGWINSZ,
+ linux.TIOCSWINSZ,
+ linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+}
diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go
new file mode 100644
index 000000000..dffb4232c
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/queue.go
@@ -0,0 +1,236 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// waitBufMaxBytes is the maximum size of a wait buffer. It is based on
+// TTYB_DEFAULT_MEM_LIMIT.
+const waitBufMaxBytes = 131072
+
+// queue represents one of the input or output queues between a pty master and
+// slave. Bytes written to a queue are added to the read buffer until it is
+// full, at which point they are written to the wait buffer. Bytes are
+// processed (i.e. undergo termios transformations) as they are added to the
+// read buffer. The read buffer is readable when its length is nonzero and
+// readable is true.
+//
+// +stateify savable
+type queue struct {
+ // mu protects everything in queue.
+ mu sync.Mutex `state:"nosave"`
+
+ // readBuf is buffer of data ready to be read when readable is true.
+ // This data has been processed.
+ readBuf []byte
+
+ // waitBuf contains data that can't fit into readBuf. It is put here
+ // until it can be loaded into the read buffer. waitBuf contains data
+ // that hasn't been processed.
+ waitBuf [][]byte
+ waitBufLen uint64
+
+ // readable indicates whether the read buffer can be read from. In
+ // canonical mode, there can be an unterminated line in the read buffer,
+ // so readable must be checked.
+ readable bool
+
+ // transform is the the queue's function for transforming bytes
+ // entering the queue. For example, transform might convert all '\r's
+ // entering the queue to '\n's.
+ transformer
+}
+
+// readReadiness returns whether q is ready to be read from.
+func (q *queue) readReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if len(q.readBuf) > 0 && q.readable {
+ return waiter.EventIn
+ }
+ return waiter.EventMask(0)
+}
+
+// writeReadiness returns whether q is ready to be written to.
+func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ if q.waitBufLen < waitBufMaxBytes {
+ return waiter.EventOut
+ }
+ return waiter.EventMask(0)
+}
+
+// readableSize writes the number of readable bytes to userspace.
+func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ var size int32
+ if q.readable {
+ size = int32(len(q.readBuf))
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return err
+
+}
+
+// read reads from q to userspace. It returns the number of bytes read as well
+// as whether the read caused more readable data to become available (whether
+// data was pushed from the wait buffer to the read buffer).
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.readable {
+ return 0, false, syserror.ErrWouldBlock
+ }
+
+ if dst.NumBytes() > canonMaxBytes {
+ dst = dst.TakeFirst(canonMaxBytes)
+ }
+
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dst safemem.BlockSeq) (uint64, error) {
+ src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(q.readBuf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.readBuf = q.readBuf[n:]
+
+ // If we read everything, this queue is no longer readable.
+ if len(q.readBuf) == 0 {
+ q.readable = false
+ }
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, false, err
+ }
+
+ // Move data from the queue's wait buffer to its read buffer.
+ nPushed := q.pushWaitBufLocked(l)
+
+ return int64(n), nPushed > 0, nil
+}
+
+// write writes to q from userspace.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Copy data into the wait buffer.
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(src safemem.BlockSeq) (uint64, error) {
+ copyLen := src.NumBytes()
+ room := waitBufMaxBytes - q.waitBufLen
+ // If out of room, return EAGAIN.
+ if room == 0 && copyLen > 0 {
+ return 0, syserror.ErrWouldBlock
+ }
+ // Cap the size of the wait buffer.
+ if copyLen > room {
+ copyLen = room
+ src = src.TakeFirst64(room)
+ }
+ buf := make([]byte, copyLen)
+
+ // Copy the data into the wait buffer.
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf))
+ n, err := safemem.CopySeq(dst, src)
+ if err != nil {
+ return 0, err
+ }
+ q.waitBufAppend(buf)
+
+ return n, nil
+ }))
+ if err != nil {
+ return 0, err
+ }
+
+ // Push data from the wait to the read buffer.
+ q.pushWaitBufLocked(l)
+
+ return n, nil
+}
+
+// writeBytes writes to q from b.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+func (q *queue) writeBytes(b []byte, l *lineDiscipline) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ // Write to the wait buffer.
+ q.waitBufAppend(b)
+ q.pushWaitBufLocked(l)
+}
+
+// pushWaitBufLocked fills the queue's read buffer with data from the wait
+// buffer.
+//
+// Preconditions:
+// * l.termiosMu must be held for reading.
+// * q.mu must be locked.
+func (q *queue) pushWaitBufLocked(l *lineDiscipline) int {
+ if q.waitBufLen == 0 {
+ return 0
+ }
+
+ // Move data from the wait to the read buffer.
+ var total int
+ var i int
+ for i = 0; i < len(q.waitBuf); i++ {
+ n := q.transform(l, q, q.waitBuf[i])
+ total += n
+ if n != len(q.waitBuf[i]) {
+ // The read buffer filled up without consuming the
+ // entire buffer.
+ q.waitBuf[i] = q.waitBuf[i][n:]
+ break
+ }
+ }
+
+ // Update wait buffer based on consumed data.
+ q.waitBuf = q.waitBuf[i:]
+ q.waitBufLen -= uint64(total)
+
+ return total
+}
+
+// Precondition: q.mu must be locked.
+func (q *queue) waitBufAppend(b []byte) {
+ q.waitBuf = append(q.waitBuf, b)
+ q.waitBufLen += uint64(len(b))
+}
diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go
new file mode 100644
index 000000000..32e4e1908
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/slave.go
@@ -0,0 +1,197 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// slaveInode is the inode for the slave end of the Terminal.
+type slaveInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // Keep a reference to this inode's dentry.
+ dentry kernfs.Dentry
+
+ // root is the devpts root inode.
+ root *rootInode
+
+ // t is the connected Terminal.
+ t *Terminal
+}
+
+var _ kernfs.Inode = (*slaveInode)(nil)
+
+// Open implements kernfs.Inode.Open.
+func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ si.IncRef()
+ fd := &slaveFileDescription{
+ inode: si,
+ }
+ fd.LockFD.Init(&si.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ si.DecRef(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+
+}
+
+// Valid implements kernfs.Inode.Valid.
+func (si *slaveInode) Valid(context.Context) bool {
+ // Return valid if the slave still exists.
+ si.root.mu.Lock()
+ defer si.root.mu.Unlock()
+ _, ok := si.root.slaves[si.t.n]
+ return ok
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ statx.Blksize = 1024
+ statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR
+ statx.RdevMinor = si.t.n
+ return statx, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat
+func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask&linux.STATX_SIZE != 0 {
+ return syserror.EINVAL
+ }
+ return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts)
+}
+
+type slaveFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ inode *slaveInode
+}
+
+var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (sfd *slaveFileDescription) Release(ctx context.Context) {
+ sfd.inode.DecRef(ctx)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) {
+ sfd.inode.t.ld.slaveWaiter.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return sfd.inode.t.ld.slaveReadiness()
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ return sfd.inode.t.ld.inputQueueRead(ctx, dst)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ return sfd.inode.t.ld.outputQueueWrite(ctx, src)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.Ioctl.
+func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := args[1].Uint(); cmd {
+ case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ
+ // Get the number of bytes in the input queue read buffer.
+ return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args)
+ case linux.TCGETS:
+ return sfd.inode.t.ld.getTermios(ctx, io, args)
+ case linux.TCSETS:
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TCSETSW:
+ // TODO(b/29356795): This should drain the output queue first.
+ return sfd.inode.t.ld.setTermios(ctx, io, args)
+ case linux.TIOCGPTN:
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case linux.TIOCGWINSZ:
+ return 0, sfd.inode.t.ld.windowSize(ctx, io, args)
+ case linux.TIOCSWINSZ:
+ return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
+ default:
+ maybeEmitUnimplementedEvent(ctx, cmd)
+ return 0, syserror.ENOTTY
+ }
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.SetStat(ctx, fs, creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return sfd.inode.Stat(ctx, fs, opts)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go
new file mode 100644
index 000000000..7d2781c54
--- /dev/null
+++ b/pkg/sentry/fsimpl/devpts/terminal.go
@@ -0,0 +1,120 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devpts
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Terminal is a pseudoterminal.
+//
+// +stateify savable
+type Terminal struct {
+ // n is the terminal index. It is immutable.
+ n uint32
+
+ // ld is the line discipline of the terminal. It is immutable.
+ ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
+}
+
+func newTerminal(n uint32) *Terminal {
+ termios := linux.DefaultSlaveTermios
+ t := Terminal{
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{Index: n},
+ slaveKTTY: &kernel.TTY{Index: n},
+ }
+ return &t
+}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// foregroundProcessGroup sets tm's foreground process.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD
new file mode 100644
index 000000000..aa0c2ad8c
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "devtmpfs",
+ srcs = ["devtmpfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "devtmpfs_test",
+ size = "small",
+ srcs = ["devtmpfs_test.go"],
+ library = ":devtmpfs",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
new file mode 100644
index 000000000..2ed5fa8a9
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go
@@ -0,0 +1,219 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package devtmpfs provides an implementation of /dev based on tmpfs,
+// analogous to Linux's devtmpfs.
+package devtmpfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Name is the default filesystem name.
+const Name = "devtmpfs"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct {
+ initOnce sync.Once
+ initErr error
+
+ // fs is the tmpfs filesystem that backs all mounts of this FilesystemType.
+ // root is fs' root. fs and root are immutable.
+ fs *vfs.Filesystem
+ root *vfs.Dentry
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (*FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fst *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fst.initOnce.Do(func() {
+ fs, root, err := tmpfs.FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "" /* source */, vfs.GetFilesystemOptions{
+ Data: "mode=0755", // opts from drivers/base/devtmpfs.c:devtmpfs_init()
+ })
+ if err != nil {
+ fst.initErr = err
+ return
+ }
+ fst.fs = fs
+ fst.root = root
+ })
+ if fst.initErr != nil {
+ return nil, nil, fst.initErr
+ }
+ fst.fs.IncRef()
+ fst.root.IncRef()
+ return fst.fs, fst.root, nil
+}
+
+// Accessor allows devices to create device special files in devtmpfs.
+type Accessor struct {
+ vfsObj *vfs.VirtualFilesystem
+ mntns *vfs.MountNamespace
+ root vfs.VirtualDentry
+ creds *auth.Credentials
+}
+
+// NewAccessor returns an Accessor that supports creation of device special
+// files in the devtmpfs instance registered with name fsTypeName in vfsObj.
+func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) {
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, err
+ }
+ return &Accessor{
+ vfsObj: vfsObj,
+ mntns: mntns,
+ root: mntns.Root(),
+ creds: creds,
+ }, nil
+}
+
+// Release must be called when a is no longer in use.
+func (a *Accessor) Release(ctx context.Context) {
+ a.root.DecRef(ctx)
+ a.mntns.DecRef(ctx)
+}
+
+// accessorContext implements context.Context by extending an existing
+// context.Context with an Accessor's values for VFS-relevant state.
+type accessorContext struct {
+ context.Context
+ a *Accessor
+}
+
+func (a *Accessor) wrapContext(ctx context.Context) *accessorContext {
+ return &accessorContext{
+ Context: ctx,
+ a: a,
+ }
+}
+
+// Value implements context.Context.Value.
+func (ac *accessorContext) Value(key interface{}) interface{} {
+ switch key {
+ case vfs.CtxMountNamespace:
+ ac.a.mntns.IncRef()
+ return ac.a.mntns
+ case vfs.CtxRoot:
+ ac.a.root.IncRef()
+ return ac.a.root
+ default:
+ return ac.Context.Value(key)
+ }
+}
+
+func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation {
+ return &vfs.PathOperation{
+ Root: a.root,
+ Start: a.root,
+ Path: fspath.Parse(pathname),
+ }
+}
+
+// CreateDeviceFile creates a device special file at the given pathname in the
+// devtmpfs instance accessed by the Accessor.
+func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error {
+ actx := a.wrapContext(ctx)
+
+ mode := (linux.FileMode)(perms)
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind))
+ }
+
+ // Create any parent directories. See
+ // devtmpfs.c:handle_create()=>path_create().
+ for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() {
+ pop := a.pathOperationAt(it.String())
+ if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ return fmt.Errorf("failed to create directory %q: %v", it.String(), err)
+ }
+ }
+
+ // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't
+ // create, which it recognizes by storing a pointer to the kdevtmpfs struct
+ // thread in struct inode::i_private. Accessor doesn't yet support deletion
+ // of files at all, and probably won't as long as we don't need to support
+ // kernel modules, so this is moot for now.
+ return a.vfsObj.MknodAt(actx, a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{
+ Mode: mode,
+ DevMajor: major,
+ DevMinor: minor,
+ })
+}
+
+// UserspaceInit creates symbolic links and mount points in the devtmpfs
+// instance accessed by the Accessor that are created by userspace in Linux. It
+// does not create mounts.
+func (a *Accessor) UserspaceInit(ctx context.Context) error {
+ actx := a.wrapContext(ctx)
+
+ // Initialize symlinks.
+ for _, symlink := range []struct {
+ source string
+ target string
+ }{
+ // systemd: src/shared/dev-setup.c:dev_setup()
+ {source: "fd", target: "/proc/self/fd"},
+ {source: "stdin", target: "/proc/self/fd/0"},
+ {source: "stdout", target: "/proc/self/fd/1"},
+ {source: "stderr", target: "/proc/self/fd/2"},
+ // /proc/kcore is not implemented.
+
+ // Linux implements /dev/ptmx as a device node, but advises
+ // container implementations to create /dev/ptmx as a symlink
+ // to pts/ptmx (Documentation/filesystems/devpts.txt). Systemd
+ // follows this advice (src/nspawn/nspawn.c:setup_pts()), while
+ // LXC tries to create a bind mount and falls back to a symlink
+ // (src/lxc/conf.c:lxc_setup_devpts()).
+ {source: "ptmx", target: "pts/ptmx"},
+ } {
+ if err := a.vfsObj.SymlinkAt(actx, a.creds, a.pathOperationAt(symlink.source), symlink.target); err != nil {
+ return fmt.Errorf("failed to create symlink %q => %q: %v", symlink.source, symlink.target, err)
+ }
+ }
+
+ // systemd: src/core/mount-setup.c:mount_table
+ for _, dir := range []string{
+ "shm",
+ "pts",
+ } {
+ if err := a.vfsObj.MkdirAt(actx, a.creds, a.pathOperationAt(dir), &vfs.MkdirOptions{
+ // systemd: src/core/mount-setup.c:mount_one()
+ Mode: 0755,
+ }); err != nil {
+ return fmt.Errorf("failed to create directory %q: %v", dir, err)
+ }
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
new file mode 100644
index 000000000..747867cca
--- /dev/null
+++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package devtmpfs
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestDevtmpfs(t *testing.T) {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ // Register tmpfs just so that we can have a root filesystem that isn't
+ // devtmpfs.
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ vfsObj.MustRegisterFilesystemType("devtmpfs", &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ // Create a test mount namespace with devtmpfs mounted at "/dev".
+ const devPath = "/dev"
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("failed to create tmpfs root mount: %v", err)
+ }
+ defer mntns.DecRef(ctx)
+ root := mntns.Root()
+ defer root.DecRef(ctx)
+ devpop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(devPath),
+ }
+ if err := vfsObj.MkdirAt(ctx, creds, &devpop, &vfs.MkdirOptions{
+ Mode: 0755,
+ }); err != nil {
+ t.Fatalf("failed to create mount point: %v", err)
+ }
+ if err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil {
+ t.Fatalf("failed to mount devtmpfs: %v", err)
+ }
+
+ a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs")
+ if err != nil {
+ t.Fatalf("failed to create devtmpfs.Accessor: %v", err)
+ }
+ defer a.Release(ctx)
+
+ // Create "userspace-initialized" files using a devtmpfs.Accessor.
+ if err := a.UserspaceInit(ctx); err != nil {
+ t.Fatalf("failed to userspace-initialize devtmpfs: %v", err)
+ }
+ // Created files should be visible in the test mount namespace.
+ abspath := devPath + "/fd"
+ target, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(abspath),
+ })
+ if want := "/proc/self/fd"; err != nil || target != want {
+ t.Fatalf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, target, err, want)
+ }
+
+ // Create a dummy device special file using a devtmpfs.Accessor.
+ const (
+ pathInDev = "dummy"
+ kind = vfs.CharDevice
+ major = 12
+ minor = 34
+ perms = 0600
+ wantMode = linux.S_IFCHR | perms
+ )
+ if err := a.CreateDeviceFile(ctx, pathInDev, kind, major, minor, perms); err != nil {
+ t.Fatalf("failed to create device file: %v", err)
+ }
+ // The device special file should be visible in the test mount namespace.
+ abspath = devPath + "/" + pathInDev
+ stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(abspath),
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE,
+ })
+ if err != nil {
+ t.Fatalf("failed to stat device file at %q: %v", abspath, err)
+ }
+ if stat.Mode != wantMode {
+ t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode)
+ }
+ if stat.RdevMajor != major {
+ t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, major)
+ }
+ if stat.RdevMinor != minor {
+ t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, minor)
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/BUILD b/pkg/sentry/fsimpl/eventfd/BUILD
new file mode 100644
index 000000000..ea167d38c
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "eventfd",
+ srcs = ["eventfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "eventfd_test",
+ size = "small",
+ srcs = ["eventfd_test.go"],
+ library = ":eventfd",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/vfs",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go
new file mode 100644
index 000000000..812171fa3
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/eventfd.go
@@ -0,0 +1,285 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package eventfd implements event fds.
+package eventfd
+
+import (
+ "math"
+ "sync"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EventFileDescription implements FileDescriptionImpl for file-based event
+// notification (eventfd). Eventfds are usually internal to the Sentry but in
+// certain situations they may be converted into a host-backed eventfd.
+type EventFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // queue is used to notify interested parties when the event object
+ // becomes readable or writable.
+ queue waiter.Queue `state:"zerovalue"`
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // val is the current value of the event counter.
+ val uint64
+
+ // semMode specifies whether the event is in "semaphore" mode.
+ semMode bool
+
+ // hostfd indicates whether this eventfd is passed through to the host.
+ hostfd int
+}
+
+var _ vfs.FileDescriptionImpl = (*EventFileDescription)(nil)
+
+// New creates a new event fd.
+func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, initVal uint64, semMode bool, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[eventfd]")
+ defer vd.DecRef(ctx)
+ efd := &EventFileDescription{
+ val: initVal,
+ semMode: semMode,
+ hostfd: -1,
+ }
+ if err := efd.vfsfd.Init(efd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &efd.vfsfd, nil
+}
+
+// HostFD returns the host eventfd associated with this event.
+func (efd *EventFileDescription) HostFD() (int, error) {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ return efd.hostfd, nil
+ }
+
+ flags := linux.EFD_NONBLOCK
+ if efd.semMode {
+ flags |= linux.EFD_SEMAPHORE
+ }
+
+ fd, _, errno := syscall.Syscall(syscall.SYS_EVENTFD2, uintptr(efd.val), uintptr(flags), 0)
+ if errno != 0 {
+ return -1, errno
+ }
+
+ if err := fdnotifier.AddFD(int32(fd), &efd.queue); err != nil {
+ if closeErr := syscall.Close(int(fd)); closeErr != nil {
+ log.Warningf("close(%d) eventfd failed: %v", fd, closeErr)
+ }
+ return -1, err
+ }
+
+ efd.hostfd = int(fd)
+ return efd.hostfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release()
+func (efd *EventFileDescription) Release(context.Context) {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.RemoveFD(int32(efd.hostfd))
+ if closeErr := syscall.Close(int(efd.hostfd)); closeErr != nil {
+ log.Warningf("close(%d) eventfd failed: %v", efd.hostfd, closeErr)
+ }
+ efd.hostfd = -1
+ }
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ if dst.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := efd.read(ctx, dst); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
+ if src.NumBytes() < 8 {
+ return 0, syscall.EINVAL
+ }
+ if err := efd.write(ctx, src); err != nil {
+ return 0, err
+ }
+ return 8, nil
+}
+
+// Preconditions: Must be called with efd.mu locked.
+func (efd *EventFileDescription) hostReadLocked(ctx context.Context, dst usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := syscall.Read(efd.hostfd, buf[:]); err != nil {
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+ }
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+func (efd *EventFileDescription) read(ctx context.Context, dst usermem.IOSequence) error {
+ efd.mu.Lock()
+ if efd.hostfd >= 0 {
+ defer efd.mu.Unlock()
+ return efd.hostReadLocked(ctx, dst)
+ }
+
+ // We can't complete the read if the value is currently zero.
+ if efd.val == 0 {
+ efd.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ // Update the value based on the mode the event is operating in.
+ var val uint64
+ if efd.semMode {
+ val = 1
+ // Consistent with Linux, this is done even if writing to memory fails.
+ efd.val--
+ } else {
+ val = efd.val
+ efd.val = 0
+ }
+
+ efd.mu.Unlock()
+
+ // Notify writers. We do this even if we were already writable because
+ // it is possible that a writer is waiting to write the maximum value
+ // to the event.
+ efd.queue.Notify(waiter.EventOut)
+
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := dst.CopyOut(ctx, buf[:])
+ return err
+}
+
+// Preconditions: Must be called with efd.mu locked.
+func (efd *EventFileDescription) hostWriteLocked(val uint64) error {
+ var buf [8]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ _, err := syscall.Write(efd.hostfd, buf[:])
+ if err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+func (efd *EventFileDescription) write(ctx context.Context, src usermem.IOSequence) error {
+ var buf [8]byte
+ if _, err := src.CopyIn(ctx, buf[:]); err != nil {
+ return err
+ }
+ val := usermem.ByteOrder.Uint64(buf[:])
+
+ return efd.Signal(val)
+}
+
+// Signal is an internal function to signal the event fd.
+func (efd *EventFileDescription) Signal(val uint64) error {
+ if val == math.MaxUint64 {
+ return syscall.EINVAL
+ }
+
+ efd.mu.Lock()
+
+ if efd.hostfd >= 0 {
+ defer efd.mu.Unlock()
+ return efd.hostWriteLocked(val)
+ }
+
+ // We only allow writes that won't cause the value to go over the max
+ // uint64 minus 1.
+ if val > math.MaxUint64-1-efd.val {
+ efd.mu.Unlock()
+ return syserror.ErrWouldBlock
+ }
+
+ efd.val += val
+ efd.mu.Unlock()
+
+ // Always trigger a notification.
+ efd.queue.Notify(waiter.EventIn)
+
+ return nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (efd *EventFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+
+ if efd.hostfd >= 0 {
+ return fdnotifier.NonBlockingPoll(int32(efd.hostfd), mask)
+ }
+
+ ready := waiter.EventMask(0)
+ if efd.val > 0 {
+ ready |= waiter.EventIn
+ }
+
+ if efd.val < math.MaxUint64-1 {
+ ready |= waiter.EventOut
+ }
+
+ return mask & ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (efd *EventFileDescription) EventRegister(entry *waiter.Entry, mask waiter.EventMask) {
+ efd.queue.EventRegister(entry, mask)
+
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(efd.hostfd))
+ }
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (efd *EventFileDescription) EventUnregister(entry *waiter.Entry) {
+ efd.queue.EventUnregister(entry)
+
+ efd.mu.Lock()
+ defer efd.mu.Unlock()
+ if efd.hostfd >= 0 {
+ fdnotifier.UpdateFD(int32(efd.hostfd))
+ }
+}
diff --git a/pkg/sentry/fsimpl/eventfd/eventfd_test.go b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
new file mode 100644
index 000000000..49916fa81
--- /dev/null
+++ b/pkg/sentry/fsimpl/eventfd/eventfd_test.go
@@ -0,0 +1,97 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package eventfd
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestEventFD(t *testing.T) {
+ initVals := []uint64{
+ 0,
+ // Using a non-zero initial value verifies that writing to an
+ // eventfd signals when the eventfd's counter was already
+ // non-zero.
+ 343,
+ }
+
+ for _, initVal := range initVals {
+ ctx := contexttest.Context(t)
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+
+ // Make a new eventfd that is writable.
+ eventfd, err := New(ctx, vfsObj, initVal, false, linux.O_RDWR)
+ if err != nil {
+ t.Fatalf("New() failed: %v", err)
+ }
+ defer eventfd.DecRef(ctx)
+
+ // Register a callback for a write event.
+ w, ch := waiter.NewChannelEntry(nil)
+ eventfd.EventRegister(&w, waiter.EventIn)
+ defer eventfd.EventUnregister(&w)
+
+ data := []byte("00000124")
+ // Create and submit a write request.
+ n, err := eventfd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 8 {
+ t.Errorf("eventfd.write wrote %d bytes, not full int64", n)
+ }
+
+ // Check if the callback fired due to the write event.
+ select {
+ case <-ch:
+ default:
+ t.Errorf("Didn't get notified of EventIn after write")
+ }
+ }
+}
+
+func TestEventFDStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+
+ // Make a new eventfd that is writable.
+ eventfd, err := New(ctx, vfsObj, 0, false, linux.O_RDWR)
+ if err != nil {
+ t.Fatalf("New() failed: %v", err)
+ }
+ defer eventfd.DecRef(ctx)
+
+ statx, err := eventfd.Stat(ctx, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ t.Fatalf("eventfd.Stat failed: %v", err)
+ }
+ if statx.Size != 0 {
+ t.Errorf("eventfd size should be 0")
+ }
+}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 7ccff8b0d..abc610ef3 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -16,6 +15,17 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "ext",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
go_library(
name = "ext",
srcs = [
@@ -27,29 +37,33 @@ go_library(
"extent_file.go",
"file_description.go",
"filesystem.go",
+ "fstree.go",
"inode.go",
"regular_file.go",
"symlink.go",
"utils.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/context",
"//pkg/fd",
+ "//pkg/fspath",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
- "//pkg/sentry/safemem",
+ "//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls/linux",
- "//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -69,19 +83,20 @@ go_test(
"//pkg/sentry/fsimpl/ext:assets/tiny.ext3",
"//pkg/sentry/fsimpl/ext:assets/tiny.ext4",
],
- embed = [":ext"],
+ library = ":ext",
deps = [
"//pkg/abi/linux",
"//pkg/binary",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext/disklayout",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//runsc/testutil",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
+ "//pkg/test/testutil",
+ "//pkg/usermem",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index bfc46dfa6..6c5a559fd 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
package(licenses = ["notice"])
@@ -7,8 +7,9 @@ go_test(
size = "small",
srcs = ["benchmark_test.go"],
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fsimpl/ext",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
index 10a8083a0..8f7d5a9bb 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
+++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go
@@ -15,6 +15,9 @@
// These benchmarks emulate memfs benchmarks. Ext4 images must be created
// before this benchmark is run using the `make_deep_ext4.sh` script at
// /tmp/image-{depth}.ext4 for all the depths tested below.
+//
+// The benchmark itself cannot run the script because the script requires
+// sudo privileges to create the file system images.
package benchmark_test
import (
@@ -24,8 +27,9 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -48,9 +52,14 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, nil, nil, nil, err
+ }
+ vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
return nil, nil, nil, nil, err
@@ -59,7 +68,7 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys
root := mntns.Root()
tearDown := func() {
- root.DecRef()
+ root.DecRef(ctx)
if err := f.Close(); err != nil {
b.Fatalf("tearDown failed: %v", err)
@@ -81,7 +90,11 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf
ctx := contexttest.Context(b)
creds := auth.CredentialsFromContext(ctx)
- if err := vfsfs.NewMount(ctx, creds, imagePath, pop, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())}); err != nil {
+ if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: int(f.Fd()),
+ },
+ }); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
return func() {
@@ -117,7 +130,7 @@ func BenchmarkVFS2Ext4fsStat(b *testing.B) {
stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
Root: *root,
Start: *root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -146,9 +159,9 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
creds := auth.CredentialsFromContext(ctx)
mountPointName := "/1/"
pop := vfs.PathOperation{
- Root: *root,
- Start: *root,
- Pathname: mountPointName,
+ Root: *root,
+ Start: *root,
+ Path: fspath.Parse(mountPointName),
}
// Save the mount point for later use.
@@ -156,7 +169,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
// Create extfs submount.
mountTearDown := mount(b, fmt.Sprintf("/tmp/image-%d.ext4", depth), vfsfs, &pop)
@@ -177,7 +190,7 @@ func BenchmarkVFS2ExtfsMountStat(b *testing.B) {
stat, err := vfsfs.StatAt(ctx, creds, &vfs.PathOperation{
Root: *root,
Start: *root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
diff --git a/pkg/sentry/fsimpl/ext/block_map_file.go b/pkg/sentry/fsimpl/ext/block_map_file.go
index cea89bcd9..8bb104ff0 100644
--- a/pkg/sentry/fsimpl/ext/block_map_file.go
+++ b/pkg/sentry/fsimpl/ext/block_map_file.go
@@ -58,15 +58,16 @@ var _ io.ReaderAt = (*blockMapFile)(nil)
// newBlockMapFile is the blockMapFile constructor. It initializes the file to
// physical blocks map with (at most) the first 12 (direct) blocks.
-func newBlockMapFile(regFile regularFile) (*blockMapFile, error) {
- file := &blockMapFile{regFile: regFile}
+func newBlockMapFile(args inodeArgs) (*blockMapFile, error) {
+ file := &blockMapFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
for i := uint(0); i < 4; i++ {
- file.coverage[i] = getCoverage(regFile.inode.blkSize, i)
+ file.coverage[i] = getCoverage(file.regFile.inode.blkSize, i)
}
- blkMap := regFile.inode.diskInode.Data()
+ blkMap := file.regFile.inode.diskInode.Data()
binary.Unmarshal(blkMap[:numDirectBlks*4], binary.LittleEndian, &file.directBlks)
binary.Unmarshal(blkMap[numDirectBlks*4:(numDirectBlks+1)*4], binary.LittleEndian, &file.indirectBlk)
binary.Unmarshal(blkMap[(numDirectBlks+1)*4:(numDirectBlks+2)*4], binary.LittleEndian, &file.doubleIndirectBlk)
@@ -154,7 +155,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
toRead = len(dst)
}
- n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], curPhyBlkOff+int64(relFileOff))
if n < toRead {
return n, syserror.EIO
}
@@ -174,7 +175,7 @@ func (f *blockMapFile) read(curPhyBlk uint32, relFileOff uint64, height uint, ds
curChildOff := relFileOff % childCov
for i := startIdx; i < endIdx; i++ {
var childPhyBlk uint32
- err := readFromDisk(f.regFile.inode.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
+ err := readFromDisk(f.regFile.inode.fs.dev, curPhyBlkOff+int64(i*4), &childPhyBlk)
if err != nil {
return read, err
}
diff --git a/pkg/sentry/fsimpl/ext/block_map_test.go b/pkg/sentry/fsimpl/ext/block_map_test.go
index 213aa3919..6fa84e7aa 100644
--- a/pkg/sentry/fsimpl/ext/block_map_test.go
+++ b/pkg/sentry/fsimpl/ext/block_map_test.go
@@ -85,18 +85,6 @@ func (n *blkNumGen) next() uint32 {
// the inode covers and that is written to disk.
func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
mockDisk := make([]byte, mockBMDiskSize)
- regFile := regularFile{
- inode: inode{
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: getMockBMFileFize(),
- },
- },
- dev: bytes.NewReader(mockDisk),
- blkSize: uint64(mockBMBlkSize),
- },
- }
-
var fileData []byte
blkNums := newBlkNumGen()
var data []byte
@@ -123,9 +111,20 @@ func blockMapSetUp(t *testing.T) (*blockMapFile, []byte) {
data = binary.Marshal(data, binary.LittleEndian, triplyIndirectBlk)
fileData = append(fileData, writeFileDataToBlock(mockDisk, triplyIndirectBlk, 3, blkNums)...)
- copy(regFile.inode.diskInode.Data(), data)
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: getMockBMFileFize(),
+ },
+ },
+ blkSize: uint64(mockBMBlkSize),
+ }
+ copy(args.diskInode.Data(), data)
- mockFile, err := newBlockMapFile(regFile)
+ mockFile, err := newBlockMapFile(args)
if err != nil {
t.Fatalf("newBlockMapFile failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go
index 054fb42b6..7a1b4219f 100644
--- a/pkg/sentry/fsimpl/ext/dentry.go
+++ b/pkg/sentry/fsimpl/ext/dentry.go
@@ -15,6 +15,7 @@
package ext
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/vfs"
)
@@ -22,6 +23,10 @@ import (
type dentry struct {
vfsd vfs.Dentry
+ // Protected by filesystem.mu.
+ parent *dentry
+ name string
+
// inode is the inode represented by this dentry. Multiple Dentries may
// share a single non-directory Inode (with hard links). inode is
// immutable.
@@ -41,16 +46,35 @@ func newDentry(in *inode) *dentry {
}
// IncRef implements vfs.DentryImpl.IncRef.
-func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
+func (d *dentry) IncRef() {
d.inode.incRef()
}
// TryIncRef implements vfs.DentryImpl.TryIncRef.
-func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
+func (d *dentry) TryIncRef() bool {
return d.inode.tryIncRef()
}
// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
- d.inode.decRef(vfsfs.Impl().(*filesystem))
+func (d *dentry) DecRef(ctx context.Context) {
+ // FIXME(b/134676337): filesystem.mu may not be locked as required by
+ // inode.decRef().
+ d.inode.decRef()
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) Watches() *vfs.Watches {
+ return nil
}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+//
+// TODO(b/134676337): Implement inotify.
+func (d *dentry) OnZeroWatches(context.Context) {}
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index 91802dc1e..0fc01668d 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -15,16 +15,15 @@
package ext
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
- "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -32,6 +31,10 @@ import (
type directory struct {
inode inode
+ // childCache maps filenames to dentries for children for which dentries
+ // have been instantiated. childCache is protected by filesystem.mu.
+ childCache map[string]*dentry
+
// mu serializes the changes to childList.
// Lock Order (outermost locks must be taken first):
// directory.mu
@@ -51,13 +54,16 @@ type directory struct {
childMap map[string]*dirent
}
-// newDirectroy is the directory constructor.
-func newDirectroy(inode inode, newDirent bool) (*directory, error) {
- file := &directory{inode: inode, childMap: make(map[string]*dirent)}
- file.inode.impl = file
+// newDirectory is the directory constructor.
+func newDirectory(args inodeArgs, newDirent bool) (*directory, error) {
+ file := &directory{
+ childCache: make(map[string]*dentry),
+ childMap: make(map[string]*dirent),
+ }
+ file.inode.init(args, file)
// Initialize childList by reading dirents from the underlying file.
- if inode.diskInode.Flags().Index {
+ if args.diskInode.Flags().Index {
// TODO(b/134676337): Support hash tree directories. Currently only the '.'
// and '..' entries are read in.
@@ -68,7 +74,7 @@ func newDirectroy(inode inode, newDirent bool) (*directory, error) {
// The dirents are organized in a linear array in the file data.
// Extract the file data and decode the dirents.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -76,7 +82,7 @@ func newDirectroy(inode inode, newDirent bool) (*directory, error) {
// buf is used as scratch space for reading in dirents from disk and
// unmarshalling them into dirent structs.
buf := make([]byte, disklayout.DirentSize)
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
for off, inc := uint64(0), uint64(0); off < size; off += inc {
toRead := size - off
if toRead > disklayout.DirentSize {
@@ -136,7 +142,7 @@ type directoryFD struct {
var _ vfs.FileDescriptionImpl = (*directoryFD)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(ctx context.Context) {
if fd.iter == nil {
return
}
@@ -189,14 +195,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
childType = fs.ToInodeType(childInode.diskInode.Mode().FileType())
}
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: child.diskDirent.FileName(),
Type: fs.ToDirentType(childType),
Ino: uint64(child.diskDirent.Inode()),
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
@@ -301,8 +307,12 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
return offset, nil
}
-// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
-func (fd *directoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
- // mmap(2) specifies that EACCESS should be returned for non-regular file fds.
- return syserror.EACCES
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *directoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *directoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
}
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index fcfaf5c3e..9bd9c76c0 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -23,7 +22,6 @@ go_library(
"superblock_old.go",
"test_utils.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
@@ -44,6 +42,6 @@ go_test(
"inode_test.go",
"superblock_test.go",
],
- embed = [":disklayout"],
+ library = ":disklayout",
deps = ["//pkg/sentry/kernel/time"],
)
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent.go b/pkg/sentry/fsimpl/ext/disklayout/extent.go
index 567523d32..4110649ab 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent.go
@@ -29,8 +29,12 @@ package disklayout
// byte (i * sb.BlockSize()) to ((i+1) * sb.BlockSize()).
const (
- // ExtentStructsSize is the size of all the three extent on-disk structs.
- ExtentStructsSize = 12
+ // ExtentHeaderSize is the size of the header of an extent tree node.
+ ExtentHeaderSize = 12
+
+ // ExtentEntrySize is the size of an entry in an extent tree node.
+ // This size is the same for both leaf and internal nodes.
+ ExtentEntrySize = 12
// ExtentMagic is the magic number which must be present in the header.
ExtentMagic = 0xf30a
@@ -57,7 +61,7 @@ type ExtentNode struct {
Entries []ExtentEntryPair
}
-// ExtentEntry reprsents an extent tree node entry. The entry can either be
+// ExtentEntry represents an extent tree node entry. The entry can either be
// an ExtentIdx or Extent itself. This exists to simplify navigation logic.
type ExtentEntry interface {
// FileBlock returns the first file block number covered by this entry.
diff --git a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
index b0fad9b71..8762b90db 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/disklayout/extent_test.go
@@ -21,7 +21,7 @@ import (
// TestExtentSize tests that the extent structs are of the correct
// size.
func TestExtentSize(t *testing.T) {
- assertSize(t, ExtentHeader{}, ExtentStructsSize)
- assertSize(t, ExtentIdx{}, ExtentStructsSize)
- assertSize(t, Extent{}, ExtentStructsSize)
+ assertSize(t, ExtentHeader{}, ExtentHeaderSize)
+ assertSize(t, ExtentIdx{}, ExtentEntrySize)
+ assertSize(t, Extent{}, ExtentEntrySize)
}
diff --git a/pkg/sentry/fsimpl/ext/ext.go b/pkg/sentry/fsimpl/ext/ext.go
index f10accafc..08ffc2834 100644
--- a/pkg/sentry/fsimpl/ext/ext.go
+++ b/pkg/sentry/fsimpl/ext/ext.go
@@ -21,15 +21,18 @@ import (
"io"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
+// Name is the name of this filesystem.
+const Name = "ext"
+
// FilesystemType implements vfs.FilesystemType.
type FilesystemType struct{}
@@ -40,14 +43,14 @@ var _ vfs.FilesystemType = (*FilesystemType)(nil)
// Currently there are two ways of mounting an ext(2/3/4) fs:
// 1. Specify a mount with our internal special MountType in the OCI spec.
// 2. Expose the device to the container and mount it from application layer.
-func getDeviceFd(source string, opts vfs.NewFilesystemOptions) (io.ReaderAt, error) {
+func getDeviceFd(source string, opts vfs.GetFilesystemOptions) (io.ReaderAt, error) {
if opts.InternalData == nil {
// User mount call.
// TODO(b/134676337): Open the device specified by `source` and return that.
panic("unimplemented")
}
- // NewFilesystem call originated from within the sentry.
+ // GetFilesystem call originated from within the sentry.
devFd, ok := opts.InternalData.(int)
if !ok {
return nil, errors.New("internal data for ext fs must be an int containing the file descriptor to device")
@@ -91,42 +94,61 @@ func isCompatible(sb disklayout.SuperBlock) bool {
return true
}
-// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
-func (FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
// TODO(b/134676337): Ensure that the user is mounting readonly. If not,
// EACCESS should be returned according to mount(2). Filesystem independent
// flags (like readonly) are currently not available in pkg/sentry/vfs.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
dev, err := getDeviceFd(source, opts)
if err != nil {
return nil, nil, err
}
- fs := filesystem{dev: dev, inodeCache: make(map[uint32]*inode)}
- fs.vfsfs.Init(&fs)
+ fs := filesystem{
+ dev: dev,
+ inodeCache: make(map[uint32]*inode),
+ devMinor: devMinor,
+ }
+ fs.vfsfs.Init(vfsObj, &fsType, &fs)
fs.sb, err = readSuperBlock(dev)
if err != nil {
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
if fs.sb.Magic() != linux.EXT_SUPER_MAGIC {
// mount(2) specifies that EINVAL should be returned if the superblock is
// invalid.
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
}
// Refuse to mount if the filesystem is incompatible.
if !isCompatible(fs.sb) {
+ fs.vfsfs.DecRef(ctx)
return nil, nil, syserror.EINVAL
}
fs.bgs, err = readBlockGroups(dev, fs.sb)
if err != nil {
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
rootInode, err := fs.getOrCreateInodeLocked(disklayout.RootDirInode)
if err != nil {
+ fs.vfsfs.DecRef(ctx)
return nil, nil, err
}
rootInode.incRef()
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 1aa2bd6a4..2dbaee287 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -25,15 +25,15 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
-
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -64,9 +64,14 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.NewFilesystemOptions{InternalData: int(f.Fd())})
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())})
if err != nil {
f.Close()
return nil, nil, nil, nil, err
@@ -75,7 +80,7 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys
root := mntns.Root()
tearDown := func() {
- root.DecRef()
+ root.DecRef(ctx)
if err := f.Close(); err != nil {
t.Fatalf("tearDown failed: %v", err)
@@ -140,62 +145,61 @@ func TestSeek(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.OpenOptions{},
)
if err != nil {
t.Fatalf("vfsfs.OpenAt failed: %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 || err != nil {
t.Errorf("expected seek position 0, got %d and error %v", n, err)
}
- stat, err := fd.Impl().Stat(ctx, vfs.StatOptions{})
+ stat, err := fd.Stat(ctx, vfs.StatOptions{})
if err != nil {
t.Errorf("fd.stat failed for file %s in image %s: %v", test.path, test.image, err)
}
// We should be able to seek beyond the end of file.
size := int64(stat.Size)
- if n, err := fd.Impl().Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
+ if n, err := fd.Seek(ctx, size, linux.SEEK_SET); n != size || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -1, linux.SEEK_SET); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
- if n, err := fd.Impl().Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
+ if n, err := fd.Seek(ctx, 3, linux.SEEK_CUR); n != size+3 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+3, n, err)
}
// Make sure negative offsets work with SEEK_CUR.
- if n, err := fd.Impl().Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, -2, linux.SEEK_CUR); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 2), linux.SEEK_CUR); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
// Make sure SEEK_END works with regular files.
- switch fd.Impl().(type) {
- case *regularFileFD:
+ if _, ok := fd.Impl().(*regularFileFD); ok {
// Seek back to 0.
- if n, err := fd.Impl().Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
+ if n, err := fd.Seek(ctx, -size, linux.SEEK_END); n != 0 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", 0, n, err)
}
// Seek forward beyond EOF.
- if n, err := fd.Impl().Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_END); n != size+1 || err != nil {
t.Errorf("expected seek position %d, got %d and error %v", size+1, n, err)
}
// EINVAL should be returned if the resulting offset is negative.
- if _, err := fd.Impl().Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
+ if _, err := fd.Seek(ctx, -(size + 1), linux.SEEK_END); err != syserror.EINVAL {
t.Errorf("expected error EINVAL but got %v", err)
}
}
@@ -360,7 +364,7 @@ func TestStatAt(t *testing.T) {
got, err := vfsfs.StatAt(ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.StatOptions{},
)
if err != nil {
@@ -430,7 +434,7 @@ func TestRead(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.absPath},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.absPath)},
&vfs.OpenOptions{},
)
if err != nil {
@@ -456,7 +460,7 @@ func TestRead(t *testing.T) {
want := make([]byte, 1)
for {
n, err := f.Read(want)
- fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
+ fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{})
if diff := cmp.Diff(got, want); diff != "" {
t.Errorf("file data mismatch (-want +got):\n%s", diff)
@@ -464,7 +468,7 @@ func TestRead(t *testing.T) {
// Make sure there is no more file data left after getting EOF.
if n == 0 || err == io.EOF {
- if n, _ := fd.Impl().Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
+ if n, _ := fd.Read(ctx, usermem.BytesIOSequence(got), vfs.ReadOptions{}); n != 0 {
t.Errorf("extra unexpected file data in file %s in image %s", test.absPath, test.image)
}
@@ -494,9 +498,9 @@ func newIterDirentCb() *iterDirentsCb {
}
// Handle implements vfs.IterDirentsCallback.Handle.
-func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) bool {
+func (cb *iterDirentsCb) Handle(dirent vfs.Dirent) error {
cb.dirents = append(cb.dirents, dirent)
- return true
+ return nil
}
// TestIterDirents tests the FileDescriptionImpl.IterDirents functionality.
@@ -509,27 +513,27 @@ func TestIterDirents(t *testing.T) {
}
wantDirents := []vfs.Dirent{
- vfs.Dirent{
+ {
Name: ".",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "..",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "lost+found",
Type: linux.DT_DIR,
},
- vfs.Dirent{
+ {
Name: "file.txt",
Type: linux.DT_REG,
},
- vfs.Dirent{
+ {
Name: "bigfile.txt",
Type: linux.DT_REG,
},
- vfs.Dirent{
+ {
Name: "symlink.txt",
Type: linux.DT_LNK,
},
@@ -566,7 +570,7 @@ func TestIterDirents(t *testing.T) {
fd, err := vfsfs.OpenAt(
ctx,
auth.CredentialsFromContext(ctx),
- &vfs.PathOperation{Root: *root, Start: *root, Pathname: test.path},
+ &vfs.PathOperation{Root: *root, Start: *root, Path: fspath.Parse(test.path)},
&vfs.OpenOptions{},
)
if err != nil {
@@ -574,7 +578,7 @@ func TestIterDirents(t *testing.T) {
}
cb := &iterDirentsCb{}
- if err = fd.Impl().IterDirents(ctx, cb); err != nil {
+ if err = fd.IterDirents(ctx, cb); err != nil {
t.Fatalf("dir fd.IterDirents() failed: %v", err)
}
diff --git a/pkg/sentry/fsimpl/ext/extent_file.go b/pkg/sentry/fsimpl/ext/extent_file.go
index 38b68a2d3..c36225a7c 100644
--- a/pkg/sentry/fsimpl/ext/extent_file.go
+++ b/pkg/sentry/fsimpl/ext/extent_file.go
@@ -38,9 +38,10 @@ var _ io.ReaderAt = (*extentFile)(nil)
// newExtentFile is the extent file constructor. It reads the entire extent
// tree into memory.
// TODO(b/134676337): Build extent tree on demand to reduce memory usage.
-func newExtentFile(regFile regularFile) (*extentFile, error) {
- file := &extentFile{regFile: regFile}
+func newExtentFile(args inodeArgs) (*extentFile, error) {
+ file := &extentFile{}
file.regFile.impl = file
+ file.regFile.inode.init(args, &file.regFile)
err := file.buildExtTree()
if err != nil {
return nil, err
@@ -57,7 +58,7 @@ func newExtentFile(regFile regularFile) (*extentFile, error) {
func (f *extentFile) buildExtTree() error {
rootNodeData := f.regFile.inode.diskInode.Data()
- binary.Unmarshal(rootNodeData[:disklayout.ExtentStructsSize], binary.LittleEndian, &f.root.Header)
+ binary.Unmarshal(rootNodeData[:disklayout.ExtentHeaderSize], binary.LittleEndian, &f.root.Header)
// Root node can not have more than 4 entries: 60 bytes = 1 header + 4 entries.
if f.root.Header.NumEntries > 4 {
@@ -67,7 +68,7 @@ func (f *extentFile) buildExtTree() error {
}
f.root.Entries = make([]disklayout.ExtentEntryPair, f.root.Header.NumEntries)
- for i, off := uint16(0), disklayout.ExtentStructsSize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), disklayout.ExtentEntrySize; i < f.root.Header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if f.root.Header.Height == 0 {
// Leaf node.
@@ -76,7 +77,7 @@ func (f *extentFile) buildExtTree() error {
// Internal node.
curEntry = &disklayout.ExtentIdx{}
}
- binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentStructsSize], binary.LittleEndian, curEntry)
+ binary.Unmarshal(rootNodeData[off:off+disklayout.ExtentEntrySize], binary.LittleEndian, curEntry)
f.root.Entries[i].Entry = curEntry
}
@@ -99,13 +100,13 @@ func (f *extentFile) buildExtTree() error {
func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*disklayout.ExtentNode, error) {
var header disklayout.ExtentHeader
off := entry.PhysicalBlock() * f.regFile.inode.blkSize
- err := readFromDisk(f.regFile.inode.dev, int64(off), &header)
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), &header)
if err != nil {
return nil, err
}
entries := make([]disklayout.ExtentEntryPair, header.NumEntries)
- for i, off := uint16(0), off+disklayout.ExtentStructsSize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentStructsSize {
+ for i, off := uint16(0), off+disklayout.ExtentEntrySize; i < header.NumEntries; i, off = i+1, off+disklayout.ExtentEntrySize {
var curEntry disklayout.ExtentEntry
if header.Height == 0 {
// Leaf node.
@@ -115,7 +116,7 @@ func (f *extentFile) buildExtTreeFromDisk(entry disklayout.ExtentEntry) (*diskla
curEntry = &disklayout.ExtentIdx{}
}
- err := readFromDisk(f.regFile.inode.dev, int64(off), curEntry)
+ err := readFromDisk(f.regFile.inode.fs.dev, int64(off), curEntry)
if err != nil {
return nil, err
}
@@ -229,7 +230,7 @@ func (f *extentFile) readFromExtent(ex *disklayout.Extent, off uint64, dst []byt
toRead = len(dst)
}
- n, _ := f.regFile.inode.dev.ReadAt(dst[:toRead], int64(readStart))
+ n, _ := f.regFile.inode.fs.dev.ReadAt(dst[:toRead], int64(readStart))
if n < toRead {
return n, syserror.EIO
}
diff --git a/pkg/sentry/fsimpl/ext/extent_test.go b/pkg/sentry/fsimpl/ext/extent_test.go
index 42d0a484b..cd10d46ee 100644
--- a/pkg/sentry/fsimpl/ext/extent_test.go
+++ b/pkg/sentry/fsimpl/ext/extent_test.go
@@ -177,19 +177,19 @@ func extentTreeSetUp(t *testing.T, root *disklayout.ExtentNode) (*extentFile, []
t.Helper()
mockDisk := make([]byte, mockExtentBlkSize*10)
- mockExtentFile := &extentFile{
- regFile: regularFile{
- inode: inode{
- diskInode: &disklayout.InodeNew{
- InodeOld: disklayout.InodeOld{
- SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
- },
- },
- blkSize: mockExtentBlkSize,
- dev: bytes.NewReader(mockDisk),
+ mockExtentFile := &extentFile{}
+ args := inodeArgs{
+ fs: &filesystem{
+ dev: bytes.NewReader(mockDisk),
+ },
+ diskInode: &disklayout.InodeNew{
+ InodeOld: disklayout.InodeOld{
+ SizeLo: uint32(mockExtentBlkSize) * getNumPhyBlks(root),
},
},
+ blkSize: mockExtentBlkSize,
}
+ mockExtentFile.regFile.inode.init(args, &mockExtentFile.regFile)
fileData := writeTree(&mockExtentFile.regFile.inode, mockDisk, node0, mockExtentBlkSize)
diff --git a/pkg/sentry/fsimpl/ext/file_description.go b/pkg/sentry/fsimpl/ext/file_description.go
index 4d18b28cb..90b086468 100644
--- a/pkg/sentry/fsimpl/ext/file_description.go
+++ b/pkg/sentry/fsimpl/ext/file_description.go
@@ -16,7 +16,7 @@ package ext
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -26,33 +26,15 @@ import (
type fileDescription struct {
vfsfd vfs.FileDescription
vfs.FileDescriptionDefaultImpl
-
- // flags is the same as vfs.OpenOptions.Flags which are passed to
- // vfs.FilesystemImpl.OpenAt.
- // TODO(b/134676337): syscalls like read(2), write(2), fchmod(2), fchown(2),
- // fgetxattr(2), ioctl(2), mmap(2) should fail with EBADF if O_PATH is set.
- // Only close(2), fstat(2), fstatfs(2) should work.
- flags uint32
+ vfs.LockFD
}
func (fd *fileDescription) filesystem() *filesystem {
- return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
}
func (fd *fileDescription) inode() *inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
-}
-
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
+ return fd.vfsfd.Dentry().Impl().(*dentry).inode
}
// Stat implements vfs.FileDescriptionImpl.Stat.
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 2d15e8aaf..c714ddf73 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -17,12 +17,15 @@ package ext
import (
"errors"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/ext/disklayout"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -61,6 +64,10 @@ type filesystem struct {
// bgs represents all the block group descriptors for the filesystem.
// Immutable after initialization.
bgs []disklayout.BlockGroup
+
+ // devMinor is this filesystem's device minor number. Immutable after
+ // initialization.
+ devMinor uint32
}
// Compiles only if filesystem implements vfs.FilesystemImpl.
@@ -77,7 +84,7 @@ var _ vfs.FilesystemImpl = (*filesystem)(nil)
// - filesystem.mu must be locked (for writing if write param is true).
// - !rp.Done().
// - inode == vfsd.Impl().(*Dentry).inode.
-func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
+func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) {
if !inode.isDir() {
return nil, nil, syserror.ENOTDIR
}
@@ -86,14 +93,33 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
}
for {
- nextVFSD, err := rp.ResolveComponent(vfsd)
- if err != nil {
- return nil, nil, err
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return vfsd, inode, nil
}
- if nextVFSD == nil {
- // Since the Dentry tree is not the sole source of truth for extfs, if it's
- // not in the Dentry tree, it might need to be pulled from disk.
- childDirent, ok := inode.impl.(*directory).childMap[rp.Component()]
+ d := vfsd.Impl().(*dentry)
+ if name == ".." {
+ isRoot, err := rp.CheckRoot(ctx, vfsd)
+ if err != nil {
+ return nil, nil, err
+ }
+ if isRoot || d.parent == nil {
+ rp.Advance()
+ return vfsd, inode, nil
+ }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, nil, err
+ }
+ rp.Advance()
+ return &d.parent.vfsd, d.parent.inode, nil
+ }
+
+ dir := inode.impl.(*directory)
+ child, ok := dir.childCache[name]
+ if !ok {
+ // We may need to instantiate a new dentry for this child.
+ childDirent, ok := dir.childMap[name]
if !ok {
// The underlying inode does not exist on disk.
return nil, nil, syserror.ENOENT
@@ -112,21 +138,22 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
}
// incRef because this is being added to the dentry tree.
childInode.incRef()
- child := newDentry(childInode)
- vfsd.InsertChild(&child.vfsd, rp.Component())
-
- // Continue as usual now that nextVFSD is not nil.
- nextVFSD = &child.vfsd
+ child = newDentry(childInode)
+ child.parent = d
+ child.name = name
+ dir.childCache[name] = child
}
- nextInode := nextVFSD.Impl().(*dentry).inode
- if nextInode.isSymlink() && rp.ShouldFollowSymlink() {
- if err := rp.HandleSymlink(inode.impl.(*symlink).target); err != nil {
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
+ return nil, nil, err
+ }
+ if child.inode.isSymlink() && rp.ShouldFollowSymlink() {
+ if err := rp.HandleSymlink(child.inode.impl.(*symlink).target); err != nil {
return nil, nil, err
}
continue
}
rp.Advance()
- return nextVFSD, nextInode, nil
+ return &child.vfsd, child.inode, nil
}
}
@@ -140,12 +167,12 @@ func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write boo
//
// Preconditions:
// - filesystem.mu must be locked (for writing if write param is true).
-func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
inode := vfsd.Impl().(*dentry).inode
for !rp.Done() {
var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write)
if err != nil {
return nil, nil, err
}
@@ -169,12 +196,12 @@ func walkLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error)
// Preconditions:
// - filesystem.mu must be locked (for writing if write param is true).
// - !rp.Done().
-func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
+func walkParentLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) {
vfsd := rp.Start()
inode := vfsd.Impl().(*dentry).inode
for !rp.Final() {
var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode, write)
+ vfsd, inode, err = stepLocked(ctx, rp, vfsd, inode, write)
if err != nil {
return nil, nil, err
}
@@ -189,7 +216,7 @@ func walkParentLocked(rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, e
// the rp till the parent of the last component which should be an existing
// directory. If parent is false then resolves rp entirely. Attemps to resolve
// the path as far as it can with a read lock and upgrades the lock if needed.
-func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
+func (fs *filesystem) walk(ctx context.Context, rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *inode, error) {
var (
vfsd *vfs.Dentry
inode *inode
@@ -200,9 +227,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in
// of disk. This reduces congestion (allows concurrent walks).
fs.mu.RLock()
if parent {
- vfsd, inode, err = walkParentLocked(rp, false)
+ vfsd, inode, err = walkParentLocked(ctx, rp, false)
} else {
- vfsd, inode, err = walkLocked(rp, false)
+ vfsd, inode, err = walkLocked(ctx, rp, false)
}
fs.mu.RUnlock()
@@ -211,9 +238,9 @@ func (fs *filesystem) walk(rp *vfs.ResolvingPath, parent bool) (*vfs.Dentry, *in
// walk is fine as this is a read only filesystem.
fs.mu.Lock()
if parent {
- vfsd, inode, err = walkParentLocked(rp, true)
+ vfsd, inode, err = walkParentLocked(ctx, rp, true)
} else {
- vfsd, inode, err = walkLocked(rp, true)
+ vfsd, inode, err = walkLocked(ctx, rp, true)
}
fs.mu.Unlock()
}
@@ -254,9 +281,18 @@ func (fs *filesystem) statTo(stat *linux.Statfs) {
// TODO(b/134676337): Set Statfs.Flags and Statfs.FSID.
}
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ _, inode, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return err
+ }
+ return inode.checkPermissions(rp.Credentials(), ats)
+}
+
// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
- vfsd, inode, err := fs.walk(rp, false)
+ vfsd, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -274,9 +310,19 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
return vfsd, nil
}
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ vfsd, inode, err := fs.walk(ctx, rp, true)
+ if err != nil {
+ return nil, err
+ }
+ inode.incRef()
+ return vfsd, nil
+}
+
// OpenAt implements vfs.FilesystemImpl.OpenAt.
func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- vfsd, inode, err := fs.walk(rp, false)
+ vfsd, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return nil, err
}
@@ -285,12 +331,12 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
if vfs.MayWriteFileWithOpenFlags(opts.Flags) || opts.Flags&(linux.O_CREAT|linux.O_EXCL|linux.O_TMPFILE) != 0 {
return nil, syserror.EROFS
}
- return inode.open(rp, vfsd, opts.Flags)
+ return inode.open(rp, vfsd, &opts)
}
// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return "", err
}
@@ -303,7 +349,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st
// StatAt implements vfs.FilesystemImpl.StatAt.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return linux.Statx{}, err
}
@@ -314,7 +360,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf
// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
- if _, _, err := fs.walk(rp, false); err != nil {
+ if _, _, err := fs.walk(ctx, rp, false); err != nil {
return linux.Statfs{}, err
}
@@ -324,7 +370,9 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu
}
// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {}
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
// Sync implements vfs.FilesystemImpl.Sync.
func (fs *filesystem) Sync(ctx context.Context) error {
@@ -342,7 +390,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.
return syserror.EEXIST
}
- if _, _, err := fs.walk(rp, true); err != nil {
+ if _, _, err := fs.walk(ctx, rp, true); err != nil {
return err
}
@@ -355,7 +403,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
- if _, _, err := fs.walk(rp, true); err != nil {
+ if _, _, err := fs.walk(ctx, rp, true); err != nil {
return err
}
@@ -368,7 +416,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
return syserror.EEXIST
}
- _, _, err := fs.walk(rp, true)
+ _, _, err := fs.walk(ctx, rp, true)
if err != nil {
return err
}
@@ -377,12 +425,12 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
}
// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
if rp.Done() {
return syserror.ENOENT
}
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -392,7 +440,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vf
// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -406,7 +454,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error
// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
- _, _, err := fs.walk(rp, false)
+ _, _, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -420,7 +468,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
return syserror.EEXIST
}
- _, _, err := fs.walk(rp, true)
+ _, _, err := fs.walk(ctx, rp, true)
if err != nil {
return err
}
@@ -430,7 +478,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ
// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- _, inode, err := fs.walk(rp, false)
+ _, inode, err := fs.walk(ctx, rp, false)
if err != nil {
return err
}
@@ -441,3 +489,60 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error
return syserror.EROFS
}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ _, inode, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+
+ // TODO(b/134676337): Support sockets.
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ _, _, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return nil, err
+ }
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ _, _, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ _, _, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ _, _, err := fs.walk(ctx, rp, false)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index e6c847a71..30636cf66 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -16,7 +16,6 @@ package ext
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -42,19 +41,21 @@ type inode struct {
// refs is a reference count. refs is accessed using atomic memory operations.
refs int64
+ // fs is the containing filesystem.
+ fs *filesystem
+
// inodeNum is the inode number of this inode on disk. This is used to
// identify inodes within the ext filesystem.
inodeNum uint32
- // dev represents the underlying device. Same as filesystem.dev.
- dev io.ReaderAt
-
// blkSize is the fs data block size. Same as filesystem.sb.BlockSize().
blkSize uint64
// diskInode gives us access to the inode struct on disk. Immutable.
diskInode disklayout.Inode
+ locks vfs.FileLocks
+
// This is immutable. The first field of the implementations must have inode
// as the first field to ensure temporality.
impl interface{}
@@ -81,10 +82,10 @@ func (in *inode) tryIncRef() bool {
// decRef decrements the inode ref count and releases the inode resources if
// the ref count hits 0.
//
-// Precondition: Must have locked fs.mu.
-func (in *inode) decRef(fs *filesystem) {
+// Precondition: Must have locked filesystem.mu.
+func (in *inode) decRef() {
if refs := atomic.AddInt64(&in.refs, -1); refs == 0 {
- delete(fs.inodeCache, in.inodeNum)
+ delete(in.fs.inodeCache, in.inodeNum)
} else if refs < 0 {
panic("ext.inode.decRef() called without holding a reference")
}
@@ -116,28 +117,28 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
// Build the inode based on its type.
- inode := inode{
+ args := inodeArgs{
+ fs: fs,
inodeNum: inodeNum,
- dev: fs.dev,
blkSize: blkSize,
diskInode: diskInode,
}
switch diskInode.Mode().FileType() {
case linux.ModeSymlink:
- f, err := newSymlink(inode)
+ f, err := newSymlink(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeRegular:
- f, err := newRegularFile(inode)
+ f, err := newRegularFile(args)
if err != nil {
return nil, err
}
return &f.inode, nil
case linux.ModeDirectory:
- f, err := newDirectroy(inode, fs.sb.IncompatibleFeatures().DirentFileType)
+ f, err := newDirectory(args, fs.sb.IncompatibleFeatures().DirentFileType)
if err != nil {
return nil, err
}
@@ -148,17 +149,35 @@ func newInode(fs *filesystem, inodeNum uint32) (*inode, error) {
}
}
+type inodeArgs struct {
+ fs *filesystem
+ inodeNum uint32
+ blkSize uint64
+ diskInode disklayout.Inode
+}
+
+func (in *inode) init(args inodeArgs, impl interface{}) {
+ in.fs = args.fs
+ in.inodeNum = args.inodeNum
+ in.blkSize = args.blkSize
+ in.diskInode = args.diskInode
+ in.impl = impl
+}
+
// open creates and returns a file description for the dentry passed in.
-func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
+func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
if err := in.checkPermissions(rp.Credentials(), ats); err != nil {
return nil, err
}
+ mnt := rp.Mount()
switch in.impl.(type) {
case *regularFile:
var fd regularFileFD
- fd.flags = flags
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ fd.LockFD.Init(&in.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
return &fd.vfsfd, nil
case *directory:
// Can't open directories writably. This check is not necessary for a read
@@ -167,17 +186,19 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
return nil, syserror.EISDIR
}
var fd directoryFD
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- fd.flags = flags
+ fd.LockFD.Init(&in.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
return &fd.vfsfd, nil
case *symlink:
- if flags&linux.O_PATH == 0 {
+ if opts.Flags&linux.O_PATH == 0 {
// Can't open symlinks without O_PATH.
return nil, syserror.ELOOP
}
var fd symlinkFD
- fd.flags = flags
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
+ fd.LockFD.Init(&in.locks)
+ fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
@@ -185,7 +206,7 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*v
}
func (in *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
- return vfs.GenericCheckPermissions(creds, ats, in.isDir(), uint16(in.diskInode.Mode()), in.diskInode.UID(), in.diskInode.GID())
+ return vfs.GenericCheckPermissions(creds, ats, in.diskInode.Mode(), in.diskInode.UID(), in.diskInode.GID())
}
// statTo writes the statx fields to the output parameter.
@@ -203,6 +224,8 @@ func (in *inode) statTo(stat *linux.Statx) {
stat.Atime = in.diskInode.AccessTime().StatxTimestamp()
stat.Ctime = in.diskInode.ChangeTime().StatxTimestamp()
stat.Mtime = in.diskInode.ModificationTime().StatxTimestamp()
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = in.fs.devMinor
// TODO(b/134676337): Set stat.Blocks which is the number of 512 byte blocks
// (including metadata blocks) required to represent this file.
}
diff --git a/pkg/sentry/fsimpl/ext/regular_file.go b/pkg/sentry/fsimpl/ext/regular_file.go
index aec33e00a..e73e740d6 100644
--- a/pkg/sentry/fsimpl/ext/regular_file.go
+++ b/pkg/sentry/fsimpl/ext/regular_file.go
@@ -16,15 +16,16 @@ package ext
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// regularFile represents a regular file's inode. This too follows the
@@ -43,28 +44,19 @@ type regularFile struct {
// newRegularFile is the regularFile constructor. It figures out what kind of
// file this is and initializes the fileReader.
-func newRegularFile(inode inode) (*regularFile, error) {
- regFile := regularFile{
- inode: inode,
- }
-
- inodeFlags := inode.diskInode.Flags()
-
- if inodeFlags.Extents {
- file, err := newExtentFile(regFile)
+func newRegularFile(args inodeArgs) (*regularFile, error) {
+ if args.diskInode.Flags().Extents {
+ file, err := newExtentFile(args)
if err != nil {
return nil, err
}
-
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
- file, err := newBlockMapFile(regFile)
+ file, err := newBlockMapFile(args)
if err != nil {
return nil, err
}
- file.regFile.inode.impl = &file.regFile
return &file.regFile, nil
}
@@ -77,6 +69,7 @@ func (in *inode) isRegular() bool {
// vfs.FileDescriptionImpl.
type regularFileFD struct {
fileDescription
+ vfs.LockFD
// off is the file offset. off is accessed using atomic memory operations.
off int64
@@ -86,7 +79,7 @@ type regularFileFD struct {
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {}
+func (fd *regularFileFD) Release(context.Context) {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
@@ -157,3 +150,13 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
// TODO(b/134676337): Implement mmap(2).
return syserror.ENODEV
}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *regularFileFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go
index bdf8705c1..2fd0d1fa8 100644
--- a/pkg/sentry/fsimpl/ext/symlink.go
+++ b/pkg/sentry/fsimpl/ext/symlink.go
@@ -15,11 +15,11 @@
package ext
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// symlink represents a symlink inode.
@@ -30,18 +30,17 @@ type symlink struct {
// newSymlink is the symlink constructor. It reads out the symlink target from
// the inode (however it might have been stored).
-func newSymlink(inode inode) (*symlink, error) {
- var file *symlink
+func newSymlink(args inodeArgs) (*symlink, error) {
var link []byte
// If the symlink target is lesser than 60 bytes, its stores in inode.Data().
// Otherwise either extents or block maps will be used to store the link.
- size := inode.diskInode.Size()
+ size := args.diskInode.Size()
if size < 60 {
- link = inode.diskInode.Data()[:size]
+ link = args.diskInode.Data()[:size]
} else {
// Create a regular file out of this inode and read out the target.
- regFile, err := newRegularFile(inode)
+ regFile, err := newRegularFile(args)
if err != nil {
return nil, err
}
@@ -52,8 +51,8 @@ func newSymlink(inode inode) (*symlink, error) {
}
}
- file = &symlink{inode: inode, target: string(link)}
- file.inode.impl = file
+ file := &symlink{target: string(link)}
+ file.inode.init(args, file)
return file, nil
}
@@ -67,13 +66,14 @@ func (in *inode) isSymlink() bool {
// O_PATH. For this reason most of the functions return EBADF.
type symlinkFD struct {
fileDescription
+ vfs.NoLockFD
}
// Compiles only if symlinkFD implements vfs.FileDescriptionImpl.
var _ vfs.FileDescriptionImpl = (*symlinkFD)(nil)
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *symlinkFD) Release() {}
+func (fd *symlinkFD) Release(context.Context) {}
// PRead implements vfs.FileDescriptionImpl.PRead.
func (fd *symlinkFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
new file mode 100644
index 000000000..999111deb
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -0,0 +1,63 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "request_list",
+ out = "request_list.go",
+ package = "fuse",
+ prefix = "request",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Request",
+ "Linker": "*Request",
+ },
+)
+
+go_library(
+ name = "fuse",
+ srcs = [
+ "connection.go",
+ "dev.go",
+ "fusefs.go",
+ "init.go",
+ "register.go",
+ "request_list.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "fuse_test",
+ size = "small",
+ srcs = ["dev_test.go"],
+ library = ":fuse",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
new file mode 100644
index 000000000..6df2728ab
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -0,0 +1,437 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// maxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const maxActiveRequestsDefault = 10000
+
+// Ordinary requests have even IDs, while interrupts IDs are odd.
+// Used to increment the unique ID for each FUSE request.
+var reqIDStep uint64 = 2
+
+const (
+ // fuseDefaultMaxBackground is the default value for MaxBackground.
+ fuseDefaultMaxBackground = 12
+
+ // fuseDefaultCongestionThreshold is the default value for CongestionThreshold,
+ // and is 75% of the default maximum of MaxGround.
+ fuseDefaultCongestionThreshold = (fuseDefaultMaxBackground * 3 / 4)
+
+ // fuseDefaultMaxPagesPerReq is the default value for MaxPagesPerReq.
+ fuseDefaultMaxPagesPerReq = 32
+)
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// connection is the struct by which the sentry communicates with the FUSE server daemon.
+type connection struct {
+ fd *DeviceFD
+
+ // The following FUSE_INIT flags are currently unsupported by this implementation:
+ // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC)
+ // - FUSE_EXPORT_SUPPORT
+ // - FUSE_HANDLE_KILLPRIV
+ // - FUSE_POSIX_LOCKS: requires POSIX locks
+ // - FUSE_FLOCK_LOCKS: requires POSIX locks
+ // - FUSE_AUTO_INVAL_DATA: requires page caching eviction
+ // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction
+ // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation
+ // - FUSE_ASYNC_DIO
+ // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler
+
+ // initialized after receiving FUSE_INIT reply.
+ // Until it's set, suspend sending FUSE requests.
+ // Use SetInitialized() and IsInitialized() for atomic access.
+ initialized int32
+
+ // initializedChan is used to block requests before initialization.
+ initializedChan chan struct{}
+
+ // blocked when there are too many outstading backgrounds requests (NumBackground == MaxBackground).
+ // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block.
+ blocked bool
+
+ // connected (connection established) when a new FUSE file system is created.
+ // Set to false when:
+ // umount,
+ // connection abort,
+ // device release.
+ connected bool
+
+ // aborted via sysfs.
+ // TODO(gvisor.dev/issue/3185): abort all queued requests.
+ aborted bool
+
+ // connInitError if FUSE_INIT encountered error (major version mismatch).
+ // Only set in INIT.
+ connInitError bool
+
+ // connInitSuccess if FUSE_INIT is successful.
+ // Only set in INIT.
+ // Used for destory.
+ connInitSuccess bool
+
+ // TODO(gvisor.dev/issue/3185): All the queue logic are working in progress.
+
+ // NumberBackground is the number of requests in the background.
+ numBackground uint16
+
+ // congestionThreshold for NumBackground.
+ // Negotiated in FUSE_INIT.
+ congestionThreshold uint16
+
+ // maxBackground is the maximum number of NumBackground.
+ // Block connection when it is reached.
+ // Negotiated in FUSE_INIT.
+ maxBackground uint16
+
+ // numActiveBackground is the number of requests in background and has being marked as active.
+ numActiveBackground uint16
+
+ // numWating is the number of requests waiting for completion.
+ numWaiting uint32
+
+ // TODO(gvisor.dev/issue/3185): BgQueue
+ // some queue for background queued requests.
+
+ // bgLock protects:
+ // MaxBackground, CongestionThreshold, NumBackground,
+ // NumActiveBackground, BgQueue, Blocked.
+ bgLock sync.Mutex
+
+ // maxRead is the maximum size of a read buffer in in bytes.
+ maxRead uint32
+
+ // maxWrite is the maximum size of a write buffer in bytes.
+ // Negotiated in FUSE_INIT.
+ maxWrite uint32
+
+ // maxPages is the maximum number of pages for a single request to use.
+ // Negotiated in FUSE_INIT.
+ maxPages uint16
+
+ // minor version of the FUSE protocol.
+ // Negotiated and only set in INIT.
+ minor uint32
+
+ // asyncRead if read pages asynchronously.
+ // Negotiated and only set in INIT.
+ asyncRead bool
+
+ // abortErr is true if kernel need to return an unique read error after abort.
+ // Negotiated and only set in INIT.
+ abortErr bool
+
+ // writebackCache is true for write-back cache policy,
+ // false for write-through policy.
+ // Negotiated and only set in INIT.
+ writebackCache bool
+
+ // cacheSymlinks if filesystem needs to cache READLINK responses in page cache.
+ // Negotiated and only set in INIT.
+ cacheSymlinks bool
+
+ // bigWrites if doing multi-page cached writes.
+ // Negotiated and only set in INIT.
+ bigWrites bool
+
+ // dontMask if filestestem does not apply umask to creation modes.
+ // Negotiated in INIT.
+ dontMask bool
+}
+
+// newFUSEConnection creates a FUSE connection to fd.
+func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) {
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ // Create the writeBuf for the header to be stored in.
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ fuseFD.writeBuf = make([]byte, hdrLen)
+ fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
+ fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.writeCursor = 0
+
+ return &connection{
+ fd: fuseFD,
+ maxBackground: fuseDefaultMaxBackground,
+ congestionThreshold: fuseDefaultCongestionThreshold,
+ maxPages: fuseDefaultMaxPagesPerReq,
+ initializedChan: make(chan struct{}),
+ connected: true,
+ }, nil
+}
+
+// SetInitialized atomically sets the connection as initialized.
+func (conn *connection) SetInitialized() {
+ // Unblock the requests sent before INIT.
+ close(conn.initializedChan)
+
+ // Close the channel first to avoid the non-atomic situation
+ // where conn.initialized is true but there are
+ // tasks being blocked on the channel.
+ // And it prevents the newer tasks from gaining
+ // unnecessary higher chance to be issued before the blocked one.
+
+ atomic.StoreInt32(&(conn.initialized), int32(1))
+}
+
+// IsInitialized atomically check if the connection is initialized.
+// pairs with SetInitialized().
+func (conn *connection) Initialized() bool {
+ return atomic.LoadInt32(&(conn.initialized)) != 0
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+ hdr.MarshalUnsafe(buf[:hdrLen])
+ payload.MarshalUnsafe(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// Call makes a request to the server and blocks the invoking task until a
+// server responds with a response. Task should never be nil.
+// Requests will not be sent before the connection is initialized.
+// For async tasks, use CallAsync().
+func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) {
+ // Block requests sent before connection is initalized.
+ if !conn.Initialized() {
+ if err := t.Block(conn.initializedChan); err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.call(t, r)
+}
+
+// CallAsync makes an async (aka background) request.
+// Those requests either do not expect a response (e.g. release) or
+// the response should be handled by others (e.g. init).
+// Return immediately unless the connection is blocked (before initialization).
+// Async call example: init, release, forget, aio, interrupt.
+// When the Request is FUSE_INIT, it will not be blocked before initialization.
+func (conn *connection) CallAsync(t *kernel.Task, r *Request) error {
+ // Block requests sent before connection is initalized.
+ if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT {
+ if err := t.Block(conn.initializedChan); err != nil {
+ return err
+ }
+ }
+
+ // This should be the only place that invokes call() with a nil task.
+ _, err := conn.call(nil, r)
+ return err
+}
+
+// call makes a call without blocking checks.
+func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) {
+ if !conn.connected {
+ return nil, syserror.ENOTCONN
+ }
+
+ if conn.connInitError {
+ return nil, syserror.ECONNREFUSED
+ }
+
+ fut, err := conn.callFuture(t, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return fut.resolve(t)
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ m.UnmarshalUnsafe(r.data[hdrLen:])
+ return nil
+}
+
+// callFuture makes a request to the server and returns a future response.
+// Call resolve() when the response needs to be fulfilled.
+func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+
+ // Is the queue full?
+ //
+ // We must busy wait here until the request can be queued. We don't
+ // block on the fd.fullQueueCh with a lock - so after being signalled,
+ // before we acquire the lock, it is possible that a barging task enters
+ // and queues a request. As a result, upon acquiring the lock we must
+ // again check if the room is available.
+ //
+ // This can potentially starve a request forever but this can only happen
+ // if there are always too many ongoing requests all the time. The
+ // supported maxActiveRequests setting should be really high to avoid this.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ if t == nil {
+ // Since there is no task that is waiting. We must error out.
+ return nil, errors.New("FUSE request queue full")
+ }
+
+ log.Infof("Blocking request %v from being queued. Too many active requests: %v",
+ r.id, conn.fd.numActiveRequests)
+ conn.fd.mu.Unlock()
+ err := t.Block(conn.fd.fullQueueCh)
+ conn.fd.mu.Lock()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.callFutureLocked(t, r)
+}
+
+// callFutureLocked makes a request to the server and returns a future response.
+func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.queue.PushBack(r)
+ conn.fd.numActiveRequests += 1
+ fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.completions[r.id] = fut
+
+ // Signal the readers that there is something to read.
+ conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ return fut, nil
+}
+
+// futureResponse represents an in-flight request, that may or may not have
+// completed yet. Convert it to a resolved Response by calling Resolve, but note
+// that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
+ return &futureResponse{
+ opcode: opcode,
+ ch: make(chan struct{}),
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // If there is no Task associated with this request - then we don't try to resolve
+ // the response. Instead, the task writing the response (proxy to the server) will
+ // process the response on our behalf.
+ if t == nil {
+ log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+}
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
new file mode 100644
index 000000000..e522ff9a0
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const fuseDevMinor = 229
+
+// fuseDevice implements vfs.Device for /dev/fuse.
+type fuseDevice struct{}
+
+// Open implements vfs.Device.Open.
+func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if !kernel.FUSEEnabled {
+ return nil, syserror.ENOENT
+ }
+
+ var fd DeviceFD
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse.
+type DeviceFD struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
+ mounted bool
+
+ // nextOpID is used to create new requests.
+ nextOpID linux.FUSEOpID
+
+ // queue is the list of requests that need to be processed by the FUSE server.
+ queue requestList
+
+ // numActiveRequests is the number of requests made by the Sentry that has
+ // yet to be responded to.
+ numActiveRequests uint64
+
+ // completions is used to map a request to its response. A Writer will use this
+ // to notify the caller of a completed response.
+ completions map[linux.FUSEOpID]*futureResponse
+
+ writeCursor uint32
+
+ // writeBuf is the memory buffer used to copy in the FUSE out header from
+ // userspace.
+ writeBuf []byte
+
+ // writeCursorFR current FR being copied from server.
+ writeCursorFR *futureResponse
+
+ // mu protects all the queues, maps, buffers and cursors and nextOpID.
+ mu sync.Mutex
+
+ // waitQueue is used to notify interested parties when the device becomes
+ // readable or writable.
+ waitQueue waiter.Queue
+
+ // fullQueueCh is a channel used to synchronize the readers with the writers.
+ // Writers (inbound requests to the filesystem) block if there are too many
+ // unprocessed in-flight requests.
+ fullQueueCh chan struct{}
+
+ // fs is the FUSE filesystem that this FD is being used for.
+ fs *filesystem
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DeviceFD) Release(context.Context) {
+ fd.fs.conn.connected = false
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ // We require that any Read done on this filesystem have a sane minimum
+ // read buffer. It must have the capacity for the fixed parts of any request
+ // header (Linux uses the request header and the FUSEWriteIn header for this
+ // calculation) + the negotiated MaxWrite room for the data.
+ minBuffSize := linux.FUSE_MIN_READ_BUFFER
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes())
+ negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.maxWrite
+ if minBuffSize < negotiatedMinBuffSize {
+ minBuffSize = negotiatedMinBuffSize
+ }
+
+ // If the read buffer is too small, error out.
+ if dst.NumBytes() < int64(minBuffSize) {
+ return 0, syserror.EINVAL
+ }
+
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readLocked(ctx, dst, opts)
+}
+
+// readLocked implements the reading of the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if fd.queue.Empty() {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var readCursor uint32
+ var bytesRead int64
+ for {
+ req := fd.queue.Front()
+ if dst.NumBytes() < int64(req.hdr.Len) {
+ // The request is too large. Cannot process it. All requests must be smaller than the
+ // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
+ // handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
+
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req); err != nil {
+ return 0, err
+ }
+
+ // We're done with this request.
+ fd.queue.Remove(req)
+
+ // Restart the read as this request was invalid.
+ log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.")
+ return fd.readLocked(ctx, dst, opts)
+ }
+
+ n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ if err != nil {
+ return 0, err
+ }
+ readCursor += uint32(n)
+ bytesRead += int64(n)
+
+ if readCursor >= req.hdr.Len {
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+ break
+ }
+ }
+
+ return bytesRead, nil
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.writeLocked(ctx, src, opts)
+}
+
+// writeLocked implements writing to the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ var cn, n int64
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+
+ for src.NumBytes() > 0 {
+ if fd.writeCursorFR != nil {
+ // Already have common header, and we're now copying the payload.
+ wantBytes := fd.writeCursorFR.hdr.Len
+
+ // Note that the FR data doesn't have the header. Copy it over if its necessary.
+ if fd.writeCursorFR.data == nil {
+ fd.writeCursorFR.data = make([]byte, wantBytes)
+ }
+
+ bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == wantBytes {
+ // Done reading this full response. Clean up and unblock the
+ // initiator.
+ break
+ }
+
+ // Check if we have more data in src.
+ continue
+ }
+
+ // Assert that the header isn't read into the writeBuf yet.
+ if fd.writeCursor >= hdrLen {
+ return 0, syserror.EINVAL
+ }
+
+ // We don't have the full common response header yet.
+ wantBytes := hdrLen - fd.writeCursor
+ bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == hdrLen {
+ // Have full header in the writeBuf. Use it to fetch the actual futureResponse
+ // from the device's completions map.
+ var hdr linux.FUSEHeaderOut
+ hdr.UnmarshalBytes(fd.writeBuf)
+
+ // We have the header now and so the writeBuf has served its purpose.
+ // We could reset it manually here but instead of doing that, at the
+ // end of the write, the writeCursor will be set to 0 thereby allowing
+ // the next request to overwrite whats in the buffer,
+
+ fut, ok := fd.completions[hdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return 0, syserror.EINVAL
+ }
+
+ delete(fd.completions, hdr.Unique)
+
+ // Copy over the header into the future response. The rest of the payload
+ // will be copied over to the FR's data in the next iteration.
+ fut.hdr = &hdr
+ fd.writeCursorFR = fut
+
+ // Next iteration will now try read the complete request, if src has
+ // any data remaining. Otherwise we're done.
+ }
+ }
+
+ if fd.writeCursorFR != nil {
+ if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil {
+ return 0, err
+ }
+
+ // Ready the device for the next request.
+ fd.writeCursorFR = nil
+ fd.writeCursor = 0
+ }
+
+ return n, nil
+}
+
+// Readiness implements vfs.FileDescriptionImpl.Readiness.
+func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ ready |= waiter.EventOut // FD is always writable
+ if !fd.queue.Empty() {
+ // Have reqs available, FD is readable.
+ ready |= waiter.EventIn
+ }
+
+ return ready & mask
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.waitQueue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
+ fd.waitQueue.EventUnregister(e)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
+ if !fd.mounted {
+ return 0, syserror.EPERM
+ }
+
+ return 0, syserror.ENOSYS
+}
+
+// sendResponse sends a response to the waiting task (if any).
+func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
+ // See if the running task need to perform some action before returning.
+ // Since we just finished writing the future, we can be sure that
+ // getResponse generates a populated response.
+ if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
+ return err
+ }
+
+ // Signal that the queue is no longer full.
+ select {
+ case fd.fullQueueCh <- struct{}{}:
+ default:
+ }
+ fd.numActiveRequests -= 1
+
+ // Signal the task waiting on a response.
+ close(fut.ch)
+ return nil
+}
+
+// sendError sends an error response to the waiting task (if any).
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+ // Return the error to the calling task.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ respHdr := linux.FUSEHeaderOut{
+ Len: outHdrLen,
+ Error: errno,
+ Unique: req.hdr.Unique,
+ }
+
+ fut, ok := fd.completions[respHdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return syserror.EINVAL
+ }
+ delete(fd.completions, respHdr.Unique)
+
+ fut.hdr = &respHdr
+ if err := fd.sendResponse(ctx, fut); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// noReceiverAction has the calling kernel.Task do some action if its known that no
+// receiver is going to be waiting on the future channel. This is to be used by:
+// FUSE_INIT.
+func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
+ if r.opcode == linux.FUSE_INIT {
+ creds := auth.CredentialsFromContext(ctx)
+ rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace()
+ return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs))
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
new file mode 100644
index 000000000..1ffe7ccd2
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -0,0 +1,428 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// echoTestOpcode is the Opcode used during testing. The server used in tests
+// will simply echo the payload back with the appropriate headers.
+const echoTestOpcode linux.FUSEOpcode = 1000
+
+type testPayload struct {
+ data uint32
+}
+
+// TestFUSECommunication tests that the communication layer between the Sentry and the
+// FUSE server daemon works as expected.
+func TestFUSECommunication(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+
+ // Create test cases with different number of concurrent clients and servers.
+ testCases := []struct {
+ Name string
+ NumClients int
+ NumServers int
+ MaxActiveRequests uint64
+ }{
+ {
+ Name: "SingleClientSingleServer",
+ NumClients: 1,
+ NumServers: 1,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "SingleClientMultipleServers",
+ NumClients: 1,
+ NumServers: 10,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsSingleServer",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsMultipleServers",
+ NumClients: 10,
+ NumServers: 10,
+ MaxActiveRequests: maxActiveRequestsDefault,
+ },
+ {
+ Name: "RequestCapacityFull",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: 1,
+ },
+ {
+ Name: "RequestCapacityContinuouslyFull",
+ NumClients: 100,
+ NumServers: 2,
+ MaxActiveRequests: 2,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ clientsDone := make([]chan struct{}, testCase.NumClients)
+ serversDone := make([]chan struct{}, testCase.NumServers)
+ serversKill := make([]chan struct{}, testCase.NumServers)
+
+ // FUSE clients.
+ for i := 0; i < testCase.NumClients; i++ {
+ clientsDone[i] = make(chan struct{})
+ go func(i int) {
+ fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
+ }(i)
+ }
+
+ // FUSE servers.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversDone[j] = make(chan struct{})
+ serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
+ go func(j int) {
+ fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
+ }(j)
+ }
+
+ // Tear down.
+ //
+ // Make sure all the clients are done.
+ for i := 0; i < testCase.NumClients; i++ {
+ <-clientsDone[i]
+ }
+
+ // Kill any server that is potentially waiting.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversKill[j] <- struct{}{}
+ }
+
+ // Make sure all the servers are done.
+ for j := 0; j < testCase.NumServers; j++ {
+ <-serversDone[j]
+ }
+ })
+ }
+}
+
+// CallTest makes a request to the server and blocks the invoking
+// goroutine until a server responds with a response. Doesn't block
+// a kernel.Task. Analogous to Connection.Call but used for testing.
+func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
+ conn.fd.mu.Lock()
+
+ // Wait until we're certain that a new request can be processed.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ conn.fd.mu.Unlock()
+ select {
+ case <-conn.fd.fullQueueCh:
+ }
+ conn.fd.mu.Lock()
+ }
+
+ fut, err := conn.callFutureLocked(t, r) // No task given.
+ conn.fd.mu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the response.
+ //
+ // Block without a task.
+ select {
+ case <-fut.ch:
+ }
+
+ // A response is ready. Resolve and return it.
+ return fut.getResponse(), nil
+}
+
+// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
+// device. However, it does so by - not blocking the task that is calling - and
+// instead just waits on a channel. The behaviour is essentially the same as
+// DeviceFD.Read except it guarantees that the task is not blocked.
+func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
+ var err error
+ var n, total int64
+
+ dev := fd.Impl().(*DeviceFD)
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ dev.EventRegister(&w, waiter.EventIn)
+ for {
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ // Emulate the blocking for when no requests are available
+ select {
+ case <-ch:
+ case <-killServer:
+ // Server killed by the main program.
+ return 0, true, nil
+ }
+ }
+
+ dev.EventUnregister(&w)
+ return total, false, err
+}
+
+// fuseClientRun emulates all the actions of a normal FUSE request. It creates
+// a header, a payload, calls the server, waits for the response, and processes
+// the response.
+func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
+ defer func() { clientDone <- struct{}{} }()
+
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+
+ // Queue up a request.
+ // Analogous to Call except it doesn't block on the task.
+ resp, err := CallTest(conn, clientTask, req, pid)
+ if err != nil {
+ t.Fatalf("CallTaskNonBlock failed: %v", err)
+ }
+
+ if err = resp.Error(); err != nil {
+ t.Fatalf("Server responded with an error: %v", err)
+ }
+
+ var respTestPayload testPayload
+ if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
+ t.Fatalf("Unmarshalling payload error: %v", err)
+ }
+
+ if resp.hdr.Unique != req.hdr.Unique {
+ t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
+ req.hdr.Unique, resp.hdr.Unique)
+ }
+
+ if respTestPayload.data != testObj.data {
+ t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ }
+
+}
+
+// fuseServerRun creates a task and emulates all the actions of a simple FUSE server
+// that simply reads a request and echos the same struct back as a response using the
+// appropriate headers.
+func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
+ defer func() { serverDone <- struct{}{} }()
+
+ // Create the tasks that the server will be using.
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ var readPayload testPayload
+
+ serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Read the request.
+ for {
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ payloadLen := uint32(readPayload.SizeBytes())
+
+ // The raed buffer must meet some certain size criteria.
+ buffSize := inHdrLen + payloadLen
+ if buffSize < linux.FUSE_MIN_READ_BUFFER {
+ buffSize = linux.FUSE_MIN_READ_BUFFER
+ }
+ inBuf := make([]byte, buffSize)
+ inIOseq := usermem.BytesIOSequence(inBuf)
+
+ n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
+ if err != nil {
+ t.Fatalf("Read failed :%v", err)
+ }
+
+ // Server should shut down. No new requests are going to be made.
+ if serverKilled {
+ break
+ }
+
+ if n <= 0 {
+ t.Fatalf("Read read no bytes")
+ }
+
+ var readFUSEHeaderIn linux.FUSEHeaderIn
+ readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
+ readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+
+ if readFUSEHeaderIn.Opcode != echoTestOpcode {
+ t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
+ }
+
+ // Write the response.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ outBuf := make([]byte, outHdrLen+payloadLen)
+ outHeader := linux.FUSEHeaderOut{
+ Len: outHdrLen + payloadLen,
+ Error: 0,
+ Unique: readFUSEHeaderIn.Unique,
+ }
+
+ // Echo the payload back.
+ outHeader.MarshalUnsafe(outBuf[:outHdrLen])
+ readPayload.MarshalUnsafe(outBuf[outHdrLen:])
+ outIOseq := usermem.BytesIOSequence(outBuf)
+
+ n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed :%v", err)
+ }
+ }
+}
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a fuse connection that the sentry can communicate with
+// and the FD for the server to communicate with.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(system.Ctx); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef(system.Ctx)
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+ t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+ t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+ panic("not implemented")
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
new file mode 100644
index 000000000..83c24ec25
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -0,0 +1,324 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fuse implements fusefs.
+package fuse
+
+import (
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "fuse"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+type filesystemOptions struct {
+ // userID specifies the numeric uid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see man page
+ // for fuse(8)
+ userID uint32
+
+ // groupID specifies the numeric gid of the mount owner.
+ // This option should not be specified by the filesystem owner.
+ // It is set by libfuse (or, if libfuse is not used, must be set
+ // by the filesystem itself). For more information, see man page
+ // for fuse(8)
+ groupID uint32
+
+ // rootMode specifies the the file mode of the filesystem's root.
+ rootMode linux.FileMode
+
+ // maxActiveRequests specifies the maximum number of active requests that can
+ // exist at any time. Any further requests will block when trying to
+ // Call the server.
+ maxActiveRequests uint64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // conn is used for communication between the FUSE server
+ // daemon and the sentry fusefs.
+ conn *connection
+
+ // opts is the options the fusefs is initialized with.
+ opts *filesystemOptions
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ var fsopts filesystemOptions
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ deviceDescriptorStr, ok := mopts["fd"]
+ if !ok {
+ log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "fd")
+
+ deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ kernelTask := kernel.TaskFromContext(ctx)
+ if kernelTask == nil {
+ log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name())
+ return nil, nil, syserror.EINVAL
+ }
+ fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+
+ // Parse and set all the other supported FUSE mount options.
+ // TODO(gVisor.dev/issue/3229): Expand the supported mount options.
+ if userIDStr, ok := mopts["user_id"]; ok {
+ delete(mopts, "user_id")
+ userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.userID = uint32(userID)
+ }
+
+ if groupIDStr, ok := mopts["group_id"]; ok {
+ delete(mopts, "group_id")
+ groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.groupID = uint32(groupID)
+ }
+
+ rootMode := linux.FileMode(0777)
+ modeStr, ok := mopts["rootmode"]
+ if ok {
+ delete(mopts, "rootmode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode)
+ }
+ fsopts.rootMode = rootMode
+
+ // Set the maxInFlightRequests option.
+ fsopts.maxActiveRequests = maxActiveRequestsDefault
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Create a new FUSE filesystem.
+ fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+ if err != nil {
+ log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
+ }
+
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ // Send a FUSE_INIT request to the FUSE daemon server before returning.
+ // This call is not blocking.
+ if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil {
+ log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
+ }
+
+ // root is the fusefs root directory.
+ root := fs.newInode(creds, fsopts.rootMode)
+
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// NewFUSEFilesystem creates a new FUSE filesystem.
+func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+ fs := &filesystem{
+ devMinor: devMinor,
+ opts: opts,
+ }
+
+ conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests)
+ if err != nil {
+ log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
+ return nil, syserror.EINVAL
+ }
+
+ fs.conn = conn
+ fuseFD := device.Impl().(*DeviceFD)
+ fuseFD.fs = fs
+
+ return fs, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry {
+ i := &inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ i.dentry.Init(i)
+
+ return &i.dentry
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The
+// opts.Sync attribute is ignored since the synchronization is handled by the
+// FUSE server.
+func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx {
+ var stat linux.Statx
+ stat.Blksize = attr.BlkSize
+ stat.DevMajor, stat.DevMinor = linux.UNNAMED_MAJOR, devMinor
+
+ rdevMajor, rdevMinor := linux.DecodeDeviceID(attr.Rdev)
+ stat.RdevMajor, stat.RdevMinor = uint32(rdevMajor), rdevMinor
+
+ if mask&linux.STATX_MODE != 0 {
+ stat.Mode = uint16(attr.Mode)
+ }
+ if mask&linux.STATX_NLINK != 0 {
+ stat.Nlink = attr.Nlink
+ }
+ if mask&linux.STATX_UID != 0 {
+ stat.UID = attr.UID
+ }
+ if mask&linux.STATX_GID != 0 {
+ stat.GID = attr.GID
+ }
+ if mask&linux.STATX_ATIME != 0 {
+ stat.Atime = linux.StatxTimestamp{
+ Sec: int64(attr.Atime),
+ Nsec: attr.AtimeNsec,
+ }
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ stat.Mtime = linux.StatxTimestamp{
+ Sec: int64(attr.Mtime),
+ Nsec: attr.MtimeNsec,
+ }
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ stat.Ctime = linux.StatxTimestamp{
+ Sec: int64(attr.Ctime),
+ Nsec: attr.CtimeNsec,
+ }
+ }
+ if mask&linux.STATX_INO != 0 {
+ stat.Ino = attr.Ino
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ stat.Size = attr.Size
+ }
+ if mask&linux.STATX_BLOCKS != 0 {
+ stat.Blocks = attr.Blocks
+ }
+ return stat
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ fusefs := fs.Impl().(*filesystem)
+ conn := fusefs.conn
+ task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
+ if task == nil {
+ log.Warningf("couldn't get kernel task from context")
+ return linux.Statx{}, syserror.EINVAL
+ }
+
+ var in linux.FUSEGetAttrIn
+ // We don't set any attribute in the request, because in VFS2 fstat(2) will
+ // finally be translated into vfs.FilesystemImpl.StatAt() (see
+ // pkg/sentry/syscalls/linux/vfs2/stat.go), resulting in the same flow
+ // as stat(2). Thus GetAttrFlags and Fh variable will never be used in VFS2.
+ req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.Ino(), linux.FUSE_GETATTR, &in)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ res, err := conn.Call(task, req)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if err := res.Error(); err != nil {
+ return linux.Statx{}, err
+ }
+
+ var out linux.FUSEGetAttrOut
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return linux.Statx{}, err
+ }
+
+ // Set all metadata into kernfs.InodeAttrs.
+ if err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{
+ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, fusefs.devMinor),
+ }); err != nil {
+ return linux.Statx{}, err
+ }
+
+ return statFromFUSEAttr(out.Attr, opts.Mask, fusefs.devMinor), nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/init.go
new file mode 100644
index 000000000..779c2bd3f
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/init.go
@@ -0,0 +1,166 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// consts used by FUSE_INIT negotiation.
+const (
+ // fuseMaxMaxPages is the maximum value for MaxPages received in InitOut.
+ // Follow the same behavior as unix fuse implementation.
+ fuseMaxMaxPages = 256
+
+ // Maximum value for the time granularity for file time stamps, 1s.
+ // Follow the same behavior as unix fuse implementation.
+ fuseMaxTimeGranNs = 1000000000
+
+ // Minimum value for MaxWrite.
+ // Follow the same behavior as unix fuse implementation.
+ fuseMinMaxWrite = 4096
+
+ // Temporary default value for max readahead, 128kb.
+ fuseDefaultMaxReadahead = 131072
+
+ // The FUSE_INIT_IN flags sent to the daemon.
+ // TODO(gvisor.dev/issue/3199): complete the flags.
+ fuseDefaultInitFlags = linux.FUSE_MAX_PAGES
+)
+
+// Adjustable maximums for Connection's cogestion control parameters.
+// Used as the upperbound of the config values.
+// Currently we do not support adjustment to them.
+var (
+ MaxUserBackgroundRequest uint16 = fuseDefaultMaxBackground
+ MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold
+)
+
+// InitSend sends a FUSE_INIT request.
+func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
+ in := linux.FUSEInitIn{
+ Major: linux.FUSE_KERNEL_VERSION,
+ Minor: linux.FUSE_KERNEL_MINOR_VERSION,
+ // TODO(gvisor.dev/issue/3196): find appropriate way to calculate this
+ MaxReadahead: fuseDefaultMaxReadahead,
+ Flags: fuseDefaultInitFlags,
+ }
+
+ req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
+ if err != nil {
+ return err
+ }
+
+ // Since there is no task to block on and FUSE_INIT is the request
+ // to unblock other requests, use nil.
+ return conn.CallAsync(nil, req)
+}
+
+// InitRecv receives a FUSE_INIT reply and process it.
+func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error {
+ if err := res.Error(); err != nil {
+ return err
+ }
+
+ var out linux.FUSEInitOut
+ if err := res.UnmarshalPayload(&out); err != nil {
+ return err
+ }
+
+ return conn.initProcessReply(&out, hasSysAdminCap)
+}
+
+// Process the FUSE_INIT reply from the FUSE server.
+func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error {
+ // No support for old major fuse versions.
+ if out.Major != linux.FUSE_KERNEL_VERSION {
+ conn.connInitError = true
+
+ // Set the connection as initialized and unblock the blocked requests
+ // (i.e. return error for them).
+ conn.SetInitialized()
+
+ return nil
+ }
+
+ // Start processing the reply.
+ conn.connInitSuccess = true
+ conn.minor = out.Minor
+
+ // No support for limits before minor version 13.
+ if out.Minor >= 13 {
+ conn.bgLock.Lock()
+
+ if out.MaxBackground > 0 {
+ conn.maxBackground = out.MaxBackground
+
+ if !hasSysAdminCap &&
+ conn.maxBackground > MaxUserBackgroundRequest {
+ conn.maxBackground = MaxUserBackgroundRequest
+ }
+ }
+
+ if out.CongestionThreshold > 0 {
+ conn.congestionThreshold = out.CongestionThreshold
+
+ if !hasSysAdminCap &&
+ conn.congestionThreshold > MaxUserCongestionThreshold {
+ conn.congestionThreshold = MaxUserCongestionThreshold
+ }
+ }
+
+ conn.bgLock.Unlock()
+ }
+
+ // No support for the following flags before minor version 6.
+ if out.Minor >= 6 {
+ conn.asyncRead = out.Flags&linux.FUSE_ASYNC_READ != 0
+ conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0
+ conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0
+ conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0
+ conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0
+ conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0
+
+ // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs).
+
+ if out.Flags&linux.FUSE_MAX_PAGES != 0 {
+ maxPages := out.MaxPages
+ if maxPages < 1 {
+ maxPages = 1
+ }
+ if maxPages > fuseMaxMaxPages {
+ maxPages = fuseMaxMaxPages
+ }
+ conn.maxPages = maxPages
+ }
+ }
+
+ // No support for negotiating MaxWrite before minor version 5.
+ if out.Minor >= 5 {
+ conn.maxWrite = out.MaxWrite
+ } else {
+ conn.maxWrite = fuseMinMaxWrite
+ }
+ if conn.maxWrite < fuseMinMaxWrite {
+ conn.maxWrite = fuseMinMaxWrite
+ }
+
+ // Set connection as initialized and unblock the requests
+ // issued before init.
+ conn.SetInitialized()
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go
new file mode 100644
index 000000000..b5b581152
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/register.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "misc",
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// CreateDevtmpfsFile creates a device special file in devtmpfs.
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
+ if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
new file mode 100644
index 000000000..16787116f
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -0,0 +1,90 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "gofer",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "gofer",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "gofer",
+ srcs = [
+ "dentry_list.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "gofer.go",
+ "handle.go",
+ "host_named_pipe.go",
+ "p9file.go",
+ "regular_file.go",
+ "socket.go",
+ "special_file.go",
+ "symlink.go",
+ "time.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/p9",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "gofer_test",
+ srcs = ["gofer_test.go"],
+ library = ":gofer",
+ deps = [
+ "//pkg/p9",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/pgalloc",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go
new file mode 100644
index 000000000..2a8011eb4
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/directory.go
@@ -0,0 +1,306 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isDir() bool {
+ return d.fileType() == linux.S_IFDIR
+}
+
+// Preconditions: filesystem.renameMu must be locked. d.dirMu must be locked.
+// d.isDir(). child must be a newly-created dentry that has never had a parent.
+func (d *dentry) cacheNewChildLocked(child *dentry, name string) {
+ d.IncRef() // reference held by child on its parent
+ child.parent = d
+ child.name = name
+ if d.children == nil {
+ d.children = make(map[string]*dentry)
+ }
+ d.children[name] = child
+}
+
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) cacheNegativeLookupLocked(name string) {
+ // Don't cache negative lookups if InteropModeShared is in effect (since
+ // this makes remote lookup unavoidable), or if d.isSynthetic() (in which
+ // case the only files in the directory are those for which a dentry exists
+ // in d.children). Instead, just delete any previously-cached dentry.
+ if d.fs.opts.interop == InteropModeShared || d.isSynthetic() {
+ delete(d.children, name)
+ return
+ }
+ if d.children == nil {
+ d.children = make(map[string]*dentry)
+ }
+ d.children[name] = nil
+}
+
+type createSyntheticOpts struct {
+ name string
+ mode linux.FileMode
+ kuid auth.KUID
+ kgid auth.KGID
+
+ // The endpoint for a synthetic socket. endpoint should be nil if the file
+ // being created is not a socket.
+ endpoint transport.BoundEndpoint
+
+ // pipe should be nil if the file being created is not a pipe.
+ pipe *pipe.VFSPipe
+}
+
+// createSyntheticChildLocked creates a synthetic file with the given name
+// in d.
+//
+// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain
+// a child with the given name.
+func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) {
+ d2 := &dentry{
+ refs: 1, // held by d
+ fs: d.fs,
+ ino: d.fs.nextSyntheticIno(),
+ mode: uint32(opts.mode),
+ uid: uint32(opts.kuid),
+ gid: uint32(opts.kgid),
+ blockSize: usermem.PageSize, // arbitrary
+ hostFD: -1,
+ nlink: uint32(2),
+ }
+ switch opts.mode.FileType() {
+ case linux.S_IFDIR:
+ // Nothing else needs to be done.
+ case linux.S_IFSOCK:
+ d2.endpoint = opts.endpoint
+ case linux.S_IFIFO:
+ d2.pipe = opts.pipe
+ default:
+ panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType()))
+ }
+ d2.pf.dentry = d2
+ d2.vfsd.Init(d2)
+
+ d.cacheNewChildLocked(d2, opts.name)
+ d.syntheticChildren++
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release(context.Context) {
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ d := fd.dentry()
+ if fd.dirents == nil {
+ ds, err := d.getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ d.InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent)
+ if d.cachedMetadataAuthoritative() {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir(). There exists at least one directoryFD representing d.
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ // NOTE(b/135560623): 9P2000.L's readdir does not specify behavior in the
+ // presence of concurrent mutation of an iterated directory, so
+ // implementations may duplicate or omit entries in this case, which
+ // violates POSIX semantics. Thus we read all directory entries while
+ // holding d.dirMu to exclude directory mutations. (Note that it is
+ // impossible for the client to exclude concurrent mutation from other
+ // remote filesystem users. Since there is no way to detect if the server
+ // has incorrectly omitted directory entries, we simply assume that the
+ // server is well-behaved under InteropModeShared.) This is inconsistent
+ // with Linux (which appears to assume that directory fids have the correct
+ // semantics, and translates struct file_operations::readdir calls directly
+ // to readdir RPCs), but is consistent with VFS1.
+
+ // filesystem.renameMu is needed for d.parent, and must be locked before
+ // dentry.dirMu.
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ // It's not clear if 9P2000.L's readdir is expected to return "." and "..",
+ // so we generate them here.
+ parent := genericParentOrSelf(d)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: uint64(d.ino),
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: uint64(parent.ino),
+ NextOff: 2,
+ },
+ }
+ var realChildren map[string]struct{}
+ if !d.isSynthetic() {
+ if d.syntheticChildren != 0 && d.fs.opts.interop == InteropModeShared {
+ // Record the set of children d actually has so that we don't emit
+ // duplicate entries for synthetic children.
+ realChildren = make(map[string]struct{})
+ }
+ off := uint64(0)
+ const count = 64 * 1024 // for consistency with the vfs1 client
+ d.handleMu.RLock()
+ if d.readFile.isNil() {
+ // This should not be possible because a readable handle should
+ // have been opened when the calling directoryFD was opened.
+ d.handleMu.RUnlock()
+ panic("gofer.dentry.getDirents called without a readable handle")
+ }
+ for {
+ p9ds, err := d.readFile.readdir(ctx, off, count)
+ if err != nil {
+ d.handleMu.RUnlock()
+ return nil, err
+ }
+ if len(p9ds) == 0 {
+ d.handleMu.RUnlock()
+ break
+ }
+ for _, p9d := range p9ds {
+ if p9d.Name == "." || p9d.Name == ".." {
+ continue
+ }
+ dirent := vfs.Dirent{
+ Name: p9d.Name,
+ Ino: uint64(inoFromPath(p9d.QID.Path)),
+ NextOff: int64(len(dirents) + 1),
+ }
+ // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or
+ // DMSOCKET.
+ switch p9d.Type {
+ case p9.TypeSymlink:
+ dirent.Type = linux.DT_LNK
+ case p9.TypeDir:
+ dirent.Type = linux.DT_DIR
+ default:
+ dirent.Type = linux.DT_REG
+ }
+ dirents = append(dirents, dirent)
+ if realChildren != nil {
+ realChildren[p9d.Name] = struct{}{}
+ }
+ }
+ off = p9ds[len(p9ds)-1].Offset
+ }
+ }
+ // Emit entries for synthetic children.
+ if d.syntheticChildren != 0 {
+ for _, child := range d.children {
+ if child == nil || !child.isSynthetic() {
+ continue
+ }
+ if _, ok := realChildren[child.name]; ok {
+ continue
+ }
+ dirents = append(dirents, vfs.Dirent{
+ Name: child.name,
+ Type: uint8(atomic.LoadUint32(&child.mode) >> 12),
+ Ino: uint64(child.ino),
+ NextOff: int64(len(dirents) + 1),
+ })
+ }
+ }
+ // Cache dirents for future directoryFDs if permitted.
+ if d.cachedMetadataAuthoritative() {
+ d.dirents = dirents
+ }
+ return dirents, nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ return fd.dentry().syncRemoteFile(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
new file mode 100644
index 000000000..a3903db33
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -0,0 +1,1550 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // Snapshot current syncable dentries and special files.
+ fs.syncMu.Lock()
+ ds := make([]*dentry, 0, len(fs.syncableDentries))
+ for d := range fs.syncableDentries {
+ d.IncRef()
+ ds = append(ds, d)
+ }
+ sffds := make([]*specialFileFD, 0, len(fs.specialFileFDs))
+ for sffd := range fs.specialFileFDs {
+ sffd.vfsfd.IncRef()
+ sffds = append(sffds, sffd)
+ }
+ fs.syncMu.Unlock()
+
+ // Return the first error we encounter, but sync everything we can
+ // regardless.
+ var retErr error
+
+ // Sync regular files.
+ for _, d := range ds {
+ err := d.syncCachedFile(ctx)
+ d.DecRef(ctx)
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ // Sync special files, which may be writable but do not use dentry shared
+ // handles (so they won't be synced by the above).
+ for _, sffd := range sffds {
+ err := sffd.Sync(ctx)
+ sffd.vfsfd.DecRef(ctx)
+ if err != nil && retErr == nil {
+ retErr = err
+ }
+ }
+
+ return retErr
+}
+
+// maxFilenameLen is the maximum length of a filename. This is dictated by 9P's
+// encoding of strings, which uses 2 bytes for the length prefix.
+const maxFilenameLen = (1 << 16) - 1
+
+// dentrySlicePool is a pool of *[]*dentry used to store dentries for which
+// dentry.checkCachingLocked() must be called. The pool holds pointers to
+// slices because Go lacks generics, so sync.Pool operates on interface{}, so
+// every call to (what should be) sync.Pool<[]*dentry>.Put() allocates a copy
+// of the slice header on the heap.
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may become cached as a result of the traversal are appended
+// to *ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done(). If !d.cachedMetadataAuthoritative(), then d's cached metadata
+// must be up to date.
+//
+// Postconditions: The returned dentry's cached metadata is up to date.
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ // We must assume that d.parent is correct, because if d has been moved
+ // elsewhere in the remote filesystem so that its parent has changed,
+ // we have no way of determining its new parent's location in the
+ // filesystem.
+ //
+ // Call rp.CheckMount() before updating d.parent's metadata, since if
+ // we traverse to another mount then d.parent's metadata is irrelevant.
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ if d != d.parent && !d.cachedMetadataAuthoritative() {
+ if err := d.parent.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds)
+ if err != nil {
+ return nil, err
+ }
+ if child == nil {
+ return nil, syserror.ENOENT
+ }
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// getChildLocked returns a dentry representing the child of parent with the
+// given name. If no such child exists, getChildLocked returns (nil, nil).
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+// parent.isDir(). name is not "." or "..".
+//
+// Postconditions: If getChildLocked returns a non-nil dentry, its cached
+// metadata is up to date.
+func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+ if len(name) > maxFilenameLen {
+ return nil, syserror.ENAMETOOLONG
+ }
+ child, ok := parent.children[name]
+ if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
+ // Whether child is nil or not, it is cached information that is
+ // assumed to be correct.
+ return child, nil
+ }
+ // We either don't have cached information or need to verify that it's
+ // still correct, either of which requires a remote lookup. Check if this
+ // name is valid before performing the lookup.
+ return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
+}
+
+// Preconditions: As for getChildLocked. !parent.isSynthetic().
+func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
+ if child != nil {
+ // Need to lock child.metadataMu because we might be updating child
+ // metadata. We need to hold the lock *before* getting metadata from the
+ // server and release it after updating local metadata.
+ child.metadataMu.Lock()
+ }
+ qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
+ if err != nil && err != syserror.ENOENT {
+ if child != nil {
+ child.metadataMu.Unlock()
+ }
+ return nil, err
+ }
+ if child != nil {
+ if !file.isNil() && inoFromPath(qid.Path) == child.ino {
+ // The file at this path hasn't changed. Just update cached metadata.
+ file.close(ctx)
+ child.updateFromP9AttrsLocked(attrMask, &attr)
+ child.metadataMu.Unlock()
+ return child, nil
+ }
+ child.metadataMu.Unlock()
+ if file.isNil() && child.isSynthetic() {
+ // We have a synthetic file, and no remote file has arisen to
+ // replace it.
+ return child, nil
+ }
+ // The file at this path has changed or no longer exists. Mark the
+ // dentry invalidated, and re-evaluate its caching status (i.e. if it
+ // has 0 references, drop it). Wait to update parent.children until we
+ // know what to replace the existing dentry with (i.e. one of the
+ // returns below), to avoid a redundant map access.
+ vfsObj.InvalidateDentry(ctx, &child.vfsd)
+ if child.isSynthetic() {
+ // Normally we don't mark invalidated dentries as deleted since
+ // they may still exist (but at a different path), and also for
+ // consistency with Linux. However, synthetic files are guaranteed
+ // to become unreachable if their dentries are invalidated, so
+ // treat their invalidation as deletion.
+ child.setDeleted()
+ parent.syntheticChildren--
+ child.decRefLocked()
+ parent.dirents = nil
+ }
+ *ds = appendDentry(*ds, child)
+ }
+ if file.isNil() {
+ // No file exists at this path now. Cache the negative lookup if
+ // allowed.
+ parent.cacheNegativeLookupLocked(name)
+ return nil, nil
+ }
+ // Create a new dentry representing the file.
+ child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+ if err != nil {
+ file.close(ctx)
+ delete(parent.children, name)
+ return nil, err
+ }
+ parent.cacheNewChildLocked(child, name)
+ // For now, child has 0 references, so our caller should call
+ // child.checkCachingLocked().
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done(). If
+// !d.cachedMetadataAuthoritative(), then d's cached metadata must be up to
+// date.
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ if !d.cachedMetadataAuthoritative() {
+ // Get updated metadata for rp.Start() as required by fs.stepLocked().
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// createInRemoteDir (if the parent directory is a real remote directory) or
+// createInSyntheticDir (if the parent directory is synthetic) to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if len(name) > maxFilenameLen {
+ return syserror.ENAMETOOLONG
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ if parent.isDeleted() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+ if parent.isSynthetic() {
+ if child := parent.children[name]; child != nil {
+ return syserror.EEXIST
+ }
+ if createInSyntheticDir == nil {
+ return syserror.EPERM
+ }
+ if err := createInSyntheticDir(parent, name); err != nil {
+ return err
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+ }
+ if fs.opts.interop == InteropModeShared {
+ if child := parent.children[name]; child != nil && child.isSynthetic() {
+ return syserror.EEXIST
+ }
+ // The existence of a non-synthetic dentry at name would be inconclusive
+ // because the file it represents may have been deleted from the remote
+ // filesystem, so we would need to make an RPC to revalidate the dentry.
+ // Just attempt the file creation RPC instead. If a file does exist, the
+ // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a
+ // stale dentry exists, the dentry will fail revalidation next time it's
+ // used.
+ if err := createInRemoteDir(parent, name); err != nil {
+ return err
+ }
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+ }
+ if child := parent.children[name]; child != nil {
+ return syserror.EEXIST
+ }
+ // No cached dentry exists; however, there might still be an existing file
+ // at name. As above, we attempt the file creation RPC anyway.
+ if err := createInRemoteDir(parent, name); err != nil {
+ return err
+ }
+ if child, ok := parent.children[name]; ok && child == nil {
+ // Delete the now-stale negative dentry.
+ delete(parent.children, name)
+ }
+ parent.touchCMtime()
+ parent.dirents = nil
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parent.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ return nil
+}
+
+// Preconditions: !rp.Done().
+func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ name := rp.Component()
+ if dir {
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ child, ok := parent.children[name]
+ if ok && child == nil {
+ return syserror.ENOENT
+ }
+
+ sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
+ if sticky {
+ if !ok {
+ // If the sticky bit is set, we need to retrieve the child to determine
+ // whether removing it is allowed.
+ child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err != nil {
+ return err
+ }
+ } else if child != nil && !child.cachedMetadataAuthoritative() {
+ // Make sure the dentry representing the file at name is up to date
+ // before examining its metadata.
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+ if err != nil {
+ return err
+ }
+ }
+ if err := parent.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ }
+
+ // If a child dentry exists, prepare to delete it. This should fail if it is
+ // a mount point. We detect mount points by speculatively calling
+ // PrepareDeleteDentry, which fails if child is a mount point. However, we
+ // may need to revalidate the file in this case to make sure that it has not
+ // been deleted or replaced on the remote fs, in which case the mount point
+ // will have disappeared. If calling PrepareDeleteDentry fails again on the
+ // up-to-date dentry, we can be sure that it is a mount point.
+ //
+ // Also note that if child is nil, then it can't be a mount point.
+ if child != nil {
+ // Hold child.dirMu so we can check child.children and
+ // child.syntheticChildren. We don't access these fields until a bit later,
+ // but locking child.dirMu after calling vfs.PrepareDeleteDentry() would
+ // create an inconsistent lock ordering between dentry.dirMu and
+ // vfs.Dentry.mu (in the VFS lock order, it would make dentry.dirMu both "a
+ // FilesystemImpl lock" and "a lock acquired by a FilesystemImpl between
+ // PrepareDeleteDentry and CommitDeleteDentry). To avoid this, lock
+ // child.dirMu before calling PrepareDeleteDentry.
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ // We can skip revalidation in several cases:
+ // - We are not in InteropModeShared
+ // - The parent directory is synthetic, in which case the child must also
+ // be synthetic
+ // - We already updated the child during the sticky bit check above
+ if parent.cachedMetadataAuthoritative() || sticky {
+ return err
+ }
+ child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
+ if err != nil {
+ return err
+ }
+ if child != nil {
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ }
+ }
+ }
+ flags := uint32(0)
+ // If a dentry exists, use it for best-effort checks on its deletability.
+ if dir {
+ if child != nil {
+ // child must be an empty directory.
+ if child.syntheticChildren != 0 {
+ // This is definitely not an empty directory, irrespective of
+ // fs.opts.interop.
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ // If InteropModeShared is in effect and the first call to
+ // PrepareDeleteDentry above succeeded, then child wasn't
+ // revalidated (so we can't expect its file type to be correct) and
+ // individually revalidating its children (to confirm that they
+ // still exist) would be a waste of time.
+ if child.cachedMetadataAuthoritative() {
+ if !child.isDir() {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTDIR
+ }
+ for _, grandchild := range child.children {
+ if grandchild != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.ENOTEMPTY
+ }
+ }
+ }
+ }
+ flags = linux.AT_REMOVEDIR
+ } else {
+ // child must be a non-directory file.
+ if child != nil && child.isDir() {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return syserror.ENOTDIR
+ }
+ }
+ if parent.isSynthetic() {
+ if child == nil {
+ return syserror.ENOENT
+ }
+ } else if child == nil || !child.isSynthetic() {
+ err = parent.file.unlinkAt(ctx, name, flags)
+ if err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+ }
+
+ // Generate inotify events for rmdir or unlink.
+ if dir {
+ parent.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ } else {
+ var cw *vfs.Watches
+ if child != nil {
+ cw = &child.watches
+ }
+ vfs.InotifyRemoveChild(ctx, cw, &parent.watches, name)
+ }
+
+ if child != nil {
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ child.setDeleted()
+ if child.isSynthetic() {
+ parent.syntheticChildren--
+ child.decRefLocked()
+ }
+ ds = appendDentry(ds, child)
+ }
+ parent.cacheNegativeLookupLocked(name)
+ if parent.cachedMetadataAuthoritative() {
+ parent.dirents = nil
+ parent.touchCMtime()
+ if dir {
+ parent.decLinks()
+ }
+ }
+ return nil
+}
+
+// renameMuRUnlockAndCheckCaching calls fs.renameMu.RUnlock(), then calls
+// dentry.checkCachingLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckCaching(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckCaching(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkCachingLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkCachingLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by
+ // fs.walkParentDirLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ d := vd.Dentry().Impl().(*dentry)
+ if d.isDir() {
+ return syserror.EPERM
+ }
+ gid := auth.KGID(atomic.LoadUint32(&d.gid))
+ uid := auth.KUID(atomic.LoadUint32(&d.uid))
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil {
+ return err
+ }
+ if d.nlink == 0 {
+ return syserror.ENOENT
+ }
+ if d.nlink == math.MaxUint32 {
+ return syserror.EMLINK
+ }
+ if err := parent.file.link(ctx, d.file, childName); err != nil {
+ return err
+ }
+
+ // Success!
+ atomic.AddUint32(&d.nlink, 1)
+ return nil
+ }, nil)
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ creds := rp.Credentials()
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error {
+ if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil {
+ if !opts.ForSyntheticMountpoint || err == syserror.EEXIST {
+ return err
+ }
+ ctx.Infof("Failed to create remote directory %q: %v; falling back to synthetic directory", name, err)
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
+ }
+ if fs.opts.interop != InteropModeShared {
+ parent.incLinks()
+ }
+ return nil
+ }, func(parent *dentry, name string) error {
+ if !opts.ForSyntheticMountpoint {
+ // Can't create non-synthetic files in synthetic directories.
+ return syserror.EPERM
+ }
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: linux.S_IFDIR | opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ })
+ parent.incLinks()
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ // If the gofer does not allow creating a socket or pipe, create a
+ // synthetic one, i.e. one that is kept entirely in memory.
+ if err == syserror.EPERM {
+ switch opts.Mode.FileType() {
+ case linux.S_IFSOCK:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ endpoint: opts.Endpoint,
+ })
+ return nil
+ case linux.S_IFIFO:
+ parent.createSyntheticChildLocked(&createSyntheticOpts{
+ name: name,
+ mode: opts.Mode,
+ kuid: creds.EffectiveKUID,
+ kgid: creds.EffectiveKGID,
+ pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ })
+ return nil
+ }
+ }
+ return err
+ }, nil)
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Reject O_TMPFILE, which is not supported; supporting it correctly in the
+ // presence of other remote filesystem users requires remote filesystem
+ // support, and it isn't clear that there's any way to implement this in
+ // 9P.
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ return nil, syserror.EOPNOTSUPP
+ }
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if !start.cachedMetadataAuthoritative() {
+ // Get updated metadata for start as required by fs.stepLocked().
+ if err := start.updateFromGetattr(ctx); err != nil {
+ return nil, err
+ }
+ }
+ if rp.Done() {
+ // Reject attempts to open mount root directory with O_CREAT.
+ if mayCreate && rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if mayCreate && rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ if parent.isSynthetic() {
+ parent.dirMu.Unlock()
+ return nil, syserror.EPERM
+ }
+ fd, err := parent.createAndOpenChildLocked(ctx, rp, &opts, &ds)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ parent.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ // Open existing child or follow symlink.
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ if rp.MustBeDir() && !child.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+
+ trunc := opts.Flags&linux.O_TRUNC != 0 && d.fileType() == linux.S_IFREG
+ if trunc {
+ // Lock metadataMu *while* we open a regular file with O_TRUNC because
+ // open(2) will change the file size on server.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ }
+
+ var vfd *vfs.FileDescription
+ var err error
+ mnt := rp.Mount()
+ switch d.fileType() {
+ case linux.S_IFREG:
+ if !d.fs.opts.regularFilesUseSpecialFileFD {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil {
+ return nil, err
+ }
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ vfd = &fd.vfsfd
+ }
+ case linux.S_IFDIR:
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if !d.isSynthetic() {
+ if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, false /* write */, false /* trunc */); err != nil {
+ return nil, err
+ }
+ }
+ fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case linux.S_IFLNK:
+ // Can't open symlinks without O_PATH (which is unimplemented).
+ return nil, syserror.ELOOP
+ case linux.S_IFSOCK:
+ if d.isSynthetic() {
+ return nil, syserror.ENXIO
+ }
+ if d.fs.iopts.OpenSocketsByConnecting {
+ return d.connectSocketLocked(ctx, opts)
+ }
+ case linux.S_IFIFO:
+ if d.isSynthetic() {
+ return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks)
+ }
+ }
+
+ if vfd == nil {
+ if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil {
+ return nil, err
+ }
+ }
+
+ if trunc {
+ // If no errors occured so far then update file size in memory. This
+ // step is required even if !d.cachedMetadataAuthoritative() because
+ // d.mappings has to be updated.
+ // d.metadataMu has already been acquired if trunc == true.
+ d.updateFileSizeLocked(0)
+
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ }
+ return vfd, err
+}
+
+func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ fdObj, err := d.file.connect(ctx, p9.AnonymousSocket)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := host.NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fdObj.FD(), &host.NewFDOptions{
+ HaveFlags: true,
+ Flags: opts.Flags,
+ })
+ if err != nil {
+ fdObj.Close()
+ return nil, err
+ }
+ fdObj.Release()
+ return fd, nil
+}
+
+func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ // We assume that the server silently inserts O_NONBLOCK in the open flags
+ // for all named pipes (because all existing gofers do this).
+ //
+ // NOTE(b/133875563): This makes named pipe opens racy, because the
+ // mechanisms for translating nonblocking to blocking opens can only detect
+ // the instantaneous presence of a peer holding the other end of the pipe
+ // open, not whether the pipe was *previously* opened by a peer that has
+ // since closed its end.
+ isBlockingOpenOfNamedPipe := d.fileType() == linux.S_IFIFO && opts.Flags&linux.O_NONBLOCK == 0
+retry:
+ h, err := openHandle(ctx, d.file, ats.MayRead(), ats.MayWrite(), opts.Flags&linux.O_TRUNC != 0)
+ if err != nil {
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayWrite && err == syserror.ENXIO {
+ // An attempt to open a named pipe with O_WRONLY|O_NONBLOCK fails
+ // with ENXIO if opening the same named pipe with O_WRONLY would
+ // block because there are no readers of the pipe.
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return nil, err
+ }
+ goto retry
+ }
+ return nil, err
+ }
+ if isBlockingOpenOfNamedPipe && ats == vfs.MayRead && h.fd >= 0 {
+ if err := blockUntilNonblockingPipeHasWriter(ctx, h.fd); err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ }
+ fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags)
+ if err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked.
+// !d.isSynthetic().
+func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if d.isDeleted() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ // 9P2000.L's lcreate takes a fid representing the parent directory, and
+ // converts it into an open fid representing the created file, so we need
+ // to duplicate the directory fid first.
+ _, dirfile, err := d.file.walk(ctx, nil)
+ if err != nil {
+ return nil, err
+ }
+ creds := rp.Credentials()
+ name := rp.Component()
+ // We only want the access mode for creating the file.
+ createFlags := p9.OpenFlags(opts.Flags) & p9.OpenFlagsModeMask
+ fdobj, openFile, createQID, _, err := dirfile.create(ctx, name, createFlags, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ if err != nil {
+ dirfile.close(ctx)
+ return nil, err
+ }
+ // Then we need to walk to the file we just created to get a non-open fid
+ // representing it, and to get its metadata. This must use d.file since, as
+ // explained above, dirfile was invalidated by dirfile.Create().
+ _, nonOpenFile, attrMask, attr, err := d.file.walkGetAttrOne(ctx, name)
+ if err != nil {
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+
+ // Construct the new dentry.
+ child, err := d.fs.newDentry(ctx, nonOpenFile, createQID, attrMask, &attr)
+ if err != nil {
+ nonOpenFile.close(ctx)
+ openFile.close(ctx)
+ if fdobj != nil {
+ fdobj.Close()
+ }
+ return nil, err
+ }
+ *ds = appendDentry(*ds, child)
+ // Incorporate the fid that was opened by lcreate.
+ useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD
+ if useRegularFileFD {
+ child.handleMu.Lock()
+ if vfs.MayReadFileWithOpenFlags(opts.Flags) {
+ child.readFile = openFile
+ if fdobj != nil {
+ child.hostFD = int32(fdobj.Release())
+ }
+ } else if fdobj != nil {
+ // Can't use fdobj if it's not readable.
+ fdobj.Close()
+ }
+ if vfs.MayWriteFileWithOpenFlags(opts.Flags) {
+ child.writeFile = openFile
+ }
+ child.handleMu.Unlock()
+ }
+ // Insert the dentry into the tree.
+ d.cacheNewChildLocked(child, name)
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtime()
+ d.dirents = nil
+ }
+
+ // Finally, construct a file description representing the created file.
+ var childVFSFD *vfs.FileDescription
+ if useRegularFileFD {
+ fd := &regularFileFD{}
+ fd.LockFD.Init(&child.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{
+ AllowDirectIO: true,
+ }); err != nil {
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ } else {
+ h := handle{
+ file: openFile,
+ fd: -1,
+ }
+ if fdobj != nil {
+ h.fd = int32(fdobj.Release())
+ }
+ fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags)
+ if err != nil {
+ h.close(ctx)
+ return nil, err
+ }
+ childVFSFD = &fd.vfsfd
+ }
+ d.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ return childVFSFD, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ if !d.isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return d.readlink(ctx, rp.Mount())
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // Requires 9P support.
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckCaching(ctx, &ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParent := oldParentVD.Dentry().Impl().(*dentry)
+ if !oldParent.cachedMetadataAuthoritative() {
+ if err := oldParent.updateFromGetattr(ctx); err != nil {
+ return err
+ }
+ }
+ creds := rp.Credentials()
+ if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ vfsObj := rp.VirtualFilesystem()
+ // We need a dentry representing the renamed file since, if it's a
+ // directory, we need to check for write permission on it.
+ oldParent.dirMu.Lock()
+ defer oldParent.dirMu.Unlock()
+ renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds)
+ if err != nil {
+ return err
+ }
+ if renamed == nil {
+ return syserror.ENOENT
+ }
+ if err := oldParent.mayDelete(creds, renamed); err != nil {
+ return err
+ }
+ if renamed.isDir() {
+ if renamed == newParent || genericIsAncestorDentry(renamed, newParent) {
+ return syserror.EINVAL
+ }
+ if oldParent != newParent {
+ if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if oldParent != newParent {
+ if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ newParent.dirMu.Lock()
+ defer newParent.dirMu.Unlock()
+ }
+ if newParent.isDeleted() {
+ return syserror.ENOENT
+ }
+ replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds)
+ if err != nil {
+ return err
+ }
+ var replacedVFSD *vfs.Dentry
+ if replaced != nil {
+ replacedVFSD = &replaced.vfsd
+ if replaced.isDir() {
+ if !renamed.isDir() {
+ return syserror.EISDIR
+ }
+ } else {
+ if rp.MustBeDir() || renamed.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ }
+
+ if oldParent == newParent && oldName == newName {
+ return nil
+ }
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+
+ // Update the remote filesystem.
+ if !renamed.isSynthetic() {
+ if err := renamed.file.rename(ctx, newParent.file, newName); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ } else if replaced != nil && !replaced.isSynthetic() {
+ // We are replacing an existing real file with a synthetic one, so we
+ // need to unlink the former.
+ flags := uint32(0)
+ if replaced.isDir() {
+ flags = linux.AT_REMOVEDIR
+ }
+ if err := newParent.file.unlinkAt(ctx, newName, flags); err != nil {
+ vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD)
+ return err
+ }
+ }
+
+ // Update the dentry tree.
+ vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD)
+ if replaced != nil {
+ replaced.setDeleted()
+ if replaced.isSynthetic() {
+ newParent.syntheticChildren--
+ replaced.decRefLocked()
+ }
+ ds = appendDentry(ds, replaced)
+ }
+ oldParent.cacheNegativeLookupLocked(oldName)
+ // We don't use newParent.cacheNewChildLocked() since we don't want to mess
+ // with reference counts and queue oldParent for checkCachingLocked if the
+ // parent isn't actually changing.
+ if oldParent != newParent {
+ ds = appendDentry(ds, oldParent)
+ newParent.IncRef()
+ if renamed.isSynthetic() {
+ oldParent.syntheticChildren--
+ newParent.syntheticChildren++
+ }
+ }
+ renamed.parent = newParent
+ renamed.name = newName
+ if newParent.children == nil {
+ newParent.children = make(map[string]*dentry)
+ }
+ newParent.children[newName] = renamed
+
+ // Update metadata.
+ if renamed.cachedMetadataAuthoritative() {
+ renamed.touchCtime()
+ }
+ if oldParent.cachedMetadataAuthoritative() {
+ oldParent.dirents = nil
+ oldParent.touchCMtime()
+ if renamed.isDir() {
+ oldParent.decLinks()
+ }
+ }
+ if newParent.cachedMetadataAuthoritative() {
+ newParent.dirents = nil
+ newParent.touchCMtime()
+ if renamed.isDir() && (replaced == nil || !replaced.isDir()) {
+ // Increase the link count if we did not replace another directory.
+ newParent.incLinks()
+ }
+ }
+ vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir())
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, true /* dir */)
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ // Since walking updates metadata for all traversed dentries under
+ // InteropModeShared, including the returned one, we can return cached
+ // metadata here regardless of fs.opts.interop.
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // If d is synthetic, invoke statfs on the first ancestor of d that isn't.
+ for d.isSynthetic() {
+ d = d.parent
+ }
+ fsstat, err := d.file.statFS(ctx)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ nameLen := uint64(fsstat.NameLength)
+ if nameLen > maxFilenameLen {
+ nameLen = maxFilenameLen
+ }
+ return linux.Statfs{
+ // This is primarily for distinguishing a gofer file system in
+ // tests. Testing is important, so instead of defining
+ // something completely random, use a standard value.
+ Type: linux.V9FS_MAGIC,
+ BlockSize: int64(fsstat.BlockSize),
+ Blocks: fsstat.Blocks,
+ BlocksFree: fsstat.BlocksFree,
+ BlocksAvailable: fsstat.BlocksAvailable,
+ Files: fsstat.Files,
+ FilesFree: fsstat.FilesFree,
+ NameLength: nameLen,
+ }, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error {
+ creds := rp.Credentials()
+ _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID))
+ return err
+ }, nil)
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ return fs.unlinkAt(ctx, rp, false /* dir */)
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if d.isSocket() {
+ if !d.isSynthetic() {
+ d.IncRef()
+ return &endpoint{
+ dentry: d,
+ file: d.file.file,
+ path: opts.Addr,
+ }, nil
+ }
+ return d.endpoint, nil
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ return d.listxattr(ctx, rp.Credentials(), size)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return d.getxattr(ctx, rp.Credentials(), &opts)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ if err := d.removexattr(ctx, rp.Credentials(), name); err != nil {
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+ return err
+ }
+ fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
+
+func (fs *filesystem) nextSyntheticIno() inodeNumber {
+ return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
new file mode 100644
index 000000000..63e589859
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -0,0 +1,1708 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gofer provides a filesystem implementation that is backed by a 9p
+// server, interchangably referred to as "gofers" throughout this package.
+//
+// Lock order:
+// regularFileFD/directoryFD.mu
+// filesystem.renameMu
+// dentry.dirMu
+// filesystem.syncMu
+// dentry.metadataMu
+// *** "memmap.Mappable locks" below this point
+// dentry.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// dentry.handleMu
+// dentry.dataMu
+//
+// Locking dentry.dirMu in multiple dentries requires that either ancestor
+// dentries are locked before descendant dentries, or that filesystem.renameMu
+// is locked for writing.
+package gofer
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Name is the default filesystem name.
+const Name = "9p"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // mfp is used to allocate memory that caches regular file contents. mfp is
+ // immutable.
+ mfp pgalloc.MemoryFileProvider
+
+ // Immutable options.
+ opts filesystemOptions
+ iopts InternalFilesystemOptions
+
+ // client is the client used by this filesystem. client is immutable.
+ client *p9.Client
+
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock ktime.Clock
+
+ // devMinor is the filesystem's minor device number. devMinor is immutable.
+ devMinor uint32
+
+ // renameMu serves two purposes:
+ //
+ // - It synchronizes path resolution with renaming initiated by this
+ // client.
+ //
+ // - It is held by path resolution to ensure that reachable dentries remain
+ // valid. A dentry is reachable by path resolution if it has a non-zero
+ // reference count (such that it is usable as vfs.ResolvingPath.Start() or
+ // is reachable from its children), or if it is a child dentry (such that
+ // it is reachable from its parent).
+ renameMu sync.RWMutex
+
+ // cachedDentries contains all dentries with 0 references. (Due to race
+ // conditions, it may also contain dentries with non-zero references.)
+ // cachedDentriesLen is the number of dentries in cachedDentries. These
+ // fields are protected by renameMu.
+ cachedDentries dentryList
+ cachedDentriesLen uint64
+
+ // syncableDentries contains all dentries in this filesystem for which
+ // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs.
+ // These fields are protected by syncMu.
+ syncMu sync.Mutex
+ syncableDentries map[*dentry]struct{}
+ specialFileFDs map[*specialFileFD]struct{}
+
+ // syntheticSeq stores a counter to used to generate unique inodeNumber for
+ // synthetic dentries.
+ syntheticSeq uint64
+}
+
+// inodeNumber represents inode number reported in Dirent.Ino. For regular
+// dentries, it comes from QID.Path from the 9P server. Synthetic dentries
+// have have their inodeNumber generated sequentially, with the MSB reserved to
+// prevent conflicts with regular dentries.
+type inodeNumber uint64
+
+// Reserve MSB for synthetic mounts.
+const syntheticInoMask = uint64(1) << 63
+
+func inoFromPath(path uint64) inodeNumber {
+ if path&syntheticInoMask != 0 {
+ log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask)
+ }
+ return inodeNumber(path &^ syntheticInoMask)
+}
+
+type filesystemOptions struct {
+ // "Standard" 9P options.
+ fd int
+ aname string
+ interop InteropMode // derived from the "cache" mount option
+ dfltuid auth.KUID
+ dfltgid auth.KGID
+ msize uint32
+ version string
+
+ // maxCachedDentries is the maximum number of dentries with 0 references
+ // retained by the client.
+ maxCachedDentries uint64
+
+ // If forcePageCache is true, host FDs may not be used for application
+ // memory mappings even if available; instead, the client must perform its
+ // own caching of regular file pages. This is primarily useful for testing.
+ forcePageCache bool
+
+ // If limitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host FD mappings returned by dentry.(memmap.Mappable).Translate(). This
+ // makes memory accounting behavior more consistent between cases where
+ // host FDs are / are not available, but may increase the frequency of
+ // sentry-handled page faults on files for which a host FD is available.
+ limitHostFDTranslation bool
+
+ // If overlayfsStaleRead is true, O_RDONLY host FDs provided by the remote
+ // filesystem may not be coherent with writable host FDs opened later, so
+ // all uses of the former must be replaced by uses of the latter. This is
+ // usually only the case when the remote filesystem is a Linux overlayfs
+ // mount. (Prior to Linux 4.18, patch series centered on commit
+ // d1d04ef8572b "ovl: stack file ops", both I/O and memory mappings were
+ // incoherent between pre-copy-up and post-copy-up FDs; after that patch
+ // series, only memory mappings are incoherent.)
+ overlayfsStaleRead bool
+
+ // If regularFilesUseSpecialFileFD is true, application FDs representing
+ // regular files will use distinct file handles for each FD, in the same
+ // way that application FDs representing "special files" such as sockets
+ // do. Note that this disables client caching and mmap for regular files.
+ regularFilesUseSpecialFileFD bool
+}
+
+// InteropMode controls the client's interaction with other remote filesystem
+// users.
+type InteropMode uint32
+
+const (
+ // InteropModeExclusive is appropriate when the filesystem client is the
+ // only user of the remote filesystem.
+ //
+ // - The client may cache arbitrary filesystem state (file data, metadata,
+ // filesystem structure, etc.).
+ //
+ // - Client changes to filesystem state may be sent to the remote
+ // filesystem asynchronously, except when server permission checks are
+ // necessary.
+ //
+ // - File timestamps are based on client clocks. This ensures that users of
+ // the client observe timestamps that are coherent with their own clocks
+ // and consistent with Linux's semantics (in particular, it is not always
+ // possible for clients to set arbitrary atimes and mtimes depending on the
+ // remote filesystem implementation, and never possible for clients to set
+ // arbitrary ctimes.) If a dentry containing a client-defined atime or
+ // mtime is evicted from cache, client timestamps will be sent to the
+ // remote filesystem on a best-effort basis to attempt to ensure that
+ // timestamps will be preserved when another dentry representing the same
+ // file is instantiated.
+ InteropModeExclusive InteropMode = iota
+
+ // InteropModeWritethrough is appropriate when there are read-only users of
+ // the remote filesystem that expect to observe changes made by the
+ // filesystem client.
+ //
+ // - The client may cache arbitrary filesystem state.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on client clocks. As a corollary, access
+ // timestamp changes from other remote filesystem users will not be visible
+ // to the client.
+ InteropModeWritethrough
+
+ // InteropModeShared is appropriate when there are users of the remote
+ // filesystem that may mutate its state other than the client.
+ //
+ // - The client must verify ("revalidate") cached filesystem state before
+ // using it.
+ //
+ // - Client changes to filesystem state must be sent to the remote
+ // filesystem synchronously.
+ //
+ // - File timestamps are based on server clocks. This is necessary to
+ // ensure that timestamp changes are synchronized between remote filesystem
+ // users.
+ //
+ // Note that the correctness of InteropModeShared depends on the server
+ // correctly implementing 9P fids (i.e. each fid immutably represents a
+ // single filesystem object), even in the presence of remote filesystem
+ // mutations from other users. If this is violated, the behavior of the
+ // client is undefined.
+ InteropModeShared
+)
+
+// InternalFilesystemOptions may be passed as
+// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem.
+type InternalFilesystemOptions struct {
+ // If LeakConnection is true, do not close the connection to the server
+ // when the Filesystem is released. This is necessary for deployments in
+ // which servers can handle only a single client and report failure if that
+ // client disconnects.
+ LeakConnection bool
+
+ // If OpenSocketsByConnecting is true, silently translate attempts to open
+ // files identifying as sockets to connect RPCs.
+ OpenSocketsByConnecting bool
+}
+
+// _V9FS_DEFUID and _V9FS_DEFGID (from Linux's fs/9p/v9fs.h) are the default
+// UIDs and GIDs used for files that do not provide a specific owner or group
+// respectively.
+const (
+ // uint32(-2) doesn't work in Go.
+ _V9FS_DEFUID = auth.KUID(4294967294)
+ _V9FS_DEFGID = auth.KGID(4294967294)
+)
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mfp := pgalloc.MemoryFileProviderFromContext(ctx)
+ if mfp == nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: context does not provide a pgalloc.MemoryFileProvider")
+ return nil, nil, syserror.EINVAL
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ var fsopts filesystemOptions
+
+ // Check that the transport is "fd".
+ trans, ok := mopts["trans"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "trans")
+ if trans != "fd" {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Check that read and write FDs are provided and identical.
+ rfdstr, ok := mopts["rfdno"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "rfdno")
+ rfd, err := strconv.Atoi(rfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ wfdstr, ok := mopts["wfdno"]
+ if !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "wfdno")
+ wfd, err := strconv.Atoi(wfdstr)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr)
+ return nil, nil, syserror.EINVAL
+ }
+ if rfd != wfd {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.fd = rfd
+
+ // Get the attach name.
+ fsopts.aname = "/"
+ if aname, ok := mopts["aname"]; ok {
+ delete(mopts, "aname")
+ fsopts.aname = aname
+ }
+
+ // Parse the cache policy. For historical reasons, this defaults to the
+ // least generally-applicable option, InteropModeExclusive.
+ fsopts.interop = InteropModeExclusive
+ if cache, ok := mopts["cache"]; ok {
+ delete(mopts, "cache")
+ switch cache {
+ case "fscache":
+ fsopts.interop = InteropModeExclusive
+ case "fscache_writethrough":
+ fsopts.interop = InteropModeWritethrough
+ case "none":
+ fsopts.regularFilesUseSpecialFileFD = true
+ fallthrough
+ case "remote_revalidating":
+ fsopts.interop = InteropModeShared
+ default:
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ // Parse the default UID and GID.
+ fsopts.dfltuid = _V9FS_DEFUID
+ if dfltuidstr, ok := mopts["dfltuid"]; ok {
+ delete(mopts, "dfltuid")
+ dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ // In Linux, dfltuid is interpreted as a UID and is converted to a KUID
+ // in the caller's user namespace, but goferfs isn't
+ // application-mountable.
+ fsopts.dfltuid = auth.KUID(dfltuid)
+ }
+ fsopts.dfltgid = _V9FS_DEFGID
+ if dfltgidstr, ok := mopts["dfltgid"]; ok {
+ delete(mopts, "dfltgid")
+ dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.dfltgid = auth.KGID(dfltgid)
+ }
+
+ // Parse the 9P message size.
+ fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
+ if msizestr, ok := mopts["msize"]; ok {
+ delete(mopts, "msize")
+ msize, err := strconv.ParseUint(msizestr, 10, 32)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.msize = uint32(msize)
+ }
+
+ // Parse the 9P protocol version.
+ fsopts.version = p9.HighestVersionString()
+ if version, ok := mopts["version"]; ok {
+ delete(mopts, "version")
+ fsopts.version = version
+ }
+
+ // Parse the dentry cache limit.
+ fsopts.maxCachedDentries = 1000
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.maxCachedDentries = maxCachedDentries
+ }
+
+ // Handle simple flags.
+ if _, ok := mopts["force_page_cache"]; ok {
+ delete(mopts, "force_page_cache")
+ fsopts.forcePageCache = true
+ }
+ if _, ok := mopts["limit_host_fd_translation"]; ok {
+ delete(mopts, "limit_host_fd_translation")
+ fsopts.limitHostFDTranslation = true
+ }
+ if _, ok := mopts["overlayfs_stale_read"]; ok {
+ delete(mopts, "overlayfs_stale_read")
+ fsopts.overlayfsStaleRead = true
+ }
+ // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
+ // "cache=none".
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Handle internal options.
+ iopts, ok := opts.InternalData.(InternalFilesystemOptions)
+ if opts.InternalData != nil && !ok {
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted gofer.InternalFilesystemOptions", opts.InternalData)
+ return nil, nil, syserror.EINVAL
+ }
+ // If !ok, iopts being the zero value is correct.
+
+ // Establish a connection with the server.
+ conn, err := unet.NewSocket(fsopts.fd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Perform version negotiation with the server.
+ ctx.UninterruptibleSleepStart(false)
+ client, err := p9.NewClient(conn, fsopts.msize, fsopts.version)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ conn.Close()
+ return nil, nil, err
+ }
+ // Ownership of conn has been transferred to client.
+
+ // Perform attach to obtain the filesystem root.
+ ctx.UninterruptibleSleepStart(false)
+ attached, err := client.Attach(fsopts.aname)
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ client.Close()
+ return nil, nil, err
+ }
+ attachFile := p9file{attached}
+ qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask())
+ if err != nil {
+ attachFile.close(ctx)
+ client.Close()
+ return nil, nil, err
+ }
+
+ // Construct the filesystem object.
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ attachFile.close(ctx)
+ client.Close()
+ return nil, nil, err
+ }
+ fs := &filesystem{
+ mfp: mfp,
+ opts: fsopts,
+ iopts: iopts,
+ client: client,
+ clock: ktime.RealtimeClockFromContext(ctx),
+ devMinor: devMinor,
+ syncableDentries: make(map[*dentry]struct{}),
+ specialFileFDs: make(map[*specialFileFD]struct{}),
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
+
+ // Construct the root dentry.
+ root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr)
+ if err != nil {
+ attachFile.close(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+ // Set the root's reference count to 2. One reference is returned to the
+ // caller, and the other is deliberately leaked to prevent the root from
+ // being "cached" and subsequently evicted. Its resources will still be
+ // cleaned up by fs.Release().
+ root.refs = 2
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ mf := fs.mfp.MemoryFile()
+
+ fs.syncMu.Lock()
+ for d := range fs.syncableDentries {
+ d.handleMu.Lock()
+ d.dataMu.Lock()
+ if h := d.writeHandleLocked(); h.isOpen() {
+ // Write dirty cached data to the remote file.
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil {
+ log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err)
+ }
+ // TODO(jamieliu): Do we need to flushf/fsync d?
+ }
+ // Discard cached pages.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ d.dataMu.Unlock()
+ // Close the host fd if one exists.
+ if d.hostFD >= 0 {
+ syscall.Close(int(d.hostFD))
+ d.hostFD = -1
+ }
+ d.handleMu.Unlock()
+ }
+ // There can't be any specialFileFDs still using fs, since each such
+ // FileDescription would hold a reference on a Mount holding a reference on
+ // fs.
+ fs.syncMu.Unlock()
+
+ if !fs.iopts.LeakConnection {
+ // Close the connection to the server. This implicitly clunks all fids.
+ fs.client.Close()
+ }
+
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // refs is the reference count. Each dentry holds a reference on its
+ // parent, even if disowned. An additional reference is held on all
+ // synthetic dentries until they are unlinked or invalidated. When refs
+ // reaches 0, the dentry may be added to the cache or destroyed. If refs ==
+ // -1, the dentry has already been destroyed. refs is accessed using atomic
+ // memory operations.
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // parent is this dentry's parent directory. Each dentry holds a reference
+ // on its parent. If this dentry is a filesystem root, parent is nil.
+ // parent is protected by filesystem.renameMu.
+ parent *dentry
+
+ // name is the name of this dentry in its parent. If this dentry is a
+ // filesystem root, name is the empty string. name is protected by
+ // filesystem.renameMu.
+ name string
+
+ // file is the unopened p9.File that backs this dentry. file is immutable.
+ //
+ // If file.isNil(), this dentry represents a synthetic file, i.e. a file
+ // that does not exist on the remote filesystem. As of this writing, the
+ // only files that can be synthetic are sockets, pipes, and directories.
+ file p9file
+
+ // If deleted is non-zero, the file represented by this dentry has been
+ // deleted. deleted is accessed using atomic memory operations.
+ deleted uint32
+
+ // If cached is true, dentryEntry links dentry into
+ // filesystem.cachedDentries. cached and dentryEntry are protected by
+ // filesystem.renameMu.
+ cached bool
+ dentryEntry
+
+ dirMu sync.Mutex
+
+ // If this dentry represents a directory, children contains:
+ //
+ // - Mappings of child filenames to dentries representing those children.
+ //
+ // - Mappings of child filenames that are known not to exist to nil
+ // dentries (only if InteropModeShared is not in effect and the directory
+ // is not synthetic).
+ //
+ // children is protected by dirMu.
+ children map[string]*dentry
+
+ // If this dentry represents a directory, syntheticChildren is the number
+ // of child dentries for which dentry.isSynthetic() == true.
+ // syntheticChildren is protected by dirMu.
+ syntheticChildren int
+
+ // If this dentry represents a directory,
+ // dentry.cachedMetadataAuthoritative() == true, and dirents is not nil, it
+ // is a cache of all entries in the directory, in the order they were
+ // returned by the server. dirents is protected by dirMu.
+ dirents []vfs.Dirent
+
+ // Cached metadata; protected by metadataMu.
+ // To access:
+ // - In situations where consistency is not required (like stat), these
+ // can be accessed using atomic operations only (without locking).
+ // - Lock metadataMu and can access without atomic operations.
+ // To mutate:
+ // - Lock metadataMu and use atomic operations to update because we might
+ // have atomic readers that don't hold the lock.
+ metadataMu sync.Mutex
+ ino inodeNumber // immutable
+ mode uint32 // type is immutable, perms are mutable
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ blockSize uint32 // 0 if unknown
+ // Timestamps, all nsecs from the Unix epoch.
+ atime int64
+ mtime int64
+ ctime int64
+ btime int64
+ // File size, which differs from other metadata in two ways:
+ //
+ // - We make a best-effort attempt to keep it up to date even if
+ // !dentry.cachedMetadataAuthoritative() for the sake of O_APPEND writes.
+ //
+ // - size is protected by both metadataMu and dataMu (i.e. both must be
+ // locked to mutate it; locking either is sufficient to access it).
+ size uint64
+ // If this dentry does not represent a synthetic file, deleted is 0, and
+ // atimeDirty/mtimeDirty are non-zero, atime/mtime may have diverged from the
+ // remote file's timestamps, which should be updated when this dentry is
+ // evicted.
+ atimeDirty uint32
+ mtimeDirty uint32
+
+ // nlink counts the number of hard links to this dentry. It's updated and
+ // accessed using atomic operations. It's not protected by metadataMu like the
+ // other metadata fields.
+ nlink uint32
+
+ mapsMu sync.Mutex
+
+ // If this dentry represents a regular file, mappings tracks mappings of
+ // the file into memmap.MappingSpaces. mappings is protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // - If this dentry represents a regular file or directory, readFile is the
+ // p9.File used for reads by all regularFileFDs/directoryFDs representing
+ // this dentry.
+ //
+ // - If this dentry represents a regular file, writeFile is the p9.File
+ // used for writes by all regularFileFDs representing this dentry.
+ //
+ // - If this dentry represents a regular file, hostFD is the host FD used
+ // for memory mappings and I/O (when applicable) in preference to readFile
+ // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must
+ // also be writable. If hostFD is -1, no such host FD is available.
+ //
+ // These fields are protected by handleMu.
+ //
+ // readFile and writeFile may or may not represent the same p9.File. Once
+ // either p9.File transitions from closed (isNil() == true) to open
+ // (isNil() == false), it may be mutated with handleMu locked, but cannot
+ // be closed until the dentry is destroyed.
+ handleMu sync.RWMutex
+ readFile p9file
+ writeFile p9file
+ hostFD int32
+
+ dataMu sync.RWMutex
+
+ // If this dentry represents a regular file that is client-cached, cache
+ // maps offsets into the cached file to offsets into
+ // filesystem.mfp.MemoryFile() that store the file's data. cache is
+ // protected by dataMu.
+ cache fsutil.FileRangeSet
+
+ // If this dentry represents a regular file that is client-cached, dirty
+ // tracks dirty segments in cache. dirty is protected by dataMu.
+ dirty fsutil.DirtySet
+
+ // pf implements platform.File for mappings of hostFD.
+ pf dentryPlatformFile
+
+ // If this dentry represents a symbolic link, InteropModeShared is not in
+ // effect, and haveTarget is true, target is the symlink target. haveTarget
+ // and target are protected by dataMu.
+ haveTarget bool
+ target string
+
+ // If this dentry represents a synthetic socket file, endpoint is the
+ // transport endpoint bound to this file.
+ endpoint transport.BoundEndpoint
+
+ // If this dentry represents a synthetic named pipe, pipe is the pipe
+ // endpoint bound to this file.
+ pipe *pipe.VFSPipe
+
+ locks vfs.FileLocks
+
+ // Inotify watches for this dentry.
+ watches vfs.Watches
+}
+
+// dentryAttrMask returns a p9.AttrMask enabling all attributes used by the
+// gofer client.
+func dentryAttrMask() p9.AttrMask {
+ return p9.AttrMask{
+ Mode: true,
+ UID: true,
+ GID: true,
+ ATime: true,
+ MTime: true,
+ CTime: true,
+ Size: true,
+ BTime: true,
+ }
+}
+
+// newDentry creates a new dentry representing the given file. The dentry
+// initially has no references, but is not cached; it is the caller's
+// responsibility to set the dentry's reference count and/or call
+// dentry.checkCachingLocked() as appropriate.
+//
+// Preconditions: !file.isNil().
+func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, mask p9.AttrMask, attr *p9.Attr) (*dentry, error) {
+ if !mask.Mode {
+ ctx.Warningf("can't create gofer.dentry without file type")
+ return nil, syserror.EIO
+ }
+ if attr.Mode.FileType() == p9.ModeRegular && !mask.Size {
+ ctx.Warningf("can't create regular file gofer.dentry without file size")
+ return nil, syserror.EIO
+ }
+
+ d := &dentry{
+ fs: fs,
+ file: file,
+ ino: inoFromPath(qid.Path),
+ mode: uint32(attr.Mode),
+ uid: uint32(fs.opts.dfltuid),
+ gid: uint32(fs.opts.dfltgid),
+ blockSize: usermem.PageSize,
+ hostFD: -1,
+ }
+ d.pf.dentry = d
+ if mask.UID {
+ d.uid = dentryUIDFromP9UID(attr.UID)
+ }
+ if mask.GID {
+ d.gid = dentryGIDFromP9GID(attr.GID)
+ }
+ if mask.Size {
+ d.size = attr.Size
+ }
+ if attr.BlockSize != 0 {
+ d.blockSize = uint32(attr.BlockSize)
+ }
+ if mask.ATime {
+ d.atime = dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)
+ }
+ if mask.MTime {
+ d.mtime = dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)
+ }
+ if mask.CTime {
+ d.ctime = dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds)
+ }
+ if mask.BTime {
+ d.btime = dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds)
+ }
+ if mask.NLink {
+ d.nlink = uint32(attr.NLink)
+ }
+ d.vfsd.Init(d)
+
+ fs.syncMu.Lock()
+ fs.syncableDentries[d] = struct{}{}
+ fs.syncMu.Unlock()
+ return d, nil
+}
+
+func (d *dentry) isSynthetic() bool {
+ return d.file.isNil()
+}
+
+func (d *dentry) cachedMetadataAuthoritative() bool {
+ return d.fs.opts.interop != InteropModeShared || d.isSynthetic()
+}
+
+// updateFromP9Attrs is called to update d's metadata after an update from the
+// remote filesystem.
+// Precondition: d.metadataMu must be locked.
+func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) {
+ if mask.Mode {
+ if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want {
+ d.metadataMu.Unlock()
+ panic(fmt.Sprintf("gofer.dentry file type changed from %#o to %#o", want, got))
+ }
+ atomic.StoreUint32(&d.mode, uint32(attr.Mode))
+ }
+ if mask.UID {
+ atomic.StoreUint32(&d.uid, dentryUIDFromP9UID(attr.UID))
+ }
+ if mask.GID {
+ atomic.StoreUint32(&d.gid, dentryGIDFromP9GID(attr.GID))
+ }
+ // There is no P9_GETATTR_* bit for I/O block size.
+ if attr.BlockSize != 0 {
+ atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize))
+ }
+ // Don't override newer client-defined timestamps with old server-defined
+ // ones.
+ if mask.ATime && atomic.LoadUint32(&d.atimeDirty) == 0 {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds))
+ }
+ if mask.MTime && atomic.LoadUint32(&d.mtimeDirty) == 0 {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds))
+ }
+ if mask.CTime {
+ atomic.StoreInt64(&d.ctime, dentryTimestampFromP9(attr.CTimeSeconds, attr.CTimeNanoSeconds))
+ }
+ if mask.BTime {
+ atomic.StoreInt64(&d.btime, dentryTimestampFromP9(attr.BTimeSeconds, attr.BTimeNanoSeconds))
+ }
+ if mask.NLink {
+ atomic.StoreUint32(&d.nlink, uint32(attr.NLink))
+ }
+ if mask.Size {
+ d.updateFileSizeLocked(attr.Size)
+ }
+}
+
+// Preconditions: !d.isSynthetic()
+func (d *dentry) updateFromGetattr(ctx context.Context) error {
+ // Use d.readFile or d.writeFile, which represent 9P fids that have been
+ // opened, in preference to d.file, which represents a 9P fid that has not.
+ // This may be significantly more efficient in some implementations. Prefer
+ // d.writeFile over d.readFile since some filesystem implementations may
+ // update a writable handle's metadata after writes to that handle, without
+ // making metadata updates immediately visible to read-only handles
+ // representing the same file.
+ var (
+ file p9file
+ handleMuRLocked bool
+ )
+ // d.metadataMu must be locked *before* we getAttr so that we do not end up
+ // updating stale attributes in d.updateFromP9AttrsLocked().
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ d.handleMu.RLock()
+ if !d.writeFile.isNil() {
+ file = d.writeFile
+ handleMuRLocked = true
+ } else if !d.readFile.isNil() {
+ file = d.readFile
+ handleMuRLocked = true
+ } else {
+ file = d.file
+ d.handleMu.RUnlock()
+ }
+ _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask())
+ if handleMuRLocked {
+ d.handleMu.RUnlock()
+ }
+ if err != nil {
+ return err
+ }
+ d.updateFromP9AttrsLocked(attrMask, &attr)
+ return nil
+}
+
+func (d *dentry) fileType() uint32 {
+ return atomic.LoadUint32(&d.mode) & linux.S_IFMT
+}
+
+func (d *dentry) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME
+ stat.Blksize = atomic.LoadUint32(&d.blockSize)
+ stat.Nlink = atomic.LoadUint32(&d.nlink)
+ if stat.Nlink == 0 {
+ // The remote filesystem doesn't support link count; just make
+ // something up. This is consistent with Linux, where
+ // fs/inode.c:inode_init_always() initializes link count to 1, and
+ // fs/9p/vfs_inode_dotl.c:v9fs_stat2inode_dotl() doesn't touch it if
+ // it's not provided by the remote filesystem.
+ stat.Nlink = 1
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = uint64(d.ino)
+ stat.Size = atomic.LoadUint64(&d.size)
+ // This is consistent with regularFileFD.Seek(), which treats regular files
+ // as having no holes.
+ stat.Blocks = (stat.Size + 511) / 512
+ stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime))
+ stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime))
+ stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime))
+ stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime))
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = d.fs.devMinor
+}
+
+func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error {
+ stat := &opts.Stat
+ if stat.Mask == 0 {
+ return nil
+ }
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // Reject attempts to truncate files other than regular files, since
+ // filesystem implementations may return the wrong errno.
+ switch mode.FileType() {
+ case linux.S_IFREG:
+ // ok
+ case linux.S_IFDIR:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
+
+ var now int64
+ if d.cachedMetadataAuthoritative() {
+ // Truncate updates mtime.
+ if stat.Mask&(linux.STATX_SIZE|linux.STATX_MTIME) == linux.STATX_SIZE {
+ stat.Mask |= linux.STATX_MTIME
+ stat.Mtime = linux.StatxTimestamp{
+ Nsec: linux.UTIME_NOW,
+ }
+ }
+
+ // Use client clocks for timestamps.
+ now = d.fs.clock.Now().Nanoseconds()
+ if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW {
+ stat.Atime = statxTimestampFromDentry(now)
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW {
+ stat.Mtime = statxTimestampFromDentry(now)
+ }
+ }
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ if !d.isSynthetic() {
+ if stat.Mask != 0 {
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
+ UID: stat.Mask&linux.STATX_UID != 0,
+ GID: stat.Mask&linux.STATX_GID != 0,
+ Size: stat.Mask&linux.STATX_SIZE != 0,
+ ATime: stat.Mask&linux.STATX_ATIME != 0,
+ MTime: stat.Mask&linux.STATX_MTIME != 0,
+ ATimeNotSystemTime: stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW,
+ MTimeNotSystemTime: stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW,
+ }, p9.SetAttr{
+ Permissions: p9.FileMode(stat.Mode),
+ UID: p9.UID(stat.UID),
+ GID: p9.GID(stat.GID),
+ Size: stat.Size,
+ ATimeSeconds: uint64(stat.Atime.Sec),
+ ATimeNanoSeconds: uint64(stat.Atime.Nsec),
+ MTimeSeconds: uint64(stat.Mtime.Sec),
+ MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
+ }); err != nil {
+ return err
+ }
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // d.size should be kept up to date, and privatized
+ // copy-on-write mappings of truncated pages need to be
+ // invalidated, even if InteropModeShared is in effect.
+ d.updateFileSizeLocked(stat.Size)
+ }
+ }
+ if d.fs.opts.interop == InteropModeShared {
+ // There's no point to updating d's metadata in this case since
+ // it'll be overwritten by revalidation before the next time it's
+ // used anyway. (InteropModeShared inhibits client caching of
+ // regular file data, so there's no cache to truncate either.)
+ return nil
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, stat.GID)
+ }
+ // Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because
+ // if d.cachedMetadataAuthoritative() then we converted stat.Atime and
+ // stat.Mtime to client-local timestamps above, and if
+ // !d.cachedMetadataAuthoritative() then we returned after calling
+ // d.file.setAttr(). For the same reason, now must have been initialized.
+ if stat.Mask&linux.STATX_ATIME != 0 {
+ atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime))
+ atomic.StoreUint32(&d.atimeDirty, 0)
+ }
+ if stat.Mask&linux.STATX_MTIME != 0 {
+ atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime))
+ atomic.StoreUint32(&d.mtimeDirty, 0)
+ }
+ atomic.StoreInt64(&d.ctime, now)
+ return nil
+}
+
+// Preconditions: d.metadataMu must be locked.
+func (d *dentry) updateFileSizeLocked(newSize uint64) {
+ d.dataMu.Lock()
+ oldSize := d.size
+ atomic.StoreUint64(&d.size, newSize)
+ // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
+ // below. This allows concurrent calls to Read/Translate/etc. These
+ // functions synchronize with truncation by refusing to use cache
+ // contents beyond the new d.size. (We are still holding d.metadataMu,
+ // so we can't race with Write or another truncate.)
+ d.dataMu.Unlock()
+ if d.size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(d.size)
+ if oldpgend != newpgend {
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ d.mapsMu.Unlock()
+ }
+ // We are now guaranteed that there are no translations of
+ // truncated pages, and can remove them from the cache. Since
+ // truncated pages have been removed from the remote file, they
+ // should be dropped without being written back.
+ d.dataMu.Lock()
+ d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
+ d.dataMu.Unlock()
+ }
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid)))
+}
+
+func dentryUIDFromP9UID(uid p9.UID) uint32 {
+ if !uid.Ok() {
+ return uint32(auth.OverflowUID)
+ }
+ return uint32(uid)
+}
+
+func dentryGIDFromP9GID(gid p9.GID) uint32 {
+ if !gid.Ok() {
+ return uint32(auth.OverflowGID)
+ }
+ return uint32(gid)
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkCachingLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef(ctx context.Context) {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked(ctx)
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// decRefLocked decrements d's reference count without calling
+// d.checkCachingLocked, even if d's reference count reaches 0; callers are
+// responsible for ensuring that d.checkCachingLocked will be called later.
+func (d *dentry) decRefLocked() {
+ if refs := atomic.AddInt64(&d.refs, -1); refs < 0 {
+ panic("gofer.dentry.decRefLocked() called without holding a reference")
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
+ if d.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ d.fs.renameMu.RLock()
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.watches.Notify(ctx, d.name, events, cookie, et, d.isDeleted())
+ }
+ d.watches.Notify(ctx, "", events, cookie, et, d.isDeleted())
+ d.fs.renameMu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.watches
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// If no watches are left on this dentry and it has no references, cache it.
+func (d *dentry) OnZeroWatches(ctx context.Context) {
+ if atomic.LoadInt64(&d.refs) == 0 {
+ d.fs.renameMu.Lock()
+ d.checkCachingLocked(ctx)
+ d.fs.renameMu.Unlock()
+ }
+}
+
+// checkCachingLocked should be called after d's reference count becomes 0 or it
+// becomes disowned.
+//
+// It may be called on a destroyed dentry. For example,
+// renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times
+// for the same dentry when the dentry is visited more than once in the same
+// operation. One of the calls may destroy the dentry, so subsequent calls will
+// do nothing.
+//
+// Preconditions: d.fs.renameMu must be locked for writing; it may be
+// temporarily unlocked.
+func (d *dentry) checkCachingLocked(ctx context.Context) {
+ // Dentries with a non-zero reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.)
+ refs := atomic.LoadInt64(&d.refs)
+ if refs > 0 {
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ return
+ }
+ if refs == -1 {
+ // Dentry has already been destroyed.
+ return
+ }
+ // Deleted and invalidated dentries with zero references are no longer
+ // reachable by path resolution and should be dropped immediately.
+ if d.vfsd.IsDead() {
+ if d.isDeleted() {
+ d.watches.HandleDeletion(ctx)
+ }
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentriesLen--
+ d.cached = false
+ }
+ d.destroyLocked(ctx)
+ return
+ }
+ // If d still has inotify watches and it is not deleted or invalidated, we
+ // cannot cache it and allow it to be evicted. Otherwise, we will lose its
+ // watches, even if a new dentry is created for the same file in the future.
+ // Note that the size of d.watches cannot concurrently transition from zero
+ // to non-zero, because adding a watch requires holding a reference on d.
+ if d.watches.Size() > 0 {
+ return
+ }
+ // If d is already cached, just move it to the front of the LRU.
+ if d.cached {
+ d.fs.cachedDentries.Remove(d)
+ d.fs.cachedDentries.PushFront(d)
+ return
+ }
+ // Cache the dentry, then evict the least recently used cached dentry if
+ // the cache becomes over-full.
+ d.fs.cachedDentries.PushFront(d)
+ d.fs.cachedDentriesLen++
+ d.cached = true
+ if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries {
+ victim := d.fs.cachedDentries.Back()
+ d.fs.cachedDentries.Remove(victim)
+ d.fs.cachedDentriesLen--
+ victim.cached = false
+ // victim.refs may have become non-zero from an earlier path resolution
+ // since it was inserted into fs.cachedDentries.
+ if atomic.LoadInt64(&victim.refs) == 0 {
+ if victim.parent != nil {
+ victim.parent.dirMu.Lock()
+ if !victim.vfsd.IsDead() {
+ // Note that victim can't be a mount point (in any mount
+ // namespace), since VFS holds references on mount points.
+ d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd)
+ delete(victim.parent.children, victim.name)
+ // We're only deleting the dentry, not the file it
+ // represents, so we don't need to update
+ // victimParent.dirents etc.
+ }
+ victim.parent.dirMu.Unlock()
+ }
+ victim.destroyLocked(ctx)
+ }
+ // Whether or not victim was destroyed, we brought fs.cachedDentriesLen
+ // back down to fs.opts.maxCachedDentries, so we don't loop.
+ }
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions:
+// * d.fs.renameMu must be locked for writing; it may be temporarily unlocked.
+// * d.refs == 0.
+// * d.parent.children[d.name] != d, i.e. d is not reachable by path traversal
+// from its former parent dentry.
+func (d *dentry) destroyLocked(ctx context.Context) {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("dentry.destroyLocked() called with references on the dentry")
+ }
+
+ // Allow the following to proceed without renameMu locked to improve
+ // scalability.
+ d.fs.renameMu.Unlock()
+
+ mf := d.fs.mfp.MemoryFile()
+ d.handleMu.Lock()
+ d.dataMu.Lock()
+ if h := d.writeHandleLocked(); h.isOpen() {
+ // Write dirty pages back to the remote filesystem.
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to write dirty data back: %v", err)
+ }
+ }
+ // Discard cached data.
+ if !d.cache.IsEmpty() {
+ mf.MarkAllUnevictable(d)
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+ }
+ d.dataMu.Unlock()
+ // Clunk open fids and close open host FDs.
+ if !d.readFile.isNil() {
+ d.readFile.close(ctx)
+ }
+ if !d.writeFile.isNil() && d.readFile != d.writeFile {
+ d.writeFile.close(ctx)
+ }
+ d.readFile = p9file{}
+ d.writeFile = p9file{}
+ if d.hostFD >= 0 {
+ syscall.Close(int(d.hostFD))
+ d.hostFD = -1
+ }
+ d.handleMu.Unlock()
+
+ if !d.file.isNil() {
+ if !d.isDeleted() {
+ // Write dirty timestamps back to the remote filesystem.
+ atimeDirty := atomic.LoadUint32(&d.atimeDirty) != 0
+ mtimeDirty := atomic.LoadUint32(&d.mtimeDirty) != 0
+ if atimeDirty || mtimeDirty {
+ atime := atomic.LoadInt64(&d.atime)
+ mtime := atomic.LoadInt64(&d.mtime)
+ if err := d.file.setAttr(ctx, p9.SetAttrMask{
+ ATime: atimeDirty,
+ ATimeNotSystemTime: atimeDirty,
+ MTime: mtimeDirty,
+ MTimeNotSystemTime: mtimeDirty,
+ }, p9.SetAttr{
+ ATimeSeconds: uint64(atime / 1e9),
+ ATimeNanoSeconds: uint64(atime % 1e9),
+ MTimeSeconds: uint64(mtime / 1e9),
+ MTimeNanoSeconds: uint64(mtime % 1e9),
+ }); err != nil {
+ log.Warningf("gofer.dentry.destroyLocked: failed to write dirty timestamps back: %v", err)
+ }
+ }
+ }
+ d.file.close(ctx)
+ d.file = p9file{}
+ // Remove d from the set of syncable dentries.
+ d.fs.syncMu.Lock()
+ delete(d.fs.syncableDentries, d)
+ d.fs.syncMu.Unlock()
+ }
+
+ d.fs.renameMu.Lock()
+
+ // Drop the reference held by d on its parent without recursively locking
+ // d.fs.renameMu.
+ if d.parent != nil {
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkCachingLocked(ctx)
+ } else if refs < 0 {
+ panic("gofer.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+func (d *dentry) isDeleted() bool {
+ return atomic.LoadUint32(&d.deleted) != 0
+}
+
+func (d *dentry) setDeleted() {
+ atomic.StoreUint32(&d.deleted, 1)
+}
+
+// We only support xattrs prefixed with "user." (see b/148380782). Currently,
+// there is no need to expose any other xattrs through a gofer.
+func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) {
+ if d.file.isNil() || !d.userXattrSupported() {
+ return nil, nil
+ }
+ xattrMap, err := d.file.listXattr(ctx, size)
+ if err != nil {
+ return nil, err
+ }
+ xattrs := make([]string, 0, len(xattrMap))
+ for x := range xattrMap {
+ if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ xattrs = append(xattrs, x)
+ }
+ }
+ return xattrs, nil
+}
+
+func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if d.file.isNil() {
+ return "", syserror.ENODATA
+ }
+ if err := d.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return d.file.getXattr(ctx, opts.Name, opts.Size)
+}
+
+func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags)
+}
+
+func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error {
+ if d.file.isNil() {
+ return syserror.EPERM
+ }
+ if err := d.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !d.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return d.file.removeXattr(ctx, name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (d *dentry) userXattrSupported() bool {
+ filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType()
+ return filetype == linux.ModeRegular || filetype == linux.ModeDirectory
+}
+
+// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir().
+func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error {
+ // O_TRUNC unconditionally requires us to obtain a new handle (opened with
+ // O_TRUNC).
+ if !trunc {
+ d.handleMu.RLock()
+ if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) {
+ // Current handles are sufficient.
+ d.handleMu.RUnlock()
+ return nil
+ }
+ d.handleMu.RUnlock()
+ }
+
+ fdToClose := int32(-1)
+ invalidateTranslations := false
+ d.handleMu.Lock()
+ if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc {
+ // Get a new handle. If this file has been opened for both reading and
+ // writing, try to get a single handle that is usable for both:
+ //
+ // - Writable memory mappings of a host FD require that the host FD is
+ // opened for both reading and writing.
+ //
+ // - NOTE(b/141991141): Some filesystems may not ensure coherence
+ // between multiple handles for the same file.
+ openReadable := !d.readFile.isNil() || read
+ openWritable := !d.writeFile.isNil() || write
+ h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ if err == syserror.EACCES && (openReadable != read || openWritable != write) {
+ // It may not be possible to use a single handle for both
+ // reading and writing, since permissions on the file may have
+ // changed to e.g. disallow reading after previously being
+ // opened for reading. In this case, we have no choice but to
+ // use separate handles for reading and writing.
+ ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d)
+ openReadable = read
+ openWritable = write
+ h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc)
+ }
+ if err != nil {
+ d.handleMu.Unlock()
+ return err
+ }
+
+ if d.hostFD < 0 && openReadable && h.fd >= 0 {
+ // We have no existing FD; use the new FD for at least reading.
+ d.hostFD = h.fd
+ } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable {
+ // We have an existing read-only FD, but the file has just been
+ // opened for writing, so we need to start supporting writable memory
+ // mappings. This may race with callers of d.pf.FD() using the existing
+ // FD, so in most cases we need to delay closing the old FD until after
+ // invalidating memmap.Translations that might have observed it.
+ if !openReadable || h.fd < 0 {
+ // We don't have a read/write FD, so we have no FD that can be
+ // used to create writable memory mappings. Switch to using the
+ // internal page cache.
+ invalidateTranslations = true
+ fdToClose = d.hostFD
+ d.hostFD = -1
+ } else if d.fs.opts.overlayfsStaleRead {
+ // We do have a read/write FD, but it may not be coherent with
+ // the existing read-only FD, so we must switch to mappings of
+ // the new FD in both the application and sentry.
+ if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil {
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err)
+ h.close(ctx)
+ return err
+ }
+ invalidateTranslations = true
+ fdToClose = d.hostFD
+ d.hostFD = h.fd
+ } else {
+ // We do have a read/write FD. To avoid invalidating existing
+ // memmap.Translations (which is expensive), use dup3 to make
+ // the old file descriptor refer to the new file description,
+ // then close the new file descriptor (which is no longer
+ // needed). Racing callers of d.pf.FD() may use the old or new
+ // file description, but this doesn't matter since they refer
+ // to the same file, and any racing mappings must be read-only.
+ if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil {
+ oldHostFD := d.hostFD
+ d.handleMu.Unlock()
+ ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err)
+ h.close(ctx)
+ return err
+ }
+ fdToClose = h.fd
+ }
+ } else {
+ // h.fd is not useful.
+ fdToClose = h.fd
+ }
+
+ // Switch to new fids.
+ var oldReadFile p9file
+ if openReadable {
+ oldReadFile = d.readFile
+ d.readFile = h.file
+ }
+ var oldWriteFile p9file
+ if openWritable {
+ oldWriteFile = d.writeFile
+ d.writeFile = h.file
+ }
+ // NOTE(b/141991141): Clunk old fids before making new fids visible (by
+ // unlocking d.handleMu).
+ if !oldReadFile.isNil() {
+ oldReadFile.close(ctx)
+ }
+ if !oldWriteFile.isNil() && oldReadFile != oldWriteFile {
+ oldWriteFile.close(ctx)
+ }
+ }
+ d.handleMu.Unlock()
+
+ if invalidateTranslations {
+ // Invalidate application mappings that may be using an old FD; they
+ // will be replaced with mappings using the new FD after future calls
+ // to d.Translate(). This requires holding d.mapsMu, which precedes
+ // d.handleMu in the lock order.
+ d.mapsMu.Lock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+ }
+ if fdToClose >= 0 {
+ syscall.Close(int(fdToClose))
+ }
+
+ return nil
+}
+
+// Preconditions: d.handleMu must be locked.
+func (d *dentry) readHandleLocked() handle {
+ return handle{
+ file: d.readFile,
+ fd: d.hostFD,
+ }
+}
+
+// Preconditions: d.handleMu must be locked.
+func (d *dentry) writeHandleLocked() handle {
+ return handle{
+ file: d.writeFile,
+ fd: d.hostFD,
+ }
+}
+
+func (d *dentry) syncRemoteFile(ctx context.Context) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return d.syncRemoteFileLocked(ctx)
+}
+
+// Preconditions: d.handleMu must be locked.
+func (d *dentry) syncRemoteFileLocked(ctx context.Context) error {
+ // If we have a host FD, fsyncing it is likely to be faster than an fsync
+ // RPC.
+ if d.hostFD >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(d.hostFD))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ if !d.writeFile.isNil() {
+ return d.writeFile.fsync(ctx)
+ }
+ if !d.readFile.isNil() {
+ return d.readFile.fsync(ctx)
+ }
+ return nil
+}
+
+// incLinks increments link count.
+func (d *dentry) incLinks() {
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
+ }
+ atomic.AddUint32(&d.nlink, 1)
+}
+
+// decLinks decrements link count.
+func (d *dentry) decLinks() {
+ if atomic.LoadUint32(&d.nlink) == 0 {
+ // The remote filesystem doesn't support link count.
+ return
+ }
+ atomic.AddUint32(&d.nlink, ^uint32(0))
+}
+
+// fileDescription is embedded by gofer implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ lockLogging sync.Once
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ d := fd.dentry()
+ const validMask = uint32(linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_SIZE | linux.STATX_BLOCKS | linux.STATX_BTIME)
+ if !d.cachedMetadataAuthoritative() && opts.Mask&validMask != 0 && opts.Sync != linux.AT_STATX_DONT_SYNC {
+ // TODO(jamieliu): Use specialFileFD.handle.file for the getattr if
+ // available?
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ var stat linux.Statx
+ d.statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil {
+ return err
+ }
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ fd.dentry().InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ d := fd.dentry()
+ if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ d := fd.dentry()
+ if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("File lock using gofer file handled internally.")
+ })
+ return fd.LockFD.LockBSD(ctx, uid, t, block)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ fd.lockLogging.Do(func() {
+ log.Infof("Range lock using gofer file handled internally.")
+ })
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go
new file mode 100644
index 000000000..bfe75dfe4
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/gofer_test.go
@@ -0,0 +1,67 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+)
+
+func TestDestroyIdempotent(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fs := filesystem{
+ mfp: pgalloc.MemoryFileProviderFromContext(ctx),
+ syncableDentries: make(map[*dentry]struct{}),
+ opts: filesystemOptions{
+ // Test relies on no dentry being held in the cache.
+ maxCachedDentries: 0,
+ },
+ }
+
+ attr := &p9.Attr{
+ Mode: p9.ModeRegular,
+ }
+ mask := p9.AttrMask{
+ Mode: true,
+ Size: true,
+ }
+ parent, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+
+ child, err := fs.newDentry(ctx, p9file{}, p9.QID{}, mask, attr)
+ if err != nil {
+ t.Fatalf("fs.newDentry(): %v", err)
+ }
+ parent.cacheNewChildLocked(child, "child")
+
+ fs.renameMu.Lock()
+ defer fs.renameMu.Unlock()
+ child.checkCachingLocked(ctx)
+ if got := atomic.LoadInt64(&child.refs); got != -1 {
+ t.Fatalf("child.refs=%d, want: -1", got)
+ }
+ // Parent will also be destroyed when child reference is removed.
+ if got := atomic.LoadInt64(&parent.refs); got != -1 {
+ t.Fatalf("parent.refs=%d, want: -1", got)
+ }
+ child.checkCachingLocked(ctx)
+ child.checkCachingLocked(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go
new file mode 100644
index 000000000..104157512
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/handle.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+)
+
+// handle represents a remote "open file descriptor", consisting of an opened
+// fid (p9.File) and optionally a host file descriptor.
+type handle struct {
+ file p9file
+ fd int32 // -1 if unavailable
+}
+
+// Preconditions: read || write.
+func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (handle, error) {
+ _, newfile, err := file.walk(ctx, nil)
+ if err != nil {
+ return handle{fd: -1}, err
+ }
+ var flags p9.OpenFlags
+ switch {
+ case read && !write:
+ flags = p9.ReadOnly
+ case !read && write:
+ flags = p9.WriteOnly
+ case read && write:
+ flags = p9.ReadWrite
+ }
+ if trunc {
+ flags |= p9.OpenTruncate
+ }
+ fdobj, _, _, err := newfile.open(ctx, flags)
+ if err != nil {
+ newfile.close(ctx)
+ return handle{fd: -1}, err
+ }
+ fd := int32(-1)
+ if fdobj != nil {
+ fd = int32(fdobj.Release())
+ }
+ return handle{
+ file: newfile,
+ fd: fd,
+ }, nil
+}
+
+func (h *handle) isOpen() bool {
+ return !h.file.isNil()
+}
+
+func (h *handle) close(ctx context.Context) {
+ h.file.close(ctx)
+ h.file = p9file{}
+ if h.fd >= 0 {
+ syscall.Close(int(h.fd))
+ h.fd = -1
+ }
+}
+
+func (h *handle) readToBlocksAt(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostfd.Preadv2(h.fd, dsts, int64(offset), 0 /* flags */)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if dsts.NumBlocks() == 1 && !dsts.Head().NeedSafecopy() {
+ n, err := h.file.readAt(ctx, dsts.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the read since p9.File.ReadAt() takes []byte.
+ buf := make([]byte, dsts.NumBytes())
+ n, err := h.file.readAt(ctx, buf, offset)
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := safemem.CopySeq(dsts, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:n]))); cperr != nil {
+ return cp, cperr
+ }
+ return uint64(n), err
+}
+
+func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if h.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := hostfd.Pwritev2(h.fd, srcs, int64(offset), 0 /* flags */)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+ }
+ if srcs.NumBlocks() == 1 && !srcs.Head().NeedSafecopy() {
+ n, err := h.file.writeAt(ctx, srcs.Head().ToSlice(), offset)
+ return uint64(n), err
+ }
+ // Buffer the write since p9.File.WriteAt() takes []byte.
+ buf := make([]byte, srcs.NumBytes())
+ cp, cperr := safemem.CopySeq(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), srcs)
+ if cp == 0 {
+ return 0, cperr
+ }
+ n, err := h.file.writeAt(ctx, buf[:cp], offset)
+ if err != nil {
+ return uint64(n), err
+ }
+ return cp, cperr
+}
diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
new file mode 100644
index 000000000..7294de7d6
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Global pipe used by blockUntilNonblockingPipeHasWriter since we can't create
+// pipes after sentry initialization due to syscall filters.
+var (
+ tempPipeMu sync.Mutex
+ tempPipeReadFD int
+ tempPipeWriteFD int
+ tempPipeBuf [1]byte
+)
+
+func init() {
+ var pipeFDs [2]int
+ if err := unix.Pipe(pipeFDs[:]); err != nil {
+ panic(fmt.Sprintf("failed to create pipe for gofer.blockUntilNonblockingPipeHasWriter: %v", err))
+ }
+ tempPipeReadFD = pipeFDs[0]
+ tempPipeWriteFD = pipeFDs[1]
+}
+
+func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error {
+ for {
+ ok, err := nonblockingPipeHasWriter(fd)
+ if err != nil {
+ return err
+ }
+ if ok {
+ return nil
+ }
+ if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil {
+ return err
+ }
+ }
+}
+
+func nonblockingPipeHasWriter(fd int32) (bool, error) {
+ tempPipeMu.Lock()
+ defer tempPipeMu.Unlock()
+ // Copy 1 byte from fd into the temporary pipe.
+ n, err := unix.Tee(int(fd), tempPipeWriteFD, 1, unix.SPLICE_F_NONBLOCK)
+ if err == syserror.EAGAIN {
+ // The pipe represented by fd is empty, but has a writer.
+ return true, nil
+ }
+ if err != nil {
+ return false, err
+ }
+ if n == 0 {
+ // The pipe represented by fd is empty and has no writer.
+ return false, nil
+ }
+ // The pipe represented by fd is non-empty, so it either has, or has
+ // previously had, a writer. Remove the byte copied to the temporary pipe
+ // before returning.
+ if n, err := unix.Read(tempPipeReadFD, tempPipeBuf[:]); err != nil || n != 1 {
+ panic(fmt.Sprintf("failed to drain pipe for gofer.blockUntilNonblockingPipeHasWriter: got (%d, %v), wanted (1, nil)", n, err))
+ }
+ return true, nil
+}
+
+func sleepBetweenNamedPipeOpenChecks(ctx context.Context) error {
+ t := time.NewTimer(100 * time.Millisecond)
+ defer t.Stop()
+ cancel := ctx.SleepStart()
+ select {
+ case <-t.C:
+ ctx.SleepFinish(true)
+ return nil
+ case <-cancel:
+ ctx.SleepFinish(false)
+ return syserror.ErrInterrupted
+ }
+}
diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go
new file mode 100644
index 000000000..87f0b877f
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/p9file.go
@@ -0,0 +1,233 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// p9file is a wrapper around p9.File that provides methods that are
+// Context-aware.
+type p9file struct {
+ file p9.File
+}
+
+func (f p9file) isNil() bool {
+ return f.file == nil
+}
+
+func (f p9file) walk(ctx context.Context, names []string) ([]p9.QID, p9file, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, err := f.file.Walk(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, err
+}
+
+func (f p9file) walkGetAttr(ctx context.Context, names []string) ([]p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr(names)
+ ctx.UninterruptibleSleepFinish(false)
+ return qids, p9file{newfile}, attrMask, attr, err
+}
+
+// walkGetAttrOne is a wrapper around p9.File.WalkGetAttr that takes a single
+// path component and returns a single qid.
+func (f p9file) walkGetAttrOne(ctx context.Context, name string) (p9.QID, p9file, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qids, newfile, attrMask, attr, err := f.file.WalkGetAttr([]string{name})
+ ctx.UninterruptibleSleepFinish(false)
+ if err != nil {
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, err
+ }
+ if len(qids) != 1 {
+ ctx.Warningf("p9.File.WalkGetAttr returned %d qids (%v), wanted 1", len(qids), qids)
+ if newfile != nil {
+ p9file{newfile}.close(ctx)
+ }
+ return p9.QID{}, p9file{}, p9.AttrMask{}, p9.Attr{}, syserror.EIO
+ }
+ return qids[0], p9file{newfile}, attrMask, attr, nil
+}
+
+func (f p9file) statFS(ctx context.Context) (p9.FSStat, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fsstat, err := f.file.StatFS()
+ ctx.UninterruptibleSleepFinish(false)
+ return fsstat, err
+}
+
+func (f p9file) getAttr(ctx context.Context, req p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, attrMask, attr, err := f.file.GetAttr(req)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, attrMask, attr, err
+}
+
+func (f p9file) setAttr(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetAttr(valid, attr)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) listXattr(ctx context.Context, size uint64) (map[string]struct{}, error) {
+ ctx.UninterruptibleSleepStart(false)
+ xattrs, err := f.file.ListXattr(size)
+ ctx.UninterruptibleSleepFinish(false)
+ return xattrs, err
+}
+
+func (f p9file) getXattr(ctx context.Context, name string, size uint64) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ val, err := f.file.GetXattr(name, size)
+ ctx.UninterruptibleSleepFinish(false)
+ return val, err
+}
+
+func (f p9file) setXattr(ctx context.Context, name, value string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.SetXattr(name, value, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) removeXattr(ctx context.Context, name string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.RemoveXattr(name)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) allocate(ctx context.Context, mode p9.AllocateMode, offset, length uint64) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Allocate(mode, offset, length)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) close(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Close()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, qid, iounit, err := f.file.Open(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, qid, iounit, err
+}
+
+func (f p9file) readAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.ReadAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) writeAt(ctx context.Context, p []byte, offset uint64) (int, error) {
+ ctx.UninterruptibleSleepStart(false)
+ n, err := f.file.WriteAt(p, offset)
+ ctx.UninterruptibleSleepFinish(false)
+ return n, err
+}
+
+func (f p9file) fsync(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.FSync()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) create(ctx context.Context, name string, flags p9.OpenFlags, permissions p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9file, p9.QID, uint32, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, newfile, qid, iounit, err := f.file.Create(name, flags, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, p9file{newfile}, qid, iounit, err
+}
+
+func (f p9file) mkdir(ctx context.Context, name string, permissions p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mkdir(name, permissions, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) symlink(ctx context.Context, oldName string, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Symlink(oldName, newName, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) link(ctx context.Context, target p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Link(target.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) mknod(ctx context.Context, name string, mode p9.FileMode, major uint32, minor uint32, uid p9.UID, gid p9.GID) (p9.QID, error) {
+ ctx.UninterruptibleSleepStart(false)
+ qid, err := f.file.Mknod(name, mode, major, minor, uid, gid)
+ ctx.UninterruptibleSleepFinish(false)
+ return qid, err
+}
+
+func (f p9file) rename(ctx context.Context, newDir p9file, newName string) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Rename(newDir.file, newName)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) unlinkAt(ctx context.Context, name string, flags uint32) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.UnlinkAt(name, flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) readdir(ctx context.Context, offset uint64, count uint32) ([]p9.Dirent, error) {
+ ctx.UninterruptibleSleepStart(false)
+ dirents, err := f.file.Readdir(offset, count)
+ ctx.UninterruptibleSleepFinish(false)
+ return dirents, err
+}
+
+func (f p9file) readlink(ctx context.Context) (string, error) {
+ ctx.UninterruptibleSleepStart(false)
+ target, err := f.file.Readlink()
+ ctx.UninterruptibleSleepFinish(false)
+ return target, err
+}
+
+func (f p9file) flush(ctx context.Context) error {
+ ctx.UninterruptibleSleepStart(false)
+ err := f.file.Flush()
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+}
+
+func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, error) {
+ ctx.UninterruptibleSleepStart(false)
+ fdobj, err := f.file.Connect(flags)
+ ctx.UninterruptibleSleepFinish(false)
+ return fdobj, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
new file mode 100644
index 000000000..7e1cbf065
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -0,0 +1,944 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isRegularFile() bool {
+ return d.fileType() == linux.S_IFREG
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. off is protected by mu.
+ mu sync.Mutex
+ off int64
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release(context.Context) {
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *regularFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ // Skip flushing if writes may be buffered by the client, since (as with
+ // the VFS1 client) we don't flush buffered writes on close anyway.
+ d := fd.dentry()
+ if d.fs.opts.interop == InteropModeExclusive {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ if d.writeFile.isNil() {
+ return nil
+ }
+ return d.writeFile.flush(ctx)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ d := fd.dentry()
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Allocating a smaller size is a noop.
+ size := offset + length
+ if d.cachedMetadataAuthoritative() && size <= d.size {
+ return nil
+ }
+
+ d.handleMu.RLock()
+ err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length)
+ d.handleMu.RUnlock()
+ if err != nil {
+ return err
+ }
+ d.dataMu.Lock()
+ atomic.StoreUint64(&d.size, size)
+ d.dataMu.Unlock()
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtimeLocked()
+ }
+ return nil
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Check for reading at EOF before calling into MM (but not under
+ // InteropModeShared, which makes d.size unreliable).
+ d := fd.dentry()
+ if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) {
+ return 0, io.EOF
+ }
+
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Lock d.metadataMu for the rest of the read to prevent d.size from
+ // changing.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+ // Write dirty cached pages that will be touched by the read back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil {
+ return 0, err
+ }
+ }
+
+ rw := getDentryReadWriter(ctx, d, offset)
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ // Require the read to go to the remote file.
+ rw.direct = true
+ }
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putDentryReadWriter(rw)
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed().
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+ return n, err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
+ if offset < 0 {
+ return 0, offset, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the fd was opened with O_APPEND, make sure the file size is updated.
+ // There is a possible race here if size is modified externally after
+ // metadata cache is updated.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
+ }
+
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the fd was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, offset, err
+ }
+ src = src.TakeFirst64(limit)
+
+ if d.fs.opts.interop != InteropModeShared {
+ // Compare Linux's mm/filemap.c:__generic_file_write_iter() =>
+ // file_update_time(). This is d.touchCMtime(), but without locking
+ // d.metadataMu (recursively).
+ d.touchCMtimeLocked()
+ }
+
+ rw := getDentryReadWriter(ctx, d, offset)
+ defer putDentryReadWriter(rw)
+
+ if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 {
+ if err := fd.writeCache(ctx, d, offset, src); err != nil {
+ return 0, offset, err
+ }
+
+ // Require the write to go to the remote file.
+ rw.direct = true
+ }
+
+ n, err := src.CopyInTo(ctx, rw)
+ if err != nil {
+ return n, offset + n, err
+ }
+ if n > 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ // Note that if any of the following fail, then we can't guarantee that
+ // any data was actually written with the semantics of O_DSYNC or
+ // O_SYNC, so we return zero bytes written. Compare Linux's
+ // mm/filemap.c:generic_file_write_iter() =>
+ // include/linux/fs.h:generic_write_sync().
+ //
+ // Write dirty cached pages touched by the write back to the remote
+ // file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return 0, offset, err
+ }
+ // Request the remote filesystem to sync the remote file.
+ if err := d.syncRemoteFile(ctx); err != nil {
+ return 0, offset, err
+ }
+ }
+ return n, offset + n, nil
+}
+
+func (fd *regularFileFD) writeCache(ctx context.Context, d *dentry, offset int64, src usermem.IOSequence) error {
+ // Write dirty cached pages that will be touched by the write back to
+ // the remote file.
+ if err := d.writeback(ctx, offset, src.NumBytes()); err != nil {
+ return err
+ }
+
+ // Remove touched pages from the cache.
+ pgstart := usermem.PageRoundDown(uint64(offset))
+ pgend, ok := usermem.PageRoundUp(uint64(offset + src.NumBytes()))
+ if !ok {
+ return syserror.EINVAL
+ }
+ mr := memmap.MappableRange{pgstart, pgend}
+ var freed []memmap.FileRange
+
+ d.dataMu.Lock()
+ cseg := d.cache.LowerBoundSegment(mr.Start)
+ for cseg.Ok() && cseg.Start() < mr.End {
+ cseg = d.cache.Isolate(cseg, mr)
+ freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ cseg = d.cache.Remove(cseg).NextSegment()
+ }
+ d.dataMu.Unlock()
+
+ // Invalidate mappings of removed pages.
+ d.mapsMu.Lock()
+ d.mappings.Invalidate(mr, memmap.InvalidateOpts{})
+ d.mapsMu.Unlock()
+
+ // Finally free pages removed from the cache.
+ mf := d.fs.mfp.MemoryFile()
+ for _, freedFR := range freed {
+ mf.DecRef(freedFR)
+ }
+ return nil
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
+ fd.mu.Unlock()
+ return n, err
+}
+
+type dentryReadWriter struct {
+ ctx context.Context
+ d *dentry
+ off uint64
+ direct bool
+}
+
+var dentryReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &dentryReadWriter{}
+ },
+}
+
+func getDentryReadWriter(ctx context.Context, d *dentry, offset int64) *dentryReadWriter {
+ rw := dentryReadWriterPool.Get().(*dentryReadWriter)
+ rw.ctx = ctx
+ rw.d = d
+ rw.off = uint64(offset)
+ rw.direct = false
+ return rw
+}
+
+func putDentryReadWriter(rw *dentryReadWriter) {
+ rw.ctx = nil
+ rw.d = nil
+ dentryReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents and makes dentry.size
+ // unreliable), or if the file was opened O_DIRECT, read directly from
+ // dentry.readHandleLocked() without locking dentry.dataMu.
+ rw.d.handleMu.RLock()
+ h := rw.d.readHandleLocked()
+ if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off)
+ rw.d.handleMu.RUnlock()
+ rw.off += n
+ return n, err
+ }
+
+ // Otherwise read from/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ fillCache := mf.ShouldCacheEvictable()
+ var dataMuUnlock func()
+ if fillCache {
+ rw.d.dataMu.Lock()
+ dataMuUnlock = rw.d.dataMu.Unlock
+ } else {
+ rw.d.dataMu.RLock()
+ dataMuUnlock = rw.d.dataMu.RUnlock
+ }
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= rw.d.size {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return 0, io.EOF
+ }
+ end := rw.d.size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ ims, err := mf.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ gapMR := gap.Range().Intersect(mr)
+ if fillCache {
+ // Read into the cache, then re-enter the loop to read from the
+ // cache.
+ gapEnd, _ := usermem.PageRoundUp(gapMR.End)
+ reqMR := memmap.MappableRange{
+ Start: usermem.PageRoundDown(gapMR.Start),
+ End: gapEnd,
+ }
+ optMR := gap.Range()
+ err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, h.readToBlocksAt)
+ mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End})
+ seg, gap = rw.d.cache.Find(rw.off)
+ if !seg.Ok() {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+ // err might have occurred in part of gap.Range() outside
+ // gapMR. Forget about it for now; if the error matters and
+ // persists, we'll run into it again in a later iteration of
+ // this loop.
+ } else {
+ // Read directly from the file.
+ gapDsts := dsts.TakeFirst64(gapMR.Length())
+ n, err := h.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start)
+ done += n
+ rw.off += n
+ dsts = dsts.DropFirst64(n)
+ // Partial reads are fine. But we must stop reading.
+ if n != gapDsts.NumBytes() || err != nil {
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ }
+ dataMuUnlock()
+ rw.d.handleMu.RUnlock()
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: rw.d.metadataMu must be locked.
+func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+
+ // If we have a mmappable host FD (which must be used here to ensure
+ // coherence with memory-mapped I/O), or if InteropModeShared is in effect
+ // (which prevents us from caching file contents), or if the file was
+ // opened with O_DIRECT, write directly to dentry.writeHandleLocked()
+ // without locking dentry.dataMu.
+ rw.d.handleMu.RLock()
+ h := rw.d.writeHandleLocked()
+ if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct {
+ n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off)
+ rw.off += n
+ rw.d.dataMu.Lock()
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
+ return n, err
+ }
+
+ // Otherwise write to/through the cache.
+ mf := rw.d.fs.mfp.MemoryFile()
+ rw.d.dataMu.Lock()
+
+ // Compute the range to write (overflow-checked).
+ start := rw.off
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.d.cache.Find(rw.off)
+ for rw.off < end {
+ mr := memmap.MappableRange{rw.off, end}
+ switch {
+ case seg.Ok():
+ // Get internal mappings from the cache.
+ segMR := seg.Range().Intersect(mr)
+ ims, err := mf.MapInternal(seg.FileRangeOf(segMR), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ rw.d.dirty.MarkDirty(segMR)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Write directly to the file. At present, we never fill the cache
+ // when writing, since doing so can convert small writes into
+ // inefficient read-modify-write cycles, and we have no mechanism
+ // for detecting or avoiding this.
+ gapMR := gap.Range().Intersect(mr)
+ gapSrcs := srcs.TakeFirst64(gapMR.Length())
+ n, err := h.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start)
+ done += n
+ rw.off += n
+ srcs = srcs.DropFirst64(n)
+ // Partial writes are fine. But we must stop writing.
+ if n != gapSrcs.NumBytes() || err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ if rw.off > rw.d.size {
+ atomic.StoreUint64(&rw.d.size, rw.off)
+ // The remote file's size will implicitly be extended to the correct
+ // value when we write back to it.
+ }
+ // If InteropModeWritethrough is in effect, flush written data back to the
+ // remote filesystem.
+ if rw.d.fs.opts.interop == InteropModeWritethrough && done != 0 {
+ if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{
+ Start: start,
+ End: rw.off,
+ }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, h.writeFromBlocksAt); err != nil {
+ // We have no idea how many bytes were actually flushed.
+ rw.off = start
+ done = 0
+ retErr = err
+ }
+ }
+ rw.d.dataMu.Unlock()
+ rw.d.handleMu.RUnlock()
+ return done, retErr
+}
+
+func (d *dentry) writeback(ctx context.Context, offset, size int64) error {
+ if size == 0 {
+ return nil
+ }
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ // Compute the range of valid bytes (overflow-checked).
+ if uint64(offset) >= d.size {
+ return nil
+ }
+ end := int64(d.size)
+ if rend := offset + size; rend > offset && rend < end {
+ end = rend
+ }
+ return fsutil.SyncDirty(ctx, memmap.MappableRange{
+ Start: uint64(offset),
+ End: uint64(end),
+ }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
+ }
+ fd.off = newOffset
+ return newOffset, nil
+}
+
+// Calculate the new offset for a seek operation on a regular file.
+func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int64, whence int32) (int64, error) {
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as specified.
+ case linux.SEEK_CUR:
+ offset += fdOffset
+ case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Ensure file size is up to date.
+ if !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, err
+ }
+ }
+ size := int64(atomic.LoadUint64(&d.size))
+ // For SEEK_DATA and SEEK_HOLE, treat the file as a single contiguous
+ // block of data.
+ switch whence {
+ case linux.SEEK_END:
+ offset += size
+ case linux.SEEK_DATA:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ // Use offset as specified.
+ case linux.SEEK_HOLE:
+ if offset > size {
+ return 0, syserror.ENXIO
+ }
+ offset = size
+ }
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ return offset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *regularFileFD) Sync(ctx context.Context) error {
+ return fd.dentry().syncCachedFile(ctx)
+}
+
+func (d *dentry) syncCachedFile(ctx context.Context) error {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+
+ if h := d.writeHandleLocked(); h.isOpen() {
+ d.dataMu.Lock()
+ // Write dirty cached data to the remote file.
+ err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt)
+ d.dataMu.Unlock()
+ if err != nil {
+ return err
+ }
+ }
+ return d.syncRemoteFileLocked(ctx)
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ d := fd.dentry()
+ switch d.fs.opts.interop {
+ case InteropModeExclusive:
+ // Any mapping is fine.
+ case InteropModeWritethrough:
+ // Shared writable mappings require a host FD, since otherwise we can't
+ // synchronously flush memory-mapped writes to the remote file.
+ if opts.Private || !opts.MaxPerms.Write {
+ break
+ }
+ fallthrough
+ case InteropModeShared:
+ // All mappings require a host FD to be coherent with other filesystem
+ // users.
+ if d.fs.opts.forcePageCache {
+ // Whether or not we have a host FD, we're not allowed to use it.
+ return syserror.ENODEV
+ }
+ d.handleMu.RLock()
+ haveFD := d.hostFD >= 0
+ d.handleMu.RUnlock()
+ if !haveFD {
+ return syserror.ENODEV
+ }
+ default:
+ panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
+ }
+ // After this point, d may be used as a memmap.Mappable.
+ d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
+}
+
+func (d *dentry) mayCachePages() bool {
+ if d.fs.opts.interop == InteropModeShared {
+ return false
+ }
+ if d.fs.opts.forcePageCache {
+ return true
+ }
+ d.handleMu.RLock()
+ haveFD := d.hostFD >= 0
+ d.handleMu.RUnlock()
+ return haveFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ d.mapsMu.Lock()
+ mapped := d.mappings.AddMapping(ms, ar, offset, writable)
+ // Do this unconditionally since whether we have a host FD can change
+ // across save/restore.
+ for _, r := range mapped {
+ d.pf.hostFileMapper.IncRefOn(r)
+ }
+ if d.mayCachePages() {
+ // d.Evict() will refuse to evict memory-mapped pages, so tell the
+ // MemoryFile to not bother trying.
+ mf := d.fs.mfp.MemoryFile()
+ for _, r := range mapped {
+ mf.MarkUnevictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ }
+ d.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ d.mapsMu.Lock()
+ unmapped := d.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ d.pf.hostFileMapper.DecRefOn(r)
+ }
+ if d.mayCachePages() {
+ // Pages that are no longer referenced by any application memory
+ // mappings are now considered unused; allow MemoryFile to evict them
+ // when necessary.
+ mf := d.fs.mfp.MemoryFile()
+ d.dataMu.Lock()
+ for _, r := range unmapped {
+ // Since these pages are no longer mapped, they are no longer
+ // concurrently dirtyable by a writable memory mapping.
+ d.dirty.AllowClean(r)
+ mf.MarkEvictable(d, pgalloc.EvictableRange{r.Start, r.End})
+ }
+ d.dataMu.Unlock()
+ }
+ d.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return d.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ d.handleMu.RLock()
+ if d.hostFD >= 0 && !d.fs.opts.forcePageCache {
+ d.handleMu.RUnlock()
+ mr := optional
+ if d.fs.opts.limitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &d.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+ }
+
+ d.dataMu.Lock()
+
+ // Constrain translations to d.size (rounded up) to prevent translation to
+ // pages that may be concurrently truncated.
+ pgend, _ := usermem.PageRoundUp(d.size)
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ mf := d.fs.mfp.MemoryFile()
+ h := d.readHandleLocked()
+ cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, h.readToBlocksAt)
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := d.cache.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ // TODO(jamieliu): Make Translations writable even if writability is
+ // not required if already kept-dirty by another writable translation.
+ perms := usermem.AccessType{
+ Read: true,
+ Execute: true,
+ }
+ if at.Write {
+ // From this point forward, this memory can be dirtied through the
+ // mapping at any time.
+ d.dirty.KeepDirty(segMR)
+ perms.Write = true
+ }
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: mf,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: perms,
+ })
+ translatedEnd = segMR.End
+ }
+
+ d.dataMu.Unlock()
+ d.handleMu.RUnlock()
+
+ // Don't return the error returned by c.cache.Fill if it occurred outside
+ // of required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange {
+ const maxReadahead = 64 << 10 // 64 KB, chosen arbitrarily
+ if required.Length() >= maxReadahead {
+ return required
+ }
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.Start = required.Start
+ if optional.Length() <= maxReadahead {
+ return optional
+ }
+ optional.End = optional.Start + maxReadahead
+ return optional
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
+ // Whether we have a host fd (and consequently what memmap.File is
+ // mapped) can change across save/restore, so invalidate all translations
+ // unconditionally.
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.mappings.InvalidateAll(memmap.InvalidateOpts{})
+
+ // Write the cache's contents back to the remote file so that if we have a
+ // host fd after restore, the remote file's contents are coherent.
+ mf := d.fs.mfp.MemoryFile()
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
+ return err
+ }
+
+ // Discard the cache so that it's not stored in saved state. This is safe
+ // because per InvalidateUnsavable invariants, no new translations can have
+ // been returned after we invalidated all existing translations above.
+ d.cache.DropAll(mf)
+ d.dirty.RemoveAll()
+
+ return nil
+}
+
+// Evict implements pgalloc.EvictableMemoryUser.Evict.
+func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
+ mr := memmap.MappableRange{er.Start, er.End}
+ mf := d.fs.mfp.MemoryFile()
+ d.mapsMu.Lock()
+ defer d.mapsMu.Unlock()
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ h := d.writeHandleLocked()
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+
+ // Only allow pages that are no longer memory-mapped to be evicted.
+ for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() {
+ mgapMR := mgap.Range().Intersect(mr)
+ if mgapMR.Length() == 0 {
+ continue
+ }
+ if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil {
+ log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err)
+ }
+ d.cache.Drop(mgapMR, mf)
+ d.dirty.KeepClean(mgapMR)
+ }
+}
+
+// dentryPlatformFile implements memmap.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef.
+//
+// dentryPlatformFile is only used when a host FD representing the remote file
+// is available (i.e. dentry.hostFD >= 0), and that FD is used for application
+// memory mappings (i.e. !filesystem.opts.forcePageCache).
+type dentryPlatformFile struct {
+ *dentry
+
+ // fdRefs counts references on memmap.File offsets. fdRefs is protected
+ // by dentry.dataMu.
+ fdRefs fsutil.FrameRefSet
+
+ // If this dentry represents a regular file, and dentry.hostFD >= 0,
+ // hostFileMapper caches mappings of dentry.hostFD.
+ hostFileMapper fsutil.HostFileMapper
+
+ // hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
+ hostFileMapperInitOnce sync.Once
+}
+
+// IncRef implements memmap.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) {
+ d.dataMu.Lock()
+ d.fdRefs.IncRefAndAccount(fr)
+ d.dataMu.Unlock()
+}
+
+// DecRef implements memmap.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
+ d.dataMu.Lock()
+ d.fdRefs.DecRefAndAccount(fr)
+ d.dataMu.Unlock()
+}
+
+// MapInternal implements memmap.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write)
+}
+
+// FD implements memmap.File.FD.
+func (d *dentryPlatformFile) FD() int {
+ d.handleMu.RLock()
+ defer d.handleMu.RUnlock()
+ return int(d.hostFD)
+}
diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go
new file mode 100644
index 000000000..85d2bee72
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/socket.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func (d *dentry) isSocket() bool {
+ return d.fileType() == linux.S_IFSOCK
+}
+
+// endpoint is a Gofer-backed transport.BoundEndpoint.
+//
+// An endpoint's lifetime is the time between when filesystem.BoundEndpointAt()
+// is called and either BoundEndpoint.BidirectionalConnect or
+// BoundEndpoint.UnidirectionalConnect is called.
+type endpoint struct {
+ // dentry is the filesystem dentry which produced this endpoint.
+ dentry *dentry
+
+ // file is the p9 file that contains a single unopened fid.
+ file p9.File
+
+ // path is the sentry path where this endpoint is bound.
+ path string
+}
+
+func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) {
+ switch t {
+ case linux.SOCK_STREAM:
+ return p9.StreamSocket, true
+ case linux.SOCK_SEQPACKET:
+ return p9.SeqpacketSocket, true
+ case linux.SOCK_DGRAM:
+ return p9.DgramSocket, true
+ }
+ return 0, false
+}
+
+// BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect.
+func (e *endpoint) BidirectionalConnect(ctx context.Context, ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error {
+ cf, ok := sockTypeToP9(ce.Type())
+ if !ok {
+ return syserr.ErrConnectionRefused
+ }
+
+ // No lock ordering required as only the ConnectingEndpoint has a mutex.
+ ce.Lock()
+
+ // Check connecting state.
+ if ce.Connected() {
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ c, err := e.newConnectedEndpoint(ctx, cf, ce.WaiterQueue())
+ if err != nil {
+ ce.Unlock()
+ return err
+ }
+
+ returnConnect(c, c)
+ ce.Unlock()
+ if err := c.Init(); err != nil {
+ return syserr.FromError(err)
+ }
+
+ return nil
+}
+
+// UnidirectionalConnect implements
+// transport.BoundEndpoint.UnidirectionalConnect.
+func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.ConnectedEndpoint, *syserr.Error) {
+ c, err := e.newConnectedEndpoint(ctx, p9.DgramSocket, &waiter.Queue{})
+ if err != nil {
+ return nil, err
+ }
+
+ if err := c.Init(); err != nil {
+ return nil, syserr.FromError(err)
+ }
+
+ // We don't need the receiver.
+ c.CloseRecv()
+ c.Release(ctx)
+
+ return c, nil
+}
+
+func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) {
+ hostFile, err := e.file.Connect(flags)
+ if err != nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ // Dup the fd so that the new endpoint can manage its lifetime.
+ hostFD, err := syscall.Dup(hostFile.FD())
+ if err != nil {
+ log.Warningf("Could not dup host socket fd %d: %v", hostFile.FD(), err)
+ return nil, syserr.FromError(err)
+ }
+ // After duplicating, we no longer need hostFile.
+ hostFile.Close()
+
+ c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path)
+ if serr != nil {
+ log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr)
+ return nil, serr
+ }
+ return c, nil
+}
+
+// Release implements transport.BoundEndpoint.Release.
+func (e *endpoint) Release(ctx context.Context) {
+ e.dentry.DecRef(ctx)
+}
+
+// Passcred implements transport.BoundEndpoint.Passcred.
+func (e *endpoint) Passcred() bool {
+ return false
+}
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go
new file mode 100644
index 000000000..a6368fdd0
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/special_file.go
@@ -0,0 +1,292 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device
+// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is
+// in effect) regular files. specialFileFD differs from regularFileFD by using
+// per-FD handles instead of shared per-dentry handles, and never buffering I/O.
+type specialFileFD struct {
+ fileDescription
+
+ // handle is used for file I/O. handle is immutable.
+ handle handle
+
+ // seekable is true if this file description represents a file for which
+ // file offset is significant, i.e. a regular file. seekable is immutable.
+ seekable bool
+
+ // haveQueue is true if this file description represents a file for which
+ // queue may send I/O readiness events. haveQueue is immutable.
+ haveQueue bool
+ queue waiter.Queue
+
+ // If seekable is true, off is the file offset. off is protected by mu.
+ mu sync.Mutex
+ off int64
+}
+
+func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) {
+ ftype := d.fileType()
+ seekable := ftype == linux.S_IFREG
+ haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0
+ fd := &specialFileFD{
+ handle: h,
+ seekable: seekable,
+ haveQueue: haveQueue,
+ }
+ fd.LockFD.Init(locks)
+ if haveQueue {
+ if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil {
+ return nil, err
+ }
+ }
+ if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: !seekable,
+ DenyPWrite: !seekable,
+ }); err != nil {
+ if haveQueue {
+ fdnotifier.RemoveFD(h.fd)
+ }
+ return nil, err
+ }
+ return fd, nil
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *specialFileFD) Release(ctx context.Context) {
+ if fd.haveQueue {
+ fdnotifier.RemoveFD(fd.handle.fd)
+ }
+ fd.handle.close(ctx)
+ fs := fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+ fs.syncMu.Lock()
+ delete(fs.specialFileFDs, fd)
+ fs.syncMu.Unlock()
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *specialFileFD) OnClose(ctx context.Context) error {
+ if !fd.vfsfd.IsWritable() {
+ return nil
+ }
+ return fd.handle.file.flush(ctx)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if fd.haveQueue {
+ return fdnotifier.NonBlockingPoll(fd.handle.fd, mask)
+ }
+ return fd.fileDescription.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ if fd.haveQueue {
+ fd.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *specialFileFD) EventUnregister(e *waiter.Entry) {
+ if fd.haveQueue {
+ fd.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(fd.handle.fd)
+ return
+ }
+ fd.fileDescription.EventUnregister(e)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if fd.seekable && offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // Going through dst.CopyOutFrom() holds MM locks around file operations of
+ // unknown duration. For regularFileFD, doing so is necessary to support
+ // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't
+ // hold here since specialFileFD doesn't client-cache data. Just buffer the
+ // read instead.
+ if d := fd.dentry(); d.cachedMetadataAuthoritative() {
+ d.touchAtime(fd.vfsfd.Mount())
+ }
+ buf := make([]byte, dst.NumBytes())
+ n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
+ if n == 0 {
+ return 0, err
+ }
+ if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil {
+ return int64(cp), cperr
+ }
+ return int64(n), err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if !fd.seekable {
+ return fd.PRead(ctx, dst, -1, opts)
+ }
+
+ fd.mu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset, error. The final
+// offset should be ignored by PWrite.
+func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
+ if fd.seekable && offset < 0 {
+ return 0, offset, syserror.EINVAL
+ }
+
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if opts.Flags&^linux.RWF_HIPRI != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ d := fd.dentry()
+ // If the regular file fd was opened with O_APPEND, make sure the file size
+ // is updated. There is a possible race here if size is modified externally
+ // after metadata cache is updated.
+ if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() {
+ if err := d.updateFromGetattr(ctx); err != nil {
+ return 0, offset, err
+ }
+ }
+
+ if fd.seekable {
+ // We need to hold the metadataMu *while* writing to a regular file.
+ d.metadataMu.Lock()
+ defer d.metadataMu.Unlock()
+
+ // Set offset to file size if the regular file was opened with O_APPEND.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Holding d.metadataMu is sufficient for reading d.size.
+ offset = int64(d.size)
+ }
+ limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, offset, err
+ }
+ src = src.TakeFirst64(limit)
+ }
+
+ // Do a buffered write. See rationale in PRead.
+ if d.cachedMetadataAuthoritative() {
+ d.touchCMtime()
+ }
+ buf := make([]byte, src.NumBytes())
+ // Don't do partial writes if we get a partial read from src.
+ if _, err := src.CopyIn(ctx, buf); err != nil {
+ return 0, offset, err
+ }
+ n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset))
+ if err == syserror.EAGAIN {
+ err = syserror.ErrWouldBlock
+ }
+ finalOff = offset
+ // Update file size for regular files.
+ if fd.seekable {
+ finalOff += int64(n)
+ // d.metadataMu is already locked at this point.
+ if uint64(finalOff) > d.size {
+ d.dataMu.Lock()
+ defer d.dataMu.Unlock()
+ atomic.StoreUint64(&d.size, uint64(finalOff))
+ }
+ }
+ return int64(n), finalOff, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ if !fd.seekable {
+ return fd.PWrite(ctx, src, -1, opts)
+ }
+
+ fd.mu.Lock()
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ if !fd.seekable {
+ return 0, syserror.ESPIPE
+ }
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence)
+ if err != nil {
+ return 0, err
+ }
+ fd.off = newOffset
+ return newOffset, nil
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *specialFileFD) Sync(ctx context.Context) error {
+ // If we have a host FD, fsyncing it is likely to be faster than an fsync
+ // RPC.
+ if fd.handle.fd >= 0 {
+ ctx.UninterruptibleSleepStart(false)
+ err := syscall.Fsync(int(fd.handle.fd))
+ ctx.UninterruptibleSleepFinish(false)
+ return err
+ }
+ return fd.handle.file.fsync(ctx)
+}
diff --git a/pkg/sentry/fsimpl/gofer/symlink.go b/pkg/sentry/fsimpl/gofer/symlink.go
new file mode 100644
index 000000000..2ec819f86
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/symlink.go
@@ -0,0 +1,47 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func (d *dentry) isSymlink() bool {
+ return d.fileType() == linux.S_IFLNK
+}
+
+// Precondition: d.isSymlink().
+func (d *dentry) readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
+ if d.fs.opts.interop != InteropModeShared {
+ d.touchAtime(mnt)
+ d.dataMu.Lock()
+ if d.haveTarget {
+ target := d.target
+ d.dataMu.Unlock()
+ return target, nil
+ }
+ }
+ target, err := d.file.readlink(ctx)
+ if d.fs.opts.interop != InteropModeShared {
+ if err == nil {
+ d.haveTarget = true
+ d.target = target
+ }
+ d.dataMu.Unlock()
+ }
+ return target, err
+}
diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go
new file mode 100644
index 000000000..2cb8191b9
--- /dev/null
+++ b/pkg/sentry/fsimpl/gofer/time.go
@@ -0,0 +1,82 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gofer
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func dentryTimestampFromP9(s, ns uint64) int64 {
+ return int64(s*1e9 + ns)
+}
+
+func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 {
+ return ts.Sec*1e9 + int64(ts.Nsec)
+}
+
+func statxTimestampFromDentry(ns int64) linux.StatxTimestamp {
+ return linux.StatxTimestamp{
+ Sec: ns / 1e9,
+ Nsec: uint32(ns % 1e9),
+ }
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true.
+func (d *dentry) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.atime, now)
+ atomic.StoreUint32(&d.atimeDirty, 1)
+ d.metadataMu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.ctime, now)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// successfully called vfs.Mount.CheckBeginWrite().
+func (d *dentry) touchCMtime() {
+ now := d.fs.clock.Now().Nanoseconds()
+ d.metadataMu.Lock()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+ atomic.StoreUint32(&d.mtimeDirty, 1)
+ d.metadataMu.Unlock()
+}
+
+// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has
+// locked d.metadataMu.
+func (d *dentry) touchCMtimeLocked() {
+ now := d.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&d.mtime, now)
+ atomic.StoreInt64(&d.ctime, now)
+ atomic.StoreUint32(&d.mtimeDirty, 1)
+}
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
new file mode 100644
index 000000000..bd701bbc7
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "host",
+ srcs = [
+ "control.go",
+ "host.go",
+ "ioctl_unsafe.go",
+ "mmap.go",
+ "socket.go",
+ "socket_iovec.go",
+ "socket_unsafe.go",
+ "tty.go",
+ "util.go",
+ "util_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/fspath",
+ "//pkg/iovec",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/unet",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go
new file mode 100644
index 000000000..0135e4428
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/control.go
@@ -0,0 +1,96 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type scmRights struct {
+ fds []int
+}
+
+func newSCMRights(fds []int) control.SCMRightsVFS2 {
+ return &scmRights{fds}
+}
+
+// Files implements control.SCMRights.Files.
+func (c *scmRights) Files(ctx context.Context, max int) (control.RightsFilesVFS2, bool) {
+ n := max
+ var trunc bool
+ if l := len(c.fds); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+
+ rf := control.RightsFilesVFS2(fdsToFiles(ctx, c.fds[:n]))
+
+ // Only consume converted FDs (fdsToFiles may convert fewer than n FDs).
+ c.fds = c.fds[len(rf):]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (c *scmRights) Clone() transport.RightsControlMessage {
+ // Host rights never need to be cloned.
+ return nil
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (c *scmRights) Release(ctx context.Context) {
+ for _, fd := range c.fds {
+ syscall.Close(fd)
+ }
+ c.fds = nil
+}
+
+// If an error is encountered, only files created before the error will be
+// returned. This is what Linux does.
+func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription {
+ files := make([]*vfs.FileDescription, 0, len(fds))
+ for _, fd := range fds {
+ // Get flags. We do it here because they may be modified
+ // by subsequent functions.
+ fileFlags, _, errno := syscall.Syscall(syscall.SYS_FCNTL, uintptr(fd), syscall.F_GETFL, 0)
+ if errno != 0 {
+ ctx.Warningf("Error retrieving host FD flags: %v", error(errno))
+ break
+ }
+
+ // Create the file backed by hostFD.
+ file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */)
+ if err != nil {
+ ctx.Warningf("Error creating file from host FD: %v", err)
+ break
+ }
+
+ if err := file.SetStatusFlags(ctx, auth.CredentialsFromContext(ctx), uint32(fileFlags&linux.O_NONBLOCK)); err != nil {
+ ctx.Warningf("Error setting flags on host FD file: %v", err)
+ break
+ }
+
+ files = append(files, file)
+ }
+ return files
+}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
new file mode 100644
index 000000000..bd6caba06
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -0,0 +1,769 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package host provides a filesystem implementation for host files imported as
+// file descriptors.
+package host
+
+import (
+ "fmt"
+ "math"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// NewFDOptions contains options to NewFD.
+type NewFDOptions struct {
+ // If IsTTY is true, the file descriptor is a TTY.
+ IsTTY bool
+
+ // If HaveFlags is true, use Flags for the new file description. Otherwise,
+ // the new file description will inherit flags from hostFD.
+ HaveFlags bool
+ Flags uint32
+}
+
+// NewFD returns a vfs.FileDescription representing the given host file
+// descriptor. mnt must be Kernel.HostMount().
+func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) (*vfs.FileDescription, error) {
+ fs, ok := mnt.Filesystem().Impl().(*filesystem)
+ if !ok {
+ return nil, fmt.Errorf("can't import host FDs into filesystems of type %T", mnt.Filesystem().Impl())
+ }
+
+ // Retrieve metadata.
+ var s unix.Stat_t
+ if err := unix.Fstat(hostFD, &s); err != nil {
+ return nil, err
+ }
+
+ flags := opts.Flags
+ if !opts.HaveFlags {
+ // Get flags for the imported FD.
+ flagsInt, err := unix.FcntlInt(uintptr(hostFD), syscall.F_GETFL, 0)
+ if err != nil {
+ return nil, err
+ }
+ flags = uint32(flagsInt)
+ }
+
+ fileMode := linux.FileMode(s.Mode)
+ fileType := fileMode.FileType()
+
+ // Determine if hostFD is seekable. If not, this syscall will return ESPIPE
+ // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character
+ // devices.
+ _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
+ seekable := err != syserror.ESPIPE
+
+ i := &inode{
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ isTTY: opts.IsTTY,
+ wouldBlock: wouldBlock(uint32(fileType)),
+ seekable: seekable,
+ // NOTE(b/38213152): Technically, some obscure char devices can be memory
+ // mapped, but we only allow regular files.
+ canMap: fileType == linux.S_IFREG,
+ }
+ i.pf.inode = i
+
+ // Non-seekable files can't be memory mapped, assert this.
+ if !i.seekable && i.canMap {
+ panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
+ }
+
+ // If the hostFD would block, we must set it to non-blocking and handle
+ // blocking behavior in the sentry.
+ if i.wouldBlock {
+ if err := syscall.SetNonblock(i.hostFD, true); err != nil {
+ return nil, err
+ }
+ if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil {
+ return nil, err
+ }
+ }
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+
+ // i.open will take a reference on d.
+ defer d.DecRef(ctx)
+
+ // For simplicity, fileDescription.offset is set to 0. Technically, we
+ // should only set to 0 on files that are not seekable (sockets, pipes,
+ // etc.), and use the offset from the host fd otherwise when importing.
+ return i.open(ctx, d.VFSDentry(), mnt, flags)
+}
+
+// ImportFD sets up and returns a vfs.FileDescription from a donated fd.
+func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) {
+ return NewFD(ctx, mnt, hostFD, &NewFDOptions{
+ IsTTY: isTTY,
+ })
+}
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("host.filesystemType.GetFilesystem should never be called")
+}
+
+// Name implements FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "none"
+}
+
+// NewFilesystem sets up and returns a new hostfs filesystem.
+//
+// Note that there should only ever be one instance of host.filesystem,
+// a global mount for host fds.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.VFSFilesystem(), nil
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ d := vd.Dentry().Impl().(*kernfs.Dentry)
+ inode := d.Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("host:[%d]", inode.ino))
+ return vfs.PrependPathSyntheticError{}
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+
+ // When the reference count reaches zero, the host fd is closed.
+ refs.AtomicRefCount
+
+ // hostFD contains the host fd that this file was originally created from,
+ // which must be available at time of restore.
+ //
+ // This field is initialized at creation time and is immutable.
+ hostFD int
+
+ // ino is an inode number unique within this filesystem.
+ //
+ // This field is initialized at creation time and is immutable.
+ ino uint64
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // seekable is false if the host fd points to a file representing a stream,
+ // e.g. a socket or a pipe. Such files are not seekable and can return
+ // EWOULDBLOCK for I/O operations.
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // wouldBlock is true if the host FD would return EWOULDBLOCK for
+ // operations that would block.
+ //
+ // This field is initialized at creation time and is immutable.
+ wouldBlock bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // canMap specifies whether we allow the file to be memory mapped.
+ //
+ // This field is initialized at creation time and is immutable.
+ canMap bool
+
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex
+
+ // If canMap is true, mappings tracks mappings of hostFD into
+ // memmap.MappingSpaces.
+ mappings memmap.MappingSet
+
+ // pf implements platform.File for mappings of hostFD.
+ pf inodePlatformFile
+}
+
+// CheckPermissions implements kernfs.Inode.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return err
+ }
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid))
+}
+
+// Mode implements kernfs.Inode.
+func (i *inode) Mode() linux.FileMode {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ // Retrieving the mode from the host fd using fstat(2) should not fail.
+ // If the syscall does not succeed, something is fundamentally wrong.
+ panic(fmt.Sprintf("failed to retrieve mode from host fd %d: %v", i.hostFD, err))
+ }
+ return linux.FileMode(s.Mode)
+}
+
+// Stat implements kernfs.Inode.
+func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ if opts.Mask&linux.STATX__RESERVED != 0 {
+ return linux.Statx{}, syserror.EINVAL
+ }
+ if opts.Sync&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
+ return linux.Statx{}, syserror.EINVAL
+ }
+
+ fs := vfsfs.Impl().(*filesystem)
+
+ // Limit our host call only to known flags.
+ mask := opts.Mask & linux.STATX_ALL
+ var s unix.Statx_t
+ err := unix.Statx(i.hostFD, "", int(unix.AT_EMPTY_PATH|opts.Sync), int(mask), &s)
+ if err == syserror.ENOSYS {
+ // Fallback to fstat(2), if statx(2) is not supported on the host.
+ //
+ // TODO(b/151263641): Remove fallback.
+ return i.fstat(fs)
+ }
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ // Unconditionally fill blksize, attributes, and device numbers, as
+ // indicated by /include/uapi/linux/stat.h. Inode number is always
+ // available, since we use our own rather than the host's.
+ ls := linux.Statx{
+ Mask: linux.STATX_INO,
+ Blksize: s.Blksize,
+ Attributes: s.Attributes,
+ Ino: i.ino,
+ AttributesMask: s.Attributes_mask,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }
+
+ // Copy other fields that were returned by the host. RdevMajor/RdevMinor
+ // are never copied (and therefore left as zero), so as not to expose host
+ // device numbers.
+ ls.Mask |= s.Mask & linux.STATX_ALL
+ if s.Mask&linux.STATX_TYPE != 0 {
+ ls.Mode |= s.Mode & linux.S_IFMT
+ }
+ if s.Mask&linux.STATX_MODE != 0 {
+ ls.Mode |= s.Mode &^ linux.S_IFMT
+ }
+ if s.Mask&linux.STATX_NLINK != 0 {
+ ls.Nlink = s.Nlink
+ }
+ if s.Mask&linux.STATX_UID != 0 {
+ ls.UID = s.Uid
+ }
+ if s.Mask&linux.STATX_GID != 0 {
+ ls.GID = s.Gid
+ }
+ if s.Mask&linux.STATX_ATIME != 0 {
+ ls.Atime = unixToLinuxStatxTimestamp(s.Atime)
+ }
+ if s.Mask&linux.STATX_BTIME != 0 {
+ ls.Btime = unixToLinuxStatxTimestamp(s.Btime)
+ }
+ if s.Mask&linux.STATX_CTIME != 0 {
+ ls.Ctime = unixToLinuxStatxTimestamp(s.Ctime)
+ }
+ if s.Mask&linux.STATX_MTIME != 0 {
+ ls.Mtime = unixToLinuxStatxTimestamp(s.Mtime)
+ }
+ if s.Mask&linux.STATX_SIZE != 0 {
+ ls.Size = s.Size
+ }
+ if s.Mask&linux.STATX_BLOCKS != 0 {
+ ls.Blocks = s.Blocks
+ }
+
+ return ls, nil
+}
+
+// fstat is a best-effort fallback for inode.Stat() if the host does not
+// support statx(2).
+//
+// We ignore the mask and sync flags in opts and simply supply
+// STATX_BASIC_STATS, as fstat(2) itself does not allow the specification
+// of a mask or sync flags. fstat(2) does not provide any metadata
+// equivalent to Statx.Attributes, Statx.AttributesMask, or Statx.Btime, so
+// those fields remain empty.
+func (i *inode) fstat(fs *filesystem) (linux.Statx, error) {
+ var s unix.Stat_t
+ if err := unix.Fstat(i.hostFD, &s); err != nil {
+ return linux.Statx{}, err
+ }
+
+ // As with inode.Stat(), we always use internal device and inode numbers,
+ // and never expose the host's represented device numbers.
+ return linux.Statx{
+ Mask: linux.STATX_BASIC_STATS,
+ Blksize: uint32(s.Blksize),
+ Nlink: uint32(s.Nlink),
+ UID: s.Uid,
+ GID: s.Gid,
+ Mode: uint16(s.Mode),
+ Ino: i.ino,
+ Size: uint64(s.Size),
+ Blocks: uint64(s.Blocks),
+ Atime: timespecToStatxTimestamp(s.Atim),
+ Ctime: timespecToStatxTimestamp(s.Ctim),
+ Mtime: timespecToStatxTimestamp(s.Mtim),
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.
+func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ s := &opts.Stat
+
+ m := s.Mask
+ if m == 0 {
+ return nil
+ }
+ if m&^(linux.STATX_MODE|linux.STATX_SIZE|linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ return syserror.EPERM
+ }
+ var hostStat syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &hostStat); err != nil {
+ return err
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil {
+ return err
+ }
+
+ if m&linux.STATX_MODE != 0 {
+ if err := syscall.Fchmod(i.hostFD, uint32(s.Mode)); err != nil {
+ return err
+ }
+ }
+ if m&linux.STATX_SIZE != 0 {
+ if hostStat.Mode&linux.S_IFMT != linux.S_IFREG {
+ return syserror.EINVAL
+ }
+ if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil {
+ return err
+ }
+ oldSize := uint64(hostStat.Size)
+ if s.Size < oldSize {
+ oldpgend, _ := usermem.PageRoundUp(oldSize)
+ newpgend, _ := usermem.PageRoundUp(s.Size)
+ if oldpgend != newpgend {
+ i.mapsMu.Lock()
+ i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/truncate.c:truncate_setsize() =>
+ // truncate_pagecache() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ i.mapsMu.Unlock()
+ }
+ }
+ }
+ if m&(linux.STATX_ATIME|linux.STATX_MTIME) != 0 {
+ ts := [2]syscall.Timespec{
+ toTimespec(s.Atime, m&linux.STATX_ATIME == 0),
+ toTimespec(s.Mtime, m&linux.STATX_MTIME == 0),
+ }
+ if err := setTimestamps(i.hostFD, &ts); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// DecRef implements kernfs.Inode.
+func (i *inode) DecRef(ctx context.Context) {
+ i.AtomicRefCount.DecRefWithDestructor(ctx, i.Destroy)
+}
+
+// Destroy implements kernfs.Inode.
+func (i *inode) Destroy(context.Context) {
+ if i.wouldBlock {
+ fdnotifier.RemoveFD(int32(i.hostFD))
+ }
+ if err := unix.Close(i.hostFD); err != nil {
+ log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
+ }
+}
+
+// Open implements kernfs.Inode.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/.
+ if i.Mode().FileType() == linux.S_IFSOCK {
+ return nil, syserror.ENXIO
+ }
+ return i.open(ctx, vfsd, rp.Mount(), opts.Flags)
+}
+
+func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, error) {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return nil, err
+ }
+ fileType := s.Mode & linux.FileTypeMask
+
+ // Constrain flags to a subset we can handle.
+ //
+ // TODO(gvisor.dev/issue/2601): Support O_NONBLOCK by adding RWF_NOWAIT to pread/pwrite calls.
+ flags &= syscall.O_ACCMODE | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND
+
+ switch fileType {
+ case syscall.S_IFSOCK:
+ if i.isTTY {
+ log.Warningf("cannot use host socket fd %d as TTY", i.hostFD)
+ return nil, syserror.ENOTTY
+ }
+
+ ep, err := newEndpoint(ctx, i.hostFD, &i.queue)
+ if err != nil {
+ return nil, err
+ }
+ // Currently, we only allow Unix sockets to be imported.
+ return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks)
+
+ case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR:
+ if i.isTTY {
+ fd := &TTYFileDescription{
+ fileDescription: fileDescription{inode: i},
+ termios: linux.DefaultSlaveTermios,
+ }
+ fd.LockFD.Init(&i.locks)
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+ }
+
+ fd := &fileDescription{inode: i}
+ fd.LockFD.Init(&i.locks)
+ vfsfd := &fd.vfsfd
+ if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+
+ default:
+ log.Warningf("cannot import host fd %d with file type %o", i.hostFD, fileType)
+ return nil, syserror.EPERM
+ }
+}
+
+// fileDescription is embedded by host fd implementations of FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // inode is vfsfd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode), but
+ // cached to reduce indirections and casting. fileDescription does not hold
+ // a reference on the inode through the inode field (since one is already
+ // held via the Dentry).
+ //
+ // inode is immutable after fileDescription creation.
+ inode *inode
+
+ // offsetMu protects offset.
+ offsetMu sync.Mutex
+
+ // offset specifies the current file offset. It is only meaningful when
+ // inode.seekable is true.
+ offset int64
+}
+
+// SetStat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts)
+}
+
+// Stat implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Release(context.Context) {
+ // noop
+}
+
+// Allocate implements vfs.FileDescriptionImpl.
+func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ if !f.inode.seekable {
+ return syserror.ESPIPE
+ }
+
+ // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds.
+ return syserror.EOPNOTSUPP
+}
+
+// PRead implements FileDescriptionImpl.
+func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags)
+}
+
+// Read implements FileDescriptionImpl.
+func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags)
+ if isBlockError(err) {
+ // If we got any data at all, return it as a "completed" partial read
+ // rather than retrying until complete.
+ if n != 0 {
+ err = nil
+ } else {
+ err = syserror.ErrWouldBlock
+ }
+ }
+ return n, err
+ }
+
+ f.offsetMu.Lock()
+ n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
+ return n, err
+}
+
+func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ // Check that flags are supported.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if flags&^linux.RWF_HIPRI != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
+ n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
+ return int64(n), err
+}
+
+// PWrite implements FileDescriptionImpl.
+func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ if !f.inode.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ return f.writeToHostFD(ctx, src, offset, opts.Flags)
+}
+
+// Write implements FileDescriptionImpl.
+func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ n, err := f.writeToHostFD(ctx, src, -1, opts.Flags)
+ if isBlockError(err) {
+ err = syserror.ErrWouldBlock
+ }
+ return n, err
+ }
+
+ f.offsetMu.Lock()
+ // NOTE(gvisor.dev/issue/2983): O_APPEND may cause memory corruption if
+ // another process modifies the host file between retrieving the file size
+ // and writing to the host fd. This is an unavoidable race condition because
+ // we cannot enforce synchronization on the host.
+ if f.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ f.offsetMu.Unlock()
+ return 0, err
+ }
+ f.offset = s.Size
+ }
+ n, err := f.writeToHostFD(ctx, src, f.offset, opts.Flags)
+ f.offset += n
+ f.offsetMu.Unlock()
+ return n, err
+}
+
+func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSequence, offset int64, flags uint32) (int64, error) {
+ hostFD := f.inode.hostFD
+ // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags.
+ if flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags)
+ n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
+ // NOTE(gvisor.dev/issue/2979): We always sync everything, even for O_DSYNC.
+ if n > 0 && f.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 {
+ if syncErr := unix.Fsync(hostFD); syncErr != nil {
+ return int64(n), syncErr
+ }
+ }
+ return int64(n), err
+}
+
+// Seek implements FileDescriptionImpl.
+//
+// Note that we do not support seeking on directories, since we do not even
+// allow directory fds to be imported at all.
+func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (int64, error) {
+ i := f.inode
+ if !i.seekable {
+ return 0, syserror.ESPIPE
+ }
+
+ f.offsetMu.Lock()
+ defer f.offsetMu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset = offset
+
+ case linux.SEEK_CUR:
+ // Check for overflow. Note that underflow cannot occur, since f.offset >= 0.
+ if offset > math.MaxInt64-f.offset {
+ return f.offset, syserror.EOVERFLOW
+ }
+ if f.offset+offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset += offset
+
+ case linux.SEEK_END:
+ var s syscall.Stat_t
+ if err := syscall.Fstat(i.hostFD, &s); err != nil {
+ return f.offset, err
+ }
+ size := s.Size
+
+ // Check for overflow. Note that underflow cannot occur, since size >= 0.
+ if offset > math.MaxInt64-size {
+ return f.offset, syserror.EOVERFLOW
+ }
+ if size+offset < 0 {
+ return f.offset, syserror.EINVAL
+ }
+ f.offset = size + offset
+
+ case linux.SEEK_DATA, linux.SEEK_HOLE:
+ // Modifying the offset in the host file table should not matter, since
+ // this is the only place where we use it.
+ //
+ // For reading and writing, we always rely on our internal offset.
+ n, err := unix.Seek(i.hostFD, offset, int(whence))
+ if err != nil {
+ return f.offset, err
+ }
+ f.offset = n
+
+ default:
+ // Invalid whence.
+ return f.offset, syserror.EINVAL
+ }
+
+ return f.offset, nil
+}
+
+// Sync implements FileDescriptionImpl.
+func (f *fileDescription) Sync(context.Context) error {
+ // TODO(gvisor.dev/issue/1897): Currently, we always sync everything.
+ return unix.Fsync(f.inode.hostFD)
+}
+
+// ConfigureMMap implements FileDescriptionImpl.
+func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error {
+ if !f.inode.canMap {
+ return syserror.ENODEV
+ }
+ i := f.inode
+ i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init)
+ return vfs.GenericConfigureMMap(&f.vfsfd, i, opts)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ f.inode.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (f *fileDescription) EventUnregister(e *waiter.Entry) {
+ f.inode.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(f.inode.hostFD))
+}
+
+// Readiness uses the poll() syscall to check the status of the underlying FD.
+func (f *fileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(f.inode.hostFD), mask)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (f *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return f.Locks().LockPOSIX(ctx, &f.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (f *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return f.Locks().UnlockPOSIX(ctx, &f.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/ioctl_unsafe.go b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
new file mode 100644
index 000000000..0983bf7d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/ioctl_unsafe.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+func ioctlGetTermios(fd int) (*linux.Termios, error) {
+ var t linux.Termios
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TCGETS, uintptr(unsafe.Pointer(&t)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &t, nil
+}
+
+func ioctlSetTermios(fd int, req uint64, t *linux.Termios) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), uintptr(req), uintptr(unsafe.Pointer(t)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+func ioctlGetWinsize(fd int) (*linux.Winsize, error) {
+ var w linux.Winsize
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCGWINSZ, uintptr(unsafe.Pointer(&w)))
+ if errno != 0 {
+ return nil, errno
+ }
+ return &w, nil
+}
+
+func ioctlSetWinsize(fd int, w *linux.Winsize) error {
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), linux.TIOCSWINSZ, uintptr(unsafe.Pointer(w)))
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
new file mode 100644
index 000000000..65d3af38c
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// inodePlatformFile implements memmap.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
+//
+// inodePlatformFile should only be used if inode.canMap is true.
+type inodePlatformFile struct {
+ *inode
+
+ // fdRefsMu protects fdRefs.
+ fdRefsMu sync.Mutex
+
+ // fdRefs counts references on memmap.File offsets. It is used solely for
+ // memory accounting.
+ fdRefs fsutil.FrameRefSet
+
+ // fileMapper caches mappings of the host file represented by this inode.
+ fileMapper fsutil.HostFileMapper
+
+ // fileMapperInitOnce is used to lazily initialize fileMapper.
+ fileMapperInitOnce sync.Once
+}
+
+// IncRef implements memmap.File.IncRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.IncRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// DecRef implements memmap.File.DecRef.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
+ i.fdRefsMu.Lock()
+ i.fdRefs.DecRefAndAccount(fr)
+ i.fdRefsMu.Unlock()
+}
+
+// MapInternal implements memmap.File.MapInternal.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+ return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
+}
+
+// FD implements memmap.File.FD.
+func (i *inodePlatformFile) FD() int {
+ return i.hostFD
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ i.mapsMu.Lock()
+ mapped := i.mappings.AddMapping(ms, ar, offset, writable)
+ for _, r := range mapped {
+ i.pf.fileMapper.IncRefOn(r)
+ }
+ i.mapsMu.Unlock()
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ i.mapsMu.Lock()
+ unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable)
+ for _, r := range unmapped {
+ i.pf.fileMapper.DecRefOn(r)
+ }
+ i.mapsMu.Unlock()
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return i.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ mr := optional
+ return []memmap.Translation{
+ {
+ Source: mr,
+ File: &i.pf,
+ Offset: mr.Start,
+ Perms: usermem.AnyAccess,
+ },
+ }, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+//
+// Precondition: i.inode.canMap must be true.
+func (i *inode) InvalidateUnsavable(ctx context.Context) error {
+ // We expect the same host fd across save/restore, so all translations
+ // should be valid.
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
new file mode 100644
index 000000000..4979dd0a9
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -0,0 +1,385 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/unet"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Create a new host-backed endpoint from the given fd and its corresponding
+// notification queue.
+func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
+ // Set up an external transport.Endpoint using the host fd.
+ addr := fmt.Sprintf("hostfd:[%d]", hostFD)
+ e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */)
+ if err != nil {
+ return nil, err.ToError()
+ }
+ ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), queue, e, e)
+ return ep, nil
+}
+
+// ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
+// transport.Receiver. It is backed by a host fd that was imported at sentry
+// startup. This fd is shared with a hostfs inode, which retains ownership of
+// it.
+//
+// ConnectedEndpoint is saveable, since we expect that the host will provide
+// the same fd upon restore.
+//
+// As of this writing, we only allow Unix sockets to be imported.
+//
+// +stateify savable
+type ConnectedEndpoint struct {
+ // ref keeps track of references to a ConnectedEndpoint.
+ ref refs.AtomicRefCount
+
+ // mu protects fd below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // fd is the host fd backing this endpoint.
+ fd int
+
+ // addr is the address at which this endpoint is bound.
+ addr string
+
+ // sndbuf is the size of the send buffer.
+ //
+ // N.B. When this is smaller than the host size, we present it via
+ // GetSockOpt and message splitting/rejection in SendMsg, but do not
+ // prevent lots of small messages from filling the real send buffer
+ // size on the host.
+ sndbuf int64 `state:"nosave"`
+
+ // stype is the type of Unix socket.
+ stype linux.SockType
+}
+
+// init performs initialization required for creating new ConnectedEndpoints and
+// for restoring them.
+func (c *ConnectedEndpoint) init() *syserr.Error {
+ family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserr.ErrInvalidEndpointState
+ }
+
+ stype, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ if err := syscall.SetNonblock(c.fd, true); err != nil {
+ return syserr.FromError(err)
+ }
+
+ sndbuf, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+
+ c.stype = linux.SockType(stype)
+ c.sndbuf = int64(sndbuf)
+
+ return nil
+}
+
+// NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd
+// imported at sentry startup,
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+ e := ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.ConnectedEndpoint")
+ return &e, nil
+}
+
+// Send implements transport.ConnectedEndpoint.Send.
+func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if !controlMessages.Empty() {
+ return 0, false, syserr.ErrInvalidEndpointState
+ }
+
+ // Since stream sockets don't preserve message boundaries, we can write
+ // only as much of the message as fits in the send buffer.
+ truncate := c.stype == linux.SOCK_STREAM
+
+ n, totalLen, err := fdWriteVec(c.fd, data, c.sndbuf, truncate)
+ if n < totalLen && err == nil {
+ // The host only returns a short write if it would otherwise
+ // block (and only for stream sockets).
+ err = syserror.EAGAIN
+ }
+ if n > 0 && err != syserror.EAGAIN {
+ // The caller may need to block to send more data, but
+ // otherwise there isn't anything that can be done about an
+ // error with a partial write.
+ err = nil
+ }
+
+ // There is no need for the callee to call SendNotify because fdWriteVec
+ // uses the host's sendmsg(2) and the host kernel's queue.
+ return n, false, syserr.FromError(err)
+}
+
+// SendNotify implements transport.ConnectedEndpoint.SendNotify.
+func (c *ConnectedEndpoint) SendNotify() {}
+
+// CloseSend implements transport.ConnectedEndpoint.CloseSend.
+func (c *ConnectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
+func (c *ConnectedEndpoint) CloseNotify() {}
+
+// Writable implements transport.ConnectedEndpoint.Writable.
+func (c *ConnectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements transport.ConnectedEndpoint.Passcred.
+func (c *ConnectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
+func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
+}
+
+// EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
+func (c *ConnectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.fd != -1 {
+ fdnotifier.UpdateFD(int32(c.fd))
+ }
+}
+
+// Recv implements transport.Receiver.Recv.
+func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+
+ // N.B. Unix sockets don't have a receive buffer, the send buffer
+ // serves both purposes.
+ rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.sndbuf)
+ if rl > 0 && err != nil {
+ // We got some data, so all we need to do on error is return
+ // the data that we got. Short reads are fine, no need to
+ // block.
+ err = nil
+ }
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ // There is no need for the callee to call RecvNotify because fdReadVec uses
+ // the host's recvmsg(2) and the host kernel's queue.
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+ }
+ return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
+}
+
+// RecvNotify implements transport.Receiver.RecvNotify.
+func (c *ConnectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements transport.Receiver.CloseRecv.
+func (c *ConnectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.fd, syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
+}
+
+// Readable implements transport.Receiver.Readable.
+func (c *ConnectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements transport.Receiver.SendQueuedSize.
+func (c *ConnectedEndpoint) SendQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
+func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
+func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
+ return int64(c.sndbuf)
+}
+
+// RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
+func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
+ // N.B. Unix sockets don't use the receive buffer. We'll claim it is
+ // the same size as the send buffer.
+ return int64(c.sndbuf)
+}
+
+func (c *ConnectedEndpoint) destroyLocked() {
+ c.fd = -1
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (c *ConnectedEndpoint) Release(ctx context.Context) {
+ c.ref.DecRefWithDestructor(ctx, func(context.Context) {
+ c.mu.Lock()
+ c.destroyLocked()
+ c.mu.Unlock()
+ })
+}
+
+// CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
+func (c *ConnectedEndpoint) CloseUnread() {}
+
+// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
+// passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the
+// following differences:
+// - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
+// the same descriptor number across S/R.
+// - SCMConnectedEndpoint holds ownership of its fd and notification queue.
+type SCMConnectedEndpoint struct {
+ ConnectedEndpoint
+
+ queue *waiter.Queue
+}
+
+// Init will do the initialization required without holding other locks.
+func (e *SCMConnectedEndpoint) Init() error {
+ return fdnotifier.AddFD(int32(e.fd), e.queue)
+}
+
+// Release implements transport.ConnectedEndpoint.Release and
+// transport.Receiver.Release.
+func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
+ e.ref.DecRefWithDestructor(ctx, func(context.Context) {
+ e.mu.Lock()
+ if err := syscall.Close(e.fd); err != nil {
+ log.Warningf("Failed to close host fd %d: %v", err)
+ }
+ fdnotifier.RemoveFD(int32(e.fd))
+ e.destroyLocked()
+ e.mu.Unlock()
+ })
+}
+
+// NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
+// was passed through a Unix socket.
+//
+// The caller is responsible for calling Init(). Additionaly, Release needs to
+// be called twice because ConnectedEndpoint is both a transport.Receiver and
+// transport.ConnectedEndpoint.
+func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
+ e := SCMConnectedEndpoint{
+ ConnectedEndpoint: ConnectedEndpoint{
+ fd: hostFD,
+ addr: addr,
+ },
+ queue: queue,
+ }
+
+ if err := e.init(); err != nil {
+ return nil, err
+ }
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+ e.ref.EnableLeakCheck("host.SCMConnectedEndpoint")
+ return &e, nil
+}
diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go
new file mode 100644
index 000000000..fc0d5fd38
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_iovec.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/iovec"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// copyToMulti copies as many bytes from src to dst as possible.
+func copyToMulti(dst [][]byte, src []byte) {
+ for _, d := range dst {
+ done := copy(d, src)
+ src = src[done:]
+ if len(src) == 0 {
+ break
+ }
+ }
+}
+
+// copyFromMulti copies as many bytes from src to dst as possible.
+func copyFromMulti(dst []byte, src [][]byte) {
+ for _, s := range src {
+ done := copy(dst, s)
+ dst = dst[done:]
+ if len(dst) == 0 {
+ break
+ }
+ }
+}
+
+// buildIovec builds an iovec slice from the given []byte slice.
+//
+// If truncate, truncate bufs > maxlen. Otherwise, immediately return an error.
+//
+// If length < the total length of bufs, err indicates why, even when returning
+// a truncated iovec.
+//
+// If intermediate != nil, iovecs references intermediate rather than bufs and
+// the caller must copy to/from bufs as necessary.
+func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovecs []syscall.Iovec, intermediate []byte, err error) {
+ var iovsRequired int
+ for _, b := range bufs {
+ length += int64(len(b))
+ if len(b) > 0 {
+ iovsRequired++
+ }
+ }
+
+ stopLen := length
+ if length > maxlen {
+ if truncate {
+ stopLen = maxlen
+ err = syserror.EAGAIN
+ } else {
+ return 0, nil, nil, syserror.EMSGSIZE
+ }
+ }
+
+ if iovsRequired > iovec.MaxIovs {
+ // The kernel will reject our call if we pass this many iovs.
+ // Use a single intermediate buffer instead.
+ b := make([]byte, stopLen)
+
+ return stopLen, []syscall.Iovec{{
+ Base: &b[0],
+ Len: uint64(stopLen),
+ }}, b, err
+ }
+
+ var total int64
+ iovecs = make([]syscall.Iovec, 0, iovsRequired)
+ for i := range bufs {
+ l := len(bufs[i])
+ if l == 0 {
+ continue
+ }
+
+ stop := int64(l)
+ if total+stop > stopLen {
+ stop = stopLen - total
+ }
+
+ iovecs = append(iovecs, syscall.Iovec{
+ Base: &bufs[i][0],
+ Len: uint64(stop),
+ })
+
+ total += stop
+ if total >= stopLen {
+ break
+ }
+ }
+
+ return total, iovecs, nil, err
+}
diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go
new file mode 100644
index 000000000..35ded24bc
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/socket_unsafe.go
@@ -0,0 +1,101 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// fdReadVec receives from fd to bufs.
+//
+// If the total length of bufs is > maxlen, fdReadVec will do a partial read
+// and err will indicate why the message was truncated.
+func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) (readLen int64, msgLen int64, controlLen uint64, controlTrunc bool, err error) {
+ flags := uintptr(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC)
+ if peek {
+ flags |= syscall.MSG_PEEK
+ }
+
+ // Always truncate the receive buffer. All socket types will truncate
+ // received messages.
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, true)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, 0, 0, false, err
+ }
+
+ var msg syscall.Msghdr
+ if len(control) != 0 {
+ msg.Control = &control[0]
+ msg.Controllen = uint64(len(control))
+ }
+
+ if len(iovecs) != 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ rawN, _, e := syscall.RawSyscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), flags)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, 0, 0, false, e
+ }
+ n := int64(rawN)
+
+ // Copy data back to bufs.
+ if intermediate != nil {
+ copyToMulti(bufs, intermediate)
+ }
+
+ controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC
+
+ if n > length {
+ return length, n, msg.Controllen, controlTrunc, err
+ }
+
+ return n, n, msg.Controllen, controlTrunc, err
+}
+
+// fdWriteVec sends from bufs to fd.
+//
+// If the total length of bufs is > maxlen && truncate, fdWriteVec will do a
+// partial write and err will indicate why the message was truncated.
+func fdWriteVec(fd int, bufs [][]byte, maxlen int64, truncate bool) (int64, int64, error) {
+ length, iovecs, intermediate, err := buildIovec(bufs, maxlen, truncate)
+ if err != nil && len(iovecs) == 0 {
+ // No partial write to do, return error immediately.
+ return 0, length, err
+ }
+
+ // Copy data to intermediate buf.
+ if intermediate != nil {
+ copyFromMulti(intermediate, bufs)
+ }
+
+ var msg syscall.Msghdr
+ if len(iovecs) > 0 {
+ msg.Iov = &iovecs[0]
+ msg.Iovlen = uint64(len(iovecs))
+ }
+
+ n, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msg)), syscall.MSG_DONTWAIT|syscall.MSG_NOSIGNAL)
+ if e != 0 {
+ // N.B. prioritize the syscall error over the buildIovec error.
+ return 0, length, e
+ }
+
+ return int64(n), length, err
+}
diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go
new file mode 100644
index 000000000..d372c60cb
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/tty.go
@@ -0,0 +1,390 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TTYFileDescription implements vfs.FileDescriptionImpl for a host file
+// descriptor that wraps a TTY FD.
+type TTYFileDescription struct {
+ fileDescription
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // session is the session attached to this TTYFileDescription.
+ session *kernel.Session
+
+ // fgProcessGroup is the foreground process group that is currently
+ // connected to this TTY.
+ fgProcessGroup *kernel.ProcessGroup
+
+ // termios contains the terminal attributes for this TTY.
+ termios linux.KernelTermios
+}
+
+// InitForegroundProcessGroup sets the foreground process group and session for
+// the TTY. This should only be called once, after the foreground process group
+// has been created, but before it has started running.
+func (t *TTYFileDescription) InitForegroundProcessGroup(pg *kernel.ProcessGroup) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.fgProcessGroup != nil {
+ panic("foreground process group is already set")
+ }
+ t.fgProcessGroup = pg
+ t.session = pg.Session()
+}
+
+// ForegroundProcessGroup returns the foreground process for the TTY.
+func (t *TTYFileDescription) ForegroundProcessGroup() *kernel.ProcessGroup {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.fgProcessGroup
+}
+
+// Release implements fs.FileOperations.Release.
+func (t *TTYFileDescription) Release(ctx context.Context) {
+ t.mu.Lock()
+ t.fgProcessGroup = nil
+ t.mu.Unlock()
+
+ t.fileDescription.Release(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+//
+// Reading from a TTY is only allowed for foreground process groups. Background
+// process groups will either get EIO or a SIGTTIN.
+func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Are we allowed to do the read?
+ // drivers/tty/n_tty.c:n_tty_read()=>job_control()=>tty_check_change().
+ if err := t.checkChange(ctx, linux.SIGTTIN); err != nil {
+ return 0, err
+ }
+
+ // Do the read.
+ return t.fileDescription.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+ }
+ return t.fileDescription.Write(ctx, src, opts)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // Ignore arg[0]. This is the real FD:
+ fd := t.inode.hostFD
+ ioctl := args[1].Uint64()
+ switch ioctl {
+ case linux.TCGETS:
+ termios, err := ioctlGetTermios(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TCSETS, linux.TCSETSW, linux.TCSETSF:
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
+
+ var termios linux.Termios
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
+ return 0, err
+
+ case linux.TIOCGPGRP:
+ // Args: pid_t *argp
+ // When successful, equivalent to *argp = tcgetpgrp(fd).
+ // Get the process group ID of the foreground process group on this
+ // terminal.
+
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace.
+ pgID := pidns.IDOfProcessGroup(t.fgProcessGroup)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSPGRP:
+ // Args: const pid_t *argp
+ // Equivalent to tcsetpgrp(fd, *argp).
+ // Set the foreground process group ID of this terminal.
+
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ return 0, syserror.ENOTTY
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Check that we are allowed to set the process group.
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ // drivers/tty/tty_io.c:tiocspgrp() converts -EIO from tty_check_change()
+ // to -ENOTTY.
+ if err == syserror.EIO {
+ return 0, syserror.ENOTTY
+ }
+ return 0, err
+ }
+
+ // Check that calling task's process group is in the TTY session.
+ if task.ThreadGroup().Session() != t.session {
+ return 0, syserror.ENOTTY
+ }
+
+ var pgID kernel.ProcessGroupID
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ // pgID must be non-negative.
+ if pgID < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Process group with pgID must exist in this PID namespace.
+ pidns := task.PIDNamespace()
+ pg := pidns.ProcessGroupWithID(pgID)
+ if pg == nil {
+ return 0, syserror.ESRCH
+ }
+
+ // Check that new process group is in the TTY session.
+ if pg.Session() != t.session {
+ return 0, syserror.EPERM
+ }
+
+ t.fgProcessGroup = pg
+ return 0, nil
+
+ case linux.TIOCGWINSZ:
+ // Args: struct winsize *argp
+ // Get window size.
+ winsize, err := ioctlGetWinsize(fd)
+ if err != nil {
+ return 0, err
+ }
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCSWINSZ:
+ // Args: const struct winsize *argp
+ // Set window size.
+
+ // Unlike setting the termios, any process group (even background ones) can
+ // set the winsize.
+
+ var winsize linux.Winsize
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ err := ioctlSetWinsize(fd, &winsize)
+ return 0, err
+
+ // Unimplemented commands.
+ case linux.TIOCSETD,
+ linux.TIOCSBRK,
+ linux.TIOCCBRK,
+ linux.TCSBRK,
+ linux.TCSBRKP,
+ linux.TIOCSTI,
+ linux.TIOCCONS,
+ linux.FIONBIO,
+ linux.TIOCEXCL,
+ linux.TIOCNXCL,
+ linux.TIOCGEXCL,
+ linux.TIOCNOTTY,
+ linux.TIOCSCTTY,
+ linux.TIOCGSID,
+ linux.TIOCGETD,
+ linux.TIOCVHANGUP,
+ linux.TIOCGDEV,
+ linux.TIOCMGET,
+ linux.TIOCMSET,
+ linux.TIOCMBIC,
+ linux.TIOCMBIS,
+ linux.TIOCGICOUNT,
+ linux.TCFLSH,
+ linux.TIOCSSERIAL,
+ linux.TIOCGPTPEER:
+
+ unimpl.EmitUnimplementedEvent(ctx)
+ fallthrough
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// checkChange checks that the process group is allowed to read, write, or
+// change the state of the TTY.
+//
+// This corresponds to Linux drivers/tty/tty_io.c:tty_check_change(). The logic
+// is a bit convoluted, but documented inline.
+//
+// Preconditions: t.mu must be held.
+func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ // No task? Linux does not have an analog for this case, but
+ // tty_check_change only blocks specific cases and is
+ // surprisingly permissive. Allowing the change seems
+ // appropriate.
+ return nil
+ }
+
+ tg := task.ThreadGroup()
+ pg := tg.ProcessGroup()
+
+ // If the session for the task is different than the session for the
+ // controlling TTY, then the change is allowed. Seems like a bad idea,
+ // but that's exactly what linux does.
+ if tg.Session() != t.fgProcessGroup.Session() {
+ return nil
+ }
+
+ // If we are the foreground process group, then the change is allowed.
+ if pg == t.fgProcessGroup {
+ return nil
+ }
+
+ // We are not the foreground process group.
+
+ // Is the provided signal blocked or ignored?
+ if (task.SignalMask()&linux.SignalSetOf(sig) != 0) || tg.SignalHandlers().IsIgnored(sig) {
+ // If the signal is SIGTTIN, then we are attempting to read
+ // from the TTY. Don't send the signal and return EIO.
+ if sig == linux.SIGTTIN {
+ return syserror.EIO
+ }
+
+ // Otherwise, we are writing or changing terminal state. This is allowed.
+ return nil
+ }
+
+ // If the process group is an orphan, return EIO.
+ if pg.IsOrphan() {
+ return syserror.EIO
+ }
+
+ // Otherwise, send the signal to the process group and return ERESTARTSYS.
+ //
+ // Note that Linux also unconditionally sets TIF_SIGPENDING on current,
+ // but this isn't necessary in gVisor because the rationale given in
+ // 040b6362d58f "tty: fix leakage of -ERESTARTSYS to userland" doesn't
+ // apply: the sentry will handle -ERESTARTSYS in
+ // kernel.runApp.execute() even if the kernel.Task isn't interrupted.
+ //
+ // Linux ignores the result of kill_pgrp().
+ _ = pg.SendSignal(kernel.SignalInfoPriv(sig))
+ return kernel.ERESTARTSYS
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (t *TTYFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, typ fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return t.Locks().LockPOSIX(ctx, &t.vfsfd, uid, typ, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (t *TTYFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return t.Locks().UnlockPOSIX(ctx, &t.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go
new file mode 100644
index 000000000..412bdb2eb
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func toTimespec(ts linux.StatxTimestamp, omit bool) syscall.Timespec {
+ if omit {
+ return syscall.Timespec{
+ Sec: 0,
+ Nsec: unix.UTIME_OMIT,
+ }
+ }
+ return syscall.Timespec{
+ Sec: ts.Sec,
+ Nsec: int64(ts.Nsec),
+ }
+}
+
+func unixToLinuxStatxTimestamp(ts unix.StatxTimestamp) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: ts.Sec, Nsec: ts.Nsec}
+}
+
+func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp {
+ return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)}
+}
+
+// wouldBlock returns true for file types that can return EWOULDBLOCK
+// for blocking operations, e.g. pipes, character devices, and sockets.
+func wouldBlock(fileType uint32) bool {
+ return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK
+}
+
+// isBlockError checks if an error is EAGAIN or EWOULDBLOCK.
+// If so, they can be transformed into syserror.ErrWouldBlock.
+func isBlockError(err error) bool {
+ return err == syserror.EAGAIN || err == syserror.EWOULDBLOCK
+}
diff --git a/pkg/sentry/fsimpl/host/util_unsafe.go b/pkg/sentry/fsimpl/host/util_unsafe.go
new file mode 100644
index 000000000..5136ac844
--- /dev/null
+++ b/pkg/sentry/fsimpl/host/util_unsafe.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+func setTimestamps(fd int, ts *[2]syscall.Timespec) error {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_UTIMENSAT,
+ uintptr(fd),
+ 0, /* path */
+ uintptr(unsafe.Pointer(ts)),
+ 0, /* flags */
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD
new file mode 100644
index 000000000..3835557fe
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/BUILD
@@ -0,0 +1,75 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "kernfs",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "Dentry",
+ },
+)
+
+go_template_instance(
+ name = "slot_list",
+ out = "slot_list.go",
+ package = "kernfs",
+ prefix = "slot",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*slot",
+ "Linker": "*slot",
+ },
+)
+
+go_library(
+ name = "kernfs",
+ srcs = [
+ "dynamic_bytes_file.go",
+ "fd_impl_util.go",
+ "filesystem.go",
+ "fstree.go",
+ "inode_impl_util.go",
+ "kernfs.go",
+ "slot_list.go",
+ "symlink.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "kernfs_test",
+ size = "small",
+ srcs = ["kernfs_test.go"],
+ deps = [
+ ":kernfs",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
new file mode 100644
index 000000000..12adf727a
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -0,0 +1,147 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// DynamicBytesFile implements kernfs.Inode and represents a read-only
+// file whose contents are backed by a vfs.DynamicBytesSource.
+//
+// Must be instantiated with NewDynamicBytesFile or initialized with Init
+// before first use.
+//
+// +stateify savable
+type DynamicBytesFile struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeNotDirectory
+ InodeNotSymlink
+
+ locks vfs.FileLocks
+ data vfs.DynamicBytesSource
+}
+
+var _ Inode = (*DynamicBytesFile)(nil)
+
+// Init initializes a dynamic bytes file.
+func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+ f.data = data
+}
+
+// Open implements Inode.Open.
+func (f *DynamicBytesFile) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &DynamicBytesFD{}
+ if err := fd.Init(rp.Mount(), vfsd, f.data, &f.locks, opts.Flags); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements Inode.SetStat. By default DynamicBytesFile doesn't allow
+// inode attributes to be changed. Override SetStat() making it call
+// f.InodeAttrs to allow it.
+func (*DynamicBytesFile) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// DynamicBytesFD implements vfs.FileDescriptionImpl for an FD backed by a
+// DynamicBytesFile.
+//
+// Must be initialized with Init before first use.
+//
+// +stateify savable
+type DynamicBytesFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ inode Inode
+}
+
+// Init initializes a DynamicBytesFD.
+func (fd *DynamicBytesFD) Init(m *vfs.Mount, d *vfs.Dentry, data vfs.DynamicBytesSource, locks *vfs.FileLocks, flags uint32) error {
+ fd.LockFD.Init(locks)
+ if err := fd.vfsfd.Init(fd, flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return err
+ }
+ fd.inode = d.Impl().(*Dentry).inode
+ fd.SetDataSource(data)
+ return nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *DynamicBytesFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *DynamicBytesFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.DynamicBytesFileDescriptionImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *DynamicBytesFD) Release(context.Context) {}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *DynamicBytesFD) SetStat(context.Context, vfs.SetStatOptions) error {
+ // DynamicBytesFiles are immutable.
+ return syserror.EPERM
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *DynamicBytesFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *DynamicBytesFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
new file mode 100644
index 000000000..fcee6200a
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go
@@ -0,0 +1,252 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory
+// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not
+// compatible with dynamic directories.
+//
+// Note that GenericDirectoryFD holds a lock over OrderedChildren while calling
+// IterDirents callback. The IterDirents callback therefore cannot hash or
+// unhash children, or recursively call IterDirents on the same underlying
+// inode.
+//
+// Must be initialize with Init before first use.
+//
+// Lock ordering: mu => children.mu.
+type GenericDirectoryFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ children *OrderedChildren
+
+ // mu protects the fields below.
+ mu sync.Mutex
+
+ // off is the current directory offset. Protected by "mu".
+ off int64
+}
+
+// NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its
+// dentry.
+func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) {
+ fd := &GenericDirectoryFD{}
+ if err := fd.Init(children, locks, opts); err != nil {
+ return nil, err
+ }
+ if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return fd, nil
+}
+
+// Init initializes a GenericDirectoryFD. Use it when overriding
+// GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the
+// correct implementation.
+func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error {
+ if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 {
+ // Can't open directories for writing.
+ return syserror.EISDIR
+ }
+ fd.LockFD.Init(locks)
+ fd.children = children
+ return nil
+}
+
+// VFSFileDescription returns a pointer to the vfs.FileDescription representing
+// this object.
+func (fd *GenericDirectoryFD) VFSFileDescription() *vfs.FileDescription {
+ return &fd.vfsfd
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *GenericDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.FileDescriptionDefaultImpl.ConfigureMMap(ctx, opts)
+}
+
+// Read implmenets vfs.FileDescriptionImpl.Read.
+func (fd *GenericDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Read(ctx, dst, opts)
+}
+
+// PRead implmenets vfs.FileDescriptionImpl.PRead.
+func (fd *GenericDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PRead(ctx, dst, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *GenericDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.Write(ctx, src, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts)
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *GenericDirectoryFD) Release(context.Context) {}
+
+func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem {
+ return fd.vfsfd.VirtualDentry().Mount().Filesystem()
+}
+
+func (fd *GenericDirectoryFD) inode() Inode {
+ return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds
+// o.mu when calling cb.
+func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ opts := vfs.StatOptions{Mask: linux.STATX_INO}
+ // Handle ".".
+ if fd.off == 0 {
+ stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: stat.Ino,
+ NextOff: 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ // Handle "..".
+ if fd.off == 1 {
+ vfsd := fd.vfsfd.VirtualDentry().Dentry()
+ parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode
+ stat, err := parentInode.Stat(ctx, fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: "..",
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: 2,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ // Handle static children.
+ fd.children.mu.RLock()
+ defer fd.children.mu.RUnlock()
+ // fd.off accounts for "." and "..", but fd.children do not track
+ // these.
+ childIdx := fd.off - 2
+ for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() {
+ inode := it.Dentry.Impl().(*Dentry).inode
+ stat, err := inode.Stat(ctx, fd.filesystem(), opts)
+ if err != nil {
+ return err
+ }
+ dirent := vfs.Dirent{
+ Name: it.Name,
+ Type: linux.FileMode(stat.Mode).DirentType(),
+ Ino: stat.Ino,
+ NextOff: fd.off + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return err
+ }
+ fd.off++
+ }
+
+ var err error
+ relOffset := fd.off - int64(len(fd.children.set)) - 2
+ fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset)
+ return err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ // Use offset as given.
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ // TODO(gvisor.dev/issue/1193): This can prevent new files from showing up
+ // if they are added after SEEK_END.
+ offset = math.MaxInt64
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := fd.filesystem()
+ inode := fd.inode()
+ return inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode
+ return inode.SetStat(ctx, fd.filesystem(), creds, opts)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *GenericDirectoryFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go
new file mode 100644
index 000000000..d7edb6342
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/filesystem.go
@@ -0,0 +1,840 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+// This file implements vfs.FilesystemImpl for kernfs.
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// stepExistingLocked resolves rp.Component() in parent directory vfsd.
+//
+// stepExistingLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postcondition: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ // Directory searchable?
+ if err := d.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ // Revalidation must be skipped if name is "." or ".."; d or its parent
+ // respectively can't be expected to transition from invalidated back to
+ // valid, so detecting invalidation and retrying would loop forever. This
+ // is consistent with Linux: fs/namei.c:walk_component() => lookup_fast()
+ // calls d_revalidate(), but walk_component() => handle_dots() does not.
+ if name == "." {
+ rp.Advance()
+ return vfsd, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(ctx, vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return vfsd, nil
+ }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return &d.parent.vfsd, nil
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ d.dirMu.Lock()
+ next, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), d, name, d.children[name])
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.CheckMount(ctx, &next.vfsd); err != nil {
+ return nil, err
+ }
+ // Resolve any symlink at current path component.
+ if mayFollowSymlinks && rp.ShouldFollowSymlink() && next.isSymlink() {
+ targetVD, targetPathname, err := next.inode.Getlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef(ctx)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
+ }
+ goto afterSymlink
+ }
+ rp.Advance()
+ return &next.vfsd, nil
+}
+
+// revalidateChildLocked must be called after a call to parent.vfsd.Child(name)
+// or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be
+// nil) to verify that the returned child (or lack thereof) is correct.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+// parent.dirMu must be locked. parent.isDir(). name is not "." or "..".
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, child *Dentry) (*Dentry, error) {
+ if child != nil {
+ // Cached dentry exists, revalidate.
+ if !child.inode.Valid(ctx) {
+ delete(parent.children, name)
+ vfsObj.InvalidateDentry(ctx, &child.vfsd)
+ fs.deferDecRef(&child.vfsd) // Reference from Lookup.
+ child = nil
+ }
+ }
+ if child == nil {
+ // Dentry isn't cached; it either doesn't exist or failed
+ // revalidation. Attempt to resolve it via Lookup.
+ //
+ // FIXME(gvisor.dev/issue/1193): Inode.Lookup() should return
+ // *(kernfs.)Dentry, not *vfs.Dentry, since (kernfs.)Filesystem assumes
+ // that all dentries in the filesystem are (kernfs.)Dentry and performs
+ // vfs.DentryImpl casts accordingly.
+ childVFSD, err := parent.inode.Lookup(ctx, name)
+ if err != nil {
+ return nil, err
+ }
+ // Reference on childVFSD dropped by a corresponding Valid.
+ child = childVFSD.Impl().(*Dentry)
+ parent.insertChildLocked(name, child)
+ }
+ return child, nil
+}
+
+// walkExistingLocked resolves rp to an existing file.
+//
+// walkExistingLocked is loosely analogous to Linux's
+// fs/namei.c:path_lookupat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Done() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory. It does not check that the returned directory is
+// searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done().
+//
+// Postconditions: Caller must call fs.processDeferredDecRefs*.
+func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) {
+ vfsd := rp.Start()
+ for !rp.Final() {
+ var err error
+ vfsd, err = fs.stepExistingLocked(ctx, rp, vfsd, true /* mayFollowSymlinks */)
+ if err != nil {
+ return nil, nil, err
+ }
+ }
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, nil, syserror.ENOTDIR
+ }
+ return vfsd, d.inode, nil
+}
+
+// checkCreateLocked checks that a file named rp.Component() may be created in
+// directory parentVFSD, then returns rp.Component().
+//
+// Preconditions: Filesystem.mu must be locked for at least reading. parentInode
+// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true.
+func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) {
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return "", err
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return "", syserror.EEXIST
+ }
+ if len(pc) > linux.NAME_MAX {
+ return "", syserror.ENAMETOOLONG
+ }
+ // FIXME(gvisor.dev/issue/1193): Data race due to not holding dirMu.
+ if _, ok := parentVFSD.Impl().(*Dentry).children[pc]; ok {
+ return "", syserror.EEXIST
+ }
+ if parentVFSD.IsDead() {
+ return "", syserror.ENOENT
+ }
+ return pc, nil
+}
+
+// checkDeleteLocked checks that the file represented by vfsd may be deleted.
+//
+// Preconditions: Filesystem.mu must be locked for at least reading.
+func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry) error {
+ parent := vfsd.Impl().(*Dentry).parent
+ if parent == nil {
+ return syserror.EBUSY
+ }
+ if parent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *Filesystem) Release(context.Context) {
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *Filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *Filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
+
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ return inode.CheckPermissions(ctx, creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *Filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+
+ if opts.CheckSearchable {
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *Filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
+ vfsd, _, err := fs.walkParentDirLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ vfsd.IncRef() // Ownership transferred to caller.
+ return vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+
+ d := vd.Dentry().Impl().(*Dentry)
+ if d.isDir() {
+ return syserror.EPERM
+ }
+
+ childVFSD, err := parentInode.NewLink(ctx, pc, d.inode)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ childVFSD, err := parentInode.NewDir(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ newVFSD, err := parentInode.NewNode(ctx, pc, opts)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, newVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ // Filter out flags that are not supported by kernfs. O_DIRECTORY and
+ // O_NOFOLLOW have no effect here (they're handled by VFS by setting
+ // appropriate bits in rp), but are returned by
+ // FileDescriptionImpl.StatusFlags().
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
+ ats := vfs.AccessTypesForOpenFlags(&opts)
+
+ // Do not create new file.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.processDeferredDecRefs(ctx)
+ defer fs.mu.RUnlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(ctx, rp, vfsd, opts)
+ }
+
+ // May create new file.
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ vfsd := rp.Start()
+ inode := vfsd.Impl().(*Dentry).inode
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return inode.Open(ctx, rp, vfsd, opts)
+ }
+afterTrailingSymlink:
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ pc := rp.Component()
+ if pc == "." || pc == ".." {
+ return nil, syserror.EISDIR
+ }
+ if len(pc) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ // Determine whether or not we need to create a file.
+ childVFSD, err := fs.stepExistingLocked(ctx, rp, parentVFSD, false /* mayFollowSymlinks */)
+ if err == syserror.ENOENT {
+ // Already checked for searchability above; now check for writability.
+ if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ childVFSD, err = parentInode.NewFile(ctx, pc, opts)
+ if err != nil {
+ return nil, err
+ }
+ child := childVFSD.Impl().(*Dentry)
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, child)
+ return child.inode.Open(ctx, rp, childVFSD, opts)
+ }
+ if err != nil {
+ return nil, err
+ }
+ // Open existing file or follow symlink.
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ child := childVFSD.Impl().(*Dentry)
+ if rp.ShouldFollowSymlink() && child.isSymlink() {
+ targetVD, targetPathname, err := child.inode.Getlink(ctx, rp.Mount())
+ if err != nil {
+ return nil, err
+ }
+ if targetVD.Ok() {
+ err := rp.HandleJump(targetVD)
+ targetVD.DecRef(ctx)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ if err := rp.HandleSymlink(targetPathname); err != nil {
+ return nil, err
+ }
+ }
+ // rp.Final() may no longer be true since we now need to resolve the
+ // symlink target.
+ goto afterTrailingSymlink
+ }
+ if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ return child.inode.Open(ctx, rp, &child.vfsd, opts)
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ d, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return "", err
+ }
+ if !d.Impl().(*Dentry).isSymlink() {
+ return "", syserror.EINVAL
+ }
+ return inode.Readlink(ctx)
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ // Only RENAME_NOREPLACE is supported.
+ if opts.Flags&^linux.RENAME_NOREPLACE != 0 {
+ return syserror.EINVAL
+ }
+ noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0
+
+ fs.mu.Lock()
+ defer fs.processDeferredDecRefsLocked(ctx)
+ defer fs.mu.Unlock()
+
+ // Resolve the destination directory first to verify that it's on this
+ // Mount.
+ dstDirVFSD, dstDirInode, err := fs.walkParentDirLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ dstDir := dstDirVFSD.Impl().(*Dentry)
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ srcDirVFSD := oldParentVD.Dentry()
+ srcDir := srcDirVFSD.Impl().(*Dentry)
+ srcDir.dirMu.Lock()
+ src, err := fs.revalidateChildLocked(ctx, rp.VirtualFilesystem(), srcDir, oldName, srcDir.children[oldName])
+ srcDir.dirMu.Unlock()
+ if err != nil {
+ return err
+ }
+ srcVFSD := &src.vfsd
+
+ // Can we remove the src dentry?
+ if err := checkDeleteLocked(ctx, rp, srcVFSD); err != nil {
+ return err
+ }
+
+ // Can we create the dst dentry?
+ var dst *Dentry
+ pc, err := checkCreateLocked(ctx, rp, dstDirVFSD, dstDirInode)
+ switch err {
+ case nil:
+ // Ok, continue with rename as replacement.
+ case syserror.EEXIST:
+ if noReplace {
+ // Won't overwrite existing node since RENAME_NOREPLACE was requested.
+ return syserror.EEXIST
+ }
+ dst = dstDir.children[pc]
+ if dst == nil {
+ panic(fmt.Sprintf("Child %q for parent Dentry %+v disappeared inside atomic section?", pc, dstDirVFSD))
+ }
+ default:
+ return err
+ }
+ var dstVFSD *vfs.Dentry
+ if dst != nil {
+ dstVFSD = &dst.vfsd
+ }
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ virtfs := rp.VirtualFilesystem()
+
+ // We can't deadlock here due to lock ordering because we're protected from
+ // concurrent renames by fs.mu held for writing.
+ srcDir.dirMu.Lock()
+ defer srcDir.dirMu.Unlock()
+ if srcDir != dstDir {
+ dstDir.dirMu.Lock()
+ defer dstDir.dirMu.Unlock()
+ }
+
+ if err := virtfs.PrepareRenameDentry(mntns, srcVFSD, dstVFSD); err != nil {
+ return err
+ }
+ replaced, err := srcDir.inode.Rename(ctx, src.name, pc, srcVFSD, dstDirVFSD)
+ if err != nil {
+ virtfs.AbortRenameDentry(srcVFSD, dstVFSD)
+ return err
+ }
+ delete(srcDir.children, src.name)
+ if srcDir != dstDir {
+ fs.deferDecRef(srcDirVFSD)
+ dstDir.IncRef()
+ }
+ src.parent = dstDir
+ src.name = pc
+ if dstDir.children == nil {
+ dstDir.children = make(map[string]*Dentry)
+ }
+ dstDir.children[pc] = src
+ virtfs.CommitRenameReplaceDentry(ctx, srcVFSD, replaced)
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ return err
+ }
+ d := vfsd.Impl().(*Dentry)
+ if !d.isDir() {
+ return syserror.ENOTDIR
+ }
+ if inode.HasChildren() {
+ return syserror.ENOTEMPTY
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := d.parent
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(ctx, vfsd)
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *Filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return err
+ }
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return inode.SetStat(ctx, fs.VFSFilesystem(), rp.Credentials(), opts)
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ return inode.Stat(ctx, fs.VFSFilesystem(), opts)
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ // TODO(gvisor.dev/issue/1193): actually implement statfs.
+ return linux.Statfs{}, syserror.ENOSYS
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ if rp.Done() {
+ return syserror.EEXIST
+ }
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentVFSD, parentInode, err := fs.walkParentDirLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ pc, err := checkCreateLocked(ctx, rp, parentVFSD, parentInode)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ childVFSD, err := parentInode.NewSymlink(ctx, pc, target)
+ if err != nil {
+ return err
+ }
+ parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry))
+ return nil
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ vfsd, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.processDeferredDecRefsLocked(ctx)
+ if err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ if err := checkDeleteLocked(ctx, rp, vfsd); err != nil {
+ return err
+ }
+ d := vfsd.Impl().(*Dentry)
+ if d.isDir() {
+ return syserror.EISDIR
+ }
+ virtfs := rp.VirtualFilesystem()
+ parentDentry := d.parent
+ parentDentry.dirMu.Lock()
+ defer parentDentry.dirMu.Unlock()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil {
+ return err
+ }
+ if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil {
+ virtfs.AbortDeleteDentry(vfsd)
+ return err
+ }
+ virtfs.CommitDeleteDentry(ctx, vfsd)
+ return nil
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ _, inode, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return nil, err
+ }
+ // kernfs currently does not support extended attributes.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return "", err
+ }
+ // kernfs currently does not support extended attributes.
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ _, _, err := fs.walkExistingLocked(ctx, rp)
+ fs.mu.RUnlock()
+ fs.processDeferredDecRefs(ctx)
+ if err != nil {
+ return err
+ }
+ // kernfs currently does not support extended attributes.
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *Filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*Dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
new file mode 100644
index 000000000..c3efcf3ec
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go
@@ -0,0 +1,613 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// InodeNoopRefCount partially implements the Inode interface, specifically the
+// inodeRefs sub interface. InodeNoopRefCount implements a simple reference
+// count for inodes, performing no extra actions when references are obtained or
+// released. This is suitable for simple file inodes that don't reference any
+// resources.
+type InodeNoopRefCount struct {
+}
+
+// IncRef implements Inode.IncRef.
+func (InodeNoopRefCount) IncRef() {
+}
+
+// DecRef implements Inode.DecRef.
+func (InodeNoopRefCount) DecRef(context.Context) {
+}
+
+// TryIncRef implements Inode.TryIncRef.
+func (InodeNoopRefCount) TryIncRef() bool {
+ return true
+}
+
+// Destroy implements Inode.Destroy.
+func (InodeNoopRefCount) Destroy(context.Context) {
+}
+
+// InodeDirectoryNoNewChildren partially implements the Inode interface.
+// InodeDirectoryNoNewChildren represents a directory inode which does not
+// support creation of new children.
+type InodeDirectoryNoNewChildren struct{}
+
+// NewFile implements Inode.NewFile.
+func (InodeDirectoryNoNewChildren) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewDir implements Inode.NewDir.
+func (InodeDirectoryNoNewChildren) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewLink implements Inode.NewLink.
+func (InodeDirectoryNoNewChildren) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (InodeDirectoryNoNewChildren) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// NewNode implements Inode.NewNode.
+func (InodeDirectoryNoNewChildren) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+// InodeNotDirectory partially implements the Inode interface, specifically the
+// inodeDirectory and inodeDynamicDirectory sub interfaces. Inodes that do not
+// represent directories can embed this to provide no-op implementations for
+// directory-related functions.
+type InodeNotDirectory struct {
+}
+
+// HasChildren implements Inode.HasChildren.
+func (InodeNotDirectory) HasChildren() bool {
+ return false
+}
+
+// NewFile implements Inode.NewFile.
+func (InodeNotDirectory) NewFile(context.Context, string, vfs.OpenOptions) (*vfs.Dentry, error) {
+ panic("NewFile called on non-directory inode")
+}
+
+// NewDir implements Inode.NewDir.
+func (InodeNotDirectory) NewDir(context.Context, string, vfs.MkdirOptions) (*vfs.Dentry, error) {
+ panic("NewDir called on non-directory inode")
+}
+
+// NewLink implements Inode.NewLinkink.
+func (InodeNotDirectory) NewLink(context.Context, string, Inode) (*vfs.Dentry, error) {
+ panic("NewLink called on non-directory inode")
+}
+
+// NewSymlink implements Inode.NewSymlink.
+func (InodeNotDirectory) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ panic("NewSymlink called on non-directory inode")
+}
+
+// NewNode implements Inode.NewNode.
+func (InodeNotDirectory) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ panic("NewNode called on non-directory inode")
+}
+
+// Unlink implements Inode.Unlink.
+func (InodeNotDirectory) Unlink(context.Context, string, *vfs.Dentry) error {
+ panic("Unlink called on non-directory inode")
+}
+
+// RmDir implements Inode.RmDir.
+func (InodeNotDirectory) RmDir(context.Context, string, *vfs.Dentry) error {
+ panic("RmDir called on non-directory inode")
+}
+
+// Rename implements Inode.Rename.
+func (InodeNotDirectory) Rename(context.Context, string, string, *vfs.Dentry, *vfs.Dentry) (*vfs.Dentry, error) {
+ panic("Rename called on non-directory inode")
+}
+
+// Lookup implements Inode.Lookup.
+func (InodeNotDirectory) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ panic("Lookup called on non-directory inode")
+}
+
+// IterDirents implements Inode.IterDirents.
+func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) {
+ panic("IterDirents called on non-directory inode")
+}
+
+// Valid implements Inode.Valid.
+func (InodeNotDirectory) Valid(context.Context) bool {
+ return true
+}
+
+// InodeNoDynamicLookup partially implements the Inode interface, specifically
+// the inodeDynamicLookup sub interface. Directory inodes that do not support
+// dymanic entries (i.e. entries that are not "hashed" into the
+// vfs.Dentry.children) can embed this to provide no-op implementations for
+// functions related to dynamic entries.
+type InodeNoDynamicLookup struct{}
+
+// Lookup implements Inode.Lookup.
+func (InodeNoDynamicLookup) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ return nil, syserror.ENOENT
+}
+
+// IterDirents implements Inode.IterDirents.
+func (InodeNoDynamicLookup) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ return offset, nil
+}
+
+// Valid implements Inode.Valid.
+func (InodeNoDynamicLookup) Valid(ctx context.Context) bool {
+ return true
+}
+
+// InodeNotSymlink partially implements the Inode interface, specifically the
+// inodeSymlink sub interface. All inodes that are not symlinks may embed this
+// to return the appropriate errors from symlink-related functions.
+type InodeNotSymlink struct{}
+
+// Readlink implements Inode.Readlink.
+func (InodeNotSymlink) Readlink(context.Context) (string, error) {
+ return "", syserror.EINVAL
+}
+
+// Getlink implements Inode.Getlink.
+func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, "", syserror.EINVAL
+}
+
+// InodeAttrs partially implements the Inode interface, specifically the
+// inodeMetadata sub interface. InodeAttrs provides functionality related to
+// inode attributes.
+//
+// Must be initialized by Init prior to first use.
+type InodeAttrs struct {
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+ mode uint32
+ uid uint32
+ gid uint32
+ nlink uint32
+}
+
+// Init initializes this InodeAttrs.
+func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode))
+ }
+
+ nlink := uint32(1)
+ if mode.FileType() == linux.ModeDirectory {
+ nlink = 2
+ }
+ a.devMajor = devMajor
+ a.devMinor = devMinor
+ atomic.StoreUint64(&a.ino, ino)
+ atomic.StoreUint32(&a.mode, uint32(mode))
+ atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID))
+ atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID))
+ atomic.StoreUint32(&a.nlink, nlink)
+}
+
+// DevMajor returns the device major number.
+func (a *InodeAttrs) DevMajor() uint32 {
+ return a.devMajor
+}
+
+// DevMinor returns the device minor number.
+func (a *InodeAttrs) DevMinor() uint32 {
+ return a.devMinor
+}
+
+// Ino returns the inode id.
+func (a *InodeAttrs) Ino() uint64 {
+ return atomic.LoadUint64(&a.ino)
+}
+
+// Mode implements Inode.Mode.
+func (a *InodeAttrs) Mode() linux.FileMode {
+ return linux.FileMode(atomic.LoadUint32(&a.mode))
+}
+
+// Stat partially implements Inode.Stat. Note that this function doesn't provide
+// all the stat fields, and the embedder should consider extending the result
+// with filesystem-specific fields.
+func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK
+ stat.DevMajor = a.devMajor
+ stat.DevMinor = a.devMinor
+ stat.Ino = atomic.LoadUint64(&a.ino)
+ stat.Mode = uint16(a.Mode())
+ stat.UID = atomic.LoadUint32(&a.uid)
+ stat.GID = atomic.LoadUint32(&a.gid)
+ stat.Nlink = atomic.LoadUint32(&a.nlink)
+
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+
+ return stat, nil
+}
+
+// SetStat implements Inode.SetStat.
+func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 {
+ return syserror.EPERM
+ }
+ if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil {
+ return err
+ }
+
+ stat := opts.Stat
+ if stat.Mask&linux.STATX_MODE != 0 {
+ for {
+ old := atomic.LoadUint32(&a.mode)
+ new := old | uint32(stat.Mode & ^uint16(linux.S_IFMT))
+ if swapped := atomic.CompareAndSwapUint32(&a.mode, old, new); swapped {
+ break
+ }
+ }
+ }
+
+ if stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&a.uid, stat.UID)
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&a.gid, stat.GID)
+ }
+
+ // Note that not all fields are modifiable. For example, the file type and
+ // inode numbers are immutable after node creation.
+
+ // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps.
+ // Also, STATX_SIZE will need some special handling, because read-only static
+ // files should return EIO for truncate operations.
+
+ return nil
+}
+
+// CheckPermissions implements Inode.CheckPermissions.
+func (a *InodeAttrs) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(
+ creds,
+ ats,
+ a.Mode(),
+ auth.KUID(atomic.LoadUint32(&a.uid)),
+ auth.KGID(atomic.LoadUint32(&a.gid)),
+ )
+}
+
+// IncLinks implements Inode.IncLinks.
+func (a *InodeAttrs) IncLinks(n uint32) {
+ if atomic.AddUint32(&a.nlink, n) <= n {
+ panic("InodeLink.IncLinks called with no existing links")
+ }
+}
+
+// DecLinks implements Inode.DecLinks.
+func (a *InodeAttrs) DecLinks() {
+ if nlink := atomic.AddUint32(&a.nlink, ^uint32(0)); nlink == ^uint32(0) {
+ // Negative overflow
+ panic("Inode.DecLinks called at 0 links")
+ }
+}
+
+type slot struct {
+ Name string
+ Dentry *vfs.Dentry
+ slotEntry
+}
+
+// OrderedChildrenOptions contains initialization options for OrderedChildren.
+type OrderedChildrenOptions struct {
+ // Writable indicates whether vfs.FilesystemImpl methods implemented by
+ // OrderedChildren may modify the tracked children. This applies to
+ // operations related to rename, unlink and rmdir. If an OrderedChildren is
+ // not writable, these operations all fail with EPERM.
+ Writable bool
+}
+
+// OrderedChildren partially implements the Inode interface. OrderedChildren can
+// be embedded in directory inodes to keep track of the children in the
+// directory, and can then be used to implement a generic directory FD -- see
+// GenericDirectoryFD. OrderedChildren is not compatible with dynamic
+// directories.
+//
+// Must be initialize with Init before first use.
+type OrderedChildren struct {
+ refs.AtomicRefCount
+
+ // Can children be modified by user syscalls? It set to false, interface
+ // methods that would modify the children return EPERM. Immutable.
+ writable bool
+
+ mu sync.RWMutex
+ order slotList
+ set map[string]*slot
+}
+
+// Init initializes an OrderedChildren.
+func (o *OrderedChildren) Init(opts OrderedChildrenOptions) {
+ o.writable = opts.Writable
+ o.set = make(map[string]*slot)
+}
+
+// DecRef implements Inode.DecRef.
+func (o *OrderedChildren) DecRef(ctx context.Context) {
+ o.AtomicRefCount.DecRefWithDestructor(ctx, o.Destroy)
+}
+
+// Destroy cleans up resources referenced by this OrderedChildren.
+func (o *OrderedChildren) Destroy(context.Context) {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ o.order.Reset()
+ o.set = nil
+}
+
+// Populate inserts children into this OrderedChildren, and d's dentry
+// cache. Populate returns the number of directories inserted, which the caller
+// may use to update the link count for the parent directory.
+//
+// Precondition: d must represent a directory inode. children must not contain
+// any conflicting entries already in o.
+func (o *OrderedChildren) Populate(d *Dentry, children map[string]*Dentry) uint32 {
+ var links uint32
+ for name, child := range children {
+ if child.isDir() {
+ links++
+ }
+ if err := o.Insert(name, child.VFSDentry()); err != nil {
+ panic(fmt.Sprintf("Collision when attempting to insert child %q (%+v) into %+v", name, child, d))
+ }
+ d.InsertChild(name, child)
+ }
+ return links
+}
+
+// HasChildren implements Inode.HasChildren.
+func (o *OrderedChildren) HasChildren() bool {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+ return len(o.set) > 0
+}
+
+// Insert inserts child into o. This ignores the writability of o, as this is
+// not part of the vfs.FilesystemImpl interface, and is a lower-level operation.
+func (o *OrderedChildren) Insert(name string, child *vfs.Dentry) error {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if _, ok := o.set[name]; ok {
+ return syserror.EEXIST
+ }
+ s := &slot{
+ Name: name,
+ Dentry: child,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) removeLocked(name string) {
+ if s, ok := o.set[name]; ok {
+ delete(o.set, name)
+ o.order.Remove(s)
+ }
+}
+
+// Precondition: caller must hold o.mu for writing.
+func (o *OrderedChildren) replaceChildLocked(name string, new *vfs.Dentry) *vfs.Dentry {
+ if s, ok := o.set[name]; ok {
+ // Existing slot with given name, simply replace the dentry.
+ var old *vfs.Dentry
+ old, s.Dentry = s.Dentry, new
+ return old
+ }
+
+ // No existing slot with given name, create and hash new slot.
+ s := &slot{
+ Name: name,
+ Dentry: new,
+ }
+ o.order.PushBack(s)
+ o.set[name] = s
+ return nil
+}
+
+// Precondition: caller must hold o.mu for reading or writing.
+func (o *OrderedChildren) checkExistingLocked(name string, child *vfs.Dentry) error {
+ s, ok := o.set[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if s.Dentry != child {
+ panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! OrderedChild: %+v, vfs: %+v", s.Dentry, child))
+ }
+ return nil
+}
+
+// Unlink implements Inode.Unlink.
+func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.Dentry) error {
+ if !o.writable {
+ return syserror.EPERM
+ }
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if err := o.checkExistingLocked(name, child); err != nil {
+ return err
+ }
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
+ o.removeLocked(name)
+ return nil
+}
+
+// Rmdir implements Inode.Rmdir.
+func (o *OrderedChildren) RmDir(ctx context.Context, name string, child *vfs.Dentry) error {
+ // We're not responsible for checking that child is a directory, that it's
+ // empty, or updating any link counts; so this is the same as unlink.
+ return o.Unlink(ctx, name, child)
+}
+
+type renameAcrossDifferentImplementationsError struct{}
+
+func (renameAcrossDifferentImplementationsError) Error() string {
+ return "rename across inodes with different implementations"
+}
+
+// Rename implements Inode.Rename.
+//
+// Precondition: Rename may only be called across two directory inodes with
+// identical implementations of Rename. Practically, this means filesystems that
+// implement Rename by embedding OrderedChildren for any directory
+// implementation must use OrderedChildren for all directory implementations
+// that will support Rename.
+//
+// Postcondition: reference on any replaced dentry transferred to caller.
+func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (*vfs.Dentry, error) {
+ dst, ok := dstDir.Impl().(*Dentry).inode.(interface{}).(*OrderedChildren)
+ if !ok {
+ return nil, renameAcrossDifferentImplementationsError{}
+ }
+ if !o.writable || !dst.writable {
+ return nil, syserror.EPERM
+ }
+ // Note: There's a potential deadlock below if concurrent calls to Rename
+ // refer to the same src and dst directories in reverse. We avoid any
+ // ordering issues because the caller is required to serialize concurrent
+ // calls to Rename in accordance with the interface declaration.
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if dst != o {
+ dst.mu.Lock()
+ defer dst.mu.Unlock()
+ }
+ if err := o.checkExistingLocked(oldname, child); err != nil {
+ return nil, err
+ }
+
+ // TODO(gvisor.dev/issue/3027): Check sticky bit before removing.
+ replaced := dst.replaceChildLocked(newname, child)
+ return replaced, nil
+}
+
+// nthLocked returns an iterator to the nth child tracked by this object. The
+// iterator is valid until the caller releases o.mu. Returns nil if the
+// requested index falls out of bounds.
+//
+// Preconditon: Caller must hold o.mu for reading.
+func (o *OrderedChildren) nthLocked(i int64) *slot {
+ for it := o.order.Front(); it != nil && i >= 0; it = it.Next() {
+ if i == 0 {
+ return it
+ }
+ i--
+ }
+ return nil
+}
+
+// InodeSymlink partially implements Inode interface for symlinks.
+type InodeSymlink struct {
+ InodeNotDirectory
+}
+
+// Open implements Inode.Open.
+func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ELOOP
+}
+
+// StaticDirectory is a standard implementation of a directory with static
+// contents.
+//
+// +stateify savable
+type StaticDirectory struct {
+ InodeNotSymlink
+ InodeDirectoryNoNewChildren
+ InodeAttrs
+ InodeNoDynamicLookup
+ OrderedChildren
+
+ locks vfs.FileLocks
+}
+
+var _ Inode = (*StaticDirectory)(nil)
+
+// NewStaticDir creates a new static directory and returns its dentry.
+func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry {
+ inode := &StaticDirectory{}
+ inode.Init(creds, devMajor, devMinor, ino, perm)
+
+ dentry := &Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, children)
+ inode.IncLinks(links)
+
+ return dentry
+}
+
+// Init initializes StaticDirectory.
+func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm)
+}
+
+// Open implements kernfs.Inode.
+func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// AlwaysValid partially implements kernfs.inodeDynamicLookup.
+type AlwaysValid struct{}
+
+// Valid implements kernfs.inodeDynamicLookup.
+func (*AlwaysValid) Valid(context.Context) bool {
+ return true
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
new file mode 100644
index 000000000..080118841
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -0,0 +1,456 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package kernfs provides the tools to implement inode-based filesystems.
+// Kernfs has two main features:
+//
+// 1. The Inode interface, which maps VFS2's path-based filesystem operations to
+// specific filesystem nodes. Kernfs uses the Inode interface to provide a
+// blanket implementation for the vfs.FilesystemImpl. Kernfs also serves as
+// the synchronization mechanism for all filesystem operations by holding a
+// filesystem-wide lock across all operations.
+//
+// 2. Various utility types which provide generic implementations for various
+// parts of the Inode and vfs.FileDescription interfaces. Client filesystems
+// based on kernfs can embed the appropriate set of these to avoid having to
+// reimplement common filesystem operations. See inode_impl_util.go and
+// fd_impl_util.go.
+//
+// Reference Model:
+//
+// Kernfs dentries represents named pointers to inodes. Dentries and inode have
+// independent lifetimes and reference counts. A child dentry unconditionally
+// holds a reference on its parent directory's dentry. A dentry also holds a
+// reference on the inode it points to. Multiple dentries can point to the same
+// inode (for example, in the case of hardlinks). File descriptors hold a
+// reference to the dentry they're opened on.
+//
+// Dentries are guaranteed to exist while holding Filesystem.mu for
+// reading. Dropping dentries require holding Filesystem.mu for writing. To
+// queue dentries for destruction from a read critical section, see
+// Filesystem.deferDecRef.
+//
+// Lock ordering:
+//
+// kernfs.Filesystem.mu
+// kernfs.Dentry.dirMu
+// vfs.VirtualFilesystem.mountMu
+// vfs.Dentry.mu
+// kernfs.Filesystem.droppedDentriesMu
+// (inode implementation locks, if any)
+package kernfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory
+// filesystem. Concrete implementations are expected to embed this in their own
+// Filesystem type.
+type Filesystem struct {
+ vfsfs vfs.Filesystem
+
+ droppedDentriesMu sync.Mutex
+
+ // droppedDentries is a list of dentries waiting to be DecRef()ed. This is
+ // used to defer dentry destruction until mu can be acquired for
+ // writing. Protected by droppedDentriesMu.
+ droppedDentries []*vfs.Dentry
+
+ // mu synchronizes the lifetime of Dentries on this filesystem. Holding it
+ // for reading guarantees continued existence of any resolved dentries, but
+ // the dentry tree may be modified.
+ //
+ // Kernfs dentries can only be DecRef()ed while holding mu for writing. For
+ // example:
+ //
+ // fs.mu.Lock()
+ // defer fs.mu.Unlock()
+ // ...
+ // dentry1.DecRef()
+ // defer dentry2.DecRef() // Ok, will run before Unlock.
+ //
+ // If discarding dentries in a read context, use Filesystem.deferDecRef. For
+ // example:
+ //
+ // fs.mu.RLock()
+ // fs.mu.processDeferredDecRefs()
+ // defer fs.mu.RUnlock()
+ // ...
+ // fs.deferDecRef(dentry)
+ mu sync.RWMutex
+
+ // nextInoMinusOne is used to to allocate inode numbers on this
+ // filesystem. Must be accessed by atomic operations.
+ nextInoMinusOne uint64
+}
+
+// deferDecRef defers dropping a dentry ref until the next call to
+// processDeferredDecRefs{,Locked}. See comment on Filesystem.mu.
+//
+// Precondition: d must not already be pending destruction.
+func (fs *Filesystem) deferDecRef(d *vfs.Dentry) {
+ fs.droppedDentriesMu.Lock()
+ fs.droppedDentries = append(fs.droppedDentries, d)
+ fs.droppedDentriesMu.Unlock()
+}
+
+// processDeferredDecRefs calls vfs.Dentry.DecRef on all dentries in the
+// droppedDentries list. See comment on Filesystem.mu.
+func (fs *Filesystem) processDeferredDecRefs(ctx context.Context) {
+ fs.mu.Lock()
+ fs.processDeferredDecRefsLocked(ctx)
+ fs.mu.Unlock()
+}
+
+// Precondition: fs.mu must be held for writing.
+func (fs *Filesystem) processDeferredDecRefsLocked(ctx context.Context) {
+ fs.droppedDentriesMu.Lock()
+ for _, d := range fs.droppedDentries {
+ d.DecRef(ctx)
+ }
+ fs.droppedDentries = fs.droppedDentries[:0] // Keep slice memory for reuse.
+ fs.droppedDentriesMu.Unlock()
+}
+
+// VFSFilesystem returns the generic vfs filesystem object.
+func (fs *Filesystem) VFSFilesystem() *vfs.Filesystem {
+ return &fs.vfsfs
+}
+
+// NextIno allocates a new inode number on this filesystem.
+func (fs *Filesystem) NextIno() uint64 {
+ return atomic.AddUint64(&fs.nextInoMinusOne, 1)
+}
+
+// These consts are used in the Dentry.flags field.
+const (
+ // Dentry points to a directory inode.
+ dflagsIsDir = 1 << iota
+
+ // Dentry points to a symlink inode.
+ dflagsIsSymlink
+)
+
+// Dentry implements vfs.DentryImpl.
+//
+// A kernfs dentry is similar to a dentry in a traditional filesystem: it's a
+// named reference to an inode. A dentry generally lives as long as it's part of
+// a mounted filesystem tree. Kernfs doesn't cache dentries once all references
+// to them are removed. Dentries hold a single reference to the inode they point
+// to, and child dentries hold a reference on their parent.
+//
+// Must be initialized by Init prior to first use.
+type Dentry struct {
+ vfsd vfs.Dentry
+
+ refs.AtomicRefCount
+
+ // flags caches useful information about the dentry from the inode. See the
+ // dflags* consts above. Must be accessed by atomic ops.
+ flags uint32
+
+ parent *Dentry
+ name string
+
+ // dirMu protects children and the names of child Dentries.
+ dirMu sync.Mutex
+ children map[string]*Dentry
+
+ inode Inode
+}
+
+// Init initializes this dentry.
+//
+// Precondition: Caller must hold a reference on inode.
+//
+// Postcondition: Caller's reference on inode is transferred to the dentry.
+func (d *Dentry) Init(inode Inode) {
+ d.vfsd.Init(d)
+ d.inode = inode
+ ftype := inode.Mode().FileType()
+ if ftype == linux.ModeDirectory {
+ d.flags |= dflagsIsDir
+ }
+ if ftype == linux.ModeSymlink {
+ d.flags |= dflagsIsSymlink
+ }
+}
+
+// VFSDentry returns the generic vfs dentry for this kernfs dentry.
+func (d *Dentry) VFSDentry() *vfs.Dentry {
+ return &d.vfsd
+}
+
+// isDir checks whether the dentry points to a directory inode.
+func (d *Dentry) isDir() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsDir != 0
+}
+
+// isSymlink checks whether the dentry points to a symlink inode.
+func (d *Dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *Dentry) DecRef(ctx context.Context) {
+ d.AtomicRefCount.DecRefWithDestructor(ctx, d.destroy)
+}
+
+// Precondition: Dentry must be removed from VFS' dentry cache.
+func (d *Dentry) destroy(ctx context.Context) {
+ d.inode.DecRef(ctx) // IncRef from Init.
+ d.inode = nil
+ if d.parent != nil {
+ d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild.
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *Dentry) Watches() *vfs.Watches {
+ return nil
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+func (d *Dentry) OnZeroWatches(context.Context) {}
+
+// InsertChild inserts child into the vfs dentry cache with the given name under
+// this dentry. This does not update the directory inode, so calling this on
+// its own isn't sufficient to insert a child into a directory. InsertChild
+// updates the link count on d if required.
+//
+// Precondition: d must represent a directory inode.
+func (d *Dentry) InsertChild(name string, child *Dentry) {
+ d.dirMu.Lock()
+ d.insertChildLocked(name, child)
+ d.dirMu.Unlock()
+}
+
+// insertChildLocked is equivalent to InsertChild, with additional
+// preconditions.
+//
+// Precondition: d.dirMu must be locked.
+func (d *Dentry) insertChildLocked(name string, child *Dentry) {
+ if !d.isDir() {
+ panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d))
+ }
+ d.IncRef() // DecRef in child's Dentry.destroy.
+ child.parent = d
+ child.name = name
+ if d.children == nil {
+ d.children = make(map[string]*Dentry)
+ }
+ d.children[name] = child
+}
+
+// Inode returns the dentry's inode.
+func (d *Dentry) Inode() Inode {
+ return d.inode
+}
+
+// The Inode interface maps filesystem-level operations that operate on paths to
+// equivalent operations on specific filesystem nodes.
+//
+// The interface methods are groups into logical categories as sub interfaces
+// below. Generally, an implementation for each sub interface can be provided by
+// embedding an appropriate type from inode_impl_utils.go. The sub interfaces
+// are purely organizational. Methods declared directly in the main interface
+// have no generic implementations, and should be explicitly provided by the
+// client filesystem.
+//
+// Generally, implementations are not responsible for tasks that are common to
+// all filesystems. These include:
+//
+// - Checking that dentries passed to methods are of the appropriate file type.
+// - Checking permissions.
+// - Updating link and reference counts.
+//
+// Specific responsibilities of implementations are documented below.
+type Inode interface {
+ // Methods related to reference counting. A generic implementation is
+ // provided by InodeNoopRefCount. These methods are generally called by the
+ // equivalent Dentry methods.
+ inodeRefs
+
+ // Methods related to node metadata. A generic implementation is provided by
+ // InodeAttrs.
+ inodeMetadata
+
+ // Method for inodes that represent symlink. InodeNotSymlink provides a
+ // blanket implementation for all non-symlink inodes.
+ inodeSymlink
+
+ // Method for inodes that represent directories. InodeNotDirectory provides
+ // a blanket implementation for all non-directory inodes.
+ inodeDirectory
+
+ // Method for inodes that represent dynamic directories and their
+ // children. InodeNoDynamicLookup provides a blanket implementation for all
+ // non-dynamic-directory inodes.
+ inodeDynamicLookup
+
+ // Open creates a file description for the filesystem object represented by
+ // this inode. The returned file description should hold a reference on the
+ // inode for its lifetime.
+ //
+ // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing
+ // the inode on which Open() is being called.
+ Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error)
+}
+
+type inodeRefs interface {
+ IncRef()
+ DecRef(ctx context.Context)
+ TryIncRef() bool
+ // Destroy is called when the inode reaches zero references. Destroy release
+ // all resources (references) on objects referenced by the inode, including
+ // any child dentries.
+ Destroy(ctx context.Context)
+}
+
+type inodeMetadata interface {
+ // CheckPermissions checks that creds may access this inode for the
+ // requested access type, per the the rules of
+ // fs/namei.c:generic_permission().
+ CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error
+
+ // Mode returns the (struct stat)::st_mode value for this inode. This is
+ // separated from Stat for performance.
+ Mode() linux.FileMode
+
+ // Stat returns the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.StatAt.
+ Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error)
+
+ // SetStat updates the metadata for this inode. This corresponds to
+ // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking
+ // if the operation can be performed (see vfs.CheckSetStat() for common
+ // checks).
+ SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error
+}
+
+// Precondition: All methods in this interface may only be called on directory
+// inodes.
+type inodeDirectory interface {
+ // The New{File,Dir,Node,Symlink} methods below should return a new inode
+ // hashed into this inode.
+ //
+ // These inode constructors are inode-level operations rather than
+ // filesystem-level operations to allow client filesystems to mix different
+ // implementations based on the new node's location in the
+ // filesystem.
+
+ // HasChildren returns true if the directory inode has any children.
+ HasChildren() bool
+
+ // NewFile creates a new regular file inode.
+ NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error)
+
+ // NewDir creates a new directory inode.
+ NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error)
+
+ // NewLink creates a new hardlink to a specified inode in this
+ // directory. Implementations should create a new kernfs Dentry pointing to
+ // target, and update target's link count.
+ NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error)
+
+ // NewSymlink creates a new symbolic link inode.
+ NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error)
+
+ // NewNode creates a new filesystem node for a mknod syscall.
+ NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error)
+
+ // Unlink removes a child dentry from this directory inode.
+ Unlink(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // RmDir removes an empty child directory from this directory
+ // inode. Implementations must update the parent directory's link count,
+ // if required. Implementations are not responsible for checking that child
+ // is a directory, checking for an empty directory.
+ RmDir(ctx context.Context, name string, child *vfs.Dentry) error
+
+ // Rename is called on the source directory containing an inode being
+ // renamed. child should point to the resolved child in the source
+ // directory. If Rename replaces a dentry in the destination directory, it
+ // should return the replaced dentry or nil otherwise.
+ //
+ // Precondition: Caller must serialize concurrent calls to Rename.
+ Rename(ctx context.Context, oldname, newname string, child, dstDir *vfs.Dentry) (replaced *vfs.Dentry, err error)
+}
+
+type inodeDynamicLookup interface {
+ // Lookup should return an appropriate dentry if name should resolve to a
+ // child of this dynamic directory inode. This gives the directory an
+ // opportunity on every lookup to resolve additional entries that aren't
+ // hashed into the directory. This is only called when the inode is a
+ // directory. If the inode is not a directory, or if the directory only
+ // contains a static set of children, the implementer can unconditionally
+ // return an appropriate error (ENOTDIR and ENOENT respectively).
+ //
+ // The child returned by Lookup will be hashed into the VFS dentry tree. Its
+ // lifetime can be controlled by the filesystem implementation with an
+ // appropriate implementation of Valid.
+ //
+ // Lookup returns the child with an extra reference and the caller owns this
+ // reference.
+ Lookup(ctx context.Context, name string) (*vfs.Dentry, error)
+
+ // Valid should return true if this inode is still valid, or needs to
+ // be resolved again by a call to Lookup.
+ Valid(ctx context.Context) bool
+
+ // IterDirents is used to iterate over dynamically created entries. It invokes
+ // cb on each entry in the directory represented by the FileDescription.
+ // 'offset' is the offset for the entire IterDirents call, which may include
+ // results from the caller (e.g. "." and ".."). 'relOffset' is the offset
+ // inside the entries returned by this IterDirents invocation. In other words,
+ // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as
+ // the return value, while 'relOffset' is the place to start iteration.
+ IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error)
+}
+
+type inodeSymlink interface {
+ // Readlink returns the target of a symbolic link. If an inode is not a
+ // symlink, the implementation should return EINVAL.
+ Readlink(ctx context.Context) (string, error)
+
+ // Getlink returns the target of a symbolic link, as used by path
+ // resolution:
+ //
+ // - If the inode is a "magic link" (a link whose target is most accurately
+ // represented as a VirtualDentry), Getlink returns (ok VirtualDentry, "",
+ // nil). A reference is taken on the returned VirtualDentry.
+ //
+ // - If the inode is an ordinary symlink, Getlink returns (zero-value
+ // VirtualDentry, symlink target, nil).
+ //
+ // - If the inode is not a symlink, Getlink returns (zero-value
+ // VirtualDentry, "", EINVAL).
+ Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error)
+}
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
new file mode 100644
index 000000000..c5d5afedf
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -0,0 +1,330 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs_test
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const defaultMode linux.FileMode = 01777
+const staticFileContent = "This is sample content for a static test file."
+
+// RootDentryFn is a generator function for creating the root dentry of a test
+// filesystem. See newTestSystem.
+type RootDentryFn func(*auth.Credentials, *filesystem) *kernfs.Dentry
+
+// newTestSystem sets up a minimal environment for running a test, including an
+// instance of a test filesystem. Tests can control the contents of the
+// filesystem by providing an appropriate rootFn, which should return a
+// pre-populated root dentry.
+func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System {
+ ctx := contexttest.Context(t)
+ creds := auth.CredentialsFromContext(ctx)
+ v := &vfs.VirtualFilesystem{}
+ if err := v.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("Failed to create testfs root mount: %v", err)
+ }
+ return testutil.NewSystem(ctx, t, v, mns)
+}
+
+type fsType struct {
+ rootFn RootDentryFn
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+}
+
+type file struct {
+ kernfs.DynamicBytesFile
+ content string
+}
+
+func (fs *filesystem) newFile(creds *auth.Credentials, content string) *kernfs.Dentry {
+ f := &file{}
+ f.content = content
+ f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(f)
+ return d
+}
+
+func (f *file) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%s", f.content)
+ return nil
+}
+
+type attrs struct {
+ kernfs.InodeAttrs
+}
+
+func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+type readonlyDir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &readonlyDir{}
+ dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+type dir struct {
+ attrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoDynamicLookup
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &dir{}
+ dir.fs = fs
+ dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode)
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true})
+ dir.dentry.Init(dir)
+
+ dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents))
+
+ return &dir.dentry
+}
+
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ dir := d.fs.newDir(creds, opts.Mode, nil)
+ dirVFSD := dir.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, dirVFSD); err != nil {
+ dir.DecRef(ctx)
+ return nil, err
+ }
+ d.IncLinks(1)
+ return dirVFSD, nil
+}
+
+func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) {
+ creds := auth.CredentialsFromContext(ctx)
+ f := d.fs.newFile(creds, "")
+ fVFSD := f.VFSDentry()
+ if err := d.OrderedChildren.Insert(name, fVFSD); err != nil {
+ f.DecRef(ctx)
+ return nil, err
+ }
+ return fVFSD, nil
+}
+
+func (*dir) NewLink(context.Context, string, kernfs.Inode) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewSymlink(context.Context, string, string) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (*dir) NewNode(context.Context, string, vfs.MknodOptions) (*vfs.Dentry, error) {
+ return nil, syserror.EPERM
+}
+
+func (fsType) Name() string {
+ return "kernfs"
+}
+
+func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ fs := &filesystem{}
+ fs.VFSFilesystem().Init(vfsObj, &fst, fs)
+ root := fst.rootFn(creds, fs)
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+// -------------------- Remainder of the file are test cases --------------------
+
+func TestBasic(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+ sys.GetDentryOrDie(sys.PathOpAtRoot("file1")).DecRef(sys.Ctx)
+}
+
+func TestMkdirGetDentry(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("dir1/a new directory")
+ if err := sys.VFS.MkdirAt(sys.Ctx, sys.Creds, pop, &vfs.MkdirOptions{Mode: 0755}); err != nil {
+ t.Fatalf("MkdirAt for PathOperation %+v failed: %v", pop, err)
+ }
+ sys.GetDentryOrDie(pop).DecRef(sys.Ctx)
+}
+
+func TestReadStaticFile(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("file1")
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef(sys.Ctx)
+
+ content, err := sys.ReadToEnd(fd)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ if diff := cmp.Diff(staticFileContent, content); diff != "" {
+ t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+func TestCreateNewFileInStaticDir(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir1": fs.newDir(creds, 0755, nil),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("dir1/newfile")
+ opts := &vfs.OpenOptions{Flags: linux.O_CREAT | linux.O_EXCL, Mode: defaultMode}
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v, opts:%+v) failed: %v", pop, opts, err)
+ }
+
+ // Close the file. The file should persist.
+ fd.DecRef(sys.Ctx)
+
+ fd, err = sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
+ }
+ fd.DecRef(sys.Ctx)
+}
+
+func TestDirFDReadWrite(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, nil)
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("/")
+ fd, err := sys.VFS.OpenAt(sys.Ctx, sys.Creds, pop, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef(sys.Ctx)
+
+ // Read/Write should fail for directory FDs.
+ if _, err := fd.Read(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Fatalf("Read for directory FD failed with unexpected error: %v", err)
+ }
+ if _, err := fd.Write(sys.Ctx, usermem.BytesIOSequence([]byte{}), vfs.WriteOptions{}); err != syserror.EBADF {
+ t.Fatalf("Write for directory FD failed with unexpected error: %v", err)
+ }
+}
+
+func TestDirFDIterDirents(t *testing.T) {
+ sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) *kernfs.Dentry {
+ return fs.newReadonlyDir(creds, 0755, map[string]*kernfs.Dentry{
+ // Fill root with nodes backed by various inode implementations.
+ "dir1": fs.newReadonlyDir(creds, 0755, nil),
+ "dir2": fs.newDir(creds, 0755, map[string]*kernfs.Dentry{
+ "dir3": fs.newDir(creds, 0755, nil),
+ }),
+ "file1": fs.newFile(creds, staticFileContent),
+ })
+ })
+ defer sys.Destroy()
+
+ pop := sys.PathOpAtRoot("/")
+ sys.AssertAllDirentTypes(sys.ListDirents(pop), map[string]testutil.DirentType{
+ "dir1": linux.DT_DIR,
+ "dir2": linux.DT_DIR,
+ "file1": linux.DT_REG,
+ })
+}
diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go
new file mode 100644
index 000000000..2ab3f53fd
--- /dev/null
+++ b/pkg/sentry/fsimpl/kernfs/symlink.go
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// StaticSymlink provides an Inode implementation for symlinks that point to
+// a immutable target.
+type StaticSymlink struct {
+ InodeAttrs
+ InodeNoopRefCount
+ InodeSymlink
+
+ target string
+}
+
+var _ Inode = (*StaticSymlink)(nil)
+
+// NewStaticSymlink creates a new symlink file pointing to 'target'.
+func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) *Dentry {
+ inode := &StaticSymlink{}
+ inode.Init(creds, devMajor, devMinor, ino, target)
+
+ d := &Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Init initializes the instance.
+func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) {
+ s.target = target
+ s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777)
+}
+
+// Readlink implements Inode.
+func (s *StaticSymlink) Readlink(_ context.Context) (string, error) {
+ return s.target, nil
+}
+
+// Getlink implements Inode.Getlink.
+func (s *StaticSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ return vfs.VirtualDentry{}, s.target, nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*StaticSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
deleted file mode 100644
index 04d667273..000000000
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ /dev/null
@@ -1,76 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-
-go_template_instance(
- name = "dentry_list",
- out = "dentry_list.go",
- package = "memfs",
- prefix = "dentry",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*dentry",
- "Linker": "*dentry",
- },
-)
-
-go_library(
- name = "memfs",
- srcs = [
- "dentry_list.go",
- "directory.go",
- "filesystem.go",
- "memfs.go",
- "named_pipe.go",
- "regular_file.go",
- "symlink.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs",
- deps = [
- "//pkg/abi/linux",
- "//pkg/amutex",
- "//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/pipe",
- "//pkg/sentry/usermem",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- ],
-)
-
-go_test(
- name = "benchmark_test",
- size = "small",
- srcs = ["benchmark_test.go"],
- deps = [
- ":memfs",
- "//pkg/abi/linux",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/tmpfs",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- ],
-)
-
-go_test(
- name = "memfs_test",
- size = "small",
- srcs = ["pipe_test.go"],
- embed = [":memfs"],
- deps = [
- "//pkg/abi/linux",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/usermem",
- "//pkg/sentry/vfs",
- "//pkg/syserror",
- ],
-)
diff --git a/pkg/sentry/fsimpl/memfs/filesystem.go b/pkg/sentry/fsimpl/memfs/filesystem.go
deleted file mode 100644
index f006c40cd..000000000
--- a/pkg/sentry/fsimpl/memfs/filesystem.go
+++ /dev/null
@@ -1,579 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "fmt"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// stepLocked resolves rp.Component() in parent directory vfsd.
-//
-// stepLocked is loosely analogous to fs/namei.c:walk_component().
-//
-// Preconditions: filesystem.mu must be locked. !rp.Done(). inode ==
-// vfsd.Impl().(*dentry).inode.
-func stepLocked(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode) (*vfs.Dentry, *inode, error) {
- if !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, nil, err
- }
-afterSymlink:
- nextVFSD, err := rp.ResolveComponent(vfsd)
- if err != nil {
- return nil, nil, err
- }
- if nextVFSD == nil {
- // Since the Dentry tree is the sole source of truth for memfs, if it's
- // not in the Dentry tree, it doesn't exist.
- return nil, nil, syserror.ENOENT
- }
- nextInode := nextVFSD.Impl().(*dentry).inode
- if symlink, ok := nextInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
- if err := rp.HandleSymlink(symlink.target); err != nil {
- return nil, nil, err
- }
- goto afterSymlink // don't check the current directory again
- }
- rp.Advance()
- return nextVFSD, nextInode, nil
-}
-
-// walkExistingLocked resolves rp to an existing file.
-//
-// walkExistingLocked is loosely analogous to Linux's
-// fs/namei.c:path_lookupat().
-//
-// Preconditions: filesystem.mu must be locked.
-func walkExistingLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- for !rp.Done() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, nil, err
- }
- }
- if rp.MustBeDir() && !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- return vfsd, inode, nil
-}
-
-// walkParentDirLocked resolves all but the last path component of rp to an
-// existing directory. It does not check that the returned directory is
-// searchable by the provider of rp.
-//
-// walkParentDirLocked is loosely analogous to Linux's
-// fs/namei.c:path_parentat().
-//
-// Preconditions: filesystem.mu must be locked. !rp.Done().
-func walkParentDirLocked(rp *vfs.ResolvingPath) (*vfs.Dentry, *inode, error) {
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- for !rp.Final() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, nil, err
- }
- }
- if !inode.isDir() {
- return nil, nil, syserror.ENOTDIR
- }
- return vfsd, inode, nil
-}
-
-// checkCreateLocked checks that a file named rp.Component() may be created in
-// directory parentVFSD, then returns rp.Component().
-//
-// Preconditions: filesystem.mu must be locked. parentInode ==
-// parentVFSD.Impl().(*dentry).inode. parentInode.isDir() == true.
-func checkCreateLocked(rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode *inode) (string, error) {
- if err := parentInode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec, true); err != nil {
- return "", err
- }
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return "", syserror.EEXIST
- }
- childVFSD, err := rp.ResolveChild(parentVFSD, pc)
- if err != nil {
- return "", err
- }
- if childVFSD != nil {
- return "", syserror.EEXIST
- }
- if parentVFSD.IsDisowned() {
- return "", syserror.ENOENT
- }
- return pc, nil
-}
-
-// checkDeleteLocked checks that the file represented by vfsd may be deleted.
-func checkDeleteLocked(vfsd *vfs.Dentry) error {
- parentVFSD := vfsd.Parent()
- if parentVFSD == nil {
- return syserror.EBUSY
- }
- if parentVFSD.IsDisowned() {
- return syserror.ENOENT
- }
- return nil
-}
-
-// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
-func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
- fs.mu.RLock()
- defer fs.mu.RUnlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return nil, err
- }
- if opts.CheckSearchable {
- if !inode.isDir() {
- return nil, syserror.ENOTDIR
- }
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, err
- }
- }
- inode.incRef() // vfsd.IncRef(&fs.vfsfs)
- return vfsd, nil
-}
-
-// LinkAt implements vfs.FilesystemImpl.LinkAt.
-func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if rp.Mount() != vd.Mount() {
- return syserror.EXDEV
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- d := vd.Dentry().Impl().(*dentry)
- if d.inode.isDir() {
- return syserror.EPERM
- }
- d.inode.incLinksLocked()
- child := fs.newDentry(d.inode)
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-}
-
-// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
-func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- child := fs.newDentry(fs.newDirectory(rp.Credentials(), opts.Mode))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- parentInode.incLinksLocked() // from child's ".."
- return nil
-}
-
-// MknodAt implements vfs.FilesystemImpl.MknodAt.
-func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
-
- switch opts.Mode.FileType() {
- case 0:
- // "Zero file type is equivalent to type S_IFREG." - mknod(2)
- fallthrough
- case linux.ModeRegular:
- // TODO(b/138862511): Implement.
- return syserror.EINVAL
-
- case linux.ModeNamedPipe:
- child := fs.newDentry(fs.newNamedPipe(rp.Credentials(), opts.Mode))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-
- case linux.ModeSocket:
- // TODO(b/138862511): Implement.
- return syserror.EINVAL
-
- case linux.ModeCharacterDevice:
- fallthrough
- case linux.ModeBlockDevice:
- // TODO(b/72101894): We don't support creating block or character
- // devices at the moment.
- //
- // When we start supporting block and character devices, we'll
- // need to check for CAP_MKNOD here.
- return syserror.EPERM
-
- default:
- // "EINVAL - mode requested creation of something other than a
- // regular file, device special file, FIFO or socket." - mknod(2)
- return syserror.EINVAL
- }
-}
-
-// OpenAt implements vfs.FilesystemImpl.OpenAt.
-func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
- // Filter out flags that are not supported by memfs. O_DIRECTORY and
- // O_NOFOLLOW have no effect here (they're handled by VFS by setting
- // appropriate bits in rp), but are returned by
- // FileDescriptionImpl.StatusFlags(). O_NONBLOCK is supported only by
- // pipes.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_TRUNC | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NONBLOCK
-
- if opts.Flags&linux.O_CREAT == 0 {
- fs.mu.RLock()
- defer fs.mu.RUnlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return nil, err
- }
- return inode.open(ctx, rp, vfsd, opts.Flags, false)
- }
-
- mustCreate := opts.Flags&linux.O_EXCL != 0
- vfsd := rp.Start()
- inode := vfsd.Impl().(*dentry).inode
- fs.mu.Lock()
- defer fs.mu.Unlock()
- if rp.Done() {
- if rp.MustBeDir() {
- return nil, syserror.EISDIR
- }
- if mustCreate {
- return nil, syserror.EEXIST
- }
- return inode.open(ctx, rp, vfsd, opts.Flags, false)
- }
-afterTrailingSymlink:
- // Walk to the parent directory of the last path component.
- for !rp.Final() {
- var err error
- vfsd, inode, err = stepLocked(rp, vfsd, inode)
- if err != nil {
- return nil, err
- }
- }
- if !inode.isDir() {
- return nil, syserror.ENOTDIR
- }
- // Check for search permission in the parent directory.
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayExec, true); err != nil {
- return nil, err
- }
- // Reject attempts to open directories with O_CREAT.
- if rp.MustBeDir() {
- return nil, syserror.EISDIR
- }
- pc := rp.Component()
- if pc == "." || pc == ".." {
- return nil, syserror.EISDIR
- }
- // Determine whether or not we need to create a file.
- childVFSD, err := rp.ResolveChild(vfsd, pc)
- if err != nil {
- return nil, err
- }
- if childVFSD == nil {
- // Already checked for searchability above; now check for writability.
- if err := inode.checkPermissions(rp.Credentials(), vfs.MayWrite, true); err != nil {
- return nil, err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return nil, err
- }
- defer rp.Mount().EndWrite()
- // Create and open the child.
- childInode := fs.newRegularFile(rp.Credentials(), opts.Mode)
- child := fs.newDentry(childInode)
- vfsd.InsertChild(&child.vfsd, pc)
- inode.impl.(*directory).childList.PushBack(child)
- return childInode.open(ctx, rp, &child.vfsd, opts.Flags, true)
- }
- // Open existing file or follow symlink.
- if mustCreate {
- return nil, syserror.EEXIST
- }
- childInode := childVFSD.Impl().(*dentry).inode
- if symlink, ok := childInode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
- // TODO: symlink traversals update access time
- if err := rp.HandleSymlink(symlink.target); err != nil {
- return nil, err
- }
- // rp.Final() may no longer be true since we now need to resolve the
- // symlink target.
- goto afterTrailingSymlink
- }
- return childInode.open(ctx, rp, childVFSD, opts.Flags, false)
-}
-
-func (i *inode) open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32, afterCreate bool) (*vfs.FileDescription, error) {
- ats := vfs.AccessTypesForOpenFlags(flags)
- if !afterCreate {
- if err := i.checkPermissions(rp.Credentials(), ats, i.isDir()); err != nil {
- return nil, err
- }
- }
- switch impl := i.impl.(type) {
- case *regularFile:
- var fd regularFileFD
- fd.flags = flags
- fd.readable = vfs.MayReadFileWithOpenFlags(flags)
- fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
- if fd.writable {
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return nil, err
- }
- // Mount.EndWrite() is called by regularFileFD.Release().
- }
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- if flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data = impl.data[:0]
- atomic.StoreInt64(&impl.dataLen, 0)
- impl.mu.Unlock()
- }
- return &fd.vfsfd, nil
- case *directory:
- // Can't open directories writably.
- if ats&vfs.MayWrite != 0 {
- return nil, syserror.EISDIR
- }
- var fd directoryFD
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- fd.flags = flags
- return &fd.vfsfd, nil
- case *symlink:
- // Can't open symlinks without O_PATH (which is unimplemented).
- return nil, syserror.ELOOP
- case *namedPipe:
- return newNamedPipeFD(ctx, impl, rp, vfsd, flags)
- default:
- panic(fmt.Sprintf("unknown inode type: %T", i.impl))
- }
-}
-
-// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
-func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
- fs.mu.RLock()
- _, inode, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return "", err
- }
- symlink, ok := inode.impl.(*symlink)
- if !ok {
- return "", syserror.EINVAL
- }
- return symlink.target, nil
-}
-
-// RenameAt implements vfs.FilesystemImpl.RenameAt.
-func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry, opts vfs.RenameOptions) error {
- if rp.Done() {
- return syserror.ENOENT
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- _, err = checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- // TODO: actually implement RenameAt
- return syserror.EPERM
-}
-
-// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
-func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- fs.mu.Lock()
- defer fs.mu.Unlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(vfsd); err != nil {
- return err
- }
- if !inode.isDir() {
- return syserror.ENOTDIR
- }
- if vfsd.HasChildren() {
- return syserror.ENOTEMPTY
- }
- if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
- return err
- }
- // Remove from parent directory's childList.
- vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
- inode.decRef()
- return nil
-}
-
-// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
-func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
- fs.mu.RLock()
- _, _, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return err
- }
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
-}
-
-// StatAt implements vfs.FilesystemImpl.StatAt.
-func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
- fs.mu.RLock()
- _, inode, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return linux.Statx{}, err
- }
- var stat linux.Statx
- inode.statTo(&stat)
- return stat, nil
-}
-
-// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
-func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
- fs.mu.RLock()
- _, _, err := walkExistingLocked(rp)
- fs.mu.RUnlock()
- if err != nil {
- return linux.Statfs{}, err
- }
- // TODO: actually implement statfs
- return linux.Statfs{}, syserror.ENOSYS
-}
-
-// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
-func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
- if rp.Done() {
- return syserror.EEXIST
- }
- fs.mu.Lock()
- defer fs.mu.Unlock()
- parentVFSD, parentInode, err := walkParentDirLocked(rp)
- if err != nil {
- return err
- }
- pc, err := checkCreateLocked(rp, parentVFSD, parentInode)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- child := fs.newDentry(fs.newSymlink(rp.Credentials(), target))
- parentVFSD.InsertChild(&child.vfsd, pc)
- parentInode.impl.(*directory).childList.PushBack(child)
- return nil
-}
-
-// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
-func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
- fs.mu.Lock()
- defer fs.mu.Unlock()
- vfsd, inode, err := walkExistingLocked(rp)
- if err != nil {
- return err
- }
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return err
- }
- defer rp.Mount().EndWrite()
- if err := checkDeleteLocked(vfsd); err != nil {
- return err
- }
- if inode.isDir() {
- return syserror.EISDIR
- }
- if err := rp.VirtualFilesystem().DeleteDentry(vfs.MountNamespaceFromContext(ctx), vfsd); err != nil {
- return err
- }
- // Remove from parent directory's childList.
- vfsd.Parent().Impl().(*dentry).inode.impl.(*directory).childList.Remove(vfsd.Impl().(*dentry))
- inode.decLinksLocked()
- return nil
-}
diff --git a/pkg/sentry/fsimpl/memfs/memfs.go b/pkg/sentry/fsimpl/memfs/memfs.go
deleted file mode 100644
index 64c851c1a..000000000
--- a/pkg/sentry/fsimpl/memfs/memfs.go
+++ /dev/null
@@ -1,302 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package memfs provides a filesystem implementation that behaves like tmpfs:
-// the Dentry tree is the sole source of truth for the state of the filesystem.
-//
-// memfs is intended primarily to demonstrate filesystem implementation
-// patterns. Real uses cases for an in-memory filesystem should use tmpfs
-// instead.
-//
-// Lock order:
-//
-// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
-// inode.mu
-package memfs
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// FilesystemType implements vfs.FilesystemType.
-type FilesystemType struct{}
-
-// filesystem implements vfs.FilesystemImpl.
-type filesystem struct {
- vfsfs vfs.Filesystem
-
- // mu serializes changes to the Dentry tree.
- mu sync.RWMutex
-
- nextInoMinusOne uint64 // accessed using atomic memory operations
-}
-
-// NewFilesystem implements vfs.FilesystemType.NewFilesystem.
-func (fstype FilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts vfs.NewFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
- var fs filesystem
- fs.vfsfs.Init(&fs)
- root := fs.newDentry(fs.newDirectory(creds, 01777))
- return &fs.vfsfs, &root.vfsd, nil
-}
-
-// Release implements vfs.FilesystemImpl.Release.
-func (fs *filesystem) Release() {
-}
-
-// Sync implements vfs.FilesystemImpl.Sync.
-func (fs *filesystem) Sync(ctx context.Context) error {
- // All filesystem state is in-memory.
- return nil
-}
-
-// dentry implements vfs.DentryImpl.
-type dentry struct {
- vfsd vfs.Dentry
-
- // inode is the inode represented by this dentry. Multiple Dentries may
- // share a single non-directory inode (with hard links). inode is
- // immutable.
- inode *inode
-
- // memfs doesn't count references on dentries; because the dentry tree is
- // the sole source of truth, it is by definition always consistent with the
- // state of the filesystem. However, it does count references on inodes,
- // because inode resources are released when all references are dropped.
- // (memfs doesn't really have resources to release, but we implement
- // reference counting because tmpfs regular files will.)
-
- // dentryEntry (ugh) links dentries into their parent directory.childList.
- dentryEntry
-}
-
-func (fs *filesystem) newDentry(inode *inode) *dentry {
- d := &dentry{
- inode: inode,
- }
- d.vfsd.Init(d)
- return d
-}
-
-// IncRef implements vfs.DentryImpl.IncRef.
-func (d *dentry) IncRef(vfsfs *vfs.Filesystem) {
- d.inode.incRef()
-}
-
-// TryIncRef implements vfs.DentryImpl.TryIncRef.
-func (d *dentry) TryIncRef(vfsfs *vfs.Filesystem) bool {
- return d.inode.tryIncRef()
-}
-
-// DecRef implements vfs.DentryImpl.DecRef.
-func (d *dentry) DecRef(vfsfs *vfs.Filesystem) {
- d.inode.decRef()
-}
-
-// inode represents a filesystem object.
-type inode struct {
- // refs is a reference count. refs is accessed using atomic memory
- // operations.
- //
- // A reference is held on all inodes that are reachable in the filesystem
- // tree. For non-directories (which may have multiple hard links), this
- // means that a reference is dropped when nlink reaches 0. For directories,
- // nlink never reaches 0 due to the "." entry; instead,
- // filesystem.RmdirAt() drops the reference.
- refs int64
-
- // Inode metadata; protected by mu and accessed using atomic memory
- // operations unless otherwise specified.
- mu sync.RWMutex
- mode uint32 // excluding file type bits, which are based on impl
- nlink uint32 // protected by filesystem.mu instead of inode.mu
- uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
- gid uint32 // auth.KGID, but ...
- ino uint64 // immutable
-
- impl interface{} // immutable
-}
-
-func (i *inode) init(impl interface{}, fs *filesystem, creds *auth.Credentials, mode linux.FileMode) {
- i.refs = 1
- i.mode = uint32(mode)
- i.uid = uint32(creds.EffectiveKUID)
- i.gid = uint32(creds.EffectiveKGID)
- i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
- // i.nlink initialized by caller
- i.impl = impl
-}
-
-// Preconditions: filesystem.mu must be locked for writing.
-func (i *inode) incLinksLocked() {
- if atomic.AddUint32(&i.nlink, 1) <= 1 {
- panic("memfs.inode.incLinksLocked() called with no existing links")
- }
-}
-
-// Preconditions: filesystem.mu must be locked for writing.
-func (i *inode) decLinksLocked() {
- if nlink := atomic.AddUint32(&i.nlink, ^uint32(0)); nlink == 0 {
- i.decRef()
- } else if nlink == ^uint32(0) { // negative overflow
- panic("memfs.inode.decLinksLocked() called with no existing links")
- }
-}
-
-func (i *inode) incRef() {
- if atomic.AddInt64(&i.refs, 1) <= 1 {
- panic("memfs.inode.incRef() called without holding a reference")
- }
-}
-
-func (i *inode) tryIncRef() bool {
- for {
- refs := atomic.LoadInt64(&i.refs)
- if refs == 0 {
- return false
- }
- if atomic.CompareAndSwapInt64(&i.refs, refs, refs+1) {
- return true
- }
- }
-}
-
-func (i *inode) decRef() {
- if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
- if regfile, ok := i.impl.(*regularFile); ok {
- regfile.mu.Lock()
- regfile.data = nil
- atomic.StoreInt64(&regfile.dataLen, 0)
- regfile.mu.Unlock()
- }
- } else if refs < 0 {
- panic("memfs.inode.decRef() called without holding a reference")
- }
-}
-
-func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes, isDir bool) error {
- return vfs.GenericCheckPermissions(creds, ats, isDir, uint16(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
-}
-
-// Go won't inline this function, and returning linux.Statx (which is quite
-// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
-// output parameter.
-func (i *inode) statTo(stat *linux.Statx) {
- stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
- stat.Blksize = 1 // usermem.PageSize in tmpfs
- stat.Nlink = atomic.LoadUint32(&i.nlink)
- stat.UID = atomic.LoadUint32(&i.uid)
- stat.GID = atomic.LoadUint32(&i.gid)
- stat.Mode = uint16(atomic.LoadUint32(&i.mode))
- stat.Ino = i.ino
- // TODO: device number
- switch impl := i.impl.(type) {
- case *regularFile:
- stat.Mode |= linux.S_IFREG
- stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
- stat.Size = uint64(atomic.LoadInt64(&impl.dataLen))
- // In tmpfs, this will be FileRangeSet.Span() / 512 (but also cached in
- // a uint64 accessed using atomic memory operations to avoid taking
- // locks).
- stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *directory:
- stat.Mode |= linux.S_IFDIR
- case *symlink:
- stat.Mode |= linux.S_IFLNK
- stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
- stat.Size = uint64(len(impl.target))
- stat.Blocks = allocatedBlocksForSize(stat.Size)
- case *namedPipe:
- stat.Mode |= linux.S_IFIFO
- default:
- panic(fmt.Sprintf("unknown inode type: %T", i.impl))
- }
-}
-
-// allocatedBlocksForSize returns the number of 512B blocks needed to
-// accommodate the given size in bytes, as appropriate for struct
-// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
-// size is independent of the "preferred block size for I/O", struct
-// stat::st_blksize and struct statx::stx_blksize.)
-func allocatedBlocksForSize(size uint64) uint64 {
- return (size + 511) / 512
-}
-
-func (i *inode) direntType() uint8 {
- switch i.impl.(type) {
- case *regularFile:
- return linux.DT_REG
- case *directory:
- return linux.DT_DIR
- case *symlink:
- return linux.DT_LNK
- default:
- panic(fmt.Sprintf("unknown inode type: %T", i.impl))
- }
-}
-
-// fileDescription is embedded by memfs implementations of
-// vfs.FileDescriptionImpl.
-type fileDescription struct {
- vfsfd vfs.FileDescription
- vfs.FileDescriptionDefaultImpl
-
- flags uint32 // status flags; immutable
-}
-
-func (fd *fileDescription) filesystem() *filesystem {
- return fd.vfsfd.VirtualDentry().Mount().Filesystem().Impl().(*filesystem)
-}
-
-func (fd *fileDescription) inode() *inode {
- return fd.vfsfd.VirtualDentry().Dentry().Impl().(*dentry).inode
-}
-
-// StatusFlags implements vfs.FileDescriptionImpl.StatusFlags.
-func (fd *fileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- return fd.flags, nil
-}
-
-// SetStatusFlags implements vfs.FileDescriptionImpl.SetStatusFlags.
-func (fd *fileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- // None of the flags settable by fcntl(F_SETFL) are supported, so this is a
- // no-op.
- return nil
-}
-
-// Stat implements vfs.FileDescriptionImpl.Stat.
-func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
- var stat linux.Statx
- fd.inode().statTo(&stat)
- return stat, nil
-}
-
-// SetStat implements vfs.FileDescriptionImpl.SetStat.
-func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
- if opts.Stat.Mask == 0 {
- return nil
- }
- // TODO: implement inode.setStat
- return syserror.EPERM
-}
diff --git a/pkg/sentry/fsimpl/memfs/regular_file.go b/pkg/sentry/fsimpl/memfs/regular_file.go
deleted file mode 100644
index b7f4853b3..000000000
--- a/pkg/sentry/fsimpl/memfs/regular_file.go
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package memfs
-
-import (
- "io"
- "sync"
- "sync/atomic"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-type regularFile struct {
- inode inode
-
- mu sync.RWMutex
- data []byte
- // dataLen is len(data), but accessed using atomic memory operations to
- // avoid locking in inode.stat().
- dataLen int64
-}
-
-func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &regularFile{}
- file.inode.init(file, fs, creds, mode)
- file.inode.nlink = 1 // from parent directory
- return &file.inode
-}
-
-type regularFileFD struct {
- fileDescription
-
- // These are immutable.
- readable bool
- writable bool
-
- // off is the file offset. off is accessed using atomic memory operations.
- // offMu serializes operations that may mutate off.
- off int64
- offMu sync.Mutex
-}
-
-// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *regularFileFD) Release() {
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
- }
-}
-
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.RLock()
- if offset >= int64(len(f.data)) {
- f.mu.RUnlock()
- return 0, io.EOF
- }
- n, err := dst.CopyOut(ctx, f.data[offset:])
- f.mu.RUnlock()
- return int64(n), err
-}
-
-// Read implements vfs.FileDescriptionImpl.Read.
-func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PRead(ctx, dst, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- srclen := src.NumBytes()
- if srclen == 0 {
- return 0, nil
- }
- f := fd.inode().impl.(*regularFile)
- f.mu.Lock()
- end := offset + srclen
- if end < offset {
- // Overflow.
- f.mu.Unlock()
- return 0, syserror.EFBIG
- }
- if end > f.dataLen {
- f.data = append(f.data, make([]byte, end-f.dataLen)...)
- atomic.StoreInt64(&f.dataLen, end)
- }
- n, err := src.CopyIn(ctx, f.data[offset:end])
- f.mu.Unlock()
- return int64(n), err
-}
-
-// Write implements vfs.FileDescriptionImpl.Write.
-func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
- fd.offMu.Lock()
- n, err := fd.PWrite(ctx, src, fd.off, opts)
- fd.off += n
- fd.offMu.Unlock()
- return n, err
-}
-
-// Seek implements vfs.FileDescriptionImpl.Seek.
-func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- fd.offMu.Lock()
- defer fd.offMu.Unlock()
- switch whence {
- case linux.SEEK_SET:
- // use offset as specified
- case linux.SEEK_CUR:
- offset += fd.off
- case linux.SEEK_END:
- offset += atomic.LoadInt64(&fd.inode().impl.(*regularFile).dataLen)
- default:
- return 0, syserror.EINVAL
- }
- if offset < 0 {
- return 0, syserror.EINVAL
- }
- fd.off = offset
- return offset, nil
-}
-
-// Sync implements vfs.FileDescriptionImpl.Sync.
-func (fd *regularFileFD) Sync(ctx context.Context) error {
- return nil
-}
diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD
new file mode 100644
index 000000000..8cf5b35d3
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "overlay",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_library(
+ name = "overlay",
+ srcs = [
+ "copy_up.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "non_directory.go",
+ "overlay.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
new file mode 100644
index 000000000..b3d19ff82
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -0,0 +1,262 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isCopiedUp() bool {
+ return atomic.LoadUint32(&d.copiedUp) != 0
+}
+
+// copyUpLocked ensures that d exists on the upper layer, i.e. d.upperVD.Ok().
+//
+// Preconditions: filesystem.renameMu must be locked.
+func (d *dentry) copyUpLocked(ctx context.Context) error {
+ // Fast path.
+ if d.isCopiedUp() {
+ return nil
+ }
+
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ switch ftype {
+ case linux.S_IFREG, linux.S_IFDIR, linux.S_IFLNK, linux.S_IFBLK, linux.S_IFCHR:
+ // Can be copied-up.
+ default:
+ // Can't be copied-up.
+ return syserror.EPERM
+ }
+
+ // Ensure that our parent directory is copied-up.
+ if d.parent == nil {
+ // d is a filesystem root with no upper layer.
+ return syserror.EROFS
+ }
+ if err := d.parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if d.upperVD.Ok() {
+ // Raced with another call to d.copyUpLocked().
+ return nil
+ }
+ if d.vfsd.IsDead() {
+ // Raced with deletion of d.
+ return syserror.ENOENT
+ }
+
+ // Perform copy-up.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: d.parent.upperVD,
+ Start: d.parent.upperVD,
+ Path: fspath.Parse(d.name),
+ }
+ cleanupUndoCopyUp := func() {
+ var err error
+ if ftype == linux.S_IFDIR {
+ err = vfsObj.RmdirAt(ctx, d.fs.creds, &newpop)
+ } else {
+ err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop)
+ }
+ if err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)
+ }
+ }
+ switch ftype {
+ case linux.S_IFREG:
+ oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ })
+ if err != nil {
+ return err
+ }
+ defer oldFD.DecRef(ctx)
+ newFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &newpop, &vfs.OpenOptions{
+ Flags: linux.O_WRONLY | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ })
+ if err != nil {
+ return err
+ }
+ defer newFD.DecRef(ctx)
+ bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
+ for {
+ readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
+ if readErr != nil && readErr != io.EOF {
+ cleanupUndoCopyUp()
+ return readErr
+ }
+ total := int64(0)
+ for total < readN {
+ writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
+ total += writeN
+ if writeErr != nil {
+ cleanupUndoCopyUp()
+ return writeErr
+ }
+ }
+ if readErr == io.EOF {
+ break
+ }
+ }
+ if err := newFD.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = newFD.VirtualDentry()
+ d.upperVD.IncRef()
+
+ case linux.S_IFDIR:
+ if err := vfsObj.MkdirAt(ctx, d.fs.creds, &newpop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(d.mode &^ linux.S_IFMT),
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFLNK:
+ target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ })
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.SymlinkAt(ctx, d.fs.creds, &newpop, target); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID,
+ Mode: uint16(d.mode),
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ case linux.S_IFBLK, linux.S_IFCHR:
+ lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVDs[0],
+ Start: d.lowerVDs[0],
+ }, &vfs.StatOptions{})
+ if err != nil {
+ return err
+ }
+ if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{
+ Mode: linux.FileMode(d.mode),
+ DevMajor: lowerStat.RdevMajor,
+ DevMinor: lowerStat.RdevMinor,
+ }); err != nil {
+ return err
+ }
+ if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: d.uid,
+ GID: d.gid,
+ },
+ }); err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ upperVD, err := vfsObj.GetDentryAt(ctx, d.fs.creds, &newpop, &vfs.GetDentryOptions{})
+ if err != nil {
+ cleanupUndoCopyUp()
+ return err
+ }
+ d.upperVD = upperVD
+
+ default:
+ // Should have rejected this at the beginning of this function?
+ panic(fmt.Sprintf("unexpected file type %o", ftype))
+ }
+
+ // TODO(gvisor.dev/issue/1199): copy up xattrs
+
+ // Update the dentry's device and inode numbers (except for directories,
+ // for which these remain overlay-assigned).
+ if ftype != linux.S_IFDIR {
+ upperStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_INO,
+ })
+ if err != nil {
+ d.upperVD.DecRef(ctx)
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return err
+ }
+ if upperStat.Mask&linux.STATX_INO == 0 {
+ d.upperVD.DecRef(ctx)
+ d.upperVD = vfs.VirtualDentry{}
+ cleanupUndoCopyUp()
+ return syserror.EREMOTE
+ }
+ atomic.StoreUint32(&d.devMajor, upperStat.DevMajor)
+ atomic.StoreUint32(&d.devMinor, upperStat.DevMinor)
+ atomic.StoreUint64(&d.ino, upperStat.Ino)
+ }
+
+ atomic.StoreUint32(&d.copiedUp, 1)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go
new file mode 100644
index 000000000..6a79f7ffe
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/directory.go
@@ -0,0 +1,289 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func (d *dentry) isDir() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR
+}
+
+// Preconditions: d.dirMu must be locked. d.isDir().
+func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string]bool, error) {
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ whiteouts := make(map[string]bool)
+ var maybeWhiteouts []string
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef(ctx)
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ err = layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := whiteouts[dirent.Name]; ok {
+ // This file has been whited-out in a previous layer.
+ return nil
+ }
+ if dirent.Type == linux.DT_CHR {
+ // We have to determine if this is a whiteout, which doesn't
+ // count against the directory's emptiness. However, we can't
+ // do so while holding locks held by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent.Name)
+ return nil
+ }
+ // Non-whiteout file in the directory prevents rmdir.
+ return syserror.ENOTEMPTY
+ }))
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, maybeWhiteoutName := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(maybeWhiteoutName),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor != 0 || stat.RdevMinor != 0 {
+ // This file is a real character device, not a whiteout.
+ readdirErr = syserror.ENOTEMPTY
+ return false
+ }
+ whiteouts[maybeWhiteoutName] = isUpper
+ }
+ // Continue iteration since we haven't found any non-whiteout files in
+ // this directory yet.
+ return true
+ })
+ return whiteouts, readdirErr
+}
+
+type directoryFD struct {
+ fileDescription
+ vfs.DirectoryFileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ mu sync.Mutex
+ off int64
+ dirents []vfs.Dirent
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *directoryFD) Release(ctx context.Context) {
+}
+
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ d := fd.dentry()
+ if fd.dirents == nil {
+ ds, err := d.getDirents(ctx)
+ if err != nil {
+ return err
+ }
+ fd.dirents = ds
+ }
+
+ for fd.off < int64(len(fd.dirents)) {
+ if err := cb.Handle(fd.dirents[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
+// Preconditions: d.isDir().
+func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) {
+ d.fs.renameMu.RLock()
+ defer d.fs.renameMu.RUnlock()
+ d.dirMu.Lock()
+ defer d.dirMu.Unlock()
+
+ if d.dirents != nil {
+ return d.dirents, nil
+ }
+
+ parent := genericParentOrSelf(d)
+ dirents := []vfs.Dirent{
+ {
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: d.ino,
+ NextOff: 1,
+ },
+ {
+ Name: "..",
+ Type: uint8(atomic.LoadUint32(&parent.mode) >> 12),
+ Ino: parent.ino,
+ NextOff: 2,
+ },
+ }
+
+ // Merge dirents from all layers comprising this directory.
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ var readdirErr error
+ prevDirents := make(map[string]struct{})
+ var maybeWhiteouts []vfs.Dirent
+ d.iterLayers(func(layerVD vfs.VirtualDentry, isUpper bool) bool {
+ layerFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ defer layerFD.DecRef(ctx)
+
+ // Reuse slice allocated for maybeWhiteouts from a previous layer to
+ // reduce allocations.
+ maybeWhiteouts = maybeWhiteouts[:0]
+ err = layerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name == "." || dirent.Name == ".." {
+ return nil
+ }
+ if _, ok := prevDirents[dirent.Name]; ok {
+ // This file is hidden by, or merged with, another file with
+ // the same name in a previous layer.
+ return nil
+ }
+ prevDirents[dirent.Name] = struct{}{}
+ if dirent.Type == linux.DT_CHR {
+ // We can't determine if this file is a whiteout while holding
+ // locks held by layerFD.IterDirents().
+ maybeWhiteouts = append(maybeWhiteouts, dirent)
+ return nil
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ return nil
+ }))
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+
+ for _, dirent := range maybeWhiteouts {
+ stat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ Path: fspath.Parse(dirent.Name),
+ }, &vfs.StatOptions{})
+ if err != nil {
+ readdirErr = err
+ return false
+ }
+ if stat.RdevMajor == 0 && stat.RdevMinor == 0 {
+ // This file is a whiteout; don't emit a dirent for it.
+ continue
+ }
+ dirent.NextOff = int64(len(dirents) + 1)
+ dirents = append(dirents, dirent)
+ }
+ return true
+ })
+ if readdirErr != nil {
+ return nil, readdirErr
+ }
+
+ // Cache dirents for future directoryFDs.
+ d.dirents = dirents
+ return dirents, nil
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ switch whence {
+ case linux.SEEK_SET:
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ if offset == 0 {
+ // Ensure that the next call to fd.IterDirents() calls
+ // fd.dentry().getDirents().
+ fd.dirents = nil
+ }
+ fd.off = offset
+ return fd.off, nil
+ case linux.SEEK_CUR:
+ offset += fd.off
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ // Don't clear fd.dirents in this case, even if offset == 0.
+ fd.off = offset
+ return fd.off, nil
+ default:
+ return 0, syserror.EINVAL
+ }
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. Forwards sync to the upper
+// layer, if there is one. The lower layer doesn't need to sync because it
+// never changes.
+func (fd *directoryFD) Sync(ctx context.Context) error {
+ d := fd.dentry()
+ if !d.isCopiedUp() {
+ return nil
+ }
+ vfsObj := d.fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }
+ upperFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY})
+ if err != nil {
+ return err
+ }
+ err = upperFD.Sync(ctx)
+ upperFD.DecRef(ctx)
+ return err
+}
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
new file mode 100644
index 000000000..986b36ead
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -0,0 +1,1364 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for
+// opaque directories.
+// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE
+const _OVL_XATTR_OPAQUE = "trusted.overlay.opaque"
+
+func isWhiteout(stat *linux.Statx) bool {
+ return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0
+}
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ if fs.opts.UpperRoot.Ok() {
+ return fs.opts.UpperRoot.Mount().Filesystem().Impl().Sync(ctx)
+ }
+ return nil
+}
+
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
+// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkDropLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkDropLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// Dentries which may have a reference count of zero, and which therefore
+// should be dropped once traversal is complete, are appended to ds.
+//
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+// !rp.Done().
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ child, err := fs.getChildLocked(ctx, d, name, ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
+ return nil, err
+ }
+ if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. d.dirMu must be locked.
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+ if child, ok := parent.children[name]; ok {
+ return child, nil
+ }
+ child, err := fs.lookupLocked(ctx, parent, name)
+ if err != nil {
+ return nil, err
+ }
+ if parent.children == nil {
+ parent.children = make(map[string]*dentry)
+ }
+ parent.children[name] = child
+ // child's refcount is initially 0, so it may be dropped after traversal.
+ *ds = appendDentry(*ds, child)
+ return child, nil
+}
+
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
+ childPath := fspath.Parse(name)
+ child := fs.newDentry()
+ existsOnAnyLayer := false
+ var lookupErr error
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ childVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.GetDentryOptions{})
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+
+ mask := uint32(linux.STATX_TYPE)
+ if !existsOnAnyLayer {
+ // Mode, UID, GID, and (for non-directories) inode number come from
+ // the topmost layer on which the file exists.
+ mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ }
+ stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.StatOptions{
+ Mask: mask,
+ })
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&mask != mask {
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ return false
+ }
+ isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR
+ if existsOnAnyLayer && !isDir {
+ // Directories are not merged with non-directory files from lower
+ // layers; instead, layers including and below the first
+ // non-directory file are ignored. (This file must be a directory
+ // on previous layers, since lower layers aren't searched for
+ // non-directory files.)
+ return false
+ }
+
+ // Update child to include this layer.
+ if isUpper {
+ child.upperVD = childVD
+ child.copiedUp = 1
+ } else {
+ child.lowerVDs = append(child.lowerVDs, childVD)
+ }
+ if !existsOnAnyLayer {
+ existsOnAnyLayer = true
+ child.mode = uint32(stat.Mode)
+ child.uid = stat.UID
+ child.gid = stat.GID
+ child.devMajor = stat.DevMajor
+ child.devMinor = stat.DevMinor
+ child.ino = stat.Ino
+ }
+
+ // For non-directory files, only the topmost layer that contains a file
+ // matters.
+ if !isDir {
+ return false
+ }
+
+ // Directories are merged with directories from lower layers if they
+ // are not explicitly opaque.
+ opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childVD,
+ Start: childVD,
+ }, &vfs.GetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Size: 1,
+ })
+ return !(err == nil && opaqueVal == "y")
+ })
+
+ if lookupErr != nil {
+ child.destroyLocked(ctx)
+ return nil, lookupErr
+ }
+ if !existsOnAnyLayer {
+ child.destroyLocked(ctx)
+ return nil, syserror.ENOENT
+ }
+
+ // Device and inode numbers were copied from the topmost layer above;
+ // override them if necessary.
+ if child.isDir() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.dirDevMinor
+ child.ino = fs.newDirIno()
+ } else if !child.upperVD.Ok() {
+ child.devMajor = linux.UNNAMED_MAJOR
+ child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()]
+ }
+
+ parent.IncRef()
+ child.parent = parent
+ child.name = name
+ return child, nil
+}
+
+// lookupLayerLocked is similar to lookupLocked, but only returns information
+// about the file rather than a dentry.
+//
+// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked.
+func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, name string) (lookupLayer, error) {
+ childPath := fspath.Parse(name)
+ lookupLayer := lookupLayerNone
+ var lookupErr error
+
+ parent.iterLayers(func(parentVD vfs.VirtualDentry, isUpper bool) bool {
+ stat, err := fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: parentVD,
+ Start: parentVD,
+ Path: childPath,
+ }, &vfs.StatOptions{
+ Mask: linux.STATX_TYPE,
+ })
+ if err == syserror.ENOENT || err == syserror.ENAMETOOLONG {
+ // The file doesn't exist on this layer. Proceed to the next
+ // one.
+ return true
+ }
+ if err != nil {
+ lookupErr = err
+ return false
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 {
+ // Linux's overlayfs tends to return EREMOTE in cases where a file
+ // is unusable for reasons that are not better captured by another
+ // errno.
+ lookupErr = syserror.EREMOTE
+ return false
+ }
+ if isWhiteout(&stat) {
+ // This is a whiteout, so it "doesn't exist" on this layer, and
+ // layers below this one are ignored.
+ if isUpper {
+ lookupLayer = lookupLayerUpperWhiteout
+ }
+ return false
+ }
+ // The file exists; we can stop searching.
+ if isUpper {
+ lookupLayer = lookupLayerUpper
+ } else {
+ lookupLayer = lookupLayerLower
+ }
+ return false
+ })
+
+ return lookupLayer, lookupErr
+}
+
+type lookupLayer int
+
+const (
+ // lookupLayerNone indicates that no file exists at the given path on the
+ // upper layer, and is either whited out or does not exist on lower layers.
+ // Therefore, the file does not exist in the overlay filesystem, and file
+ // creation may proceed normally (if an upper layer exists).
+ lookupLayerNone lookupLayer = iota
+
+ // lookupLayerLower indicates that no file exists at the given path on the
+ // upper layer, but exists on a lower layer. Therefore, the file exists in
+ // the overlay filesystem, but must be copied-up before mutation.
+ lookupLayerLower
+
+ // lookupLayerUpper indicates that a non-whiteout file exists at the given
+ // path on the upper layer. Therefore, the file exists in the overlay
+ // filesystem, and is already copied-up.
+ lookupLayerUpper
+
+ // lookupLayerUpperWhiteout indicates that a whiteout exists at the given
+ // path on the upper layer. Therefore, the file does not exist in the
+ // overlay filesystem, and file creation must remove the whiteout before
+ // proceeding.
+ lookupLayerUpperWhiteout
+)
+
+func (ll lookupLayer) existsInOverlay() bool {
+ return ll == lookupLayerLower || ll == lookupLayerUpper
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done().
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ for !rp.Final() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// Preconditions: fs.renameMu must be locked.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ d.dirMu.Lock()
+ next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+ d.dirMu.Unlock()
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ if parent.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Determine if a file already exists at name.
+ if _, ok := parent.children[name]; ok {
+ return syserror.EEXIST
+ }
+ childLayer, err := fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if childLayer.existsInOverlay() {
+ return syserror.EEXIST
+ }
+
+ // Ensure that the parent directory is copied-up so that we can create the
+ // new file in the upper layer.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Finally create the new file.
+ if err := create(parent, name, childLayer == lookupLayerUpperWhiteout); err != nil {
+ return err
+ }
+ parent.dirents = nil
+ return nil
+}
+
+// Preconditions: pop's parent directory has been copied up.
+func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) error {
+ return vfsObj.MknodAt(ctx, fs.creds, pop, &vfs.MknodOptions{
+ Mode: linux.S_IFCHR, // permissions == include/linux/fs.h:WHITEOUT_MODE == 0
+ // DevMajor == DevMinor == 0, from include/linux/fs.h:WHITEOUT_DEV
+ })
+}
+
+func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) {
+ if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)
+ }
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().BoundEndpointAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &opts)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ old := vd.Dentry().Impl().(*dentry)
+ if old.isDir() {
+ return syserror.EPERM
+ }
+ if err := old.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ newpop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.LinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: old.upperVD,
+ Start: old.upperVD,
+ }, &newpop); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &newpop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MkdirAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ if haveUpperWhiteout {
+ // There may be directories on lower layers (previously hidden by
+ // the whiteout) that the new directory should not be merged with.
+ // Mark it opaque to prevent merging.
+ if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{
+ Name: _OVL_XATTR_OPAQUE,
+ Value: "y",
+ }); err != nil {
+ if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)
+ } else {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ }
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ // Disallow attempts to create whiteouts.
+ if opts.Mode&linux.S_IFMT == linux.S_IFCHR && opts.DevMajor == 0 && opts.DevMinor == 0 {
+ return syserror.EPERM
+ }
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.MknodAt(ctx, fs.creds, &pop, &opts); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ mayCreate := opts.Flags&linux.O_CREAT != 0
+ mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL)
+
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+
+ start := rp.Start().Impl().(*dentry)
+ if rp.Done() {
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.openLocked(ctx, rp, &opts)
+ }
+
+afterTrailingSymlink:
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Determine whether or not we need to create a file.
+ parent.dirMu.Lock()
+ child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+ if err == syserror.ENOENT && mayCreate {
+ fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds)
+ parent.dirMu.Unlock()
+ return fd, err
+ }
+ if err != nil {
+ parent.dirMu.Unlock()
+ return nil, err
+ }
+ // Open existing child or follow symlink.
+ parent.dirMu.Unlock()
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ if child.isSymlink() && rp.ShouldFollowSymlink() {
+ target, err := child.readlink(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if err := rp.HandleSymlink(target); err != nil {
+ return nil, err
+ }
+ start = parent
+ goto afterTrailingSymlink
+ }
+ return child.openLocked(ctx, rp, &opts)
+}
+
+// Preconditions: fs.renameMu must be locked.
+func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if err := d.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ if ats.MayWrite() {
+ if err := d.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+ }
+ mnt := rp.Mount()
+
+ // Directory FDs open FDs from each layer when directory entries are read,
+ // so they don't require opening an FD from d.topLayer() up front.
+ ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT
+ if ftype == linux.S_IFDIR {
+ // Can't open directories with O_CREAT.
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EISDIR
+ }
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ if opts.Flags&linux.O_DIRECT != 0 {
+ return nil, syserror.EINVAL
+ }
+ fd := &directoryFD{}
+ fd.LockFD.Init(&d.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ }
+
+ layerVD, isUpper := d.topLayerInfo()
+ layerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, opts)
+ if err != nil {
+ return nil, err
+ }
+ layerFlags := layerFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: isUpper,
+ cachedFD: layerFD,
+ cachedFlags: layerFlags,
+ }
+ fd.LockFD.Init(&d.locks)
+ layerFDOpts := layerFD.Options()
+ if err := fd.vfsfd.Init(fd, layerFlags, mnt, &d.vfsd, &layerFDOpts); err != nil {
+ layerFD.DecRef(ctx)
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Preconditions: parent.dirMu must be locked. parent does not already contain
+// a child named rp.Component().
+func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) {
+ creds := rp.Credentials()
+ if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if parent.vfsd.IsDead() {
+ return nil, syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer mnt.EndWrite()
+
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return nil, err
+ }
+
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ childName := rp.Component()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ // We don't know if a whiteout exists on the upper layer; speculatively
+ // unlink it.
+ //
+ // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do
+ // know whether a whiteout exists.
+ var haveUpperWhiteout bool
+ switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err {
+ case nil:
+ haveUpperWhiteout = true
+ case syserror.ENOENT:
+ haveUpperWhiteout = false
+ default:
+ return nil, err
+ }
+ // Create the file on the upper layer, and get an FD representing it.
+ upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{
+ Flags: opts.Flags&^vfs.FileCreationFlags | linux.O_CREAT | linux.O_EXCL,
+ Mode: opts.Mode,
+ })
+ if err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Change the file's owner to the caller. We can't use upperFD.SetStat()
+ // because it will pick up creds from ctx.
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Re-lookup to get a dentry representing the new file, which is needed for
+ // the returned FD.
+ child, err := fs.getChildLocked(ctx, parent, childName, ds)
+ if err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return nil, err
+ }
+ // Finally construct the overlay FD.
+ upperFlags := upperFD.StatusFlags()
+ fd := &nonDirectoryFD{
+ copiedUp: true,
+ cachedFD: upperFD,
+ cachedFlags: upperFlags,
+ }
+ fd.LockFD.Init(&child.locks)
+ upperFDOpts := upperFD.Options()
+ if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil {
+ upperFD.DecRef(ctx)
+ // Don't bother with cleanup; the file was created successfully, we
+ // just can't open it anymore for some reason.
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ layerVD := d.topLayer()
+ return fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ return syserror.EINVAL
+ }
+
+ var ds *[]*dentry
+ fs.renameMu.Lock()
+ defer fs.renameMuUnlockAndCheckDrop(ctx, &ds)
+ newParent, err := fs.walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry), &ds)
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ // FIXME(gvisor.dev/issue/1199): Actually implement rename.
+ _ = newParent
+ return syserror.EXDEV
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ // Unlike UnlinkAt, we need a dentry representing the child directory being
+ // removed in order to verify that it's empty.
+ child, err := fs.getChildLocked(ctx, parent, name, &ds)
+ if err != nil {
+ return err
+ }
+ if !child.isDir() {
+ return syserror.ENOTDIR
+ }
+ child.dirMu.Lock()
+ defer child.dirMu.Unlock()
+ whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx)
+ if err != nil {
+ return err
+ }
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if child.upperVD.Ok() {
+ cleanupRecreateWhiteouts := func() {
+ if !child.upperVD.Ok() {
+ return
+ }
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil && err != syserror.EEXIST {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)
+ }
+ }
+ }
+ // Remove existing whiteouts on the upper layer.
+ for whiteoutName, whiteoutUpper := range whiteouts {
+ if !whiteoutUpper {
+ continue
+ }
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: child.upperVD,
+ Start: child.upperVD,
+ Path: fspath.Parse(whiteoutName),
+ }); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ // Remove the existing directory on the upper layer.
+ if err := vfsObj.RmdirAt(ctx, fs.creds, &pop); err != nil {
+ cleanupRecreateWhiteouts()
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ // Don't attempt to recover from this: the original directory is
+ // already gone, so any dentries representing it are invalid, and
+ // creating a new directory won't undo that.
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ return err
+ }
+
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ parent.dirents = nil
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ if err := d.fs.vfsfs.VirtualFilesystem().SetStatAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.upperVD,
+ Start: d.upperVD,
+ }, &opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ layerVD := d.topLayer()
+ stat, err = fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ }, &vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ d.statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ return fs.statFS(ctx)
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, haveUpperWhiteout bool) error {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(childName),
+ }
+ if haveUpperWhiteout {
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ return err
+ }
+ }
+ if err := vfsObj.SymlinkAt(ctx, fs.creds, &pop, target); err != nil {
+ if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ creds := rp.Credentials()
+ if err := vfsObj.SetStatAt(ctx, fs.creds, &pop, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_UID | linux.STATX_GID,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ },
+ }); err != nil {
+ if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)
+ } else if haveUpperWhiteout {
+ fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop)
+ }
+ return err
+ }
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return err
+ }
+ if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer rp.Mount().EndWrite()
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ parent.dirMu.Lock()
+ defer parent.dirMu.Unlock()
+
+ // Ensure that parent is copied-up before potentially holding child.copyMu
+ // below.
+ if err := parent.copyUpLocked(ctx); err != nil {
+ return err
+ }
+
+ child := parent.children[name]
+ var childLayer lookupLayer
+ if child != nil {
+ if child.isDir() {
+ return syserror.EISDIR
+ }
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ // Hold child.copyMu to prevent it from being copied-up during
+ // deletion.
+ child.copyMu.RLock()
+ defer child.copyMu.RUnlock()
+ if child.upperVD.Ok() {
+ childLayer = lookupLayerUpper
+ } else {
+ childLayer = lookupLayerLower
+ }
+ } else {
+ // Determine if the file being unlinked actually exists. Holding
+ // parent.dirMu prevents a dentry from being instantiated for the file,
+ // which in turn prevents it from being copied-up, so this result is
+ // stable.
+ childLayer, err = fs.lookupLayerLocked(ctx, parent, name)
+ if err != nil {
+ return err
+ }
+ if !childLayer.existsInOverlay() {
+ return syserror.ENOENT
+ }
+ }
+
+ pop := vfs.PathOperation{
+ Root: parent.upperVD,
+ Start: parent.upperVD,
+ Path: fspath.Parse(name),
+ }
+ if childLayer == lookupLayerUpper {
+ // Remove the existing file on the upper layer.
+ if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil {
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+ }
+ if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil {
+ ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)
+ if child != nil {
+ vfsObj.AbortDeleteDentry(&child.vfsd)
+ }
+ return err
+ }
+
+ if child != nil {
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ delete(parent.children, name)
+ ds = appendDentry(ds, child)
+ }
+ parent.dirents = nil
+ return nil
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr,
+ // but not any other xattr syscalls. For now we just reject all of them.
+ return nil, syserror.ENOTSUP
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ _, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return syserror.ENOTSUP
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
+}
diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go
new file mode 100644
index 000000000..d3060a481
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/non_directory.go
@@ -0,0 +1,266 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package overlay
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (d *dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
+}
+
+func (d *dentry) readlink(ctx context.Context) (string, error) {
+ layerVD := d.topLayer()
+ return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: layerVD,
+ Start: layerVD,
+ })
+}
+
+type nonDirectoryFD struct {
+ fileDescription
+
+ // If copiedUp is false, cachedFD represents
+ // fileDescription.dentry().lowerVDs[0]; otherwise, cachedFD represents
+ // fileDescription.dentry().upperVD. cachedFlags is the last known value of
+ // cachedFD.StatusFlags(). copiedUp, cachedFD, and cachedFlags are
+ // protected by mu.
+ mu sync.Mutex
+ copiedUp bool
+ cachedFD *vfs.FileDescription
+ cachedFlags uint32
+}
+
+func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return nil, err
+ }
+ wrappedFD.IncRef()
+ return wrappedFD, nil
+}
+
+func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) {
+ d := fd.dentry()
+ statusFlags := fd.vfsfd.StatusFlags()
+ if !fd.copiedUp && d.isCopiedUp() {
+ // Switch to the copied-up file.
+ upperVD := d.topLayer()
+ upperFD, err := fd.filesystem().vfsfs.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: upperVD,
+ Start: upperVD,
+ }, &vfs.OpenOptions{
+ Flags: statusFlags,
+ })
+ if err != nil {
+ return nil, err
+ }
+ oldOff, oldOffErr := fd.cachedFD.Seek(ctx, 0, linux.SEEK_CUR)
+ if oldOffErr == nil {
+ if _, err := upperFD.Seek(ctx, oldOff, linux.SEEK_SET); err != nil {
+ upperFD.DecRef(ctx)
+ return nil, err
+ }
+ }
+ fd.cachedFD.DecRef(ctx)
+ fd.copiedUp = true
+ fd.cachedFD = upperFD
+ fd.cachedFlags = statusFlags
+ } else if fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil {
+ return nil, err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ return fd.cachedFD, nil
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *nonDirectoryFD) Release(ctx context.Context) {
+ fd.cachedFD.DecRef(ctx)
+ fd.cachedFD = nil
+}
+
+// OnClose implements vfs.FileDescriptionImpl.OnClose.
+func (fd *nonDirectoryFD) OnClose(ctx context.Context) error {
+ // Linux doesn't define ovl_file_operations.flush at all (i.e. its
+ // equivalent to OnClose is a no-op). We pass through to
+ // fd.cachedFD.OnClose() without upgrading if fd.dentry() has been
+ // copied-up, since OnClose is mostly used to define post-close writeback,
+ // and if fd.cachedFD hasn't been updated then it can't have been used to
+ // mutate fd.dentry() anyway.
+ fd.mu.Lock()
+ if statusFlags := fd.vfsfd.StatusFlags(); fd.cachedFlags != statusFlags {
+ if err := fd.cachedFD.SetStatusFlags(ctx, fd.filesystem().creds, statusFlags); err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ fd.cachedFlags = statusFlags
+ }
+ wrappedFD := fd.cachedFD
+ defer wrappedFD.IncRef()
+ fd.mu.Unlock()
+ return wrappedFD.OnClose(ctx)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ stat, err = wrappedFD.Stat(ctx, vfs.StatOptions{
+ Mask: layerMask,
+ Sync: opts.Sync,
+ })
+ wrappedFD.DecRef(ctx)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ }
+ fd.dentry().statInternalTo(ctx, &opts, &stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ d := fd.dentry()
+ mode := linux.FileMode(atomic.LoadUint32(&d.mode))
+ if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil {
+ return err
+ }
+ mnt := fd.vfsfd.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := d.copyUpLocked(ctx); err != nil {
+ return err
+ }
+ // Changes to d's attributes are serialized by d.copyMu.
+ d.copyMu.Lock()
+ defer d.copyMu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return err
+ }
+ if err := wrappedFD.SetStat(ctx, opts); err != nil {
+ return err
+ }
+ d.updateAfterSetStatLocked(&opts)
+ return nil
+}
+
+// StatFS implements vfs.FileDescriptionImpl.StatFS.
+func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) {
+ return fd.filesystem().statFS(ctx)
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.PRead(ctx, dst, offset, opts)
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // Hold fd.mu during the read to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Read(ctx, dst, opts)
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return 0, err
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.PWrite(ctx, src, offset, opts)
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // Hold fd.mu during the write to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Write(ctx, src, opts)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Hold fd.mu during the seek to serialize the file offset.
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ return 0, err
+ }
+ return wrappedFD.Seek(ctx, offset, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync.
+func (fd *nonDirectoryFD) Sync(ctx context.Context) error {
+ fd.mu.Lock()
+ if !fd.dentry().isCopiedUp() {
+ fd.mu.Unlock()
+ return nil
+ }
+ wrappedFD, err := fd.currentFDLocked(ctx)
+ if err != nil {
+ fd.mu.Unlock()
+ return err
+ }
+ wrappedFD.IncRef()
+ defer wrappedFD.DecRef(ctx)
+ fd.mu.Unlock()
+ return wrappedFD.Sync(ctx)
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ wrappedFD, err := fd.getCurrentFD(ctx)
+ if err != nil {
+ return err
+ }
+ defer wrappedFD.DecRef(ctx)
+ return wrappedFD.ConfigureMMap(ctx, opts)
+}
diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go
new file mode 100644
index 000000000..75cc006bf
--- /dev/null
+++ b/pkg/sentry/fsimpl/overlay/overlay.go
@@ -0,0 +1,627 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package overlay provides an overlay filesystem implementation, which
+// synthesizes a filesystem by composing one or more immutable filesystems
+// ("lower layers") with an optional mutable filesystem ("upper layer").
+//
+// Lock order:
+//
+// directoryFD.mu / nonDirectoryFD.mu
+// filesystem.renameMu
+// dentry.dirMu
+// dentry.copyMu
+//
+// Locking dentry.dirMu in multiple dentries requires that parent dentries are
+// locked before child dentries, and that filesystem.renameMu is locked to
+// stabilize this relationship.
+package overlay
+
+import (
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "overlay"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOptions may be passed as vfs.GetFilesystemOptions.InternalData to
+// FilesystemType.GetFilesystem.
+type FilesystemOptions struct {
+ // Callers passing FilesystemOptions to
+ // overlay.FilesystemType.GetFilesystem() are responsible for ensuring that
+ // the vfs.Mounts comprising the layers of the overlay filesystem do not
+ // contain submounts.
+
+ // If UpperRoot.Ok(), it is the root of the writable upper layer of the
+ // overlay.
+ UpperRoot vfs.VirtualDentry
+
+ // LowerRoots contains the roots of the immutable lower layers of the
+ // overlay. LowerRoots is immutable.
+ LowerRoots []vfs.VirtualDentry
+}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // Immutable options.
+ opts FilesystemOptions
+
+ // creds is a copy of the filesystem's creator's credentials, which are
+ // used for accesses to the filesystem's layers. creds is immutable.
+ creds *auth.Credentials
+
+ // dirDevMinor is the device minor number used for directories. dirDevMinor
+ // is immutable.
+ dirDevMinor uint32
+
+ // lowerDevMinors maps lower layer filesystems to device minor numbers
+ // assigned to non-directory files originating from that filesystem.
+ // lowerDevMinors is immutable.
+ lowerDevMinors map[*vfs.Filesystem]uint32
+
+ // renameMu synchronizes renaming with non-renaming operations in order to
+ // ensure consistent lock ordering between dentry.dirMu in different
+ // dentries.
+ renameMu sync.RWMutex
+
+ // lastDirIno is the last inode number assigned to a directory. lastDirIno
+ // is accessed using atomic memory operations.
+ lastDirIno uint64
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ fsoptsRaw := opts.InternalData
+ fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions)
+ if fsoptsRaw != nil && !haveFSOpts {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw)
+ return nil, nil, syserror.EINVAL
+ }
+ if haveFSOpts {
+ if len(fsopts.LowerRoots) == 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ // We don't enforce a maximum number of lower layers when not
+ // configured by applications; the sandbox owner can have an overlay
+ // filesystem with any number of lower layers.
+ } else {
+ vfsroot := vfs.RootFromContext(ctx)
+ defer vfsroot.DecRef(ctx)
+ upperPathname, ok := mopts["upperdir"]
+ if ok {
+ delete(mopts, "upperdir")
+ // Linux overlayfs also requires a workdir when upperdir is
+ // specified; we don't, so silently ignore this option.
+ delete(mopts, "workdir")
+ upperPath := fspath.Parse(upperPathname)
+ if !upperPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: upperPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer upperRoot.DecRef(ctx)
+ privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err)
+ return nil, nil, err
+ }
+ defer privateUpperRoot.DecRef(ctx)
+ fsopts.UpperRoot = privateUpperRoot
+ }
+ lowerPathnamesStr, ok := mopts["lowerdir"]
+ if !ok {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir")
+ return nil, nil, syserror.EINVAL
+ }
+ delete(mopts, "lowerdir")
+ lowerPathnames := strings.Split(lowerPathnamesStr, ":")
+ const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK
+ if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified")
+ return nil, nil, syserror.EINVAL
+ }
+ if len(lowerPathnames) > maxLowerLayers {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers)
+ return nil, nil, syserror.EINVAL
+ }
+ for _, lowerPathname := range lowerPathnames {
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer lowerRoot.DecRef(ctx)
+ privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */)
+ if err != nil {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ defer privateLowerRoot.DecRef(ctx)
+ fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot)
+ }
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Allocate device numbers.
+ dirDevMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerDevMinors := make(map[*vfs.Filesystem]uint32)
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerFS := lowerRoot.Mount().Filesystem()
+ if _, ok := lowerDevMinors[lowerFS]; !ok {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ vfsObj.PutAnonBlockDevMinor(dirDevMinor)
+ for _, lowerDevMinor := range lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ return nil, nil, err
+ }
+ lowerDevMinors[lowerFS] = devMinor
+ }
+ }
+
+ // Take extra references held by the filesystem.
+ if fsopts.UpperRoot.Ok() {
+ fsopts.UpperRoot.IncRef()
+ }
+ for _, lowerRoot := range fsopts.LowerRoots {
+ lowerRoot.IncRef()
+ }
+
+ fs := &filesystem{
+ opts: fsopts,
+ creds: creds.Fork(),
+ dirDevMinor: dirDevMinor,
+ lowerDevMinors: lowerDevMinors,
+ }
+ fs.vfsfs.Init(vfsObj, &fstype, fs)
+
+ // Construct the root dentry.
+ root := fs.newDentry()
+ root.refs = 1
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.IncRef()
+ root.copiedUp = 1
+ root.upperVD = fs.opts.UpperRoot
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.IncRef()
+ root.lowerVDs = append(root.lowerVDs, lowerRoot)
+ }
+ rootTopVD := root.topLayer()
+ // Get metadata from the topmost layer. See fs.lookupLocked().
+ const rootStatMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+ rootStat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
+ Root: rootTopVD,
+ Start: rootTopVD,
+ }, &vfs.StatOptions{
+ Mask: rootStatMask,
+ })
+ if err != nil {
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, err
+ }
+ if rootStat.Mask&rootStatMask != rootStatMask {
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, syserror.EREMOTE
+ }
+ if isWhiteout(&rootStat) {
+ ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout")
+ root.destroyLocked(ctx)
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, syserror.EINVAL
+ }
+ root.mode = uint32(rootStat.Mode)
+ root.uid = rootStat.UID
+ root.gid = rootStat.GID
+ if rootStat.Mode&linux.S_IFMT == linux.S_IFDIR {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.dirDevMinor
+ root.ino = fs.newDirIno()
+ } else if !root.upperVD.Ok() {
+ root.devMajor = linux.UNNAMED_MAJOR
+ root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()]
+ root.ino = rootStat.Ino
+ } else {
+ root.devMajor = rootStat.DevMajor
+ root.devMinor = rootStat.DevMinor
+ root.ino = rootStat.Ino
+ }
+
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// clonePrivateMount creates a non-recursive bind mount rooted at vd, not
+// associated with any MountNamespace, and returns the root of the new mount.
+// (This is required to ensure that each layer of an overlay comprises only a
+// single mount, and therefore can't cross into e.g. the overlay filesystem
+// itself, risking lock recursion.) A reference is held on the returned
+// VirtualDentry.
+func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forceReadOnly bool) (vfs.VirtualDentry, error) {
+ oldmnt := vd.Mount()
+ opts := oldmnt.Options()
+ if forceReadOnly {
+ opts.ReadOnly = true
+ }
+ newmnt, err := vfsObj.NewDisconnectedMount(oldmnt.Filesystem(), vd.Dentry(), &opts)
+ if err != nil {
+ return vfs.VirtualDentry{}, err
+ }
+ return vfs.MakeVirtualDentry(newmnt, vd.Dentry()), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ vfsObj := fs.vfsfs.VirtualFilesystem()
+ vfsObj.PutAnonBlockDevMinor(fs.dirDevMinor)
+ for _, lowerDevMinor := range fs.lowerDevMinors {
+ vfsObj.PutAnonBlockDevMinor(lowerDevMinor)
+ }
+ if fs.opts.UpperRoot.Ok() {
+ fs.opts.UpperRoot.DecRef(ctx)
+ }
+ for _, lowerRoot := range fs.opts.LowerRoots {
+ lowerRoot.DecRef(ctx)
+ }
+}
+
+func (fs *filesystem) statFS(ctx context.Context) (linux.Statfs, error) {
+ // Always statfs the root of the topmost layer. Compare Linux's
+ // fs/overlayfs/super.c:ovl_statfs().
+ var rootVD vfs.VirtualDentry
+ if fs.opts.UpperRoot.Ok() {
+ rootVD = fs.opts.UpperRoot
+ } else {
+ rootVD = fs.opts.LowerRoots[0]
+ }
+ fsstat, err := fs.vfsfs.VirtualFilesystem().StatFSAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: rootVD,
+ Start: rootVD,
+ })
+ if err != nil {
+ return linux.Statfs{}, err
+ }
+ fsstat.Type = linux.OVERLAYFS_SUPER_MAGIC
+ return fsstat, nil
+}
+
+func (fs *filesystem) newDirIno() uint64 {
+ return atomic.AddUint64(&fs.lastDirIno, 1)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // mode, uid, and gid are the file mode, owner, and group of the file in
+ // the topmost layer (and therefore the overlay file as well), and are used
+ // for permission checks on this dentry. These fields are protected by
+ // copyMu and accessed using atomic memory operations.
+ mode uint32
+ uid uint32
+ gid uint32
+
+ // copiedUp is 1 if this dentry has been copied-up (i.e. upperVD.Ok()) and
+ // 0 otherwise. copiedUp is accessed using atomic memory operations.
+ copiedUp uint32
+
+ // parent is the dentry corresponding to this dentry's parent directory.
+ // name is this dentry's name in parent. If this dentry is a filesystem
+ // root, parent is nil and name is the empty string. parent and name are
+ // protected by fs.renameMu.
+ parent *dentry
+ name string
+
+ // If this dentry represents a directory, children maps the names of
+ // children for which dentries have been instantiated to those dentries,
+ // and dirents (if not nil) is a cache of dirents as returned by
+ // directoryFDs representing this directory. children is protected by
+ // dirMu.
+ dirMu sync.Mutex
+ children map[string]*dentry
+ dirents []vfs.Dirent
+
+ // upperVD and lowerVDs are the files from the overlay filesystem's layers
+ // that comprise the file on the overlay filesystem.
+ //
+ // If !upperVD.Ok(), it can transition to a valid vfs.VirtualDentry (i.e.
+ // be copied up) with copyMu locked for writing; otherwise, it is
+ // immutable. lowerVDs is always immutable.
+ copyMu sync.RWMutex
+ upperVD vfs.VirtualDentry
+ lowerVDs []vfs.VirtualDentry
+
+ // inlineLowerVDs backs lowerVDs in the common case where len(lowerVDs) <=
+ // len(inlineLowerVDs).
+ inlineLowerVDs [1]vfs.VirtualDentry
+
+ // devMajor, devMinor, and ino are the device major/minor and inode numbers
+ // used by this dentry. These fields are protected by copyMu and accessed
+ // using atomic memory operations.
+ devMajor uint32
+ devMinor uint32
+ ino uint64
+
+ locks vfs.FileLocks
+}
+
+// newDentry creates a new dentry. The dentry initially has no references; it
+// is the caller's responsibility to set the dentry's reference count and/or
+// call dentry.destroy() as appropriate. The dentry is initially invalid in
+// that it contains no layers; the caller is responsible for setting them.
+func (fs *filesystem) newDentry() *dentry {
+ d := &dentry{
+ fs: fs,
+ }
+ d.lowerVDs = d.inlineLowerVDs[:0]
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ // d.refs may be 0 if d.fs.renameMu is locked, which serializes against
+ // d.checkDropLocked().
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef(ctx context.Context) {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkDropLocked(ctx)
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// checkDropLocked should be called after d's reference count becomes 0 or it
+// becomes deleted.
+//
+// Preconditions: d.fs.renameMu must be locked for writing.
+func (d *dentry) checkDropLocked(ctx context.Context) {
+ // Dentries with a positive reference count must be retained. (The only way
+ // to obtain a reference on a dentry with zero references is via path
+ // resolution, which requires renameMu, so if d.refs is zero then it will
+ // remain zero while we hold renameMu for writing.) Dentries with a
+ // negative reference count have already been destroyed.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ return
+ }
+ // Refs is still zero; destroy it.
+ d.destroyLocked(ctx)
+ return
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+func (d *dentry) destroyLocked(ctx context.Context) {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("overlay.dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("overlay.dentry.destroyLocked() called with references on the dentry")
+ }
+
+ if d.upperVD.Ok() {
+ d.upperVD.DecRef(ctx)
+ }
+ for _, lowerVD := range d.lowerVDs {
+ lowerVD.DecRef(ctx)
+ }
+
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ if !d.vfsd.IsDead() {
+ delete(d.parent.children, d.name)
+ }
+ d.parent.dirMu.Unlock()
+ // Drop the reference held by d on its parent without recursively
+ // locking d.fs.renameMu.
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkDropLocked(ctx)
+ } else if refs < 0 {
+ panic("overlay.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(ctx context.Context, events uint32, cookie uint32, et vfs.EventType) {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ // TODO(gvisor.dev/issue/1479): Implement inotify.
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+//
+// TODO(gvisor.dev/issue/1479): Implement inotify.
+func (d *dentry) OnZeroWatches(context.Context) {}
+
+// iterLayers invokes yield on each layer comprising d, from top to bottom. If
+// any call to yield returns false, iterLayer stops iteration.
+func (d *dentry) iterLayers(yield func(vd vfs.VirtualDentry, isUpper bool) bool) {
+ if d.isCopiedUp() {
+ if !yield(d.upperVD, true) {
+ return
+ }
+ }
+ for _, lowerVD := range d.lowerVDs {
+ if !yield(lowerVD, false) {
+ return
+ }
+ }
+}
+
+func (d *dentry) topLayerInfo() (vd vfs.VirtualDentry, isUpper bool) {
+ if d.isCopiedUp() {
+ return d.upperVD, true
+ }
+ return d.lowerVDs[0], false
+}
+
+func (d *dentry) topLayer() vfs.VirtualDentry {
+ vd, _ := d.topLayerInfo()
+ return vd
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+// statInternalMask is the set of stat fields that is set by
+// dentry.statInternalTo().
+const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO
+
+// statInternalTo writes fields to stat that are stored in d, and therefore do
+// not requiring invoking StatAt on the overlay's layers.
+func (d *dentry) statInternalTo(ctx context.Context, opts *vfs.StatOptions, stat *linux.Statx) {
+ stat.Mask |= statInternalMask
+ if d.isDir() {
+ // Linux sets nlink to 1 for merged directories
+ // (fs/overlayfs/inode.c:ovl_getattr()); we set it to 2 because this is
+ // correct more often ("." and the directory's entry in its parent),
+ // and some of our tests expect this.
+ stat.Nlink = 2
+ }
+ stat.UID = atomic.LoadUint32(&d.uid)
+ stat.GID = atomic.LoadUint32(&d.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&d.mode))
+ stat.Ino = atomic.LoadUint64(&d.ino)
+ stat.DevMajor = atomic.LoadUint32(&d.devMajor)
+ stat.DevMinor = atomic.LoadUint32(&d.devMinor)
+}
+
+// Preconditions: d.copyMu must be locked for writing.
+func (d *dentry) updateAfterSetStatLocked(opts *vfs.SetStatOptions) {
+ if opts.Stat.Mask&linux.STATX_MODE != 0 {
+ atomic.StoreUint32(&d.mode, (d.mode&linux.S_IFMT)|uint32(opts.Stat.Mode&^linux.S_IFMT))
+ }
+ if opts.Stat.Mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&d.uid, opts.Stat.UID)
+ }
+ if opts.Stat.Mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&d.gid, opts.Stat.GID)
+ }
+}
+
+// fileDescription is embedded by overlay implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/pipefs/BUILD b/pkg/sentry/fsimpl/pipefs/BUILD
new file mode 100644
index 000000000..5950a2d59
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "pipefs",
+ srcs = ["pipefs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
new file mode 100644
index 000000000..2ca793db9
--- /dev/null
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -0,0 +1,165 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipefs provides the filesystem implementation backing
+// Kernel.PipeMount.
+package pipefs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type filesystemType struct{}
+
+// Name implements vfs.FilesystemType.Name.
+func (filesystemType) Name() string {
+ return "pipefs"
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (filesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("pipefs.filesystemType.GetFilesystem should never be called")
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// NewFilesystem sets up and returns a new vfs.Filesystem implemented by pipefs.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.Filesystem.VFSFilesystem(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("pipe:[%d]", inode.ino))
+ return vfs.PrependPathSyntheticError{}
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeNoopRefCount
+
+ locks vfs.FileLocks
+ pipe *pipe.VFSPipe
+
+ ino uint64
+ uid auth.KUID
+ gid auth.KGID
+ // We use the creation timestamp for all of atime, mtime, and ctime.
+ ctime ktime.Time
+}
+
+func newInode(ctx context.Context, fs *filesystem) *inode {
+ creds := auth.CredentialsFromContext(ctx)
+ return &inode{
+ pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize),
+ ino: fs.Filesystem.NextIno(),
+ uid: creds.EffectiveKUID,
+ gid: creds.EffectiveKGID,
+ ctime: ktime.NowFromContext(ctx),
+ }
+}
+
+const pipeMode = 0600 | linux.S_IFIFO
+
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, pipeMode, i.uid, i.gid)
+}
+
+// Mode implements kernfs.Inode.Mode.
+func (i *inode) Mode() linux.FileMode {
+ return pipeMode
+}
+
+// Stat implements kernfs.Inode.Stat.
+func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds())
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: usermem.PageSize,
+ Nlink: 1,
+ UID: uint32(i.uid),
+ GID: uint32(i.gid),
+ Mode: pipeMode,
+ Ino: i.ino,
+ Size: 0,
+ Blocks: 0,
+ Atime: ts,
+ Ctime: ts,
+ Mtime: ts,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: vfsfs.Impl().(*filesystem).devMinor,
+ }, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ if opts.Stat.Mask == 0 {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// TODO(gvisor.dev/issue/1193): kernfs does not provide a way to implement
+// statfs, from which we should indicate PIPEFS_MAGIC.
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks)
+}
+
+// NewConnectedPipeFDs returns a pair of FileDescriptions representing the read
+// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ fs := mnt.Filesystem().Impl().(*filesystem)
+ inode := newInode(ctx, fs)
+ var d kernfs.Dentry
+ d.Init(inode)
+ defer d.DecRef(ctx)
+ return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+}
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index ade6ac946..f074e6056 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -1,50 +1,68 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "proc",
srcs = [
- "filesystems.go",
- "loadavg.go",
- "meminfo.go",
- "mounts.go",
- "net.go",
- "proc.go",
- "stat.go",
- "sys.go",
+ "filesystem.go",
+ "subtasks.go",
"task.go",
- "version.go",
+ "task_fds.go",
+ "task_files.go",
+ "task_net.go",
+ "tasks.go",
+ "tasks_files.go",
+ "tasks_sys.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc",
+ visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
- "//pkg/sentry/fs",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
"//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/usermem",
],
)
go_test(
name = "proc_test",
size = "small",
- srcs = ["net_test.go"],
- embed = [":proc"],
+ srcs = [
+ "tasks_sys_test.go",
+ "tasks_test.go",
+ ],
+ library = ":proc",
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
new file mode 100644
index 000000000..2463d51cd
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -0,0 +1,117 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package proc implements a partial in-memory file system for procfs.
+package proc
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Name is the default filesystem name.
+const Name = "proc"
+
+// FilesystemType is the factory class for procfs.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+var _ vfs.FilesystemType = (*FilesystemType)(nil)
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, nil, fmt.Errorf("procfs requires a kernel")
+ }
+ pidns := kernel.PIDNamespaceFromContext(ctx)
+ if pidns == nil {
+ return nil, nil, fmt.Errorf("procfs requires a PID namespace")
+ }
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ procfs := &filesystem{
+ devMinor: devMinor,
+ }
+ procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
+
+ var cgroups map[string]string
+ if opts.InternalData != nil {
+ data := opts.InternalData.(*InternalData)
+ cgroups = data.Cgroups
+ }
+
+ _, dentry := procfs.newTasksInode(k, pidns, cgroups)
+ return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// dynamicInode is an overfitted interface for common Inodes with
+// dynamicByteSource types used in procfs.
+type dynamicInode interface {
+ kernfs.Inode
+ vfs.DynamicBytesSource
+
+ Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode)
+}
+
+func (fs *filesystem) newDentry(creds *auth.Credentials, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+type staticFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFile)(nil)
+
+func newStaticFile(data string) *staticFile {
+ return &staticFile{StaticData: vfs.StaticData{Data: data}}
+}
+
+// InternalData contains internal data passed in to the procfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+type InternalData struct {
+ Cgroups map[string]string
+}
diff --git a/pkg/sentry/fsimpl/proc/loadavg.go b/pkg/sentry/fsimpl/proc/loadavg.go
deleted file mode 100644
index 9135afef1..000000000
--- a/pkg/sentry/fsimpl/proc/loadavg.go
+++ /dev/null
@@ -1,40 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// loadavgData backs /proc/loadavg.
-//
-// +stateify savable
-type loadavgData struct{}
-
-var _ vfs.DynamicBytesSource = (*loadavgData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/62345059): Include real data in fields.
- // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
- // Column 4-5: currently running processes and the total number of processes.
- // Column 6: the last process ID used.
- fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/meminfo.go b/pkg/sentry/fsimpl/proc/meminfo.go
deleted file mode 100644
index 9a827cd66..000000000
--- a/pkg/sentry/fsimpl/proc/meminfo.go
+++ /dev/null
@@ -1,77 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
-//
-// +stateify savable
-type meminfoData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*meminfoData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- mf := d.k.MemoryFile()
- mf.UpdateUsage()
- snapshot, totalUsage := usage.MemoryAccounting.Copy()
- totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
- anon := snapshot.Anonymous + snapshot.Tmpfs
- file := snapshot.PageCache + snapshot.Mapped
- // We don't actually have active/inactive LRUs, so just make up numbers.
- activeFile := (file / 2) &^ (usermem.PageSize - 1)
- inactiveFile := file - activeFile
-
- fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
- memFree := (totalSize - totalUsage) / 1024
- // We use MemFree as MemAvailable because we don't swap.
- // TODO(rahat): When reclaim is implemented the value of MemAvailable
- // should change.
- fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree)
- fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree)
- fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
- fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
- // Emulate a system with no swap, which disables inactivation of anon pages.
- fmt.Fprintf(buf, "SwapCache: 0 kB\n")
- fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
- fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
- fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
- fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
- fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
- fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
- fmt.Fprintf(buf, "SwapFree: 0 kB\n")
- fmt.Fprintf(buf, "Dirty: 0 kB\n")
- fmt.Fprintf(buf, "Writeback: 0 kB\n")
- fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
- fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
- fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/net.go b/pkg/sentry/fsimpl/proc/net.go
deleted file mode 100644
index fd46eebf8..000000000
--- a/pkg/sentry/fsimpl/proc/net.go
+++ /dev/null
@@ -1,338 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/socket/unix"
- "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
-//
-// +stateify savable
-type ifinet6 struct {
- s inet.Stack
-}
-
-var _ vfs.DynamicBytesSource = (*ifinet6)(nil)
-
-func (n *ifinet6) contents() []string {
- var lines []string
- nics := n.s.Interfaces()
- for id, naddrs := range n.s.InterfaceAddrs() {
- nic, ok := nics[id]
- if !ok {
- // NIC was added after NICNames was called. We'll just
- // ignore it.
- continue
- }
-
- for _, a := range naddrs {
- // IPv6 only.
- if a.Family != linux.AF_INET6 {
- continue
- }
-
- // Fields:
- // IPv6 address displayed in 32 hexadecimal chars without colons
- // Netlink device number (interface index) in hexadecimal (use nic id)
- // Prefix length in hexadecimal
- // Scope value (use 0)
- // Interface flags
- // Device name
- lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
- }
- }
- return lines
-}
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error {
- for _, l := range n.contents() {
- buf.WriteString(l)
- }
- return nil
-}
-
-// netDev implements vfs.DynamicBytesSource for /proc/net/dev.
-//
-// +stateify savable
-type netDev struct {
- s inet.Stack
-}
-
-var _ vfs.DynamicBytesSource = (*netDev)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (n *netDev) Generate(ctx context.Context, buf *bytes.Buffer) error {
- interfaces := n.s.Interfaces()
- buf.WriteString("Inter-| Receive | Transmit\n")
- buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n")
-
- for _, i := range interfaces {
- // Implements the same format as
- // net/core/net-procfs.c:dev_seq_printf_stats.
- var stats inet.StatDev
- if err := n.s.Statistics(&stats, i.Name); err != nil {
- log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
- continue
- }
- fmt.Fprintf(
- buf,
- "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
- i.Name,
- // Received
- stats[0], // bytes
- stats[1], // packets
- stats[2], // errors
- stats[3], // dropped
- stats[4], // fifo
- stats[5], // frame
- stats[6], // compressed
- stats[7], // multicast
- // Transmitted
- stats[8], // bytes
- stats[9], // packets
- stats[10], // errors
- stats[11], // dropped
- stats[12], // fifo
- stats[13], // frame
- stats[14], // compressed
- stats[15], // multicast
- )
- }
-
- return nil
-}
-
-// netUnix implements vfs.DynamicBytesSource for /proc/net/unix.
-//
-// +stateify savable
-type netUnix struct {
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*netUnix)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (n *netUnix) Generate(ctx context.Context, buf *bytes.Buffer) error {
- buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
- for _, se := range n.k.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
- continue
- }
- sfile := s.(*fs.File)
- if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
- s.DecRef()
- // Not a unix socket.
- continue
- }
- sops := sfile.FileOperations.(*unix.SocketOperations)
-
- addr, err := sops.Endpoint().GetLocalAddress()
- if err != nil {
- log.Warningf("Failed to retrieve socket name from %+v: %v", sfile, err)
- addr.Addr = "<unknown>"
- }
-
- sockFlags := 0
- if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
- if ce.Listening() {
- // For unix domain sockets, linux reports a single flag
- // value if the socket is listening, of __SO_ACCEPTCON.
- sockFlags = linux.SO_ACCEPTCON
- }
- }
-
- // In the socket entry below, the value for the 'Num' field requires
- // some consideration. Linux prints the address to the struct
- // unix_sock representing a socket in the kernel, but may redact the
- // value for unprivileged users depending on the kptr_restrict
- // sysctl.
- //
- // One use for this field is to allow a privileged user to
- // introspect into the kernel memory to determine information about
- // a socket not available through procfs, such as the socket's peer.
- //
- // In gvisor, returning a pointer to our internal structures would
- // be pointless, as it wouldn't match the memory layout for struct
- // unix_sock, making introspection difficult. We could populate a
- // struct unix_sock with the appropriate data, but even that
- // requires consideration for which kernel version to emulate, as
- // the definition of this struct changes over time.
- //
- // For now, we always redact this pointer.
- fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %5d",
- (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
- sfile.ReadRefs()-1, // RefCount, don't count our own ref.
- 0, // Protocol, always 0 for UDS.
- sockFlags, // Flags.
- sops.Endpoint().Type(), // Type.
- sops.State(), // State.
- sfile.InodeID(), // Inode.
- )
-
- // Path
- if len(addr.Addr) != 0 {
- if addr.Addr[0] == 0 {
- // Abstract path.
- fmt.Fprintf(buf, " @%s", string(addr.Addr[1:]))
- } else {
- fmt.Fprintf(buf, " %s", string(addr.Addr))
- }
- }
- fmt.Fprintf(buf, "\n")
-
- s.DecRef()
- }
- return nil
-}
-
-// netTCP implements vfs.DynamicBytesSource for /proc/net/tcp.
-//
-// +stateify savable
-type netTCP struct {
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*netTCP)(nil)
-
-func (n *netTCP) Generate(ctx context.Context, buf *bytes.Buffer) error {
- t := kernel.TaskFromContext(ctx)
- buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
- for _, se := range n.k.ListSockets() {
- s := se.Sock.Get()
- if s == nil {
- log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
- continue
- }
- sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(socket.Socket)
- if !ok {
- panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
- }
- if family, stype, _ := sops.Type(); !(family == linux.AF_INET && stype == linux.SOCK_STREAM) {
- s.DecRef()
- // Not tcp4 sockets.
- continue
- }
-
- // Linux's documentation for the fields below can be found at
- // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
- // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
- // Note that the header doesn't contain labels for all the fields.
-
- // Field: sl; entry number.
- fmt.Fprintf(buf, "%4d: ", se.ID)
-
- portBuf := make([]byte, 2)
-
- // Field: local_adddress.
- var localAddr linux.SockAddrInet
- if local, _, err := sops.GetSockName(t); err == nil {
- localAddr = *local.(*linux.SockAddrInet)
- }
- binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
- fmt.Fprintf(buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(localAddr.Addr[:]),
- portBuf)
-
- // Field: rem_address.
- var remoteAddr linux.SockAddrInet
- if remote, _, err := sops.GetPeerName(t); err == nil {
- remoteAddr = *remote.(*linux.SockAddrInet)
- }
- binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
- fmt.Fprintf(buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
- portBuf)
-
- // Field: state; socket state.
- fmt.Fprintf(buf, "%02X ", sops.State())
-
- // Field: tx_queue, rx_queue; number of packets in the transmit and
- // receive queue. Unimplemented.
- fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
-
- // Field: tr, tm->when; timer active state and number of jiffies
- // until timer expires. Unimplemented.
- fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
-
- // Field: retrnsmt; number of unrecovered RTO timeouts.
- // Unimplemented.
- fmt.Fprintf(buf, "%08X ", 0)
-
- // Field: uid.
- uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
- if err != nil {
- log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
- fmt.Fprintf(buf, "%5d ", 0)
- } else {
- fmt.Fprintf(buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- }
-
- // Field: timeout; number of unanswered 0-window probes.
- // Unimplemented.
- fmt.Fprintf(buf, "%8d ", 0)
-
- // Field: inode.
- fmt.Fprintf(buf, "%8d ", sfile.InodeID())
-
- // Field: refcount. Don't count the ref we obtain while deferencing
- // the weakref to this socket.
- fmt.Fprintf(buf, "%d ", sfile.ReadRefs()-1)
-
- // Field: Socket struct address. Redacted due to the same reason as
- // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
- fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
-
- // Field: retransmit timeout. Unimplemented.
- fmt.Fprintf(buf, "%d ", 0)
-
- // Field: predicted tick of soft clock (delayed ACK control data).
- // Unimplemented.
- fmt.Fprintf(buf, "%d ", 0)
-
- // Field: (ack.quick<<1)|ack.pingpong, Unimplemented.
- fmt.Fprintf(buf, "%d ", 0)
-
- // Field: sending congestion window, Unimplemented.
- fmt.Fprintf(buf, "%d ", 0)
-
- // Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
- // Unimplemented, report as large threshold.
- fmt.Fprintf(buf, "%d", -1)
-
- fmt.Fprintf(buf, "\n")
-
- s.DecRef()
- }
-
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/stat.go b/pkg/sentry/fsimpl/proc/stat.go
deleted file mode 100644
index 720db3828..000000000
--- a/pkg/sentry/fsimpl/proc/stat.go
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// cpuStats contains the breakdown of CPU time for /proc/stat.
-type cpuStats struct {
- // user is time spent in userspace tasks with non-positive niceness.
- user uint64
-
- // nice is time spent in userspace tasks with positive niceness.
- nice uint64
-
- // system is time spent in non-interrupt kernel context.
- system uint64
-
- // idle is time spent idle.
- idle uint64
-
- // ioWait is time spent waiting for IO.
- ioWait uint64
-
- // irq is time spent in interrupt context.
- irq uint64
-
- // softirq is time spent in software interrupt context.
- softirq uint64
-
- // steal is involuntary wait time.
- steal uint64
-
- // guest is time spent in guests with non-positive niceness.
- guest uint64
-
- // guestNice is time spent in guests with positive niceness.
- guestNice uint64
-}
-
-// String implements fmt.Stringer.
-func (c cpuStats) String() string {
- return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
-}
-
-// statData implements vfs.DynamicBytesSource for /proc/stat.
-//
-// +stateify savable
-type statData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*statData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- // TODO(b/37226836): We currently export only zero CPU stats. We could
- // at least provide some aggregate stats.
- var cpu cpuStats
- fmt.Fprintf(buf, "cpu %s\n", cpu)
-
- for c, max := uint(0), s.k.ApplicationCores(); c < max; c++ {
- fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
- }
-
- // The total number of interrupts is dependent on the CPUs and PCI
- // devices on the system. See arch_probe_nr_irqs.
- //
- // Since we don't report real interrupt stats, just choose an arbitrary
- // value from a representative VM.
- const numInterrupts = 256
-
- // The Kernel doesn't handle real interrupts, so report all zeroes.
- // TODO(b/37226836): We could count page faults as #PF.
- fmt.Fprintf(buf, "intr 0") // total
- for i := 0; i < numInterrupts; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
-
- // Total number of context switches.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "ctxt 0\n")
-
- // CLOCK_REALTIME timestamp from boot, in seconds.
- fmt.Fprintf(buf, "btime %d\n", s.k.Timekeeper().BootTime().Seconds())
-
- // Total number of clones.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "processes 0\n")
-
- // Number of runnable tasks.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_running 0\n")
-
- // Number of tasks waiting on IO.
- // TODO(b/37226836): Count this.
- fmt.Fprintf(buf, "procs_blocked 0\n")
-
- // Number of each softirq handled.
- fmt.Fprintf(buf, "softirq 0") // total
- for i := 0; i < linux.NumSoftIRQ; i++ {
- fmt.Fprintf(buf, " 0")
- }
- fmt.Fprintf(buf, "\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go
new file mode 100644
index 000000000..79c2725f3
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/subtasks.go
@@ -0,0 +1,182 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// subtasksInode represents the inode for /proc/[pid]/task/ directory.
+//
+// +stateify savable
+type subtasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ cgroupControllers map[string]string
+}
+
+var _ kernfs.Inode = (*subtasksInode)(nil)
+
+func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *kernfs.Dentry {
+ subInode := &subtasksInode{
+ fs: fs,
+ task: task,
+ pidns: pidns,
+ cgroupControllers: cgroupControllers,
+ }
+ // Note: credentials are overridden by taskOwnedInode.
+ subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+ subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ inode := &taskOwnedInode{Inode: subInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ tid, err := strconv.ParseUint(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+
+ subTask := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if subTask == nil {
+ return nil, syserror.ENOENT
+ }
+ if subTask.ThreadGroup() != i.task.ThreadGroup() {
+ return nil, syserror.ENOENT
+ }
+
+ subTaskDentry := i.fs.newTaskInode(subTask, i.pidns, false, i.cgroupControllers)
+ return subTaskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ tasks := i.task.ThreadGroup().MemberIDs(i.pidns)
+ if len(tasks) == 0 {
+ return offset, syserror.ENOENT
+ }
+ if relOffset >= int64(len(tasks)) {
+ return offset, nil
+ }
+
+ tids := make([]int, 0, len(tasks))
+ for _, tid := range tasks {
+ tids = append(tids, int(tid))
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids[relOffset:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+type subtasksFD struct {
+ kernfs.GenericDirectoryFD
+
+ task *kernel.Task
+}
+
+func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.IterDirents(ctx, cb)
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return 0, syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.Seek(ctx, offset, whence)
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *subtasksFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return linux.Statx{}, syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.Stat(ctx, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ if fd.task.ExitState() >= kernel.TaskExitZombie {
+ return syserror.ENOENT
+ }
+ return fd.GenericDirectoryFD.SetStat(ctx, opts)
+}
+
+// Open implements kernfs.Inode.
+func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &subtasksFD{task: i.task}
+ if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil {
+ return nil, err
+ }
+ if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// Stat implements kernfs.Inode.
+func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ stat.Nlink += uint32(i.task.ThreadGroup().Count())
+ }
+ return stat, nil
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
diff --git a/pkg/sentry/fsimpl/proc/sys.go b/pkg/sentry/fsimpl/proc/sys.go
deleted file mode 100644
index b88256e12..000000000
--- a/pkg/sentry/fsimpl/proc/sys.go
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// mmapMinAddrData implements vfs.DynamicBytesSource for
-// /proc/sys/vm/mmap_min_addr.
-//
-// +stateify savable
-type mmapMinAddrData struct {
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*mmapMinAddrData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
- return nil
-}
-
-// +stateify savable
-type overcommitMemory struct{}
-
-var _ vfs.DynamicBytesSource = (*overcommitMemory)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (d *overcommitMemory) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "0\n")
- return nil
-}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index c46e05c3a..a5c7aa470 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -19,243 +19,221 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// mapsCommon is embedded by mapsData and smapsData.
-type mapsCommon struct {
- t *kernel.Task
-}
+// taskInode represents the inode for /proc/PID/ directory.
+//
+// +stateify savable
+type taskInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
-// mm gets the kernel task's MemoryManager. No additional reference is taken on
-// mm here. This is safe because MemoryManager.destroy is required to leave the
-// MemoryManager in a state where it's still usable as a DynamicBytesSource.
-func (md *mapsCommon) mm() *mm.MemoryManager {
- var tmm *mm.MemoryManager
- md.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- tmm = mm
- }
- })
- return tmm
+ locks vfs.FileLocks
+
+ task *kernel.Task
}
-// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
-//
-// +stateify savable
-type mapsData struct {
- mapsCommon
+var _ kernfs.Inode = (*taskInode)(nil)
+
+func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) *kernfs.Dentry {
+ // TODO(gvisor.dev/issue/164): Fail with ESRCH if task exited.
+ contents := map[string]*kernfs.Dentry{
+ "auxv": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &auxvData{task: task}),
+ "cmdline": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: cmdlineDataArg}),
+ "comm": fs.newComm(task, fs.NextIno(), 0444),
+ "environ": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &cmdlineData{task: task, arg: environDataArg}),
+ "exe": fs.newExeSymlink(task, fs.NextIno()),
+ "fd": fs.newFDDirInode(task),
+ "fdinfo": fs.newFDInfoDirInode(task),
+ "gid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}),
+ "io": fs.newTaskOwnedFile(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)),
+ "maps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mapsData{task: task}),
+ "mountinfo": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountInfoData{task: task}),
+ "mounts": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &mountsData{task: task}),
+ "net": fs.newTaskNetDir(task),
+ "ns": fs.newTaskOwnedDir(task, fs.NextIno(), 0511, map[string]*kernfs.Dentry{
+ "net": fs.newNamespaceSymlink(task, fs.NextIno(), "net"),
+ "pid": fs.newNamespaceSymlink(task, fs.NextIno(), "pid"),
+ "user": fs.newNamespaceSymlink(task, fs.NextIno(), "user"),
+ }),
+ "oom_score": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newStaticFile("0\n")),
+ "oom_score_adj": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &oomScoreAdj{task: task}),
+ "smaps": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &smapsData{task: task}),
+ "stat": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
+ "statm": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statmData{task: task}),
+ "status": fs.newTaskOwnedFile(task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "uid_map": fs.newTaskOwnedFile(task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
+ }
+ if isThreadGroup {
+ contents["task"] = fs.newSubtasks(task, pidns, cgroupControllers)
+ }
+ if len(cgroupControllers) > 0 {
+ contents["cgroup"] = fs.newTaskOwnedFile(task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ }
+
+ taskInode := &taskInode{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ inode := &taskOwnedInode{Inode: taskInode, owner: task}
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ taskInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := taskInode.OrderedChildren.Populate(dentry, contents)
+ taskInode.IncLinks(links)
+
+ return dentry
}
-var _ vfs.DynamicBytesSource = (*mapsData)(nil)
+// Valid implements kernfs.inodeDynamicLookup. This inode remains valid as long
+// as the task is still running. When it's dead, another tasks with the same
+// PID could replace it.
+func (i *taskInode) Valid(ctx context.Context) bool {
+ return i.task.ExitState() != kernel.TaskExitDead
+}
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (md *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := md.mm(); mm != nil {
- mm.ReadMapsDataInto(ctx, buf)
+// Open implements kernfs.Inode.
+func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
}
- return nil
+ return fd.VFSFileDescription(), nil
}
-// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
-//
-// +stateify savable
-type smapsData struct {
- mapsCommon
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
}
-var _ vfs.DynamicBytesSource = (*smapsData)(nil)
+// taskOwnedInode implements kernfs.Inode and overrides inode owner with task
+// effective user and group.
+type taskOwnedInode struct {
+ kernfs.Inode
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (sd *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- if mm := sd.mm(); mm != nil {
- mm.ReadSmapsDataInto(ctx, buf)
- }
- return nil
+ // owner is the task that owns this inode.
+ owner *kernel.Task
}
-// +stateify savable
-type taskStatData struct {
- t *kernel.Task
+var _ kernfs.Inode = (*taskOwnedInode)(nil)
- // If tgstats is true, accumulate fault stats (not implemented) and CPU
- // time across all tasks in t's thread group.
- tgstats bool
+func (fs *filesystem) newTaskOwnedFile(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) *kernfs.Dentry {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm)
- // pidns is the PID namespace associated with the proc filesystem that
- // includes the file using this statData.
- pidns *kernel.PIDNamespace
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
}
-var _ vfs.DynamicBytesSource = (*taskStatData)(nil)
+func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry {
+ dir := &kernfs.StaticDirectory{}
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.t))
- fmt.Fprintf(buf, "(%s) ", s.t.Name())
- fmt.Fprintf(buf, "%c ", s.t.StateStatus()[0])
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
- }
- fmt.Fprintf(buf, "%d ", ppid)
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.t.ThreadGroup().ProcessGroup()))
- fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.t.ThreadGroup().Session()))
- fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
- fmt.Fprintf(buf, "0 " /* flags */)
- fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
- var cputime usage.CPUStats
- if s.tgstats {
- cputime = s.t.ThreadGroup().CPUStats()
- } else {
- cputime = s.t.CPUStats()
- }
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- cputime = s.t.ThreadGroup().JoinedChildCPUStats()
- fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
- fmt.Fprintf(buf, "%d %d ", s.t.Priority(), s.t.Niceness())
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Count())
-
- // itrealvalue. Since kernel 2.6.17, this field is no longer
- // maintained, and is hard coded as 0.
- fmt.Fprintf(buf, "0 ")
-
- // Start time is relative to boot time, expressed in clock ticks.
- fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.t.StartTime().Sub(s.t.Kernel().Timekeeper().BootTime())))
-
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- }
- })
- fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
-
- // rsslim.
- fmt.Fprintf(buf, "%d ", s.t.ThreadGroup().Limits().Get(limits.Rss).Cur)
-
- fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
- fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
- fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
- terminationSignal := linux.Signal(0)
- if s.t == s.t.ThreadGroup().Leader() {
- terminationSignal = s.t.ThreadGroup().TerminationSignal()
- }
- fmt.Fprintf(buf, "%d ", terminationSignal)
- fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
- fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
- fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
- fmt.Fprintf(buf, "0\n" /* exit_code */)
+ // Note: credentials are overridden by taskOwnedInode.
+ dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm)
- return nil
-}
+ inode := &taskOwnedInode{Inode: dir, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(inode)
-// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
-//
-// +stateify savable
-type statmData struct {
- t *kernel.Task
-}
+ dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := dir.OrderedChildren.Populate(d, children)
+ dir.IncLinks(links)
-var _ vfs.DynamicBytesSource = (*statmData)(nil)
+ return d
+}
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- var vss, rss uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
+// Stat implements kernfs.Inode.
+func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.Inode.Stat(ctx, fs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ if opts.Mask&(linux.STATX_UID|linux.STATX_GID) != 0 {
+ uid, gid := i.getOwner(linux.FileMode(stat.Mode))
+ if opts.Mask&linux.STATX_UID != 0 {
+ stat.UID = uint32(uid)
}
- })
-
- fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
- return nil
+ if opts.Mask&linux.STATX_GID != 0 {
+ stat.GID = uint32(gid)
+ }
+ }
+ return stat, nil
}
-// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
-//
-// +stateify savable
-type statusData struct {
- t *kernel.Task
- pidns *kernel.PIDNamespace
+// CheckPermissions implements kernfs.Inode.
+func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := i.Mode()
+ uid, gid := i.getOwner(mode)
+ return vfs.GenericCheckPermissions(creds, ats, mode, uid, gid)
}
-var _ vfs.DynamicBytesSource = (*statusData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "Name:\t%s\n", s.t.Name())
- fmt.Fprintf(buf, "State:\t%s\n", s.t.StateStatus())
- fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.t.ThreadGroup()))
- fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.t))
- ppid := kernel.ThreadID(0)
- if parent := s.t.Parent(); parent != nil {
- ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+func (i *taskOwnedInode) getOwner(mode linux.FileMode) (auth.KUID, auth.KGID) {
+ // By default, set the task owner as the file owner.
+ creds := i.owner.Credentials()
+ uid := creds.EffectiveKUID
+ gid := creds.EffectiveKGID
+
+ // Linux doesn't apply dumpability adjustments to world readable/executable
+ // directories so that applications can stat /proc/PID to determine the
+ // effective UID of a process. See fs/proc/base.c:task_dump_owner.
+ if mode.FileType() == linux.ModeDirectory && mode.Permissions() == 0555 {
+ return uid, gid
}
- fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
- tpid := kernel.ThreadID(0)
- if tracer := s.t.Tracer(); tracer != nil {
- tpid = s.pidns.IDOfTask(tracer)
+
+ // If the task is not dumpable, then root (in the namespace preferred)
+ // owns the file.
+ m := getMM(i.owner)
+ if m == nil {
+ return auth.RootKUID, auth.RootKGID
}
- fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
- var fds int
- var vss, rss, data uint64
- s.t.WithMuLocked(func(t *kernel.Task) {
- if fdTable := t.FDTable(); fdTable != nil {
- fds = fdTable.Size()
+ if m.Dumpability() != mm.UserDumpable {
+ uid = auth.RootKUID
+ if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() {
+ uid = kuid
}
- if mm := t.MemoryManager(); mm != nil {
- vss = mm.VirtualMemorySize()
- rss = mm.ResidentSetSize()
- data = mm.VirtualDataSize()
+ gid = auth.RootKGID
+ if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() {
+ gid = kgid
}
- })
- fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
- fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
- fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
- fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
- fmt.Fprintf(buf, "Threads:\t%d\n", s.t.ThreadGroup().Count())
- creds := s.t.Credentials()
- fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
- fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
- fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
- fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
- fmt.Fprintf(buf, "Seccomp:\t%d\n", s.t.SeccompMode())
- return nil
-}
-
-// ioUsage is the /proc/<pid>/io and /proc/<pid>/task/<tid>/io data provider.
-type ioUsage interface {
- // IOUsage returns the io usage data.
- IOUsage() *usage.IO
+ }
+ return uid, gid
}
-// +stateify savable
-type ioData struct {
- ioUsage
+func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
+ if isThreadGroup {
+ return &ioData{ioUsage: t.ThreadGroup()}
+ }
+ return &ioData{ioUsage: t}
}
-var _ vfs.DynamicBytesSource = (*ioData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- io := usage.IO{}
- io.Accumulate(i.IOUsage())
-
- fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
- fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
- fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
- fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
- fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
- fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
- fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
- return nil
+// newCgroupData creates inode that shows cgroup information.
+// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
+// member, there is one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+func newCgroupData(controllers map[string]string) dynamicInode {
+ var buf bytes.Buffer
+
+ // The hierarchy ids must be positive integers (for cgroup v1), but the
+ // exact number does not matter, so long as they are unique. We can
+ // just use a counter, but since linux sorts this file in descending
+ // order, we must count down to preserve this behavior.
+ i := len(controllers)
+ for name, dir := range controllers {
+ fmt.Fprintf(&buf, "%d:%s:%s\n", i, name, dir)
+ i--
+ }
+ return newStaticFile(buf.String())
}
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
new file mode 100644
index 000000000..f0d3f7f5e
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -0,0 +1,307 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func getTaskFD(t *kernel.Task, fd int32) (*vfs.FileDescription, kernel.FDFlags) {
+ var (
+ file *vfs.FileDescription
+ flags kernel.FDFlags
+ )
+ t.WithMuLocked(func(t *kernel.Task) {
+ if fdt := t.FDTable(); fdt != nil {
+ file, flags = fdt.GetVFS2(fd)
+ }
+ })
+ return file, flags
+}
+
+func taskFDExists(ctx context.Context, t *kernel.Task, fd int32) bool {
+ file, _ := getTaskFD(t, fd)
+ if file == nil {
+ return false
+ }
+ file.DecRef(ctx)
+ return true
+}
+
+type fdDir struct {
+ locks vfs.FileLocks
+
+ fs *filesystem
+ task *kernel.Task
+
+ // When produceSymlinks is set, dirents produces for the FDs are reported
+ // as symlink. Otherwise, they are reported as regular files.
+ produceSymlink bool
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) {
+ var fds []int32
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.GetFDs(ctx)
+ }
+ })
+
+ typ := uint8(linux.DT_REG)
+ if i.produceSymlink {
+ typ = linux.DT_LNK
+ }
+
+ // Find the appropriate starting point.
+ idx := sort.Search(len(fds), func(i int) bool { return fds[i] >= int32(relOffset) })
+ if idx >= len(fds) {
+ return offset, nil
+ }
+ for _, fd := range fds[idx:] {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(fd), 10),
+ Type: typ,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return offset, nil
+}
+
+// fdDirInode represents the inode for /proc/[pid]/fd directory.
+//
+// +stateify savable
+type fdDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdDirInode)(nil)
+
+func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry {
+ inode := &fdDirInode{
+ fdDir: fdDir{
+ fs: fs,
+ task: task,
+ produceSymlink: true,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(ctx, i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ taskDentry := i.fs.newFDSymlink(i.task, fd, i.fs.NextIno())
+ return taskDentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// CheckPermissions implements kernfs.Inode.
+//
+// This is to match Linux, which uses a special permission handler to guarantee
+// that a process can still access /proc/self/fd after it has executed
+// setuid. See fs/proc/fd.c:proc_fd_permission.
+func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ err := i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ if err == nil {
+ // Access granted, no extra check needed.
+ return nil
+ }
+ if t := kernel.TaskFromContext(ctx); t != nil {
+ // Allow access if the task trying to access it is in the thread group
+ // corresponding to this directory.
+ if i.task.ThreadGroup() == t.ThreadGroup() {
+ // Access granted (overridden).
+ return nil
+ }
+ }
+ return err
+}
+
+// fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file.
+//
+// +stateify savable
+type fdSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ kernfs.Inode = (*fdSymlink)(nil)
+
+func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *kernfs.Dentry {
+ inode := &fdSymlink{
+ task: task,
+ fd: fd,
+ }
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *fdSymlink) Readlink(ctx context.Context) (string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return "", syserror.ENOENT
+ }
+ defer file.DecRef(ctx)
+ root := vfs.RootFromContext(ctx)
+ defer root.DecRef(ctx)
+ return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
+}
+
+func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ file, _ := getTaskFD(s.task, s.fd)
+ if file == nil {
+ return vfs.VirtualDentry{}, "", syserror.ENOENT
+ }
+ defer file.DecRef(ctx)
+ vd := file.VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+// fdInfoDirInode represents the inode for /proc/[pid]/fdinfo directory.
+//
+// +stateify savable
+type fdInfoDirInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+ fdDir
+}
+
+var _ kernfs.Inode = (*fdInfoDirInode)(nil)
+
+func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry {
+ inode := &fdInfoDirInode{
+ fdDir: fdDir{
+ fs: fs,
+ task: task,
+ },
+ }
+ inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+
+ return dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ fdInt, err := strconv.ParseInt(name, 10, 32)
+ if err != nil {
+ return nil, syserror.ENOENT
+ }
+ fd := int32(fdInt)
+ if !taskFDExists(ctx, i.task, fd) {
+ return nil, syserror.ENOENT
+ }
+ data := &fdInfoData{
+ task: i.task,
+ fd: fd,
+ }
+ dentry := i.fs.newTaskOwnedFile(i.task, i.fs.NextIno(), 0444, data)
+ return dentry.VFSDentry(), nil
+}
+
+// Open implements kernfs.Inode.
+func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd].
+//
+// +stateify savable
+type fdInfoData struct {
+ kernfs.DynamicBytesFile
+ refs.AtomicRefCount
+
+ task *kernel.Task
+ fd int32
+}
+
+var _ dynamicInode = (*fdInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *fdInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ file, descriptorFlags := getTaskFD(d.task, d.fd)
+ if file == nil {
+ return syserror.ENOENT
+ }
+ defer file.DecRef(ctx)
+ // TODO(b/121266871): Include pos, locks, and other data. For now we only
+ // have flags.
+ // See https://www.kernel.org/doc/Documentation/filesystems/proc.txt
+ flags := uint(file.StatusFlags()) | descriptorFlags.ToLinuxFileFlags()
+ fmt.Fprintf(buf, "flags:\t0%o\n", flags)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
new file mode 100644
index 000000000..830b78949
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -0,0 +1,902 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// "There is an (arbitrary) limit on the number of lines in the file. As at
+// Linux 3.18, the limit is five lines." - user_namespaces(7)
+const maxIDMapLines = 5
+
+// mm gets the kernel task's MemoryManager. No additional reference is taken on
+// mm here. This is safe because MemoryManager.destroy is required to leave the
+// MemoryManager in a state where it's still usable as a DynamicBytesSource.
+func getMM(task *kernel.Task) *mm.MemoryManager {
+ var tmm *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ tmm = mm
+ }
+ })
+ return tmm
+}
+
+// getMMIncRef returns t's MemoryManager. If getMMIncRef succeeds, the
+// MemoryManager's users count is incremented, and must be decremented by the
+// caller when it is no longer in use.
+func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
+ if task.ExitState() == kernel.TaskExitDead {
+ return nil, syserror.ESRCH
+ }
+ var m *mm.MemoryManager
+ task.WithMuLocked(func(t *kernel.Task) {
+ m = t.MemoryManager()
+ })
+ if m == nil || !m.IncUsers() {
+ return nil, io.EOF
+ }
+ return m, nil
+}
+
+func checkTaskState(t *kernel.Task) error {
+ switch t.ExitState() {
+ case kernel.TaskExitZombie:
+ return syserror.EACCES
+ case kernel.TaskExitDead:
+ return syserror.ESRCH
+ }
+ return nil
+}
+
+type bufferWriter struct {
+ buf *bytes.Buffer
+}
+
+// WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns
+// the number of bytes written. It may return a partial write without an
+// error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not
+// return a full write with an error (i.e. srcs.NumBytes(), err) where err
+// != nil).
+func (w *bufferWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ written := srcs.NumBytes()
+ for !srcs.IsEmpty() {
+ w.buf.Write(srcs.Head().ToSlice())
+ srcs = srcs.Tail()
+ }
+ return written, nil
+}
+
+// auxvData implements vfs.DynamicBytesSource for /proc/[pid]/auxv.
+//
+// +stateify savable
+type auxvData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*auxvData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ auxv := m.Auxv()
+ // Space for buffer with AT_NULL (0) terminator at the end.
+ buf.Grow((len(auxv) + 1) * 16)
+ for _, e := range auxv {
+ var tmp [16]byte
+ usermem.ByteOrder.PutUint64(tmp[:8], e.Key)
+ usermem.ByteOrder.PutUint64(tmp[8:], uint64(e.Value))
+ buf.Write(tmp[:])
+ }
+ var atNull [16]byte
+ buf.Write(atNull[:])
+
+ return nil
+}
+
+// execArgType enumerates the types of exec arguments that are exposed through
+// proc.
+type execArgType int
+
+const (
+ cmdlineDataArg execArgType = iota
+ environDataArg
+)
+
+// cmdlineData implements vfs.DynamicBytesSource for /proc/[pid]/cmdline.
+//
+// +stateify savable
+type cmdlineData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // arg is the type of exec argument this file contains.
+ arg execArgType
+}
+
+var _ dynamicInode = (*cmdlineData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ m, err := getMMIncRef(d.task)
+ if err != nil {
+ return err
+ }
+ defer m.DecUsers(ctx)
+
+ // Figure out the bounds of the exec arg we are trying to read.
+ var ar usermem.AddrRange
+ switch d.arg {
+ case cmdlineDataArg:
+ ar = usermem.AddrRange{
+ Start: m.ArgvStart(),
+ End: m.ArgvEnd(),
+ }
+ case environDataArg:
+ ar = usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+ default:
+ panic(fmt.Sprintf("unknown exec arg type %v", d.arg))
+ }
+ if ar.Start == 0 || ar.End == 0 {
+ // Don't attempt to read before the start/end are set up.
+ return io.EOF
+ }
+
+ // N.B. Technically this should be usermem.IOOpts.IgnorePermissions = true
+ // until Linux 4.9 (272ddc8b3735 "proc: don't use FOLL_FORCE for reading
+ // cmdline and environment").
+ writer := &bufferWriter{buf: buf}
+ if n, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(ar), writer, usermem.IOOpts{}); n == 0 || err != nil {
+ // Nothing to copy or something went wrong.
+ return err
+ }
+
+ // On Linux, if the NULL byte at the end of the argument vector has been
+ // overwritten, it continues reading the environment vector as part of
+ // the argument vector.
+ if d.arg == cmdlineDataArg && buf.Bytes()[buf.Len()-1] != 0 {
+ if end := bytes.IndexByte(buf.Bytes(), 0); end != -1 {
+ // If we found a NULL character somewhere else in argv, truncate the
+ // return up to the NULL terminator (including it).
+ buf.Truncate(end)
+ return nil
+ }
+
+ // There is no NULL terminator in the string, return into envp.
+ arEnvv := usermem.AddrRange{
+ Start: m.EnvvStart(),
+ End: m.EnvvEnd(),
+ }
+
+ // Upstream limits the returned amount to one page of slop.
+ // https://elixir.bootlin.com/linux/v4.20/source/fs/proc/base.c#L208
+ // we'll return one page total between argv and envp because of the
+ // above page restrictions.
+ if buf.Len() >= usermem.PageSize {
+ // Returned at least one page already, nothing else to add.
+ return nil
+ }
+ remaining := usermem.PageSize - buf.Len()
+ if int(arEnvv.Length()) > remaining {
+ end, ok := arEnvv.Start.AddLength(uint64(remaining))
+ if !ok {
+ return syserror.EFAULT
+ }
+ arEnvv.End = end
+ }
+ if _, err := m.CopyInTo(ctx, usermem.AddrRangeSeqOf(arEnvv), writer, usermem.IOOpts{}); err != nil {
+ return err
+ }
+
+ // Linux will return envp up to and including the first NULL character,
+ // so find it.
+ envStart := int(ar.Length())
+ if nullIdx := bytes.IndexByte(buf.Bytes()[envStart:], 0); nullIdx != -1 {
+ buf.Truncate(envStart + nullIdx)
+ }
+ }
+
+ return nil
+}
+
+// +stateify savable
+type commInode struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) *kernfs.Dentry {
+ inode := &commInode{task: task}
+ inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (i *commInode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // This file can always be read or written by members of the same thread
+ // group. See fs/proc/base.c:proc_tid_comm_permission.
+ //
+ // N.B. This check is currently a no-op as we don't yet support writing and
+ // this file is world-readable anyways.
+ t := kernel.TaskFromContext(ctx)
+ if t != nil && t.ThreadGroup() == i.task.ThreadGroup() && !ats.MayExec() {
+ return nil
+ }
+
+ return i.DynamicBytesFile.CheckPermissions(ctx, creds, ats)
+}
+
+// commData implements vfs.DynamicBytesSource for /proc/[pid]/comm.
+//
+// +stateify savable
+type commData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*commData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.task.Name())
+ buf.WriteString("\n")
+ return nil
+}
+
+// idMapData implements vfs.WritableDynamicBytesSource for
+// /proc/[pid]/{gid_map|uid_map}.
+//
+// +stateify savable
+type idMapData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ gids bool
+}
+
+var _ dynamicInode = (*idMapData)(nil)
+
+// Generate implements vfs.WritableDynamicBytesSource.Generate.
+func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var entries []auth.IDMapEntry
+ if d.gids {
+ entries = d.task.UserNamespace().GIDMap()
+ } else {
+ entries = d.task.UserNamespace().UIDMap()
+ }
+ for _, e := range entries {
+ fmt.Fprintf(buf, "%10d %10d %10d\n", e.FirstID, e.FirstParentID, e.Length)
+ }
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // "In addition, the number of bytes written to the file must be less than
+ // the system page size, and the write must be performed at the start of
+ // the file ..." - user_namespaces(7)
+ srclen := src.NumBytes()
+ if srclen >= usermem.PageSize || offset != 0 {
+ return 0, syserror.EINVAL
+ }
+ b := make([]byte, srclen)
+ if _, err := src.CopyIn(ctx, b); err != nil {
+ return 0, err
+ }
+
+ // Truncate from the first NULL byte.
+ var nul int64
+ nul = int64(bytes.IndexByte(b, 0))
+ if nul == -1 {
+ nul = srclen
+ }
+ b = b[:nul]
+ // Remove the last \n.
+ if nul >= 1 && b[nul-1] == '\n' {
+ b = b[:nul-1]
+ }
+ lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1)
+ if len(lines) > maxIDMapLines {
+ return 0, syserror.EINVAL
+ }
+
+ entries := make([]auth.IDMapEntry, len(lines))
+ for i, l := range lines {
+ var e auth.IDMapEntry
+ _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length)
+ if err != nil {
+ return 0, syserror.EINVAL
+ }
+ entries[i] = e
+ }
+ var err error
+ if d.gids {
+ err = d.task.UserNamespace().SetGIDMap(ctx, entries)
+ } else {
+ err = d.task.UserNamespace().SetUIDMap(ctx, entries)
+ }
+ if err != nil {
+ return 0, err
+ }
+
+ // On success, Linux's kernel/user_namespace.c:map_write() always returns
+ // count, even if fewer bytes were used.
+ return int64(srclen), nil
+}
+
+// mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps.
+//
+// +stateify savable
+type mapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadMapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// smapsData implements vfs.DynamicBytesSource for /proc/[pid]/smaps.
+//
+// +stateify savable
+type smapsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*smapsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *smapsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if mm := getMM(d.task); mm != nil {
+ mm.ReadSmapsDataInto(ctx, buf)
+ }
+ return nil
+}
+
+// +stateify savable
+type taskStatData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+
+ // If tgstats is true, accumulate fault stats (not implemented) and CPU
+ // time across all tasks in t's thread group.
+ tgstats bool
+
+ // pidns is the PID namespace associated with the proc filesystem that
+ // includes the file using this statData.
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*taskStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *taskStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfTask(s.task))
+ fmt.Fprintf(buf, "(%s) ", s.task.Name())
+ fmt.Fprintf(buf, "%c ", s.task.StateStatus()[0])
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "%d ", ppid)
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfProcessGroup(s.task.ThreadGroup().ProcessGroup()))
+ fmt.Fprintf(buf, "%d ", s.pidns.IDOfSession(s.task.ThreadGroup().Session()))
+ fmt.Fprintf(buf, "0 0 " /* tty_nr tpgid */)
+ fmt.Fprintf(buf, "0 " /* flags */)
+ fmt.Fprintf(buf, "0 0 0 0 " /* minflt cminflt majflt cmajflt */)
+ var cputime usage.CPUStats
+ if s.tgstats {
+ cputime = s.task.ThreadGroup().CPUStats()
+ } else {
+ cputime = s.task.CPUStats()
+ }
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ cputime = s.task.ThreadGroup().JoinedChildCPUStats()
+ fmt.Fprintf(buf, "%d %d ", linux.ClockTFromDuration(cputime.UserTime), linux.ClockTFromDuration(cputime.SysTime))
+ fmt.Fprintf(buf, "%d %d ", s.task.Priority(), s.task.Niceness())
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Count())
+
+ // itrealvalue. Since kernel 2.6.17, this field is no longer
+ // maintained, and is hard coded as 0.
+ fmt.Fprintf(buf, "0 ")
+
+ // Start time is relative to boot time, expressed in clock ticks.
+ fmt.Fprintf(buf, "%d ", linux.ClockTFromDuration(s.task.StartTime().Sub(s.task.Kernel().Timekeeper().BootTime())))
+
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+ fmt.Fprintf(buf, "%d %d ", vss, rss/usermem.PageSize)
+
+ // rsslim.
+ fmt.Fprintf(buf, "%d ", s.task.ThreadGroup().Limits().Get(limits.Rss).Cur)
+
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* startcode endcode startstack kstkesp kstkeip */)
+ fmt.Fprintf(buf, "0 0 0 0 0 " /* signal blocked sigignore sigcatch wchan */)
+ fmt.Fprintf(buf, "0 0 " /* nswap cnswap */)
+ terminationSignal := linux.Signal(0)
+ if s.task == s.task.ThreadGroup().Leader() {
+ terminationSignal = s.task.ThreadGroup().TerminationSignal()
+ }
+ fmt.Fprintf(buf, "%d ", terminationSignal)
+ fmt.Fprintf(buf, "0 0 0 " /* processor rt_priority policy */)
+ fmt.Fprintf(buf, "0 0 0 " /* delayacct_blkio_ticks guest_time cguest_time */)
+ fmt.Fprintf(buf, "0 0 0 0 0 0 0 " /* start_data end_data start_brk arg_start arg_end env_start env_end */)
+ fmt.Fprintf(buf, "0\n" /* exit_code */)
+
+ return nil
+}
+
+// statmData implements vfs.DynamicBytesSource for /proc/[pid]/statm.
+//
+// +stateify savable
+type statmData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*statmData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var vss, rss uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ }
+ })
+
+ fmt.Fprintf(buf, "%d %d 0 0 0 0 0\n", vss/usermem.PageSize, rss/usermem.PageSize)
+ return nil
+}
+
+// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+//
+// +stateify savable
+type statusData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+}
+
+var _ dynamicInode = (*statusData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
+ fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
+ fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
+ fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
+ ppid := kernel.ThreadID(0)
+ if parent := s.task.Parent(); parent != nil {
+ ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
+ }
+ fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+ tpid := kernel.ThreadID(0)
+ if tracer := s.task.Tracer(); tracer != nil {
+ tpid = s.pidns.IDOfTask(tracer)
+ }
+ fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+ var fds int
+ var vss, rss, data uint64
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ if fdTable := t.FDTable(); fdTable != nil {
+ fds = fdTable.Size()
+ }
+ if mm := t.MemoryManager(); mm != nil {
+ vss = mm.VirtualMemorySize()
+ rss = mm.ResidentSetSize()
+ data = mm.VirtualDataSize()
+ }
+ })
+ fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
+ fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
+ fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+ fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
+ creds := s.task.Credentials()
+ fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
+ fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
+ fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
+ fmt.Fprintf(buf, "CapBnd:\t%016x\n", creds.BoundingCaps)
+ fmt.Fprintf(buf, "Seccomp:\t%d\n", s.task.SeccompMode())
+ // We unconditionally report a single NUMA node. See
+ // pkg/sentry/syscalls/linux/sys_mempolicy.go.
+ fmt.Fprintf(buf, "Mems_allowed:\t1\n")
+ fmt.Fprintf(buf, "Mems_allowed_list:\t0\n")
+ return nil
+}
+
+// ioUsage is the /proc/[pid]/io and /proc/[pid]/task/[tid]/io data provider.
+type ioUsage interface {
+ // IOUsage returns the io usage data.
+ IOUsage() *usage.IO
+}
+
+// +stateify savable
+type ioData struct {
+ kernfs.DynamicBytesFile
+
+ ioUsage
+}
+
+var _ dynamicInode = (*ioData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *ioData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ io := usage.IO{}
+ io.Accumulate(i.IOUsage())
+
+ fmt.Fprintf(buf, "char: %d\n", io.CharsRead)
+ fmt.Fprintf(buf, "wchar: %d\n", io.CharsWritten)
+ fmt.Fprintf(buf, "syscr: %d\n", io.ReadSyscalls)
+ fmt.Fprintf(buf, "syscw: %d\n", io.WriteSyscalls)
+ fmt.Fprintf(buf, "read_bytes: %d\n", io.BytesRead)
+ fmt.Fprintf(buf, "write_bytes: %d\n", io.BytesWritten)
+ fmt.Fprintf(buf, "cancelled_write_bytes: %d\n", io.BytesWriteCancelled)
+ return nil
+}
+
+// oomScoreAdj is a stub of the /proc/<pid>/oom_score_adj file.
+//
+// +stateify savable
+type oomScoreAdj struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ vfs.WritableDynamicBytesSource = (*oomScoreAdj)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (o *oomScoreAdj) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
+ fmt.Fprintf(buf, "%d\n", o.task.OOMScoreAdj())
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ if o.task.ExitState() == kernel.TaskExitDead {
+ return 0, syserror.ESRCH
+ }
+ if err := o.task.SetOOMScoreAdj(v); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// exeSymlink is an symlink for the /proc/[pid]/exe file.
+//
+// +stateify savable
+type exeSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ task *kernel.Task
+}
+
+var _ kernfs.Inode = (*exeSymlink)(nil)
+
+func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentry {
+ inode := &exeSymlink{task: task}
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+// Readlink implements kernfs.Inode.
+func (s *exeSymlink) Readlink(ctx context.Context) (string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return "", syserror.EACCES
+ }
+
+ // Pull out the executable for /proc/[pid]/exe.
+ exec, err := s.executable()
+ if err != nil {
+ return "", err
+ }
+ defer exec.DecRef(ctx)
+
+ return exec.PathnameWithDeleted(ctx), nil
+}
+
+// Getlink implements kernfs.Inode.Getlink.
+func (s *exeSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if !kernel.ContextCanTrace(ctx, s.task, false) {
+ return vfs.VirtualDentry{}, "", syserror.EACCES
+ }
+
+ exec, err := s.executable()
+ if err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+ defer exec.DecRef(ctx)
+
+ vd := exec.(*fsbridge.VFSFile).FileDescription().VirtualDentry()
+ vd.IncRef()
+ return vd, "", nil
+}
+
+func (s *exeSymlink) executable() (file fsbridge.File, err error) {
+ if err := checkTaskState(s.task); err != nil {
+ return nil, err
+ }
+
+ s.task.WithMuLocked(func(t *kernel.Task) {
+ mm := t.MemoryManager()
+ if mm == nil {
+ err = syserror.EACCES
+ return
+ }
+
+ // The MemoryManager may be destroyed, in which case
+ // MemoryManager.destroy will simply set the executable to nil
+ // (with locks held).
+ file = mm.Executable()
+ if file == nil {
+ err = syserror.ESRCH
+ }
+ })
+ return
+}
+
+// mountInfoData is used to implement /proc/[pid]/mountinfo.
+//
+// +stateify savable
+type mountInfoData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountInfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountInfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef(ctx)
+ i.task.Kernel().VFS().GenerateProcMountInfo(ctx, rootDir, buf)
+ return nil
+}
+
+// mountsData is used to implement /proc/[pid]/mounts.
+//
+// +stateify savable
+type mountsData struct {
+ kernfs.DynamicBytesFile
+
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*mountsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (i *mountsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ var fsctx *kernel.FSContext
+ i.task.WithMuLocked(func(t *kernel.Task) {
+ fsctx = t.FSContext()
+ })
+ if fsctx == nil {
+ // The task has been destroyed. Nothing to show here.
+ return nil
+ }
+ rootDir := fsctx.RootDirectoryVFS2()
+ if !rootDir.Ok() {
+ // Root has been destroyed. Don't try to read mounts.
+ return nil
+ }
+ defer rootDir.DecRef(ctx)
+ i.task.Kernel().VFS().GenerateProcMounts(ctx, rootDir, buf)
+ return nil
+}
+
+type namespaceSymlink struct {
+ kernfs.StaticSymlink
+
+ task *kernel.Task
+}
+
+func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns string) *kernfs.Dentry {
+ // Namespace symlinks should contain the namespace name and the inode number
+ // for the namespace instance, so for example user:[123456]. We currently fake
+ // the inode number by sticking the symlink inode in its place.
+ target := fmt.Sprintf("%s:[%d]", ns, ino)
+
+ inode := &namespaceSymlink{task: task}
+ // Note: credentials are overridden by taskOwnedInode.
+ inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target)
+
+ taskInode := &taskOwnedInode{Inode: inode, owner: task}
+ d := &kernfs.Dentry{}
+ d.Init(taskInode)
+ return d
+}
+
+// Readlink implements Inode.
+func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return "", err
+ }
+ return s.StaticSymlink.Readlink(ctx)
+}
+
+// Getlink implements Inode.Getlink.
+func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ if err := checkTaskState(s.task); err != nil {
+ return vfs.VirtualDentry{}, "", err
+ }
+
+ // Create a synthetic inode to represent the namespace.
+ dentry := &kernfs.Dentry{}
+ dentry.Init(&namespaceInode{})
+ vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry())
+ vd.IncRef()
+ dentry.DecRef(ctx)
+ return vd, "", nil
+}
+
+// namespaceInode is a synthetic inode created to represent a namespace in
+// /proc/[pid]/ns/*.
+type namespaceInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+
+ locks vfs.FileLocks
+}
+
+var _ kernfs.Inode = (*namespaceInode)(nil)
+
+// Init initializes a namespace inode.
+func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) {
+ if perm&^linux.PermissionsMask != 0 {
+ panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask))
+ }
+ i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm)
+}
+
+// Open implements Inode.Open.
+func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &namespaceFD{inode: i}
+ i.IncRef()
+ fd.LockFD.Init(&i.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// namespace FD is a synthetic file that represents a namespace in
+// /proc/[pid]/ns/*.
+type namespaceFD struct {
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+ inode *namespaceInode
+}
+
+var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil)
+
+// Stat implements FileDescriptionImpl.
+func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ return fd.inode.Stat(ctx, vfs, opts)
+}
+
+// SetStat implements FileDescriptionImpl.
+func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem()
+ creds := auth.CredentialsFromContext(ctx)
+ return fd.inode.SetStat(ctx, vfs, creds, opts)
+}
+
+// Release implements FileDescriptionImpl.
+func (fd *namespaceFD) Release(ctx context.Context) {
+ fd.inode.DecRef(ctx)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *namespaceFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *namespaceFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go
new file mode 100644
index 000000000..a4c884bf9
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/task_net.go
@@ -0,0 +1,810 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "reflect"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func (fs *filesystem) newTaskNetDir(task *kernel.Task) *kernfs.Dentry {
+ k := task.Kernel()
+ pidns := task.PIDNamespace()
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+
+ var contents map[string]*kernfs.Dentry
+ if stack := task.NetworkNamespace().Stack(); stack != nil {
+ const (
+ arp = "IP address HW type Flags HW address Mask Device\n"
+ netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
+ packet = "sk RefCnt Type Proto Iface R Rmem User Inode\n"
+ protocols = "protocol size sockets memory press maxhdr slab module cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n"
+ ptype = "Type Device Function\n"
+ upd6 = " sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n"
+ )
+ psched := fmt.Sprintf("%08x %08x %08x %08x\n", uint64(time.Microsecond/time.Nanosecond), 64, 1000000, uint64(time.Second/time.Nanosecond))
+
+ // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task
+ // network namespace.
+ contents = map[string]*kernfs.Dentry{
+ "dev": fs.newDentry(root, fs.NextIno(), 0444, &netDevData{stack: stack}),
+ "snmp": fs.newDentry(root, fs.NextIno(), 0444, &netSnmpData{stack: stack}),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack, if the file contains a header the stub is just the header
+ // otherwise it is an empty file.
+ "arp": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(arp)),
+ "netlink": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(netlink)),
+ "netstat": fs.newDentry(root, fs.NextIno(), 0444, &netStatData{}),
+ "packet": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(packet)),
+ "protocols": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(protocols)),
+
+ // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000,
+ // high res timer ticks per sec (ClockGetres returns 1ns resolution).
+ "psched": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(psched)),
+ "ptype": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(ptype)),
+ "route": fs.newDentry(root, fs.NextIno(), 0444, &netRouteData{stack: stack}),
+ "tcp": fs.newDentry(root, fs.NextIno(), 0444, &netTCPData{kernel: k}),
+ "udp": fs.newDentry(root, fs.NextIno(), 0444, &netUDPData{kernel: k}),
+ "unix": fs.newDentry(root, fs.NextIno(), 0444, &netUnixData{kernel: k}),
+ }
+
+ if stack.SupportsIPv6() {
+ contents["if_inet6"] = fs.newDentry(root, fs.NextIno(), 0444, &ifinet6{stack: stack})
+ contents["ipv6_route"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(""))
+ contents["tcp6"] = fs.newDentry(root, fs.NextIno(), 0444, &netTCP6Data{kernel: k})
+ contents["udp6"] = fs.newDentry(root, fs.NextIno(), 0444, newStaticFile(upd6))
+ }
+ }
+
+ return fs.newTaskOwnedDir(task, fs.NextIno(), 0555, contents)
+}
+
+// ifinet6 implements vfs.DynamicBytesSource for /proc/net/if_inet6.
+//
+// +stateify savable
+type ifinet6 struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*ifinet6)(nil)
+
+func (n *ifinet6) contents() []string {
+ var lines []string
+ nics := n.stack.Interfaces()
+ for id, naddrs := range n.stack.InterfaceAddrs() {
+ nic, ok := nics[id]
+ if !ok {
+ // NIC was added after NICNames was called. We'll just ignore it.
+ continue
+ }
+
+ for _, a := range naddrs {
+ // IPv6 only.
+ if a.Family != linux.AF_INET6 {
+ continue
+ }
+
+ // Fields:
+ // IPv6 address displayed in 32 hexadecimal chars without colons
+ // Netlink device number (interface index) in hexadecimal (use nic id)
+ // Prefix length in hexadecimal
+ // Scope value (use 0)
+ // Interface flags
+ // Device name
+ lines = append(lines, fmt.Sprintf("%032x %02x %02x %02x %02x %8s\n", a.Addr, id, a.PrefixLen, 0, a.Flags, nic.Name))
+ }
+ }
+ return lines
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *ifinet6) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ for _, l := range n.contents() {
+ buf.WriteString(l)
+ }
+ return nil
+}
+
+// netDevData implements vfs.DynamicBytesSource for /proc/net/dev.
+//
+// +stateify savable
+type netDevData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netDevData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netDevData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ interfaces := n.stack.Interfaces()
+ buf.WriteString("Inter-| Receive | Transmit\n")
+ buf.WriteString(" face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n")
+
+ for _, i := range interfaces {
+ // Implements the same format as
+ // net/core/net-procfs.c:dev_seq_printf_stats.
+ var stats inet.StatDev
+ if err := n.stack.Statistics(&stats, i.Name); err != nil {
+ log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err)
+ continue
+ }
+ fmt.Fprintf(
+ buf,
+ "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n",
+ i.Name,
+ // Received
+ stats[0], // bytes
+ stats[1], // packets
+ stats[2], // errors
+ stats[3], // dropped
+ stats[4], // fifo
+ stats[5], // frame
+ stats[6], // compressed
+ stats[7], // multicast
+ // Transmitted
+ stats[8], // bytes
+ stats[9], // packets
+ stats[10], // errors
+ stats[11], // dropped
+ stats[12], // fifo
+ stats[13], // frame
+ stats[14], // compressed
+ stats[15], // multicast
+ )
+ }
+
+ return nil
+}
+
+// netUnixData implements vfs.DynamicBytesSource for /proc/net/unix.
+//
+// +stateify savable
+type netUnixData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netUnixData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString("Num RefCount Protocol Flags Type St Inode Path\n")
+ for _, se := range n.kernel.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX {
+ s.DecRef(ctx)
+ // Not a unix socket.
+ continue
+ }
+ sops := s.Impl().(*unix.SocketVFS2)
+
+ addr, err := sops.Endpoint().GetLocalAddress()
+ if err != nil {
+ log.Warningf("Failed to retrieve socket name from %+v: %v", s, err)
+ addr.Addr = "<unknown>"
+ }
+
+ sockFlags := 0
+ if ce, ok := sops.Endpoint().(transport.ConnectingEndpoint); ok {
+ if ce.Listening() {
+ // For unix domain sockets, linux reports a single flag
+ // value if the socket is listening, of __SO_ACCEPTCON.
+ sockFlags = linux.SO_ACCEPTCON
+ }
+ }
+
+ // Get inode number.
+ var ino uint64
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_INO})
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve ino for socket file: %v", statErr)
+ } else {
+ ino = stat.Ino
+ }
+
+ // In the socket entry below, the value for the 'Num' field requires
+ // some consideration. Linux prints the address to the struct
+ // unix_sock representing a socket in the kernel, but may redact the
+ // value for unprivileged users depending on the kptr_restrict
+ // sysctl.
+ //
+ // One use for this field is to allow a privileged user to
+ // introspect into the kernel memory to determine information about
+ // a socket not available through procfs, such as the socket's peer.
+ //
+ // In gvisor, returning a pointer to our internal structures would
+ // be pointless, as it wouldn't match the memory layout for struct
+ // unix_sock, making introspection difficult. We could populate a
+ // struct unix_sock with the appropriate data, but even that
+ // requires consideration for which kernel version to emulate, as
+ // the definition of this struct changes over time.
+ //
+ // For now, we always redact this pointer.
+ fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d",
+ (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct.
+ s.Refs()-1, // RefCount, don't count our own ref.
+ 0, // Protocol, always 0 for UDS.
+ sockFlags, // Flags.
+ sops.Endpoint().Type(), // Type.
+ sops.State(), // State.
+ ino, // Inode.
+ )
+
+ // Path
+ if len(addr.Addr) != 0 {
+ if addr.Addr[0] == 0 {
+ // Abstract path.
+ fmt.Fprintf(buf, " @%s", string(addr.Addr[1:]))
+ } else {
+ fmt.Fprintf(buf, " %s", string(addr.Addr))
+ }
+ }
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef(ctx)
+ }
+ return nil
+}
+
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.Uint16() because Go does not support
+ // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to
+ // binary.BigEndian.Uint16() require a read of binary.BigEndian and an
+ // interface method call, defeating inlining.
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
+ switch family {
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if i != nil {
+ a = *i.(*linux.SockAddrInet)
+ }
+
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if i != nil {
+ a = *i.(*linux.SockAddrInet6)
+ }
+
+ port := networkToHost16(a.Port)
+ addr0 := usermem.ByteOrder.Uint32(a.Addr[0:4])
+ addr1 := usermem.ByteOrder.Uint32(a.Addr[4:8])
+ addr2 := usermem.ByteOrder.Uint32(a.Addr[8:12])
+ addr3 := usermem.ByteOrder.Uint32(a.Addr[12:16])
+ fmt.Fprintf(w, "%08X%08X%08X%08X:%04X ", addr0, addr1, addr2, addr3, port)
+ }
+}
+
+func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, family int) error {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ for _, se := range k.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ sops, ok := s.Impl().(socket.SocketVFS2)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
+ }
+ if fa, stype, _ := sops.Type(); !(family == fa && stype == linux.SOCK_STREAM) {
+ s.DecRef(ctx)
+ // Not tcp4 sockets.
+ continue
+ }
+
+ // Linux's documentation for the fields below can be found at
+ // https://www.kernel.org/doc/Documentation/networking/proc_net_tcp.txt.
+ // For Linux's implementation, see net/ipv4/tcp_ipv4.c:get_tcp4_sock().
+ // Note that the header doesn't contain labels for all the fields.
+
+ // Field: sl; entry number.
+ fmt.Fprintf(buf, "%4d: ", se.ID)
+
+ // Field: local_adddress.
+ var localAddr linux.SockAddr
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = local
+ }
+ }
+ writeInetAddr(buf, family, localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddr
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = remote
+ }
+ }
+ writeInetAddr(buf, family, remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when; timer active state and number of jiffies
+ // until timer expires. Unimplemented.
+ fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt; number of unrecovered RTO timeouts.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%08X ", 0)
+
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
+ // Field: uid.
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout; number of unanswered 0-window probes.
+ // Unimplemented.
+ fmt.Fprintf(buf, "%8d ", 0)
+
+ // Field: inode.
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
+
+ // Field: refcount. Don't count the ref we obtain while deferencing
+ // the weakref to this socket.
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: retransmit timeout. Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: predicted tick of soft clock (delayed ACK control data).
+ // Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: (ack.quick<<1)|ack.pingpong, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: sending congestion window, Unimplemented.
+ fmt.Fprintf(buf, "%d ", 0)
+
+ // Field: Slow start size threshold, -1 if threshold >= 0xFFFF.
+ // Unimplemented, report as large threshold.
+ fmt.Fprintf(buf, "%d", -1)
+
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef(ctx)
+ }
+
+ return nil
+}
+
+// netTCPData implements vfs.DynamicBytesSource for /proc/net/tcp.
+//
+// +stateify savable
+type netTCPData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netTCPData)(nil)
+
+func (d *netTCPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
+ return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET)
+}
+
+// netTCP6Data implements vfs.DynamicBytesSource for /proc/net/tcp6.
+//
+// +stateify savable
+type netTCP6Data struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netTCP6Data)(nil)
+
+func (d *netTCP6Data) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")
+ return commonGenerateTCP(ctx, buf, d.kernel, linux.AF_INET6)
+}
+
+// netUDPData implements vfs.DynamicBytesSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDPData struct {
+ kernfs.DynamicBytesFile
+
+ kernel *kernel.Kernel
+}
+
+var _ dynamicInode = (*netUDPData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ buf.WriteString(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n")
+
+ for _, se := range d.kernel.ListSockets() {
+ s := se.SockVFS2
+ if !s.TryIncRef() {
+ log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s)
+ continue
+ }
+ sops, ok := s.Impl().(socket.SocketVFS2)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", s))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef(ctx)
+ // Not udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(buf, "%5d: ", se.ID)
+
+ // Field: local_adddress.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(buf, linux.AF_INET, &localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(buf, linux.AF_INET, &remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(buf, "%08X ", 0)
+
+ stat, statErr := s.Stat(ctx, vfs.StatOptions{Mask: linux.STATX_UID | linux.STATX_INO})
+
+ // Field: uid.
+ if statErr != nil || stat.Mask&linux.STATX_UID == 0 {
+ log.Warningf("Failed to retrieve uid for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(buf, "%5d ", uint32(auth.KUID(stat.UID).In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(buf, "%8d ", 0)
+
+ // Field: inode.
+ if statErr != nil || stat.Mask&linux.STATX_INO == 0 {
+ log.Warningf("Failed to retrieve inode for socket file: %v", statErr)
+ fmt.Fprintf(buf, "%8d ", 0)
+ } else {
+ fmt.Fprintf(buf, "%8d ", stat.Ino)
+ }
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while deferencing the weakref to this socket.
+ fmt.Fprintf(buf, "%d ", s.Refs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(buf, "%d", 0)
+
+ fmt.Fprintf(buf, "\n")
+
+ s.DecRef(ctx)
+ }
+ return nil
+}
+
+// netSnmpData implements vfs.DynamicBytesSource for /proc/net/snmp.
+//
+// +stateify savable
+type netSnmpData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netSnmpData)(nil)
+
+type snmpLine struct {
+ prefix string
+ header string
+}
+
+var snmp = []snmpLine{
+ {
+ prefix: "Ip",
+ header: "Forwarding DefaultTTL InReceives InHdrErrors InAddrErrors ForwDatagrams InUnknownProtos InDiscards InDelivers OutRequests OutDiscards OutNoRoutes ReasmTimeout ReasmReqds ReasmOKs ReasmFails FragOKs FragFails FragCreates",
+ },
+ {
+ prefix: "Icmp",
+ header: "InMsgs InErrors InCsumErrors InDestUnreachs InTimeExcds InParmProbs InSrcQuenchs InRedirects InEchos InEchoReps InTimestamps InTimestampReps InAddrMasks InAddrMaskReps OutMsgs OutErrors OutDestUnreachs OutTimeExcds OutParmProbs OutSrcQuenchs OutRedirects OutEchos OutEchoReps OutTimestamps OutTimestampReps OutAddrMasks OutAddrMaskReps",
+ },
+ {
+ prefix: "IcmpMsg",
+ },
+ {
+ prefix: "Tcp",
+ header: "RtoAlgorithm RtoMin RtoMax MaxConn ActiveOpens PassiveOpens AttemptFails EstabResets CurrEstab InSegs OutSegs RetransSegs InErrs OutRsts InCsumErrors",
+ },
+ {
+ prefix: "Udp",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+ {
+ prefix: "UdpLite",
+ header: "InDatagrams NoPorts InErrors OutDatagrams RcvbufErrors SndbufErrors InCsumErrors IgnoredMulti",
+ },
+}
+
+func toSlice(a interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(a))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
+func sprintSlice(s []uint64) string {
+ if len(s) == 0 {
+ return ""
+ }
+ r := fmt.Sprint(s)
+ return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice.
+}
+
+// Generate implements vfs.DynamicBytesSource.
+func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ types := []interface{}{
+ &inet.StatSNMPIP{},
+ &inet.StatSNMPICMP{},
+ nil, // TODO(gvisor.dev/issue/628): Support IcmpMsg stats.
+ &inet.StatSNMPTCP{},
+ &inet.StatSNMPUDP{},
+ &inet.StatSNMPUDPLite{},
+ }
+ for i, stat := range types {
+ line := snmp[i]
+ if stat == nil {
+ fmt.Fprintf(buf, "%s:\n", line.prefix)
+ fmt.Fprintf(buf, "%s:\n", line.prefix)
+ continue
+ }
+ if err := d.stack.Statistics(stat, line.prefix); err != nil {
+ if err == syserror.EOPNOTSUPP {
+ log.Infof("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ } else {
+ log.Warningf("Failed to retrieve %s of /proc/net/snmp: %v", line.prefix, err)
+ }
+ }
+
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, line.header)
+
+ if line.prefix == "Tcp" {
+ tcp := stat.(*inet.StatSNMPTCP)
+ // "Tcp" needs special processing because MaxConn is signed. RFC 2012.
+ fmt.Fprintf(buf, "%s: %s %d %s\n", line.prefix, sprintSlice(tcp[:3]), int64(tcp[3]), sprintSlice(tcp[4:]))
+ } else {
+ fmt.Fprintf(buf, "%s: %s\n", line.prefix, sprintSlice(toSlice(stat)))
+ }
+ }
+ return nil
+}
+
+// netRouteData implements vfs.DynamicBytesSource for /proc/net/route.
+//
+// +stateify savable
+type netRouteData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netRouteData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
+func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT")
+
+ interfaces := d.stack.Interfaces()
+ for _, rt := range d.stack.RouteTable() {
+ // /proc/net/route only includes ipv4 routes.
+ if rt.Family != linux.AF_INET {
+ continue
+ }
+
+ // /proc/net/route does not include broadcast or multicast routes.
+ if rt.Type == linux.RTN_BROADCAST || rt.Type == linux.RTN_MULTICAST {
+ continue
+ }
+
+ iface, ok := interfaces[rt.OutputInterface]
+ if !ok || iface.Name == "lo" {
+ continue
+ }
+
+ var (
+ gw uint32
+ prefix uint32
+ flags = linux.RTF_UP
+ )
+ if len(rt.GatewayAddr) == header.IPv4AddressSize {
+ flags |= linux.RTF_GATEWAY
+ gw = usermem.ByteOrder.Uint32(rt.GatewayAddr)
+ }
+ if len(rt.DstAddr) == header.IPv4AddressSize {
+ prefix = usermem.ByteOrder.Uint32(rt.DstAddr)
+ }
+ l := fmt.Sprintf(
+ "%s\t%08X\t%08X\t%04X\t%d\t%d\t%d\t%08X\t%d\t%d\t%d",
+ iface.Name,
+ prefix,
+ gw,
+ flags,
+ 0, // RefCnt.
+ 0, // Use.
+ 0, // Metric.
+ (uint32(1)<<rt.DstLen)-1,
+ 0, // MTU.
+ 0, // Window.
+ 0, // RTT.
+ )
+ fmt.Fprintf(buf, "%-127s\n", l)
+ }
+ return nil
+}
+
+// netStatData implements vfs.DynamicBytesSource for /proc/net/netstat.
+//
+// +stateify savable
+type netStatData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack
+}
+
+var _ dynamicInode = (*netStatData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+// See Linux's net/ipv4/fib_trie.c:fib_route_seq_show.
+func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " +
+ "EmbryonicRsts PruneCalled RcvPruned OfoPruned OutOfWindowIcmps " +
+ "LockDroppedIcmps ArpFilter TW TWRecycled TWKilled PAWSPassive " +
+ "PAWSActive PAWSEstab DelayedACKs DelayedACKLocked DelayedACKLost " +
+ "ListenOverflows ListenDrops TCPPrequeued TCPDirectCopyFromBacklog " +
+ "TCPDirectCopyFromPrequeue TCPPrequeueDropped TCPHPHits TCPHPHitsToUser " +
+ "TCPPureAcks TCPHPAcks TCPRenoRecovery TCPSackRecovery TCPSACKReneging " +
+ "TCPFACKReorder TCPSACKReorder TCPRenoReorder TCPTSReorder TCPFullUndo " +
+ "TCPPartialUndo TCPDSACKUndo TCPLossUndo TCPLostRetransmit " +
+ "TCPRenoFailures TCPSackFailures TCPLossFailures TCPFastRetrans " +
+ "TCPForwardRetrans TCPSlowStartRetrans TCPTimeouts TCPLossProbes " +
+ "TCPLossProbeRecovery TCPRenoRecoveryFail TCPSackRecoveryFail " +
+ "TCPSchedulerFailed TCPRcvCollapsed TCPDSACKOldSent TCPDSACKOfoSent " +
+ "TCPDSACKRecv TCPDSACKOfoRecv TCPAbortOnData TCPAbortOnClose " +
+ "TCPAbortOnMemory TCPAbortOnTimeout TCPAbortOnLinger TCPAbortFailed " +
+ "TCPMemoryPressures TCPSACKDiscard TCPDSACKIgnoredOld " +
+ "TCPDSACKIgnoredNoUndo TCPSpuriousRTOs TCPMD5NotFound TCPMD5Unexpected " +
+ "TCPMD5Failure TCPSackShifted TCPSackMerged TCPSackShiftFallback " +
+ "TCPBacklogDrop TCPMinTTLDrop TCPDeferAcceptDrop IPReversePathFilter " +
+ "TCPTimeWaitOverflow TCPReqQFullDoCookies TCPReqQFullDrop TCPRetransFail " +
+ "TCPRcvCoalesce TCPOFOQueue TCPOFODrop TCPOFOMerge TCPChallengeACK " +
+ "TCPSYNChallenge TCPFastOpenActive TCPFastOpenActiveFail " +
+ "TCPFastOpenPassive TCPFastOpenPassiveFail TCPFastOpenListenOverflow " +
+ "TCPFastOpenCookieReqd TCPSpuriousRtxHostQueues BusyPollRxPackets " +
+ "TCPAutoCorking TCPFromZeroWindowAdv TCPToZeroWindowAdv " +
+ "TCPWantZeroWindowAdv TCPSynRetrans TCPOrigDataSent TCPHystartTrainDetect " +
+ "TCPHystartTrainCwnd TCPHystartDelayDetect TCPHystartDelayCwnd " +
+ "TCPACKSkippedSynRecv TCPACKSkippedPAWS TCPACKSkippedSeq " +
+ "TCPACKSkippedFinWait2 TCPACKSkippedTimeWait TCPACKSkippedChallenge " +
+ "TCPWinProbe TCPKeepAlive TCPMTUPFail TCPMTUPSuccess\n")
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
new file mode 100644
index 000000000..6d2b90a8b
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -0,0 +1,256 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "sort"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ selfName = "self"
+ threadSelfName = "thread-self"
+)
+
+// tasksInode represents the inode for /proc/ directory.
+//
+// +stateify savable
+type tasksInode struct {
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.InodeAttrs
+ kernfs.OrderedChildren
+ kernfs.AlwaysValid
+
+ locks vfs.FileLocks
+
+ fs *filesystem
+ pidns *kernel.PIDNamespace
+
+ // '/proc/self' and '/proc/thread-self' have custom directory offsets in
+ // Linux. So handle them outside of OrderedChildren.
+ selfSymlink *vfs.Dentry
+ threadSelfSymlink *vfs.Dentry
+
+ // cgroupControllers is a map of controller name to directory in the
+ // cgroup hierarchy. These controllers are immutable and will be listed
+ // in /proc/pid/cgroup if not nil.
+ cgroupControllers map[string]string
+}
+
+var _ kernfs.Inode = (*tasksInode)(nil)
+
+func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) (*tasksInode, *kernfs.Dentry) {
+ root := auth.NewRootCredentials(pidns.UserNamespace())
+ contents := map[string]*kernfs.Dentry{
+ "cpuinfo": fs.newDentry(root, fs.NextIno(), 0444, newStaticFileSetStat(cpuInfoData(k))),
+ "filesystems": fs.newDentry(root, fs.NextIno(), 0444, &filesystemsData{}),
+ "loadavg": fs.newDentry(root, fs.NextIno(), 0444, &loadavgData{}),
+ "sys": fs.newSysDir(root, k),
+ "meminfo": fs.newDentry(root, fs.NextIno(), 0444, &meminfoData{}),
+ "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"),
+ "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"),
+ "stat": fs.newDentry(root, fs.NextIno(), 0444, &statData{}),
+ "uptime": fs.newDentry(root, fs.NextIno(), 0444, &uptimeData{}),
+ "version": fs.newDentry(root, fs.NextIno(), 0444, &versionData{}),
+ }
+
+ inode := &tasksInode{
+ pidns: pidns,
+ fs: fs,
+ selfSymlink: fs.newSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
+ threadSelfSymlink: fs.newThreadSelfSymlink(root, fs.NextIno(), pidns).VFSDentry(),
+ cgroupControllers: cgroupControllers,
+ }
+ inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
+
+ dentry := &kernfs.Dentry{}
+ dentry.Init(inode)
+
+ inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ links := inode.OrderedChildren.Populate(dentry, contents)
+ inode.IncLinks(links)
+
+ return inode, dentry
+}
+
+// Lookup implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) {
+ // Try to lookup a corresponding task.
+ tid, err := strconv.ParseUint(name, 10, 64)
+ if err != nil {
+ // If it failed to parse, check if it's one of the special handled files.
+ switch name {
+ case selfName:
+ return i.selfSymlink, nil
+ case threadSelfName:
+ return i.threadSelfSymlink, nil
+ }
+ return nil, syserror.ENOENT
+ }
+
+ task := i.pidns.TaskWithID(kernel.ThreadID(tid))
+ if task == nil {
+ return nil, syserror.ENOENT
+ }
+
+ taskDentry := i.fs.newTaskInode(task, i.pidns, true, i.cgroupControllers)
+ return taskDentry.VFSDentry(), nil
+}
+
+// IterDirents implements kernfs.inodeDynamicLookup.
+func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) {
+ // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256
+ const FIRST_PROCESS_ENTRY = 256
+
+ // Use maxTaskID to shortcut searches that will result in 0 entries.
+ const maxTaskID = kernel.TasksLimit + 1
+ if offset >= maxTaskID {
+ return offset, nil
+ }
+
+ // According to Linux (fs/proc/base.c:proc_pid_readdir()), process directories
+ // start at offset FIRST_PROCESS_ENTRY with '/proc/self', followed by
+ // '/proc/thread-self' and then '/proc/[pid]'.
+ if offset < FIRST_PROCESS_ENTRY {
+ offset = FIRST_PROCESS_ENTRY
+ }
+
+ if offset == FIRST_PROCESS_ENTRY {
+ dirent := vfs.Dirent{
+ Name: selfName,
+ Type: linux.DT_LNK,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ if offset == FIRST_PROCESS_ENTRY+1 {
+ dirent := vfs.Dirent{
+ Name: threadSelfName,
+ Type: linux.DT_LNK,
+ Ino: i.fs.NextIno(),
+ NextOff: offset + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+
+ // Collect all tasks that TGIDs are greater than the offset specified. Per
+ // Linux we only include in directory listings if it's the leader. But for
+ // whatever crazy reason, you can still walk to the given node.
+ var tids []int
+ startTid := offset - FIRST_PROCESS_ENTRY - 2
+ for _, tg := range i.pidns.ThreadGroups() {
+ tid := i.pidns.IDOfThreadGroup(tg)
+ if int64(tid) < startTid {
+ continue
+ }
+ if leader := tg.Leader(); leader != nil {
+ tids = append(tids, int(tid))
+ }
+ }
+
+ if len(tids) == 0 {
+ return offset, nil
+ }
+
+ sort.Ints(tids)
+ for _, tid := range tids {
+ dirent := vfs.Dirent{
+ Name: strconv.FormatUint(uint64(tid), 10),
+ Type: linux.DT_DIR,
+ Ino: i.fs.NextIno(),
+ NextOff: FIRST_PROCESS_ENTRY + 2 + int64(tid) + 1,
+ }
+ if err := cb.Handle(dirent); err != nil {
+ return offset, err
+ }
+ offset++
+ }
+ return maxTaskID, nil
+}
+
+// Open implements kernfs.Inode.
+func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) {
+ stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ if opts.Mask&linux.STATX_NLINK != 0 {
+ // Add dynamic children to link count.
+ for _, tg := range i.pidns.ThreadGroups() {
+ if leader := tg.Leader(); leader != nil {
+ stat.Nlink++
+ }
+ }
+ }
+
+ return stat, nil
+}
+
+// staticFileSetStat implements a special static file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type staticFileSetStat struct {
+ dynamicBytesFileSetAttr
+ vfs.StaticData
+}
+
+var _ dynamicInode = (*staticFileSetStat)(nil)
+
+func newStaticFileSetStat(data string) *staticFileSetStat {
+ return &staticFileSetStat{StaticData: vfs.StaticData{Data: data}}
+}
+
+func cpuInfoData(k *kernel.Kernel) string {
+ features := k.FeatureSet()
+ if features == nil {
+ // Kernel is always initialized with a FeatureSet.
+ panic("cpuinfo read with nil FeatureSet")
+ }
+ var buf bytes.Buffer
+ for i, max := uint(0), k.ApplicationCores(); i < max; i++ {
+ features.WriteCPUInfoTo(i, &buf)
+ }
+ return buf.String()
+}
+
+func shmData(v uint64) dynamicInode {
+ return newStaticFile(strconv.FormatUint(v, 10))
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
new file mode 100644
index 000000000..7d8983aa5
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -0,0 +1,384 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+ "strconv"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type selfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*selfSymlink)(nil)
+
+func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &selfSymlink{pidns: pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *selfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ if tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return strconv.FormatUint(uint64(tgid), 10), nil
+}
+
+func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+type threadSelfSymlink struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+ kernfs.InodeSymlink
+
+ pidns *kernel.PIDNamespace
+}
+
+var _ kernfs.Inode = (*threadSelfSymlink)(nil)
+
+func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, pidns *kernel.PIDNamespace) *kernfs.Dentry {
+ inode := &threadSelfSymlink{pidns: pidns}
+ inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777)
+
+ d := &kernfs.Dentry{}
+ d.Init(inode)
+ return d
+}
+
+func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // Who is reading this link?
+ return "", syserror.EINVAL
+ }
+ tgid := s.pidns.IDOfThreadGroup(t.ThreadGroup())
+ tid := s.pidns.IDOfTask(t)
+ if tid == 0 || tgid == 0 {
+ return "", syserror.ENOENT
+ }
+ return fmt.Sprintf("%d/task/%d", tgid, tid), nil
+}
+
+func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) {
+ target, err := s.Readlink(ctx)
+ return vfs.VirtualDentry{}, target, err
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// dynamicBytesFileSetAttr implements a special file that allows inode
+// attributes to be set. This is to support /proc files that are readonly, but
+// allow attributes to be set.
+type dynamicBytesFileSetAttr struct {
+ kernfs.DynamicBytesFile
+}
+
+// SetStat implements Inode.SetStat.
+func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts)
+}
+
+// cpuStats contains the breakdown of CPU time for /proc/stat.
+type cpuStats struct {
+ // user is time spent in userspace tasks with non-positive niceness.
+ user uint64
+
+ // nice is time spent in userspace tasks with positive niceness.
+ nice uint64
+
+ // system is time spent in non-interrupt kernel context.
+ system uint64
+
+ // idle is time spent idle.
+ idle uint64
+
+ // ioWait is time spent waiting for IO.
+ ioWait uint64
+
+ // irq is time spent in interrupt context.
+ irq uint64
+
+ // softirq is time spent in software interrupt context.
+ softirq uint64
+
+ // steal is involuntary wait time.
+ steal uint64
+
+ // guest is time spent in guests with non-positive niceness.
+ guest uint64
+
+ // guestNice is time spent in guests with positive niceness.
+ guestNice uint64
+}
+
+// String implements fmt.Stringer.
+func (c cpuStats) String() string {
+ return fmt.Sprintf("%d %d %d %d %d %d %d %d %d %d", c.user, c.nice, c.system, c.idle, c.ioWait, c.irq, c.softirq, c.steal, c.guest, c.guestNice)
+}
+
+// statData implements vfs.DynamicBytesSource for /proc/stat.
+//
+// +stateify savable
+type statData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*statData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*statData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/37226836): We currently export only zero CPU stats. We could
+ // at least provide some aggregate stats.
+ var cpu cpuStats
+ fmt.Fprintf(buf, "cpu %s\n", cpu)
+
+ k := kernel.KernelFromContext(ctx)
+ for c, max := uint(0), k.ApplicationCores(); c < max; c++ {
+ fmt.Fprintf(buf, "cpu%d %s\n", c, cpu)
+ }
+
+ // The total number of interrupts is dependent on the CPUs and PCI
+ // devices on the system. See arch_probe_nr_irqs.
+ //
+ // Since we don't report real interrupt stats, just choose an arbitrary
+ // value from a representative VM.
+ const numInterrupts = 256
+
+ // The Kernel doesn't handle real interrupts, so report all zeroes.
+ // TODO(b/37226836): We could count page faults as #PF.
+ fmt.Fprintf(buf, "intr 0") // total
+ for i := 0; i < numInterrupts; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+
+ // Total number of context switches.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "ctxt 0\n")
+
+ // CLOCK_REALTIME timestamp from boot, in seconds.
+ fmt.Fprintf(buf, "btime %d\n", k.Timekeeper().BootTime().Seconds())
+
+ // Total number of clones.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "processes 0\n")
+
+ // Number of runnable tasks.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_running 0\n")
+
+ // Number of tasks waiting on IO.
+ // TODO(b/37226836): Count this.
+ fmt.Fprintf(buf, "procs_blocked 0\n")
+
+ // Number of each softirq handled.
+ fmt.Fprintf(buf, "softirq 0") // total
+ for i := 0; i < linux.NumSoftIRQ; i++ {
+ fmt.Fprintf(buf, " 0")
+ }
+ fmt.Fprintf(buf, "\n")
+ return nil
+}
+
+// loadavgData backs /proc/loadavg.
+//
+// +stateify savable
+type loadavgData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*loadavgData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*loadavgData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/62345059): Include real data in fields.
+ // Column 1-3: CPU and IO utilization of the last 1, 5, and 10 minute periods.
+ // Column 4-5: currently running processes and the total number of processes.
+ // Column 6: the last process ID used.
+ fmt.Fprintf(buf, "%.2f %.2f %.2f %d/%d %d\n", 0.00, 0.00, 0.00, 0, 0, 0)
+ return nil
+}
+
+// meminfoData implements vfs.DynamicBytesSource for /proc/meminfo.
+//
+// +stateify savable
+type meminfoData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*meminfoData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*meminfoData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
+ mf.UpdateUsage()
+ snapshot, totalUsage := usage.MemoryAccounting.Copy()
+ totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ anon := snapshot.Anonymous + snapshot.Tmpfs
+ file := snapshot.PageCache + snapshot.Mapped
+ // We don't actually have active/inactive LRUs, so just make up numbers.
+ activeFile := (file / 2) &^ (usermem.PageSize - 1)
+ inactiveFile := file - activeFile
+
+ fmt.Fprintf(buf, "MemTotal: %8d kB\n", totalSize/1024)
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
+ // We use MemFree as MemAvailable because we don't swap.
+ // TODO(rahat): When reclaim is implemented the value of MemAvailable
+ // should change.
+ fmt.Fprintf(buf, "MemFree: %8d kB\n", memFree/1024)
+ fmt.Fprintf(buf, "MemAvailable: %8d kB\n", memFree/1024)
+ fmt.Fprintf(buf, "Buffers: 0 kB\n") // memory usage by block devices
+ fmt.Fprintf(buf, "Cached: %8d kB\n", (file+snapshot.Tmpfs)/1024)
+ // Emulate a system with no swap, which disables inactivation of anon pages.
+ fmt.Fprintf(buf, "SwapCache: 0 kB\n")
+ fmt.Fprintf(buf, "Active: %8d kB\n", (anon+activeFile)/1024)
+ fmt.Fprintf(buf, "Inactive: %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Active(anon): %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Inactive(anon): 0 kB\n")
+ fmt.Fprintf(buf, "Active(file): %8d kB\n", activeFile/1024)
+ fmt.Fprintf(buf, "Inactive(file): %8d kB\n", inactiveFile/1024)
+ fmt.Fprintf(buf, "Unevictable: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "Mlocked: 0 kB\n") // TODO(b/31823263)
+ fmt.Fprintf(buf, "SwapTotal: 0 kB\n")
+ fmt.Fprintf(buf, "SwapFree: 0 kB\n")
+ fmt.Fprintf(buf, "Dirty: 0 kB\n")
+ fmt.Fprintf(buf, "Writeback: 0 kB\n")
+ fmt.Fprintf(buf, "AnonPages: %8d kB\n", anon/1024)
+ fmt.Fprintf(buf, "Mapped: %8d kB\n", file/1024) // doesn't count mapped tmpfs, which we don't know
+ fmt.Fprintf(buf, "Shmem: %8d kB\n", snapshot.Tmpfs/1024)
+ return nil
+}
+
+// uptimeData implements vfs.DynamicBytesSource for /proc/uptime.
+//
+// +stateify savable
+type uptimeData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*uptimeData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*uptimeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ now := time.NowFromContext(ctx)
+
+ // Pretend that we've spent zero time sleeping (second number).
+ fmt.Fprintf(buf, "%.2f 0.00\n", now.Sub(k.Timekeeper().BootTime()).Seconds())
+ return nil
+}
+
+// versionData implements vfs.DynamicBytesSource for /proc/version.
+//
+// +stateify savable
+type versionData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*versionData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ init := k.GlobalInit()
+ if init == nil {
+ // Attempted to read before the init Task is created. This can
+ // only occur during startup, which should never need to read
+ // this file.
+ panic("Attempted to read version before initial Task is available")
+ }
+
+ // /proc/version takes the form:
+ //
+ // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
+ // (COMPILER_VERSION) VERSION"
+ //
+ // where:
+ // - SYSNAME, RELEASE, and VERSION are the same as returned by
+ // sys_utsname
+ // - COMPILE_USER is the user that build the kernel
+ // - COMPILE_HOST is the hostname of the machine on which the kernel
+ // was built
+ // - COMPILER_VERSION is the version reported by the building compiler
+ //
+ // Since we don't really want to expose build information to
+ // applications, those fields are omitted.
+ //
+ // FIXME(mpratt): Using Version from the init task SyscallTable
+ // disregards the different version a task may have (e.g., in a uts
+ // namespace).
+ ver := init.Leader().SyscallTable().Version
+ fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
+ return nil
+}
+
+// filesystemsData backs /proc/filesystems.
+//
+// +stateify savable
+type filesystemsData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*filesystemsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ k := kernel.KernelFromContext(ctx)
+ k.VFS().GenerateProcFilesystems(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
new file mode 100644
index 000000000..6768aa880
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -0,0 +1,317 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// newSysDir returns the dentry corresponding to /proc/sys directory.
+func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
+ return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "kernel": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}),
+ "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)),
+ "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)),
+ "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)),
+ }),
+ "vm": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}),
+ "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")),
+ }),
+ "net": fs.newSysNetDir(root, k),
+ })
+}
+
+// newSysNetDir returns the dentry corresponding to /proc/sys/net directory.
+func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry {
+ var contents map[string]*kernfs.Dentry
+
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
+ contents = map[string]*kernfs.Dentry{
+ "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}),
+
+ // The following files are simple stubs until they are implemented in
+ // netstack, most of these files are configuration related. We use the
+ // value closest to the actual netstack behavior or any empty file, all
+ // of these files will have mode 0444 (read-only for all users).
+ "ip_local_port_range": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("16000 65535")),
+ "ip_local_reserved_ports": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "ipfrag_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("30")),
+ "ip_nonlocal_bind": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "ip_no_pmtu_disc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+
+ // tcp_allowed_congestion_control tell the user what they are able to
+ // do as an unprivledged process so we leave it empty.
+ "tcp_allowed_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "tcp_available_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+ "tcp_congestion_control": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("reno")),
+
+ // Many of the following stub files are features netstack doesn't
+ // support. The unsupported features return "0" to indicate they are
+ // disabled.
+ "tcp_base_mss": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1280")),
+ "tcp_dsack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_early_retrans": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fack": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_fastopen_key": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("")),
+ "tcp_invalid_ratelimit": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_intvl": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_probes": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_keepalive_time": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("7200")),
+ "tcp_mtu_probing": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_no_metrics_save": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_probe_interval": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_probe_threshold": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "tcp_retries1": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
+ "tcp_retries2": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("15")),
+ "tcp_rfc1337": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_slow_start_after_idle": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ "tcp_synack_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
+ "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")),
+ "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")),
+ }),
+ "core": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{
+ "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")),
+ "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")),
+ "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")),
+ "optmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0")),
+ "rmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "rmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "somaxconn": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("128")),
+ "wmem_default": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ "wmem_max": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("212992")),
+ }),
+ }
+ }
+
+ return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents)
+}
+
+// mmapMinAddrData implements vfs.DynamicBytesSource for
+// /proc/sys/vm/mmap_min_addr.
+//
+// +stateify savable
+type mmapMinAddrData struct {
+ kernfs.DynamicBytesFile
+
+ k *kernel.Kernel
+}
+
+var _ dynamicInode = (*mmapMinAddrData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *mmapMinAddrData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d\n", d.k.Platform.MinUserAddress())
+ return nil
+}
+
+// hostnameData implements vfs.DynamicBytesSource for /proc/sys/kernel/hostname.
+//
+// +stateify savable
+type hostnameData struct {
+ kernfs.DynamicBytesFile
+}
+
+var _ dynamicInode = (*hostnameData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*hostnameData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ utsns := kernel.UTSNamespaceFromContext(ctx)
+ buf.WriteString(utsns.HostName())
+ buf.WriteString("\n")
+ return nil
+}
+
+// tcpSackData implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/tcp_sack.
+//
+// +stateify savable
+type tcpSackData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+}
+
+var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.enabled == nil {
+ sack, err := d.stack.TCPSACKEnabled()
+ if err != nil {
+ return err
+ }
+ d.enabled = &sack
+ }
+
+ val := "0\n"
+ if *d.enabled {
+ // Technically, this is not quite compatible with Linux. Linux stores these
+ // as an integer, so if you write "2" into tcp_sack, you should get 2 back.
+ // Tough luck.
+ val = "1\n"
+ }
+ buf.WriteString(val)
+ return nil
+}
+
+func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit the amount of memory allocated.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return n, err
+ }
+ if d.enabled == nil {
+ d.enabled = new(bool)
+ }
+ *d.enabled = v != 0
+ return n, d.stack.SetTCPSACKEnabled(*d.enabled)
+}
+
+// tcpRecoveryData implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/tcp_recovery.
+//
+// +stateify savable
+type tcpRecoveryData struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+}
+
+var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.
+func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ recovery, err := d.stack.TCPRecovery()
+ if err != nil {
+ return err
+ }
+
+ buf.WriteString(fmt.Sprintf("%d\n", recovery))
+ return nil
+}
+
+func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit the amount of memory allocated.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ if err := d.stack.SetTCPRecovery(inet.TCPLossRecovery(v)); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
+// ipForwarding implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/ip_forwarding.
+//
+// +stateify savable
+type ipForwarding struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+ enabled *bool
+}
+
+var _ vfs.WritableDynamicBytesSource = (*ipForwarding)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (ipf *ipForwarding) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if ipf.enabled == nil {
+ enabled := ipf.stack.Forwarding(ipv4.ProtocolNumber)
+ ipf.enabled = &enabled
+ }
+
+ val := "0\n"
+ if *ipf.enabled {
+ // Technically, this is not quite compatible with Linux. Linux stores these
+ // as an integer, so if you write "2" into tcp_sack, you should get 2 back.
+ // Tough luck.
+ val = "1\n"
+ }
+ buf.WriteString(val)
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ var v int32
+ n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+ if ipf.enabled == nil {
+ ipf.enabled = new(bool)
+ }
+ *ipf.enabled = v != 0
+ if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/proc/net_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
index 20a77a8ca..1abf56da2 100644
--- a/pkg/sentry/fsimpl/proc/net_test.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go
@@ -20,8 +20,10 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func newIPv6TestStack() *inet.TestStack {
@@ -31,7 +33,7 @@ func newIPv6TestStack() *inet.TestStack {
}
func TestIfinet6NoAddresses(t *testing.T) {
- n := &ifinet6{s: newIPv6TestStack()}
+ n := &ifinet6{stack: newIPv6TestStack()}
var buf bytes.Buffer
n.Generate(contexttest.Context(t), &buf)
if buf.Len() > 0 {
@@ -62,7 +64,7 @@ func TestIfinet6(t *testing.T) {
"101112131415161718191a1b1c1d1e1f 02 80 00 00 eth1\n": {},
}
- n := &ifinet6{s: s}
+ n := &ifinet6{stack: s}
contents := n.contents()
if len(contents) != len(want) {
t.Errorf("Got len(n.contents()) = %d, want = %d", len(contents), len(want))
@@ -76,3 +78,72 @@ func TestIfinet6(t *testing.T) {
t.Errorf("Got n.contents() = %v, want = %v", got, want)
}
}
+
+// TestIPForwarding tests the implementation of
+// /proc/sys/net/ipv4/ip_forwarding
+func TestConfigureIPForwarding(t *testing.T) {
+ ctx := context.Background()
+ s := inet.NewTestStack()
+
+ var cases = []struct {
+ comment string
+ initial bool
+ str string
+ final bool
+ }{
+ {
+ comment: `Forwarding is disabled; write 1 and enable forwarding`,
+ initial: false,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is disabled; write 0 and disable forwarding`,
+ initial: false,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is enabled; write 1 and enable forwarding`,
+ initial: true,
+ str: "1",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 0 and disable forwarding`,
+ initial: true,
+ str: "0",
+ final: false,
+ },
+ {
+ comment: `Forwarding is disabled; write 2404 and enable forwarding`,
+ initial: false,
+ str: "2404",
+ final: true,
+ },
+ {
+ comment: `Forwarding is enabled; write 2404 and enable forwarding`,
+ initial: true,
+ str: "2404",
+ final: true,
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.comment, func(t *testing.T) {
+ s.IPForwarding = c.initial
+
+ file := &ipForwarding{stack: s, enabled: &c.initial}
+
+ // Write the values.
+ src := usermem.BytesIOSequence([]byte(c.str))
+ if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil {
+ t.Errorf("file.Write(ctx, nil, %v, 0) = (%d, %v); wanted (%d, nil)", c.str, n, err, len(c.str))
+ }
+
+ // Read the values from the stack and check them.
+ if s.IPForwarding != c.final {
+ t.Errorf("s.IPForwarding = %v; wanted %v", s.IPForwarding, c.final)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go
new file mode 100644
index 000000000..3c9297dee
--- /dev/null
+++ b/pkg/sentry/fsimpl/proc/tasks_test.go
@@ -0,0 +1,505 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "fmt"
+ "math"
+ "path"
+ "strconv"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // Next offset 256 by convention. Adds 1 for the next offset.
+ selfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 0 + 1}
+ threadSelfLink = vfs.Dirent{Type: linux.DT_LNK, NextOff: 256 + 1 + 1}
+
+ // /proc/[pid] next offset starts at 256+2 (files above), then adds the
+ // PID, and adds 1 for the next offset.
+ proc1 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 1 + 1}
+ proc2 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 2 + 1}
+ proc3 = vfs.Dirent{Type: linux.DT_DIR, NextOff: 258 + 3 + 1}
+)
+
+var (
+ tasksStaticFiles = map[string]testutil.DirentType{
+ "cpuinfo": linux.DT_REG,
+ "filesystems": linux.DT_REG,
+ "loadavg": linux.DT_REG,
+ "meminfo": linux.DT_REG,
+ "mounts": linux.DT_LNK,
+ "net": linux.DT_LNK,
+ "self": linux.DT_LNK,
+ "stat": linux.DT_REG,
+ "sys": linux.DT_DIR,
+ "thread-self": linux.DT_LNK,
+ "uptime": linux.DT_REG,
+ "version": linux.DT_REG,
+ }
+ tasksStaticFilesNextOffs = map[string]int64{
+ "self": selfLink.NextOff,
+ "thread-self": threadSelfLink.NextOff,
+ }
+ taskStaticFiles = map[string]testutil.DirentType{
+ "auxv": linux.DT_REG,
+ "cgroup": linux.DT_REG,
+ "cmdline": linux.DT_REG,
+ "comm": linux.DT_REG,
+ "environ": linux.DT_REG,
+ "exe": linux.DT_LNK,
+ "fd": linux.DT_DIR,
+ "fdinfo": linux.DT_DIR,
+ "gid_map": linux.DT_REG,
+ "io": linux.DT_REG,
+ "maps": linux.DT_REG,
+ "mountinfo": linux.DT_REG,
+ "mounts": linux.DT_REG,
+ "net": linux.DT_DIR,
+ "ns": linux.DT_DIR,
+ "oom_score": linux.DT_REG,
+ "oom_score_adj": linux.DT_REG,
+ "smaps": linux.DT_REG,
+ "stat": linux.DT_REG,
+ "statm": linux.DT_REG,
+ "status": linux.DT_REG,
+ "task": linux.DT_DIR,
+ "uid_map": linux.DT_REG,
+ }
+)
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+ pop := &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ if err := k.VFS().MkdirAt(ctx, creds, pop, &vfs.MkdirOptions{Mode: 0777}); err != nil {
+ t.Fatalf("MkDir(/proc): %v", err)
+ }
+
+ pop = &vfs.PathOperation{
+ Root: mntns.Root(),
+ Start: mntns.Root(),
+ Path: fspath.Parse("/proc"),
+ }
+ mntOpts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ InternalData: &InternalData{
+ Cgroups: map[string]string{
+ "cpuset": "/foo/cpuset",
+ "memory": "/foo/memory",
+ },
+ },
+ },
+ }
+ if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil {
+ t.Fatalf("MountAt(/proc): %v", err)
+ }
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+func TestTasksEmpty(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
+ s.AssertAllDirentTypes(collector, tasksStaticFiles)
+ s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
+}
+
+func TestTasks(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ expectedDirents := make(map[string]testutil.DirentType)
+ for n, d := range tasksStaticFiles {
+ expectedDirents[n] = d
+ }
+
+ k := kernel.KernelFromContext(s.Ctx)
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ tasks = append(tasks, task)
+ expectedDirents[fmt.Sprintf("%d", i+1)] = linux.DT_DIR
+ }
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc"))
+ s.AssertAllDirentTypes(collector, expectedDirents)
+ s.AssertDirentOffsets(collector, tasksStaticFilesNextOffs)
+
+ lastPid := 0
+ dirents := collector.OrderedDirents()
+ doneSkippingNonTaskDirs := false
+ for _, d := range dirents {
+ pid, err := strconv.Atoi(d.Name)
+ if err != nil {
+ if !doneSkippingNonTaskDirs {
+ // We haven't gotten to the task dirs yet.
+ continue
+ }
+ t.Fatalf("Invalid process directory %q", d.Name)
+ }
+ doneSkippingNonTaskDirs = true
+ if lastPid > pid {
+ t.Errorf("pids not in order: %v", dirents)
+ }
+ found := false
+ for _, t := range tasks {
+ if k.TaskSet().Root.IDOfTask(t) == kernel.ThreadID(pid) {
+ found = true
+ }
+ }
+ if !found {
+ t.Errorf("Additional task ID %d listed: %v", pid, tasks)
+ }
+ // Next offset starts at 256+2 ('self' and 'thread-self'), then adds the
+ // PID, and adds 1 for the next offset.
+ if want := int64(256 + 2 + pid + 1); d.NextOff != want {
+ t.Errorf("Wrong dirent offset want: %d got: %d: %+v", want, d.NextOff, d)
+ }
+ }
+ if !doneSkippingNonTaskDirs {
+ t.Fatalf("Never found any process directories.")
+ }
+
+ // Test lookup.
+ for _, path := range []string{"/proc/1", "/proc/2"} {
+ fd, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot(path),
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(%q) failed: %v", path, err)
+ }
+ defer fd.DecRef(s.Ctx)
+ buf := make([]byte, 1)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ if _, err := fd.Read(s.Ctx, bufIOSeq, vfs.ReadOptions{}); err != syserror.EISDIR {
+ t.Errorf("wrong error reading directory: %v", err)
+ }
+ }
+
+ if _, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot("/proc/9999"),
+ &vfs.OpenOptions{},
+ ); err != syserror.ENOENT {
+ t.Fatalf("wrong error from vfsfs.OpenAt(/proc/9999): %v", err)
+ }
+}
+
+func TestTasksOffset(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ for i := 0; i < 3; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ if _, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root); err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ }
+
+ for _, tc := range []struct {
+ name string
+ offset int64
+ wants map[string]vfs.Dirent
+ }{
+ {
+ name: "small offset",
+ offset: 100,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "offset at start",
+ offset: 256,
+ wants: map[string]vfs.Dirent{
+ "self": selfLink,
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip /proc/self",
+ offset: 257,
+ wants: map[string]vfs.Dirent{
+ "thread-self": threadSelfLink,
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip symlinks",
+ offset: 258,
+ wants: map[string]vfs.Dirent{
+ "1": proc1,
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "skip first process",
+ offset: 260,
+ wants: map[string]vfs.Dirent{
+ "2": proc2,
+ "3": proc3,
+ },
+ },
+ {
+ name: "last process",
+ offset: 261,
+ wants: map[string]vfs.Dirent{
+ "3": proc3,
+ },
+ },
+ {
+ name: "after last",
+ offset: 262,
+ wants: nil,
+ },
+ {
+ name: "TaskLimit+1",
+ offset: kernel.TasksLimit + 1,
+ wants: nil,
+ },
+ {
+ name: "max",
+ offset: math.MaxInt64,
+ wants: nil,
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ s := s.WithSubtest(t)
+ fd, err := s.VFS.OpenAt(
+ s.Ctx,
+ s.Creds,
+ s.PathOpAtRoot("/proc"),
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/) failed: %v", err)
+ }
+ defer fd.DecRef(s.Ctx)
+ if _, err := fd.Seek(s.Ctx, tc.offset, linux.SEEK_SET); err != nil {
+ t.Fatalf("Seek(%d, SEEK_SET): %v", tc.offset, err)
+ }
+
+ var collector testutil.DirentCollector
+ if err := fd.IterDirents(s.Ctx, &collector); err != nil {
+ t.Fatalf("IterDirent(): %v", err)
+ }
+
+ expectedTypes := make(map[string]testutil.DirentType)
+ expectedOffsets := make(map[string]int64)
+ for name, want := range tc.wants {
+ expectedTypes[name] = want.Type
+ if want.NextOff != 0 {
+ expectedOffsets[name] = want.NextOff
+ }
+ }
+
+ collector.SkipDotsChecks(true) // We seek()ed past the dots.
+ s.AssertAllDirentTypes(&collector, expectedTypes)
+ s.AssertDirentOffsets(&collector, expectedOffsets)
+ })
+ }
+}
+
+func TestTask(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ _, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ collector := s.ListDirents(s.PathOpAtRoot("/proc/1"))
+ s.AssertAllDirentTypes(collector, taskStaticFiles)
+}
+
+func TestProcSelf(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, "name", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+
+ collector := s.WithTemporaryContext(task).ListDirents(&vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("/proc/self/"),
+ FollowFinalSymlink: true,
+ })
+ s.AssertAllDirentTypes(collector, taskStaticFiles)
+}
+
+func iterateDir(ctx context.Context, t *testing.T, s *testutil.System, fd *vfs.FileDescription) {
+ t.Logf("Iterating: %s", fd.MappedName(ctx))
+
+ var collector testutil.DirentCollector
+ if err := fd.IterDirents(ctx, &collector); err != nil {
+ t.Fatalf("IterDirents(): %v", err)
+ }
+ if err := collector.Contains(".", linux.DT_DIR); err != nil {
+ t.Error(err.Error())
+ }
+ if err := collector.Contains("..", linux.DT_DIR); err != nil {
+ t.Error(err.Error())
+ }
+
+ for _, d := range collector.Dirents() {
+ if d.Name == "." || d.Name == ".." {
+ continue
+ }
+ absPath := path.Join(fd.MappedName(ctx), d.Name)
+ if d.Type == linux.DT_LNK {
+ link, err := s.VFS.ReadlinkAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.ReadlinkAt(%v) failed: %v", absPath, err)
+ } else {
+ t.Logf("Skipping symlink: %s => %s", absPath, link)
+ }
+ continue
+ }
+
+ t.Logf("Opening: %s", absPath)
+ child, err := s.VFS.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse(absPath)},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Errorf("vfsfs.OpenAt(%v) failed: %v", absPath, err)
+ continue
+ }
+ defer child.DecRef(ctx)
+ stat, err := child.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Errorf("Stat(%v) failed: %v", absPath, err)
+ }
+ if got := linux.FileMode(stat.Mode).DirentType(); got != d.Type {
+ t.Errorf("wrong file mode, stat: %v, dirent: %v", got, d.Type)
+ }
+ if d.Type == linux.DT_DIR {
+ // Found another dir, let's do it again!
+ iterateDir(ctx, t, s, child)
+ }
+ }
+}
+
+// TestTree iterates all directories and stats every file.
+func TestTree(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+
+ pop := &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse("test-file"),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ file, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, opts)
+ if err != nil {
+ t.Fatalf("failed to create test file: %v", err)
+ }
+ defer file.DecRef(s.Ctx)
+
+ var tasks []*kernel.Task
+ for i := 0; i < 5; i++ {
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ task, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("name-%d", i), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatalf("CreateTask(): %v", err)
+ }
+ // Add file to populate /proc/[pid]/fd and fdinfo directories.
+ task.FDTable().NewFDVFS2(task, 0, file, kernel.FDFlags{})
+ tasks = append(tasks, task)
+ }
+
+ ctx := tasks[0]
+ fd, err := s.VFS.OpenAt(
+ ctx,
+ auth.CredentialsFromContext(s.Ctx),
+ &vfs.PathOperation{Root: s.Root, Start: s.Root, Path: fspath.Parse("/proc")},
+ &vfs.OpenOptions{},
+ )
+ if err != nil {
+ t.Fatalf("vfsfs.OpenAt(/proc) failed: %v", err)
+ }
+ iterateDir(ctx, t, s, fd)
+ fd.DecRef(ctx)
+}
diff --git a/pkg/sentry/fsimpl/proc/version.go b/pkg/sentry/fsimpl/proc/version.go
deleted file mode 100644
index e1643d4e0..000000000
--- a/pkg/sentry/fsimpl/proc/version.go
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package proc
-
-import (
- "bytes"
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
-)
-
-// versionData implements vfs.DynamicBytesSource for /proc/version.
-//
-// +stateify savable
-type versionData struct {
- // k is the owning Kernel.
- k *kernel.Kernel
-}
-
-var _ vfs.DynamicBytesSource = (*versionData)(nil)
-
-// Generate implements vfs.DynamicBytesSource.Generate.
-func (v *versionData) Generate(ctx context.Context, buf *bytes.Buffer) error {
- init := v.k.GlobalInit()
- if init == nil {
- // Attempted to read before the init Task is created. This can
- // only occur during startup, which should never need to read
- // this file.
- panic("Attempted to read version before initial Task is available")
- }
-
- // /proc/version takes the form:
- //
- // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST)
- // (COMPILER_VERSION) VERSION"
- //
- // where:
- // - SYSNAME, RELEASE, and VERSION are the same as returned by
- // sys_utsname
- // - COMPILE_USER is the user that build the kernel
- // - COMPILE_HOST is the hostname of the machine on which the kernel
- // was built
- // - COMPILER_VERSION is the version reported by the building compiler
- //
- // Since we don't really want to expose build information to
- // applications, those fields are omitted.
- //
- // FIXME(mpratt): Using Version from the init task SyscallTable
- // disregards the different version a task may have (e.g., in a uts
- // namespace).
- ver := init.Leader().SyscallTable().Version
- fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version)
- return nil
-}
diff --git a/pkg/sentry/fsimpl/signalfd/BUILD b/pkg/sentry/fsimpl/signalfd/BUILD
new file mode 100644
index 000000000..067c1657f
--- /dev/null
+++ b/pkg/sentry/fsimpl/signalfd/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
new file mode 100644
index 000000000..6297e1df4
--- /dev/null
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -0,0 +1,136 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package signalfd
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalFileDescription implements FileDescriptionImpl for signal fds.
+type SignalFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ // target is the original signal target task.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects mask.
+ mu sync.Mutex
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+var _ vfs.FileDescriptionImpl = (*SignalFileDescription)(nil)
+
+// New creates a new signal fd.
+func New(vfsObj *vfs.VirtualFilesystem, target *kernel.Task, mask linux.SignalSet, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[signalfd]")
+ defer vd.DecRef(target)
+ sfd := &SignalFileDescription{
+ target: target,
+ mask: mask,
+ }
+ if err := sfd.vfsfd.Init(sfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &sfd.vfsfd, nil
+}
+
+// Mask returns the signal mask.
+func (sfd *SignalFileDescription) Mask() linux.SignalSet {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ return sfd.mask
+}
+
+// SetMask sets the signal mask.
+func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ sfd.mask = mask
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (sfd *SignalFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ if mask&waiter.EventIn != 0 && sfd.target.PendingSignals()&sfd.mask != 0 {
+ return waiter.EventIn // Pending signals.
+ }
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (sfd *SignalFileDescription) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ sfd.mu.Lock()
+ defer sfd.mu.Unlock()
+ // Register for the signal set; ignore the passed events.
+ sfd.target.SignalRegister(entry, waiter.EventMask(sfd.mask))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ sfd.target.SignalUnregister(entry)
+}
+
+// Release implements FileDescriptionImpl.Release()
+func (sfd *SignalFileDescription) Release(context.Context) {}
diff --git a/pkg/sentry/fsimpl/sockfs/BUILD b/pkg/sentry/fsimpl/sockfs/BUILD
new file mode 100644
index 000000000..9453277b8
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "sockfs",
+ srcs = ["sockfs.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
new file mode 100644
index 000000000..c61818ff6
--- /dev/null
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sockfs provides a filesystem implementation for anonymous sockets.
+package sockfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// filesystemType implements vfs.FilesystemType.
+type filesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ panic("sockfs.filesystemType.GetFilesystem should never be called")
+}
+
+// Name implements FilesystemType.Name.
+//
+// Note that registering sockfs is unnecessary, except for the fact that it
+// will not show up under /proc/filesystems as a result. This is a very minor
+// discrepancy from Linux.
+func (filesystemType) Name() string {
+ return "sockfs"
+}
+
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// NewFilesystem sets up and returns a new sockfs filesystem.
+//
+// Note that there should only ever be one instance of sockfs.Filesystem,
+// backing a global socket mount.
+func NewFilesystem(vfsObj *vfs.VirtualFilesystem) (*vfs.Filesystem, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, err
+ }
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.Filesystem.VFSFilesystem().Init(vfsObj, filesystemType{}, fs)
+ return fs.Filesystem.VFSFilesystem(), nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ inode := vd.Dentry().Impl().(*kernfs.Dentry).Inode().(*inode)
+ b.PrependComponent(fmt.Sprintf("socket:[%d]", inode.InodeAttrs.Ino()))
+ return vfs.PrependPathSyntheticError{}
+}
+
+// inode implements kernfs.Inode.
+type inode struct {
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
+ kernfs.InodeAttrs
+ kernfs.InodeNoopRefCount
+}
+
+// Open implements kernfs.Inode.Open.
+func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ return nil, syserror.ENXIO
+}
+
+// NewDentry constructs and returns a sockfs dentry.
+//
+// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
+func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry {
+ fs := mnt.Filesystem().Impl().(*filesystem)
+
+ // File mode matches net/socket.c:sock_alloc.
+ filemode := linux.FileMode(linux.S_IFSOCK | 0600)
+ i := &inode{}
+ i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode)
+
+ d := &kernfs.Dentry{}
+ d.Init(i)
+ return d.VFSDentry()
+}
diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD
new file mode 100644
index 000000000..1b548ccd4
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "sys",
+ srcs = [
+ "sys.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "sys_test",
+ srcs = ["sys_test.go"],
+ deps = [
+ ":sys",
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
new file mode 100644
index 000000000..0401726b6
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -0,0 +1,159 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package sys implements sysfs.
+package sys
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "sysfs"
+const defaultSysDirMode = linux.FileMode(0755)
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ kernfs.Filesystem
+
+ devMinor uint32
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ root := fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "block": fs.newDir(creds, defaultSysDirMode, nil),
+ "bus": fs.newDir(creds, defaultSysDirMode, nil),
+ "class": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "power_supply": fs.newDir(creds, defaultSysDirMode, nil),
+ }),
+ "dev": fs.newDir(creds, defaultSysDirMode, nil),
+ "devices": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "system": fs.newDir(creds, defaultSysDirMode, map[string]*kernfs.Dentry{
+ "cpu": cpuDir(ctx, fs, creds),
+ }),
+ }),
+ "firmware": fs.newDir(creds, defaultSysDirMode, nil),
+ "fs": fs.newDir(creds, defaultSysDirMode, nil),
+ "kernel": fs.newDir(creds, defaultSysDirMode, nil),
+ "module": fs.newDir(creds, defaultSysDirMode, nil),
+ "power": fs.newDir(creds, defaultSysDirMode, nil),
+ })
+ return fs.VFSFilesystem(), root.VFSDentry(), nil
+}
+
+func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry {
+ k := kernel.KernelFromContext(ctx)
+ maxCPUCores := k.ApplicationCores()
+ children := map[string]*kernfs.Dentry{
+ "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)),
+ }
+ for i := uint(0); i < maxCPUCores; i++ {
+ children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil)
+ }
+ return fs.newDir(creds, defaultSysDirMode, children)
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// dir implements kernfs.Inode.
+type dir struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoDynamicLookup
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren
+ kernfs.OrderedChildren
+
+ locks vfs.FileLocks
+
+ dentry kernfs.Dentry
+}
+
+func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]*kernfs.Dentry) *kernfs.Dentry {
+ d := &dir{}
+ d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755)
+ d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ d.dentry.Init(d)
+
+ d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents))
+
+ return &d.dentry
+}
+
+// SetStat implements Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts)
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// cpuFile implements kernfs.Inode.
+type cpuFile struct {
+ kernfs.DynamicBytesFile
+ maxCores uint
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "0-%d\n", c.maxCores-1)
+ return nil
+}
+
+func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) *kernfs.Dentry {
+ c := &cpuFile{maxCores: maxCores}
+ c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode)
+ d := &kernfs.Dentry{}
+ d.Init(c)
+ return d
+}
diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go
new file mode 100644
index 000000000..9fd38b295
--- /dev/null
+++ b/pkg/sentry/fsimpl/sys/sys_test.go
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sys_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func newTestSystem(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Failed to create test kernel: %v", err)
+ }
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+ k.VFS().MustRegisterFilesystemType(sys.Name, sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+
+ mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("Failed to create new mount namespace: %v", err)
+ }
+ return testutil.NewSystem(ctx, t, k.VFS(), mns)
+}
+
+func TestReadCPUFile(t *testing.T) {
+ s := newTestSystem(t)
+ defer s.Destroy()
+ k := kernel.KernelFromContext(s.Ctx)
+ maxCPUCores := k.ApplicationCores()
+
+ expected := fmt.Sprintf("0-%d\n", maxCPUCores-1)
+
+ for _, fname := range []string{"online", "possible", "present"} {
+ pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname))
+ fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{})
+ if err != nil {
+ t.Fatalf("OpenAt(pop:%+v) = %+v failed: %v", pop, fd, err)
+ }
+ defer fd.DecRef(s.Ctx)
+ content, err := s.ReadToEnd(fd)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ if diff := cmp.Diff(expected, content); diff != "" {
+ t.Fatalf("Read returned unexpected data:\n--- want\n+++ got\n%v", diff)
+ }
+ }
+}
+
+func TestSysRootContainsExpectedEntries(t *testing.T) {
+ s := newTestSystem(t)
+ defer s.Destroy()
+ pop := s.PathOpAtRoot("/")
+ s.AssertAllDirentTypes(s.ListDirents(pop), map[string]testutil.DirentType{
+ "block": linux.DT_DIR,
+ "bus": linux.DT_DIR,
+ "class": linux.DT_DIR,
+ "dev": linux.DT_DIR,
+ "devices": linux.DT_DIR,
+ "firmware": linux.DT_DIR,
+ "fs": linux.DT_DIR,
+ "kernel": linux.DT_DIR,
+ "module": linux.DT_DIR,
+ "power": linux.DT_DIR,
+ })
+}
diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD
new file mode 100644
index 000000000..400a97996
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/BUILD
@@ -0,0 +1,37 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "testutil",
+ testonly = 1,
+ srcs = [
+ "kernel.go",
+ "testutil.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/cpuid",
+ "//pkg/fspath",
+ "//pkg/memutil",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/sched",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/platform/kvm",
+ "//pkg/sentry/platform/ptrace",
+ "//pkg/sentry/time",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/usermem",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
new file mode 100644
index 000000000..1813269e0
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -0,0 +1,180 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "runtime"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/cpuid"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/sched"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+
+ // Platforms are plugable.
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/kvm"
+ _ "gvisor.dev/gvisor/pkg/sentry/platform/ptrace"
+)
+
+var (
+ platformFlag = flag.String("platform", "ptrace", "specify which platform to use")
+)
+
+// Boot initializes a new bare bones kernel for test.
+func Boot() (*kernel.Kernel, error) {
+ platformCtr, err := platform.Lookup(*platformFlag)
+ if err != nil {
+ return nil, fmt.Errorf("platform not found: %v", err)
+ }
+ deviceFile, err := platformCtr.OpenDevice()
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+ plat, err := platformCtr.New(deviceFile)
+ if err != nil {
+ return nil, fmt.Errorf("creating platform: %v", err)
+ }
+
+ kernel.VFS2Enabled = true
+ k := &kernel.Kernel{
+ Platform: plat,
+ }
+
+ mf, err := createMemoryFile()
+ if err != nil {
+ return nil, err
+ }
+ k.SetMemoryFile(mf)
+
+ // Pass k as the platform since it is savable, unlike the actual platform.
+ vdso, err := loader.PrepareVDSO(k)
+ if err != nil {
+ return nil, fmt.Errorf("creating vdso: %v", err)
+ }
+
+ // Create timekeeper.
+ tk, err := kernel.NewTimekeeper(k, vdso.ParamPage.FileRange())
+ if err != nil {
+ return nil, fmt.Errorf("creating timekeeper: %v", err)
+ }
+ tk.SetClocks(time.NewCalibratedClocks())
+
+ creds := auth.NewRootCredentials(auth.NewRootUserNamespace())
+
+ // Initiate the Kernel object, which is required by the Context passed
+ // to createVFS in order to mount (among other things) procfs.
+ if err = k.Init(kernel.InitKernelArgs{
+ ApplicationCores: uint(runtime.GOMAXPROCS(-1)),
+ FeatureSet: cpuid.HostFeatureSet(),
+ Timekeeper: tk,
+ RootUserNamespace: creds.UserNamespace,
+ Vdso: vdso,
+ RootUTSNamespace: kernel.NewUTSNamespace("hostname", "domain", creds.UserNamespace),
+ RootIPCNamespace: kernel.NewIPCNamespace(creds.UserNamespace),
+ RootAbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ PIDNamespace: kernel.NewRootPIDNamespace(creds.UserNamespace),
+ }); err != nil {
+ return nil, fmt.Errorf("initializing kernel: %v", err)
+ }
+
+ k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ ls, err := limits.NewLinuxLimitSet()
+ if err != nil {
+ return nil, err
+ }
+ tg := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, ls)
+ k.TestOnly_SetGlobalInit(tg)
+
+ return k, nil
+}
+
+// CreateTask creates a new bare bones task for tests.
+func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns *vfs.MountNamespace, root, cwd vfs.VirtualDentry) (*kernel.Task, error) {
+ k := kernel.KernelFromContext(ctx)
+ if k == nil {
+ return nil, fmt.Errorf("cannot find kernel from context")
+ }
+
+ exe, err := newFakeExecutable(ctx, k.VFS(), auth.CredentialsFromContext(ctx), root)
+ if err != nil {
+ return nil, err
+ }
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
+ m.SetExecutable(ctx, fsbridge.NewVFSFile(exe))
+
+ config := &kernel.TaskConfig{
+ Kernel: k,
+ ThreadGroup: tc,
+ TaskContext: &kernel.TaskContext{Name: name, MemoryManager: m},
+ Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
+ AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
+ UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
+ IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
+ AbstractSocketNamespace: kernel.NewAbstractSocketNamespace(),
+ MountNamespaceVFS2: mntns,
+ FSContext: kernel.NewFSContextVFS2(root, cwd, 0022),
+ FDTable: k.NewFDTable(),
+ }
+ return k.TaskSet().NewTask(config)
+}
+
+func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) {
+ const name = "executable"
+ pop := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }
+ opts := &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_CREAT,
+ Mode: 0777,
+ }
+ return vfsObj.OpenAt(ctx, creds, pop, opts)
+}
+
+func createMemoryFile() (*pgalloc.MemoryFile, error) {
+ const memfileName = "test-memory"
+ memfd, err := memutil.CreateMemFD(memfileName, 0)
+ if err != nil {
+ return nil, fmt.Errorf("error creating memfd: %v", err)
+ }
+ memfile := os.NewFile(uintptr(memfd), memfileName)
+ mf, err := pgalloc.NewMemoryFile(memfile, pgalloc.MemoryFileOpts{})
+ if err != nil {
+ memfile.Close()
+ return nil, fmt.Errorf("error creating pgalloc.MemoryFile: %v", err)
+ }
+ return mf, nil
+}
diff --git a/pkg/sentry/fsimpl/testutil/testutil.go b/pkg/sentry/fsimpl/testutil/testutil.go
new file mode 100644
index 000000000..568132121
--- /dev/null
+++ b/pkg/sentry/fsimpl/testutil/testutil.go
@@ -0,0 +1,284 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testutil provides common test utilities for kernfs-based
+// filesystems.
+package testutil
+
+import (
+ "fmt"
+ "io"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// System represents the context for a single test.
+//
+// Test systems must be explicitly destroyed with System.Destroy.
+type System struct {
+ t *testing.T
+ Ctx context.Context
+ Creds *auth.Credentials
+ VFS *vfs.VirtualFilesystem
+ Root vfs.VirtualDentry
+ MntNs *vfs.MountNamespace
+}
+
+// NewSystem constructs a System.
+//
+// Precondition: Caller must hold a reference on MntNs, whose ownership
+// is transferred to the new System.
+func NewSystem(ctx context.Context, t *testing.T, v *vfs.VirtualFilesystem, mns *vfs.MountNamespace) *System {
+ s := &System{
+ t: t,
+ Ctx: ctx,
+ Creds: auth.CredentialsFromContext(ctx),
+ VFS: v,
+ MntNs: mns,
+ Root: mns.Root(),
+ }
+ return s
+}
+
+// WithSubtest creates a temporary test system with a new test harness,
+// referencing all other resources from the original system. This is useful when
+// a system is reused for multiple subtests, and the T needs to change for each
+// case. Note that this is safe when test cases run in parallel, as all
+// resources referenced by the system are immutable, or handle interior
+// mutations in a thread-safe manner.
+//
+// The returned system must not outlive the original and should not be destroyed
+// via System.Destroy.
+func (s *System) WithSubtest(t *testing.T) *System {
+ return &System{
+ t: t,
+ Ctx: s.Ctx,
+ Creds: s.Creds,
+ VFS: s.VFS,
+ MntNs: s.MntNs,
+ Root: s.Root,
+ }
+}
+
+// WithTemporaryContext constructs a temporary test system with a new context
+// ctx. The temporary system borrows all resources and references from the
+// original system. The returned temporary system must not outlive the original
+// system, and should not be destroyed via System.Destroy.
+func (s *System) WithTemporaryContext(ctx context.Context) *System {
+ return &System{
+ t: s.t,
+ Ctx: ctx,
+ Creds: s.Creds,
+ VFS: s.VFS,
+ MntNs: s.MntNs,
+ Root: s.Root,
+ }
+}
+
+// Destroy release resources associated with a test system.
+func (s *System) Destroy() {
+ s.Root.DecRef(s.Ctx)
+ s.MntNs.DecRef(s.Ctx) // Reference on MntNs passed to NewSystem.
+}
+
+// ReadToEnd reads the contents of fd until EOF to a string.
+func (s *System) ReadToEnd(fd *vfs.FileDescription) (string, error) {
+ buf := make([]byte, usermem.PageSize)
+ bufIOSeq := usermem.BytesIOSequence(buf)
+ opts := vfs.ReadOptions{}
+
+ var content strings.Builder
+ for {
+ n, err := fd.Read(s.Ctx, bufIOSeq, opts)
+ if n == 0 || err != nil {
+ if err == io.EOF {
+ err = nil
+ }
+ return content.String(), err
+ }
+ content.Write(buf[:n])
+ }
+}
+
+// PathOpAtRoot constructs a PathOperation with the given path from
+// the root of the filesystem.
+func (s *System) PathOpAtRoot(path string) *vfs.PathOperation {
+ return &vfs.PathOperation{
+ Root: s.Root,
+ Start: s.Root,
+ Path: fspath.Parse(path),
+ }
+}
+
+// GetDentryOrDie attempts to resolve a dentry referred to by the
+// provided path operation. If unsuccessful, the test fails.
+func (s *System) GetDentryOrDie(pop *vfs.PathOperation) vfs.VirtualDentry {
+ vd, err := s.VFS.GetDentryAt(s.Ctx, s.Creds, pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ s.t.Fatalf("GetDentryAt(pop:%+v) failed: %v", pop, err)
+ }
+ return vd
+}
+
+// DirentType is an alias for values for linux_dirent64.d_type.
+type DirentType = uint8
+
+// ListDirents lists the Dirents for a directory at pop.
+func (s *System) ListDirents(pop *vfs.PathOperation) *DirentCollector {
+ fd, err := s.VFS.OpenAt(s.Ctx, s.Creds, pop, &vfs.OpenOptions{Flags: linux.O_RDONLY})
+ if err != nil {
+ s.t.Fatalf("OpenAt for PathOperation %+v failed: %v", pop, err)
+ }
+ defer fd.DecRef(s.Ctx)
+
+ collector := &DirentCollector{}
+ if err := fd.IterDirents(s.Ctx, collector); err != nil {
+ s.t.Fatalf("IterDirent failed: %v", err)
+ }
+ return collector
+}
+
+// AssertAllDirentTypes verifies that the set of dirents in collector contains
+// exactly the specified set of expected entries. AssertAllDirentTypes respects
+// collector.skipDots, and implicitly checks for "." and ".." accordingly.
+func (s *System) AssertAllDirentTypes(collector *DirentCollector, expected map[string]DirentType) {
+ if expected == nil {
+ expected = make(map[string]DirentType)
+ }
+ // Also implicitly check for "." and "..", if enabled.
+ if !collector.skipDots {
+ expected["."] = linux.DT_DIR
+ expected[".."] = linux.DT_DIR
+ }
+
+ dentryTypes := make(map[string]DirentType)
+ collector.mu.Lock()
+ for _, dirent := range collector.dirents {
+ dentryTypes[dirent.Name] = dirent.Type
+ }
+ collector.mu.Unlock()
+ if diff := cmp.Diff(expected, dentryTypes); diff != "" {
+ s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+// AssertDirentOffsets verifies that collector contains at least the entries
+// specified in expected, with the given NextOff field. Entries specified in
+// expected but missing from collector result in failure. Extra entries in
+// collector are ignored. AssertDirentOffsets respects collector.skipDots, and
+// implicitly checks for "." and ".." accordingly.
+func (s *System) AssertDirentOffsets(collector *DirentCollector, expected map[string]int64) {
+ // Also implicitly check for "." and "..", if enabled.
+ if !collector.skipDots {
+ expected["."] = 1
+ expected[".."] = 2
+ }
+
+ dentryNextOffs := make(map[string]int64)
+ collector.mu.Lock()
+ for _, dirent := range collector.dirents {
+ // Ignore extra entries in dentries that are not in expected.
+ if _, ok := expected[dirent.Name]; ok {
+ dentryNextOffs[dirent.Name] = dirent.NextOff
+ }
+ }
+ collector.mu.Unlock()
+ if diff := cmp.Diff(expected, dentryNextOffs); diff != "" {
+ s.t.Fatalf("IterDirent had unexpected results:\n--- want\n+++ got\n%v", diff)
+ }
+}
+
+// DirentCollector provides an implementation for vfs.IterDirentsCallback for
+// testing. It simply iterates to the end of a given directory FD and collects
+// all dirents emitted by the callback.
+type DirentCollector struct {
+ mu sync.Mutex
+ order []*vfs.Dirent
+ dirents map[string]*vfs.Dirent
+ // When the collector is used in various Assert* functions, should "." and
+ // ".." be implicitly checked?
+ skipDots bool
+}
+
+// SkipDotsChecks enables or disables the implicit checks on "." and ".." when
+// the collector is used in various Assert* functions. Note that "." and ".."
+// are still collected if passed to d.Handle, so the caller should only disable
+// the checks when they aren't expected.
+func (d *DirentCollector) SkipDotsChecks(value bool) {
+ d.skipDots = value
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (d *DirentCollector) Handle(dirent vfs.Dirent) error {
+ d.mu.Lock()
+ if d.dirents == nil {
+ d.dirents = make(map[string]*vfs.Dirent)
+ }
+ d.order = append(d.order, &dirent)
+ d.dirents[dirent.Name] = &dirent
+ d.mu.Unlock()
+ return nil
+}
+
+// Count returns the number of dirents currently in the collector.
+func (d *DirentCollector) Count() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return len(d.dirents)
+}
+
+// Contains checks whether the collector has a dirent with the given name and
+// type.
+func (d *DirentCollector) Contains(name string, typ uint8) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ dirent, ok := d.dirents[name]
+ if !ok {
+ return fmt.Errorf("No dirent named %q found", name)
+ }
+ if dirent.Type != typ {
+ return fmt.Errorf("Dirent named %q found, but was expecting type %s, got: %+v", name, linux.DirentType.Parse(uint64(typ)), dirent)
+ }
+ return nil
+}
+
+// Dirents returns all dirents discovered by this collector.
+func (d *DirentCollector) Dirents() map[string]*vfs.Dirent {
+ d.mu.Lock()
+ dirents := make(map[string]*vfs.Dirent)
+ for n, d := range d.dirents {
+ dirents[n] = d
+ }
+ d.mu.Unlock()
+ return dirents
+}
+
+// OrderedDirents returns an ordered list of dirents as discovered by this
+// collector.
+func (d *DirentCollector) OrderedDirents() []*vfs.Dirent {
+ d.mu.Lock()
+ dirents := make([]*vfs.Dirent, len(d.order))
+ copy(dirents, d.order)
+ d.mu.Unlock()
+ return dirents
+}
diff --git a/pkg/sentry/fsimpl/timerfd/BUILD b/pkg/sentry/fsimpl/timerfd/BUILD
new file mode 100644
index 000000000..fbb02a271
--- /dev/null
+++ b/pkg/sentry/fsimpl/timerfd/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "timerfd",
+ srcs = ["timerfd.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/context",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go
new file mode 100644
index 000000000..86beaa0a8
--- /dev/null
+++ b/pkg/sentry/fsimpl/timerfd/timerfd.go
@@ -0,0 +1,144 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package timerfd implements timer fds.
+package timerfd
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TimerFileDescription implements FileDescriptionImpl for timer fds. It also
+// implements ktime.TimerListener.
+type TimerFileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.NoLockFD
+
+ events waiter.Queue
+ timer *ktime.Timer
+
+ // val is the number of timer expirations since the last successful
+ // call to PRead, or SetTime. val must be accessed using atomic memory
+ // operations.
+ val uint64
+}
+
+var _ vfs.FileDescriptionImpl = (*TimerFileDescription)(nil)
+var _ ktime.TimerListener = (*TimerFileDescription)(nil)
+
+// New returns a new timer fd.
+func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, flags uint32) (*vfs.FileDescription, error) {
+ vd := vfsObj.NewAnonVirtualDentry("[timerfd]")
+ defer vd.DecRef(ctx)
+ tfd := &TimerFileDescription{}
+ tfd.timer = ktime.NewTimer(clock, tfd)
+ if err := tfd.vfsfd.Init(tfd, flags, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &tfd.vfsfd, nil
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ const sizeofUint64 = 8
+ if dst.NumBytes() < sizeofUint64 {
+ return 0, syserror.EINVAL
+ }
+ if val := atomic.SwapUint64(&tfd.val, 0); val != 0 {
+ var buf [sizeofUint64]byte
+ usermem.ByteOrder.PutUint64(buf[:], val)
+ if _, err := dst.CopyOut(ctx, buf[:]); err != nil {
+ // Linux does not undo consuming the number of
+ // expirations even if writing to userspace fails.
+ return 0, err
+ }
+ return sizeofUint64, nil
+ }
+ return 0, syserror.ErrWouldBlock
+}
+
+// Clock returns the timer fd's Clock.
+func (tfd *TimerFileDescription) Clock() ktime.Clock {
+ return tfd.timer.Clock()
+}
+
+// GetTime returns the associated Timer's setting and the time at which it was
+// observed.
+func (tfd *TimerFileDescription) GetTime() (ktime.Time, ktime.Setting) {
+ return tfd.timer.Get()
+}
+
+// SetTime atomically changes the associated Timer's setting, resets the number
+// of expirations to 0, and returns the previous setting and the time at which
+// it was observed.
+func (tfd *TimerFileDescription) SetTime(s ktime.Setting) (ktime.Time, ktime.Setting) {
+ return tfd.timer.SwapAnd(s, func() { atomic.StoreUint64(&tfd.val, 0) })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (tfd *TimerFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ if atomic.LoadUint64(&tfd.val) != 0 {
+ ready |= waiter.EventIn
+ }
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (tfd *TimerFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ tfd.events.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (tfd *TimerFileDescription) EventUnregister(e *waiter.Entry) {
+ tfd.events.EventUnregister(e)
+}
+
+// PauseTimer pauses the associated Timer.
+func (tfd *TimerFileDescription) PauseTimer() {
+ tfd.timer.Pause()
+}
+
+// ResumeTimer resumes the associated Timer.
+func (tfd *TimerFileDescription) ResumeTimer() {
+ tfd.timer.Resume()
+}
+
+// Release implements FileDescriptionImpl.Release()
+func (tfd *TimerFileDescription) Release(context.Context) {
+ tfd.timer.Destroy()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (tfd *TimerFileDescription) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
+ atomic.AddUint64(&tfd.val, exp)
+ tfd.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (tfd *TimerFileDescription) Destroy() {}
diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD
new file mode 100644
index 000000000..5cd428d64
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/BUILD
@@ -0,0 +1,125 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dentry_list",
+ out = "dentry_list.go",
+ package = "tmpfs",
+ prefix = "dentry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*dentry",
+ "Linker": "*dentry",
+ },
+)
+
+go_template_instance(
+ name = "fstree",
+ out = "fstree.go",
+ package = "tmpfs",
+ prefix = "generic",
+ template = "//pkg/sentry/vfs/genericfstree:generic_fstree",
+ types = {
+ "Dentry": "dentry",
+ },
+)
+
+go_template_instance(
+ name = "inode_refs",
+ out = "inode_refs.go",
+ package = "tmpfs",
+ prefix = "inode",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "inode",
+ },
+)
+
+go_library(
+ name = "tmpfs",
+ srcs = [
+ "dentry_list.go",
+ "device_file.go",
+ "directory.go",
+ "filesystem.go",
+ "fstree.go",
+ "inode_refs.go",
+ "named_pipe.go",
+ "regular_file.go",
+ "socket_file.go",
+ "symlink.go",
+ "tmpfs.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/pgalloc",
+ "//pkg/sentry/platform",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sentry/vfs/memxattr",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ size = "small",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":tmpfs",
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/refs",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ ],
+)
+
+go_test(
+ name = "tmpfs_test",
+ size = "small",
+ srcs = [
+ "pipe_test.go",
+ "regular_file_test.go",
+ "stat_test.go",
+ "tmpfs_test.go",
+ ],
+ library = ":tmpfs",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/contexttest",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/memfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
index a94b17db6..d263147c2 100644
--- a/pkg/sentry/fsimpl/memfs/benchmark_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go
@@ -21,11 +21,13 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/memfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -81,7 +83,7 @@ func fileOpOn(ctx context.Context, mntns *fs.MountNamespace, root, wd *fs.Dirent
}
err = fn(root, d)
- d.DecRef()
+ d.DecRef(ctx)
return err
}
@@ -103,17 +105,17 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create mount namespace: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create nested directories with given depth.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
d := root
d.IncRef()
- defer d.DecRef()
+ defer d.DecRef(ctx)
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
@@ -123,7 +125,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- d.DecRef()
+ d.DecRef(ctx)
d = next
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -134,7 +136,7 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- file.DecRef()
+ file.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -160,39 +162,46 @@ func BenchmarkVFS1TmpfsStat(b *testing.B) {
b.Fatalf("stat(%q) failed: %v", filePath, err)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
-func BenchmarkVFS2MemfsStat(b *testing.B) {
+func BenchmarkVFS2TmpfsStat(b *testing.B) {
for _, depth := range depths {
b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
ctx := contexttest.Context(b)
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create nested directories with given depth.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
vd := root
vd.IncRef()
- defer vd.DecRef()
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
- Root: root,
- Start: vd,
- Pathname: name,
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -203,7 +212,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- vd.DecRef()
+ vd.DecRef(ctx)
vd = nextVD
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -213,16 +222,18 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: vd,
- Pathname: filename,
+ Path: fspath.Parse(filename),
FollowFinalSymlink: true,
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
+ vd.DecRef(ctx)
+ vd = vfs.VirtualDentry{}
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -232,7 +243,7 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -243,6 +254,8 @@ func BenchmarkVFS2MemfsStat(b *testing.B) {
b.Fatalf("got wrong permissions (%0o)", stat.Mode)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
@@ -265,14 +278,14 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create mount namespace: %v", err)
}
- defer mntns.DecRef()
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create and mount the submount.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
if err := root.Inode.CreateDirectory(ctx, root, mountPointName, fs.FilePermsFromMode(0755)); err != nil {
b.Fatalf("failed to create mount point: %v", err)
}
@@ -280,7 +293,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
submountInode, err := tmpfsFS.Mount(ctx, "tmpfs", fs.MountSourceFlags{}, "", nil)
if err != nil {
b.Fatalf("failed to create tmpfs submount: %v", err)
@@ -296,7 +309,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount root: %v", err)
}
- defer d.DecRef()
+ defer d.DecRef(ctx)
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
if err := d.Inode.CreateDirectory(ctx, d, name, fs.FilePermsFromMode(0755)); err != nil {
@@ -306,7 +319,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- d.DecRef()
+ d.DecRef(ctx)
d = next
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
@@ -317,7 +330,7 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- file.DecRef()
+ file.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -343,34 +356,42 @@ func BenchmarkVFS1TmpfsMountStat(b *testing.B) {
b.Fatalf("stat(%q) failed: %v", filePath, err)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
-func BenchmarkVFS2MemfsMountStat(b *testing.B) {
+func BenchmarkVFS2TmpfsMountStat(b *testing.B) {
for _, depth := range depths {
b.Run(fmt.Sprintf("%d", depth), func(b *testing.B) {
ctx := contexttest.Context(b)
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", memfs.FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ vfsObj := vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ b.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
b.Fatalf("failed to create tmpfs root mount: %v", err)
}
+ defer mntns.DecRef(ctx)
var filePathBuilder strings.Builder
filePathBuilder.WriteByte('/')
// Create the mount point.
root := mntns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
pop := vfs.PathOperation{
- Root: root,
- Start: root,
- Pathname: mountPointName,
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(mountPointName),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -382,9 +403,9 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount point: %v", err)
}
- defer mountPoint.DecRef()
+ defer mountPoint.DecRef(ctx)
// Create and mount the submount.
- if err := vfsObj.NewMount(ctx, creds, "", &pop, "memfs", &vfs.NewFilesystemOptions{}); err != nil {
+ if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil {
b.Fatalf("failed to mount tmpfs submount: %v", err)
}
filePathBuilder.WriteString(mountPointName)
@@ -395,13 +416,12 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to mount root: %v", err)
}
- defer vd.DecRef()
for i := depth; i > 0; i-- {
name := fmt.Sprintf("%d", i)
pop := vfs.PathOperation{
- Root: root,
- Start: vd,
- Pathname: name,
+ Root: root,
+ Start: vd,
+ Path: fspath.Parse(name),
}
if err := vfsObj.MkdirAt(ctx, creds, &pop, &vfs.MkdirOptions{
Mode: 0755,
@@ -412,33 +432,27 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
if err != nil {
b.Fatalf("failed to walk to directory %q: %v", name, err)
}
- vd.DecRef()
+ vd.DecRef(ctx)
vd = nextVD
filePathBuilder.WriteString(name)
filePathBuilder.WriteByte('/')
}
- // Verify that we didn't create any directories under the mount
- // point (i.e. they were all created on the submount).
- firstDirName := fmt.Sprintf("%d", depth)
- if child := mountPoint.Dentry().Child(firstDirName); child != nil {
- b.Fatalf("created directory %q under root mount, not submount", firstDirName)
- }
-
// Create the file that will be stat'd.
fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: vd,
- Pathname: filename,
+ Path: fspath.Parse(filename),
FollowFinalSymlink: true,
}, &vfs.OpenOptions{
Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
Mode: 0644,
})
+ vd.DecRef(ctx)
if err != nil {
b.Fatalf("failed to create file %q: %v", filename, err)
}
- fd.DecRef()
+ fd.DecRef(ctx)
filePathBuilder.WriteString(filename)
filePath := filePathBuilder.String()
@@ -448,7 +462,7 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: filePath,
+ Path: fspath.Parse(filePath),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -459,6 +473,14 @@ func BenchmarkVFS2MemfsMountStat(b *testing.B) {
b.Fatalf("got wrong permissions (%0o)", stat.Mode)
}
}
+ // Don't include deferred cleanup in benchmark time.
+ b.StopTimer()
})
}
}
+
+func init() {
+ // Turn off reference leak checking for a fair comparison between vfs1 and
+ // vfs2.
+ refs.SetLeakMode(refs.NoLeakChecking)
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/device_file.go b/pkg/sentry/fsimpl/tmpfs/device_file.go
new file mode 100644
index 000000000..ac54d420d
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/device_file.go
@@ -0,0 +1,49 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+type deviceFile struct {
+ inode inode
+ kind vfs.DeviceKind
+ major uint32
+ minor uint32
+}
+
+func (fs *filesystem) newDeviceFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, kind vfs.DeviceKind, major, minor uint32) *inode {
+ file := &deviceFile{
+ kind: kind,
+ major: major,
+ minor: minor,
+ }
+ switch kind {
+ case vfs.BlockDevice:
+ mode |= linux.S_IFBLK
+ case vfs.CharDevice:
+ mode |= linux.S_IFCHR
+ default:
+ panic(fmt.Sprintf("invalid DeviceKind: %v", kind))
+ }
+ file.inode.init(file, fs, kuid, kgid, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go
index 0bd82e480..78b4fc5be 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/tmpfs/directory.go
@@ -12,55 +12,95 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
type directory struct {
- inode inode
+ // Since directories can't be hard-linked, each directory can only be
+ // associated with a single dentry, which we can store in the directory
+ // struct.
+ dentry dentry
+ inode inode
+
+ // childMap maps the names of the directory's children to their dentries.
+ // childMap is protected by filesystem.mu.
+ childMap map[string]*dentry
- // childList is a list containing (1) child Dentries and (2) fake Dentries
+ // numChildren is len(childMap), but accessed using atomic memory
+ // operations to avoid locking in inode.statTo().
+ numChildren int64
+
+ // childList is a list containing (1) child dentries and (2) fake dentries
// (with inode == nil) that represent the iteration position of
// directoryFDs. childList is used to support directoryFD.IterDirents()
- // efficiently. childList is protected by filesystem.mu.
+ // efficiently. childList is protected by iterMu.
+ iterMu sync.Mutex
childList dentryList
}
-func (fs *filesystem) newDirectory(creds *auth.Credentials, mode linux.FileMode) *inode {
+func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *directory {
dir := &directory{}
- dir.inode.init(dir, fs, creds, mode)
+ dir.inode.init(dir, fs, kuid, kgid, linux.S_IFDIR|mode)
dir.inode.nlink = 2 // from "." and parent directory or ".." for root
- return &dir.inode
+ dir.dentry.inode = &dir.inode
+ dir.dentry.vfsd.Init(&dir.dentry)
+ return dir
+}
+
+// Preconditions: filesystem.mu must be locked for writing. dir must not
+// already contain a child with the given name.
+func (dir *directory) insertChildLocked(child *dentry, name string) {
+ child.parent = &dir.dentry
+ child.name = name
+ if dir.childMap == nil {
+ dir.childMap = make(map[string]*dentry)
+ }
+ dir.childMap[name] = child
+ atomic.AddInt64(&dir.numChildren, 1)
+ dir.iterMu.Lock()
+ dir.childList.PushBack(child)
+ dir.iterMu.Unlock()
+}
+
+// Preconditions: filesystem.mu must be locked for writing.
+func (dir *directory) removeChildLocked(child *dentry) {
+ delete(dir.childMap, child.name)
+ atomic.AddInt64(&dir.numChildren, -1)
+ dir.iterMu.Lock()
+ dir.childList.Remove(child)
+ dir.iterMu.Unlock()
}
-func (i *inode) isDir() bool {
- _, ok := i.impl.(*directory)
- return ok
+func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error {
+ return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid)))
}
type directoryFD struct {
fileDescription
vfs.DirectoryFileDescriptionDefaultImpl
- // Protected by filesystem.mu.
+ // Protected by directory.iterMu.
iter *dentry
off int64
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *directoryFD) Release() {
+func (fd *directoryFD) Release(ctx context.Context) {
if fd.iter != nil {
- fs := fd.filesystem()
dir := fd.inode().impl.(*directory)
- fs.mu.Lock()
+ dir.iterMu.Lock()
dir.childList.Remove(fd.iter)
- fs.mu.Unlock()
+ dir.iterMu.Unlock()
fd.iter = nil
}
}
@@ -68,36 +108,43 @@ func (fd *directoryFD) Release() {
// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
fs := fd.filesystem()
- vfsd := fd.vfsfd.VirtualDentry().Dentry()
+ dir := fd.inode().impl.(*directory)
+
+ defer fd.dentry().InotifyWithParent(ctx, linux.IN_ACCESS, 0, vfs.PathEvent)
- fs.mu.Lock()
- defer fs.mu.Unlock()
+ // fs.mu is required to read d.parent and dentry.name.
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ dir.iterMu.Lock()
+ defer dir.iterMu.Unlock()
+
+ fd.inode().touchAtime(fd.vfsfd.Mount())
if fd.off == 0 {
- if !cb.Handle(vfs.Dirent{
+ if err := cb.Handle(vfs.Dirent{
Name: ".",
Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
+ Ino: dir.inode.ino,
NextOff: 1,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
+
if fd.off == 1 {
- parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
- if !cb.Handle(vfs.Dirent{
+ parentInode := genericParentOrSelf(&dir.dentry).inode
+ if err := cb.Handle(vfs.Dirent{
Name: "..",
Type: parentInode.direntType(),
Ino: parentInode.ino,
NextOff: 2,
- }) {
- return nil
+ }); err != nil {
+ return err
}
fd.off++
}
- dir := vfsd.Impl().(*dentry).inode.impl.(*directory)
var child *dentry
if fd.iter == nil {
// Start iteration at the beginning of dir.
@@ -111,14 +158,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
for child != nil {
// Skip other directoryFD iterators.
if child.inode != nil {
- if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
+ if err := cb.Handle(vfs.Dirent{
+ Name: child.name,
Type: child.inode.direntType(),
Ino: child.inode.ino,
NextOff: fd.off + 1,
- }) {
+ }); err != nil {
dir.childList.InsertBefore(child, fd.iter)
- return nil
+ return err
}
fd.off++
}
@@ -130,9 +177,9 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
- fs := fd.filesystem()
- fs.mu.Lock()
- defer fs.mu.Unlock()
+ dir := fd.inode().impl.(*directory)
+ dir.iterMu.Lock()
+ defer dir.iterMu.Unlock()
switch whence {
case linux.SEEK_SET:
@@ -160,8 +207,6 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in
remChildren = offset - 2
}
- dir := fd.inode().impl.(*directory)
-
// Ensure that fd.iter exists and is not linked into dir.childList.
if fd.iter == nil {
fd.iter = &dentry{}
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
new file mode 100644
index 000000000..cb8b2d944
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -0,0 +1,860 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // All filesystem state is in-memory.
+ return nil
+}
+
+// stepLocked resolves rp.Component() to an existing file, starting from the
+// given directory.
+//
+// stepLocked is loosely analogous to fs/namei.c:walk_component().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*dentry, error) {
+ dir, ok := d.inode.impl.(*directory)
+ if !ok {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+afterSymlink:
+ name := rp.Component()
+ if name == "." {
+ rp.Advance()
+ return d, nil
+ }
+ if name == ".." {
+ if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+ return nil, err
+ } else if isRoot || d.parent == nil {
+ rp.Advance()
+ return d, nil
+ }
+ if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+ return nil, err
+ }
+ rp.Advance()
+ return d.parent, nil
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ child, ok := dir.childMap[name]
+ if !ok {
+ return nil, syserror.ENOENT
+ }
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
+ return nil, err
+ }
+ if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
+ // Symlink traversal updates access time.
+ child.inode.touchAtime(rp.Mount())
+ if err := rp.HandleSymlink(symlink.target); err != nil {
+ return nil, err
+ }
+ goto afterSymlink // don't check the current directory again
+ }
+ rp.Advance()
+ return child, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// walkParentDirLocked is loosely analogous to Linux's
+// fs/namei.c:path_parentat().
+//
+// Preconditions: filesystem.mu must be locked. !rp.Done().
+func walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*directory, error) {
+ for !rp.Final() {
+ next, err := stepLocked(ctx, rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ dir, ok := d.inode.impl.(*directory)
+ if !ok {
+ return nil, syserror.ENOTDIR
+ }
+ return dir, nil
+}
+
+// resolveLocked resolves rp to an existing file.
+//
+// resolveLocked is loosely analogous to Linux's fs/namei.c:path_lookupat().
+//
+// Preconditions: filesystem.mu must be locked.
+func resolveLocked(ctx context.Context, rp *vfs.ResolvingPath) (*dentry, error) {
+ d := rp.Start().Impl().(*dentry)
+ for !rp.Done() {
+ next, err := stepLocked(ctx, rp, d)
+ if err != nil {
+ return nil, err
+ }
+ d = next
+ }
+ if rp.MustBeDir() && !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return d, nil
+}
+
+// doCreateAt checks that creating a file at rp is permitted, then invokes
+// create to do so.
+//
+// doCreateAt is loosely analogous to a conjunction of Linux's
+// fs/namei.c:filename_create() and done_path_create().
+//
+// Preconditions: !rp.Done(). For the final path component in rp,
+// !rp.ShouldFollowSymlink().
+func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EEXIST
+ }
+ if len(name) > linux.NAME_MAX {
+ return syserror.ENAMETOOLONG
+ }
+ if _, ok := parentDir.childMap[name]; ok {
+ return syserror.EEXIST
+ }
+ if !dir && rp.MustBeDir() {
+ return syserror.ENOENT
+ }
+ // tmpfs never calls VFS.InvalidateDentry(), so parentDir.dentry can only
+ // be dead if it was deleted.
+ if parentDir.dentry.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ if err := create(parentDir, name); err != nil {
+ return err
+ }
+
+ ev := linux.IN_CREATE
+ if dir {
+ ev |= linux.IN_ISDIR
+ }
+ parentDir.inode.watches.Notify(ctx, name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return err
+ }
+ return d.inode.checkPermissions(creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ dir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return nil, err
+ }
+ dir.dentry.IncRef()
+ return &dir.dentry.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
+ if rp.Mount() != vd.Mount() {
+ return syserror.EXDEV
+ }
+ d := vd.Dentry().Impl().(*dentry)
+ i := d.inode
+ if i.isDir() {
+ return syserror.EPERM
+ }
+ if err := vfs.MayLink(auth.CredentialsFromContext(ctx), linux.FileMode(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
+ if i.nlink == 0 {
+ return syserror.ENOENT
+ }
+ if i.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ i.incLinksLocked()
+ i.watches.Notify(ctx, "", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */)
+ parentDir.insertChildLocked(fs.newDentry(i), name)
+ return nil
+ })
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ return fs.doCreateAt(ctx, rp, true /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ if parentDir.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ parentDir.inode.incLinksLocked() // from child's ".."
+ childDir := fs.newDirectory(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ parentDir.insertChildLocked(&childDir.dentry, name)
+ return nil
+ })
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ var childInode *inode
+ switch opts.Mode.FileType() {
+ case linux.S_IFREG:
+ childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ case linux.S_IFIFO:
+ childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)
+ case linux.S_IFBLK:
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.BlockDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFCHR:
+ childInode = fs.newDeviceFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, vfs.CharDevice, opts.DevMajor, opts.DevMinor)
+ case linux.S_IFSOCK:
+ childInode = fs.newSocketFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode, opts.Endpoint)
+ default:
+ return syserror.EINVAL
+ }
+ child := fs.newDentry(childInode)
+ parentDir.insertChildLocked(child, name)
+ return nil
+ })
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ // Not yet supported.
+ return nil, syserror.EOPNOTSUPP
+ }
+
+ // Handle O_CREAT and !O_CREAT separately, since in the latter case we
+ // don't need fs.mu for writing.
+ if opts.Flags&linux.O_CREAT == 0 {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ return d.open(ctx, rp, &opts, false /* afterCreate */)
+ }
+
+ mustCreate := opts.Flags&linux.O_EXCL != 0
+ start := rp.Start().Impl().(*dentry)
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ if rp.Done() {
+ // Reject attempts to open mount root directory with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ return start.open(ctx, rp, &opts, false /* afterCreate */)
+ }
+afterTrailingSymlink:
+ parentDir, err := walkParentDirLocked(ctx, rp, start)
+ if err != nil {
+ return nil, err
+ }
+ // Check for search permission in the parent directory.
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ // Reject attempts to open directories with O_CREAT.
+ if rp.MustBeDir() {
+ return nil, syserror.EISDIR
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return nil, syserror.EISDIR
+ }
+ if len(name) > linux.NAME_MAX {
+ return nil, syserror.ENAMETOOLONG
+ }
+ // Determine whether or not we need to create a file.
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ // Already checked for searchability above; now check for writability.
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ if err := rp.Mount().CheckBeginWrite(); err != nil {
+ return nil, err
+ }
+ defer rp.Mount().EndWrite()
+ // Create and open the child.
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode))
+ parentDir.insertChildLocked(child, name)
+ fd, err := child.open(ctx, rp, &opts, true)
+ if err != nil {
+ return nil, err
+ }
+ parentDir.inode.watches.Notify(ctx, name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */)
+ parentDir.inode.touchCMtime()
+ return fd, nil
+ }
+ if mustCreate {
+ return nil, syserror.EEXIST
+ }
+ // Is the file mounted over?
+ if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
+ return nil, err
+ }
+ // Do we need to resolve a trailing symlink?
+ if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() {
+ // Symlink traversal updates access time.
+ child.inode.touchAtime(rp.Mount())
+ if err := rp.HandleSymlink(symlink.target); err != nil {
+ return nil, err
+ }
+ start = &parentDir.dentry
+ goto afterTrailingSymlink
+ }
+ if rp.MustBeDir() && !child.inode.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ return child.open(ctx, rp, &opts, false)
+}
+
+func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) {
+ ats := vfs.AccessTypesForOpenFlags(opts)
+ if !afterCreate {
+ if err := d.inode.checkPermissions(rp.Credentials(), ats); err != nil {
+ return nil, err
+ }
+ }
+ switch impl := d.inode.impl.(type) {
+ case *regularFile:
+ var fd regularFileFD
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
+ return nil, err
+ }
+ if !afterCreate && opts.Flags&linux.O_TRUNC != 0 {
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
+ }
+ return &fd.vfsfd, nil
+ case *directory:
+ // Can't open directories writably.
+ if ats&vfs.MayWrite != 0 {
+ return nil, syserror.EISDIR
+ }
+ var fd directoryFD
+ fd.LockFD.Init(&d.inode.locks)
+ if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+ case *symlink:
+ // TODO(gvisor.dev/issue/2782): Can't open symlinks without O_PATH.
+ return nil, syserror.ELOOP
+ case *namedPipe:
+ return impl.pipe.Open(ctx, rp.Mount(), &d.vfsd, opts.Flags, &d.inode.locks)
+ case *deviceFile:
+ return rp.VirtualFilesystem().OpenDeviceSpecialFile(ctx, rp.Mount(), &d.vfsd, impl.kind, impl.major, impl.minor, opts)
+ case *socketFile:
+ return nil, syserror.ENXIO
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", d.inode.impl))
+ }
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return "", err
+ }
+ symlink, ok := d.inode.impl.(*symlink)
+ if !ok {
+ return "", syserror.EINVAL
+ }
+ symlink.inode.touchAtime(rp.Mount())
+ return symlink.target, nil
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ if opts.Flags != 0 {
+ // TODO(b/145974740): Support renameat2 flags.
+ return syserror.EINVAL
+ }
+
+ // Resolve newParent first to verify that it's on this Mount.
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ newName := rp.Component()
+ if newName == "." || newName == ".." {
+ return syserror.EBUSY
+ }
+ mnt := rp.Mount()
+ if mnt != oldParentVD.Mount() {
+ return syserror.EXDEV
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+
+ oldParentDir := oldParentVD.Dentry().Impl().(*dentry).inode.impl.(*directory)
+ if err := oldParentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ renamed, ok := oldParentDir.childMap[oldName]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil {
+ return err
+ }
+ // Note that we don't need to call rp.CheckMount(), since if renamed is a
+ // mount point then we want to rename the mount point, not anything in the
+ // mounted filesystem.
+ if renamed.inode.isDir() {
+ if renamed == &newParentDir.dentry || genericIsAncestorDentry(renamed, &newParentDir.dentry) {
+ return syserror.EINVAL
+ }
+ if oldParentDir != newParentDir {
+ // Writability is needed to change renamed's "..".
+ if err := renamed.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return err
+ }
+ }
+ } else {
+ if opts.MustBeDir || rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ }
+
+ if err := newParentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ replaced, ok := newParentDir.childMap[newName]
+ if ok {
+ replacedDir, ok := replaced.inode.impl.(*directory)
+ if ok {
+ if !renamed.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if len(replacedDir.childMap) != 0 {
+ return syserror.ENOTEMPTY
+ }
+ } else {
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ if renamed.inode.isDir() {
+ return syserror.ENOTDIR
+ }
+ }
+ } else {
+ if renamed.inode.isDir() && newParentDir.inode.nlink == maxLinks {
+ return syserror.EMLINK
+ }
+ }
+ // tmpfs never calls VFS.InvalidateDentry(), so newParentDir.dentry can
+ // only be dead if it was deleted.
+ if newParentDir.dentry.vfsd.IsDead() {
+ return syserror.ENOENT
+ }
+
+ // Linux places this check before some of those above; we do it here for
+ // simplicity, under the assumption that applications are not intentionally
+ // doing noop renames expecting them to succeed where non-noop renames
+ // would fail.
+ if renamed == replaced {
+ return nil
+ }
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ var replacedVFSD *vfs.Dentry
+ if replaced != nil {
+ replacedVFSD = &replaced.vfsd
+ }
+ if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil {
+ return err
+ }
+ if replaced != nil {
+ newParentDir.removeChildLocked(replaced)
+ if replaced.inode.isDir() {
+ // Remove links for replaced/. and replaced/..
+ replaced.inode.decLinksLocked(ctx)
+ newParentDir.inode.decLinksLocked(ctx)
+ }
+ replaced.inode.decLinksLocked(ctx)
+ }
+ oldParentDir.removeChildLocked(renamed)
+ newParentDir.insertChildLocked(renamed, newName)
+ vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD)
+ oldParentDir.inode.touchCMtime()
+ if oldParentDir != newParentDir {
+ if renamed.inode.isDir() {
+ oldParentDir.inode.decLinksLocked(ctx)
+ newParentDir.inode.incLinksLocked()
+ }
+ newParentDir.inode.touchCMtime()
+ }
+ renamed.inode.touchCtime()
+
+ vfs.InotifyRename(ctx, &renamed.inode.watches, &oldParentDir.inode.watches, &newParentDir.inode.watches, oldName, newName, renamed.inode.isDir())
+ return nil
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." {
+ return syserror.EINVAL
+ }
+ if name == ".." {
+ return syserror.ENOTEMPTY
+ }
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ childDir, ok := child.inode.impl.(*directory)
+ if !ok {
+ return syserror.ENOTDIR
+ }
+ if len(childDir.childMap) != 0 {
+ return syserror.ENOTEMPTY
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+ parentDir.removeChildLocked(child)
+ parentDir.inode.watches.Notify(ctx, name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */)
+ // Remove links for child, child/., and child/..
+ child.inode.decLinksLocked(ctx)
+ child.inode.decLinksLocked(ctx)
+ parentDir.inode.decLinksLocked(ctx)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ var stat linux.Statx
+ d.inode.statTo(&stat)
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ if _, err := resolveLocked(ctx, rp); err != nil {
+ return linux.Statfs{}, err
+ }
+ statfs := linux.Statfs{
+ Type: linux.TMPFS_MAGIC,
+ BlockSize: usermem.PageSize,
+ FragmentSize: usermem.PageSize,
+ NameLength: linux.NAME_MAX,
+ // TODO(b/29637826): Allow configuring a tmpfs size and enforce it.
+ Blocks: 0,
+ BlocksFree: 0,
+ }
+ return statfs, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ return fs.doCreateAt(ctx, rp, false /* dir */, func(parentDir *directory, name string) error {
+ creds := rp.Credentials()
+ child := fs.newDentry(fs.newSymlink(creds.EffectiveKUID, creds.EffectiveKGID, 0777, target))
+ parentDir.insertChildLocked(child, name)
+ return nil
+ })
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ fs.mu.Lock()
+ defer fs.mu.Unlock()
+ parentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry))
+ if err != nil {
+ return err
+ }
+ if err := parentDir.inode.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil {
+ return err
+ }
+ name := rp.Component()
+ if name == "." || name == ".." {
+ return syserror.EISDIR
+ }
+ child, ok := parentDir.childMap[name]
+ if !ok {
+ return syserror.ENOENT
+ }
+ if err := parentDir.mayDelete(rp.Credentials(), child); err != nil {
+ return err
+ }
+ if child.inode.isDir() {
+ return syserror.EISDIR
+ }
+ if rp.MustBeDir() {
+ return syserror.ENOTDIR
+ }
+ mnt := rp.Mount()
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ defer mnt.EndWrite()
+ vfsObj := rp.VirtualFilesystem()
+ mntns := vfs.MountNamespaceFromContext(ctx)
+ defer mntns.DecRef(ctx)
+ if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
+ return err
+ }
+
+ // Generate inotify events. Note that this must take place before the link
+ // count of the child is decremented, or else the watches may be dropped
+ // before these events are added.
+ vfs.InotifyRemoveChild(ctx, &child.inode.watches, &parentDir.inode.watches, name)
+
+ parentDir.removeChildLocked(child)
+ child.inode.decLinksLocked(ctx)
+ vfsObj.CommitDeleteDentry(ctx, &child.vfsd)
+ parentDir.inode.touchCMtime()
+ return nil
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ if err := d.inode.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+ return nil, err
+ }
+ switch impl := d.inode.impl.(type) {
+ case *socketFile:
+ return impl.ep, nil
+ default:
+ return nil, syserror.ECONNREFUSED
+ }
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return nil, err
+ }
+ return d.inode.listxattr(size)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ return "", err
+ }
+ return d.inode.getxattr(rp.Credentials(), &opts)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ fs.mu.RLock()
+ d, err := resolveLocked(ctx, rp)
+ if err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ if err := d.inode.removexattr(rp.Credentials(), name); err != nil {
+ fs.mu.RUnlock()
+ return err
+ }
+ fs.mu.RUnlock()
+
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.mu.RLock()
+ defer fs.mu.RUnlock()
+ mnt := vd.Mount()
+ d := vd.Dentry().Impl().(*dentry)
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ if d.name != "" {
+ // This must be an anonymous memfd file.
+ b.PrependComponent("/" + d.name)
+ return vfs.PrependPathSyntheticError{}
+ }
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
diff --git a/pkg/sentry/fsimpl/memfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
index 732ed7c58..739350cf0 100644
--- a/pkg/sentry/fsimpl/memfs/named_pipe.go
+++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go
@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type namedPipe struct {
@@ -32,28 +30,9 @@ type namedPipe struct {
// Preconditions:
// * fs.mu must be locked.
// * rp.Mount().CheckBeginWrite() has been called successfully.
-func (fs *filesystem) newNamedPipe(creds *auth.Credentials, mode linux.FileMode) *inode {
- file := &namedPipe{pipe: pipe.NewVFSPipe(pipe.DefaultPipeSize, usermem.PageSize)}
- file.inode.init(file, fs, creds, mode)
+func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
+ file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)}
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode)
file.inode.nlink = 1 // Only the parent has a link.
return &file.inode
}
-
-// namedPipeFD implements vfs.FileDescriptionImpl. Methods are implemented
-// entirely via struct embedding.
-type namedPipeFD struct {
- fileDescription
-
- *pipe.VFSPipeFD
-}
-
-func newNamedPipeFD(ctx context.Context, np *namedPipe, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, flags uint32) (*vfs.FileDescription, error) {
- var err error
- var fd namedPipeFD
- fd.VFSPipeFD, err = np.pipe.NewVFSPipeFD(ctx, rp, vfsd, &fd.vfsfd, flags)
- if err != nil {
- return nil, err
- }
- fd.vfsfd.Init(&fd, rp.Mount(), vfsd)
- return &fd.vfsfd, nil
-}
diff --git a/pkg/sentry/fsimpl/memfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
index 0674b81a3..ec2701d8b 100644
--- a/pkg/sentry/fsimpl/memfs/pipe_test.go
+++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go
@@ -12,33 +12,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
"bytes"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const fileName = "mypipe"
func TestSeparateFDs(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the read side. This is done in a concurrently because opening
// One end the pipe blocks until the other end is opened.
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
rfdchan := make(chan *vfs.FileDescription)
@@ -54,13 +55,13 @@ func TestSeparateFDs(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer wfd.DecRef()
+ defer wfd.DecRef(ctx)
rfd, ok := <-rfdchan
if !ok {
t.Fatalf("failed to open pipe for reading %q", fileName)
}
- defer rfd.DecRef()
+ defer rfd.DecRef(ctx)
const msg = "vamos azul"
checkEmpty(ctx, t, rfd)
@@ -70,13 +71,13 @@ func TestSeparateFDs(t *testing.T) {
func TestNonblockingRead(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the read side as nonblocking.
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_NONBLOCK}
@@ -84,7 +85,7 @@ func TestNonblockingRead(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for reading %q: %v", fileName, err)
}
- defer rfd.DecRef()
+ defer rfd.DecRef(ctx)
// Open the write side.
openOpts = vfs.OpenOptions{Flags: linux.O_WRONLY}
@@ -92,7 +93,7 @@ func TestNonblockingRead(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer wfd.DecRef()
+ defer wfd.DecRef(ctx)
const msg = "geh blau"
checkEmpty(ctx, t, rfd)
@@ -102,13 +103,13 @@ func TestNonblockingRead(t *testing.T) {
func TestNonblockingWriteError(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the write side as nonblocking, which should return ENXIO.
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_WRONLY | linux.O_NONBLOCK}
@@ -120,13 +121,13 @@ func TestNonblockingWriteError(t *testing.T) {
func TestSingleFD(t *testing.T) {
ctx, creds, vfsObj, root := setup(t)
- defer root.DecRef()
+ defer root.DecRef(ctx)
// Open the pipe as readable and writable.
pop := vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}
openOpts := vfs.OpenOptions{Flags: linux.O_RDWR}
@@ -134,7 +135,7 @@ func TestSingleFD(t *testing.T) {
if err != nil {
t.Fatalf("failed to open pipe for writing %q: %v", fileName, err)
}
- defer fd.DecRef()
+ defer fd.DecRef(ctx)
const msg = "forza blu"
checkEmpty(ctx, t, fd)
@@ -150,9 +151,14 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
creds := auth.CredentialsFromContext(ctx)
// Create VFS.
- vfsObj := vfs.New()
- vfsObj.MustRegisterFilesystemType("memfs", FilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "memfs", &vfs.NewFilesystemOptions{})
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
if err != nil {
t.Fatalf("failed to create tmpfs root mount: %v", err)
}
@@ -160,10 +166,9 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
// Create the pipe.
root := mntns.Root()
pop := vfs.PathOperation{
- Root: root,
- Start: root,
- Pathname: fileName,
- FollowFinalSymlink: true,
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(fileName),
}
mknodOpts := vfs.MknodOptions{Mode: linux.ModeNamedPipe | 0644}
if err := vfsObj.MknodAt(ctx, creds, &pop, &mknodOpts); err != nil {
@@ -174,7 +179,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{
Root: root,
Start: root,
- Pathname: fileName,
+ Path: fspath.Parse(fileName),
FollowFinalSymlink: true,
}, &vfs.StatOptions{})
if err != nil {
@@ -194,7 +199,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy
func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
readData := make([]byte, 1)
dst := usermem.BytesIOSequence(readData)
- bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
if err != syserror.ErrWouldBlock {
t.Fatalf("expected ErrWouldBlock reading from empty pipe %q, but got: %v", fileName, err)
}
@@ -207,7 +212,7 @@ func checkEmpty(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
writeData := []byte(msg)
src := usermem.BytesIOSequence(writeData)
- bytesWritten, err := fd.Impl().Write(ctx, src, vfs.WriteOptions{})
+ bytesWritten, err := fd.Write(ctx, src, vfs.WriteOptions{})
if err != nil {
t.Fatalf("error writing to pipe %q: %v", fileName, err)
}
@@ -220,7 +225,7 @@ func checkWrite(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg
func checkRead(ctx context.Context, t *testing.T, fd *vfs.FileDescription, msg string) {
readData := make([]byte, len(msg))
dst := usermem.BytesIOSequence(readData)
- bytesRead, err := fd.Impl().Read(ctx, dst, vfs.ReadOptions{})
+ bytesRead, err := fd.Read(ctx, dst, vfs.ReadOptions{})
if err != nil {
t.Fatalf("error reading from pipe %q: %v", fileName, err)
}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
new file mode 100644
index 000000000..0710b65db
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -0,0 +1,637 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "io"
+ "math"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// regularFile is a regular (=S_IFREG) tmpfs file.
+type regularFile struct {
+ inode inode
+
+ // memFile is a platform.File used to allocate pages to this regularFile.
+ memFile *pgalloc.MemoryFile
+
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
+
+ // data maps offsets into the file to offsets into memFile that store
+ // the file's data.
+ //
+ // Protected by dataMu.
+ data fsutil.FileRangeSet
+
+ // seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
+ seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
+}
+
+func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode {
+ file := &regularFile{
+ memFile: fs.memFile,
+ seals: linux.F_SEAL_SEAL,
+ }
+ file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
+
+// truncate grows or shrinks the file to the given size. It returns true if the
+// file size was updated.
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
+
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
+ // Nothing to do.
+ return false, nil
+ }
+
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
+ if rf.seals&linux.F_SEAL_GROW != 0 {
+ rf.dataMu.Unlock()
+ return false, syserror.EPERM
+ }
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+ return true, nil
+ }
+
+ // We are shrinking the file. First check if this is allowed.
+ if rf.seals&linux.F_SEAL_SHRINK != 0 {
+ rf.dataMu.Unlock()
+ return false, syserror.EPERM
+ }
+
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
+
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
+ return true, nil
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
+type regularFileFD struct {
+ fileDescription
+
+ // off is the file offset. off is accessed using atomic memory operations.
+ // offMu serializes operations that may mutate off.
+ off int64
+ offMu sync.Mutex
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *regularFileFD) Release(context.Context) {
+ // noop
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ f := fd.inode().impl.(*regularFile)
+
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ oldSize := f.size
+ size := offset + length
+ if oldSize >= size {
+ return nil
+ }
+ _, err := f.truncateLocked(size)
+ return err
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := dst.CopyOutFrom(ctx, rw)
+ putRegularFileReadWriter(rw)
+ fd.inode().touchAtime(fd.vfsfd.Mount())
+ return n, err
+}
+
+// Read implements vfs.FileDescriptionImpl.Read.
+func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, err := fd.PRead(ctx, dst, fd.off, opts)
+ fd.off += n
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.PWrite.
+func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, _, err := fd.pwrite(ctx, src, offset, opts)
+ return n, err
+}
+
+// pwrite returns the number of bytes written, final offset and error. The
+// final offset should be ignored by PWrite.
+func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) {
+ if offset < 0 {
+ return 0, offset, syserror.EINVAL
+ }
+
+ // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since
+ // all state is in-memory.
+ //
+ // TODO(gvisor.dev/issue/2601): Support select preadv2 flags.
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, offset, syserror.EOPNOTSUPP
+ }
+
+ srclen := src.NumBytes()
+ if srclen == 0 {
+ return 0, offset, nil
+ }
+ f := fd.inode().impl.(*regularFile)
+ f.inode.mu.Lock()
+ defer f.inode.mu.Unlock()
+ // If the file is opened with O_APPEND, update offset to file size.
+ if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 {
+ // Locking f.inode.mu is sufficient for reading f.size.
+ offset = int64(f.size)
+ }
+ if end := offset + srclen; end < offset {
+ // Overflow.
+ return 0, offset, syserror.EINVAL
+ }
+
+ srclen, err = vfs.CheckLimit(ctx, offset, srclen)
+ if err != nil {
+ return 0, offset, err
+ }
+ src = src.TakeFirst64(srclen)
+
+ rw := getRegularFileReadWriter(f, offset)
+ n, err := src.CopyInTo(ctx, rw)
+ f.inode.touchCMtimeLocked()
+ putRegularFileReadWriter(rw)
+ return n, n + offset, err
+}
+
+// Write implements vfs.FileDescriptionImpl.Write.
+func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.offMu.Lock()
+ n, off, err := fd.pwrite(ctx, src, fd.off, opts)
+ fd.off = off
+ fd.offMu.Unlock()
+ return n, err
+}
+
+// Seek implements vfs.FileDescriptionImpl.Seek.
+func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ fd.offMu.Lock()
+ defer fd.offMu.Unlock()
+ switch whence {
+ case linux.SEEK_SET:
+ // use offset as specified
+ case linux.SEEK_CUR:
+ offset += fd.off
+ case linux.SEEK_END:
+ offset += int64(atomic.LoadUint64(&fd.inode().impl.(*regularFile).size))
+ default:
+ return 0, syserror.EINVAL
+ }
+ if offset < 0 {
+ return 0, syserror.EINVAL
+ }
+ fd.off = offset
+ return offset, nil
+}
+
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
+// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
+type regularFileReadWriter struct {
+ file *regularFile
+
+ // Offset into the file to read/write at. Note that this may be
+ // different from the FD offset if PRead/PWrite is used.
+ off uint64
+}
+
+var regularFileReadWriterPool = sync.Pool{
+ New: func() interface{} {
+ return &regularFileReadWriter{}
+ },
+}
+
+func getRegularFileReadWriter(file *regularFile, offset int64) *regularFileReadWriter {
+ rw := regularFileReadWriterPool.Get().(*regularFileReadWriter)
+ rw.file = file
+ rw.off = uint64(offset)
+ return rw
+}
+
+func putRegularFileReadWriter(rw *regularFileReadWriter) {
+ rw.file = nil
+ regularFileReadWriterPool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
+
+ // Compute the range to read (limited by file size and overflow-checked).
+ if rw.off >= size {
+ return 0, io.EOF
+ }
+ end := size
+ if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
+ end = rend
+ }
+
+ var done uint64
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
+ if err != nil {
+ return done, err
+ }
+
+ // Copy from internal mappings.
+ n, err := safemem.CopySeq(dsts, ims)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Tmpfs holes are zero-filled.
+ gapmr := gap.Range().Intersect(mr)
+ dst := dsts.TakeFirst64(gapmr.Length())
+ n, err := safemem.ZeroSeq(dst)
+ done += n
+ rw.off += uint64(n)
+ dsts = dsts.DropFirst64(n)
+ if err != nil {
+ return done, err
+ }
+
+ // Continue.
+ seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
+ }
+ }
+ return done, nil
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
+func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
+
+ // Compute the range to write (overflow-checked).
+ end := rw.off + srcs.NumBytes()
+ if end <= rw.off {
+ end = math.MaxInt64
+ }
+
+ // Check if seals prevent either file growth or all writes.
+ switch {
+ case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
+ return 0, syserror.EPERM
+ case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
+ // When growth is sealed, Linux effectively allows writes which would
+ // normally grow the file to partially succeed up to the current EOF,
+ // rounded down to the page boundary before the EOF.
+ //
+ // This happens because writes (and thus the growth check) for tmpfs
+ // files proceed page-by-page on Linux, and the final write to the page
+ // containing EOF fails, resulting in a partial write up to the start of
+ // that page.
+ //
+ // To emulate this behaviour, artifically truncate the write to the
+ // start of the page containing the current EOF.
+ //
+ // See Linux, mm/filemap.c:generic_perform_write() and
+ // mm/shmem.c:shmem_write_begin().
+ if pgstart := uint64(usermem.Addr(rw.file.size).RoundDown()); end > pgstart {
+ end = pgstart
+ }
+ if end <= rw.off {
+ // Truncation would result in no data being written.
+ return 0, syserror.EPERM
+ }
+ }
+
+ // Page-aligned mr for when we need to allocate memory. RoundUp can't
+ // overflow since end is an int64.
+ pgstartaddr := usermem.Addr(rw.off).RoundDown()
+ pgendaddr, _ := usermem.Addr(end).RoundUp()
+ pgMR := memmap.MappableRange{uint64(pgstartaddr), uint64(pgendaddr)}
+
+ var (
+ done uint64
+ retErr error
+ )
+ seg, gap := rw.file.data.Find(uint64(rw.off))
+ for rw.off < end {
+ mr := memmap.MappableRange{uint64(rw.off), uint64(end)}
+ switch {
+ case seg.Ok():
+ // Get internal mappings.
+ ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Write)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Copy to internal mappings.
+ n, err := safemem.CopySeq(ims, srcs)
+ done += n
+ rw.off += uint64(n)
+ srcs = srcs.DropFirst64(n)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Continue.
+ seg, gap = seg.NextNonEmpty()
+
+ case gap.Ok():
+ // Allocate memory for the write.
+ gapMR := gap.Range().Intersect(pgMR)
+ fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs)
+ if err != nil {
+ retErr = err
+ goto exitLoop
+ }
+
+ // Write to that memory as usual.
+ seg, gap = rw.file.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{}
+ }
+ }
+exitLoop:
+ // If the write ends beyond the file's previous size, it causes the
+ // file to grow.
+ if rw.off > rw.file.size {
+ rw.file.size = rw.off
+ }
+
+ return done, retErr
+}
+
+// GetSeals returns the current set of seals on a memfd inode.
+func GetSeals(fd *vfs.FileDescription) (uint32, error) {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+ return rf.seals, nil
+}
+
+// AddSeals adds new file seals to a memfd inode.
+func AddSeals(fd *vfs.FileDescription, val uint32) error {
+ f, ok := fd.Impl().(*regularFileFD)
+ if !ok {
+ return syserror.EINVAL
+ }
+ rf := f.inode().impl.(*regularFile)
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ if rf.seals&linux.F_SEAL_SEAL != 0 {
+ // Seal applied which prevents addition of any new seals.
+ return syserror.EPERM
+ }
+
+ // F_SEAL_WRITE can only be added if there are no active writable maps.
+ if rf.seals&linux.F_SEAL_WRITE == 0 && val&linux.F_SEAL_WRITE != 0 {
+ if rf.writableMappingPages > 0 {
+ return syserror.EBUSY
+ }
+ }
+
+ // Seals can only be added, never removed.
+ rf.seals |= val
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file_test.go b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
new file mode 100644
index 000000000..146c7fdfe
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file_test.go
@@ -0,0 +1,349 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Test that we can write some data to a file and read it back.`
+func TestSimpleWriteRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write.
+ data := []byte("foobarbaz")
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+
+ // Seek back to beginning.
+ if _, err := fd.Seek(ctx, 0, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Seek failed: %v", err)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(0); got != want {
+ t.Errorf("fd.Seek(0) left offset at %d, want %d", got, want)
+ }
+
+ // Read.
+ buf := make([]byte, len(data))
+ n, err = fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Read got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(buf), string(data); got != want {
+ t.Errorf("Read got %q want %s", got, want)
+ }
+ if got, want := fd.Impl().(*regularFileFD).off, int64(len(data)); got != want {
+ t.Errorf("fd.Write left offset at %d, want %d", got, want)
+ }
+}
+
+func TestPWrite(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill file with 1k 'a's.
+ data := bytes.Repeat([]byte{'a'}, 1000)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Write "gVisor is awesome" at various offsets.
+ buf := []byte("gVisor is awesome")
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PWrite offset=%d", offset)
+ t.Run(name, func(t *testing.T) {
+ n, err := fd.PWrite(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.WriteOptions{})
+ if err != nil {
+ t.Errorf("fd.PWrite got err %v want nil", err)
+ }
+ if n != int64(len(buf)) {
+ t.Errorf("fd.PWrite got %d bytes want %d", n, len(buf))
+ }
+
+ // Update data to reflect expected file contents.
+ if len(data) < offset+len(buf) {
+ data = append(data, make([]byte, (offset+len(buf))-len(data))...)
+ }
+ copy(data[offset:], buf)
+
+ // Read the whole file and compare with data.
+ readBuf := make([]byte, len(data))
+ n, err = fd.PRead(ctx, usermem.BytesIOSequence(readBuf), 0, vfs.ReadOptions{})
+ if err != nil {
+ t.Fatalf("fd.PRead failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.PRead got short read length %d, want %d", n, len(data))
+ }
+ if got, want := string(readBuf), string(data); got != want {
+ t.Errorf("PRead got %q want %s", got, want)
+ }
+
+ })
+ }
+}
+
+func TestLocks(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ uid1 := 123
+ uid2 := 456
+ if err := fd.Impl().LockBSD(ctx, uid1, lock.ReadLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.ReadLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockBSD failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockBSD(ctx, uid1); err != nil {
+ t.Fatalf("fd.Impl().UnlockBSD failed: err = %v", err)
+ }
+ if err := fd.Impl().LockBSD(ctx, uid2, lock.WriteLock, nil); err != nil {
+ t.Fatalf("fd.Impl().LockBSD failed: err = %v", err)
+ }
+
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.ReadLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 1, 2, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if err := fd.Impl().LockPOSIX(ctx, uid1, lock.WriteLock, 0, 1, linux.SEEK_SET, nil); err != nil {
+ t.Fatalf("fd.Impl().LockPOSIX failed: err = %v", err)
+ }
+ if got, want := fd.Impl().LockPOSIX(ctx, uid2, lock.ReadLock, 0, 1, linux.SEEK_SET, nil), syserror.ErrWouldBlock; got != want {
+ t.Fatalf("fd.Impl().LockPOSIX failed: got = %v, want = %v", got, want)
+ }
+ if err := fd.Impl().UnlockPOSIX(ctx, uid1, 0, 1, linux.SEEK_SET); err != nil {
+ t.Fatalf("fd.Impl().UnlockPOSIX failed: err = %v", err)
+ }
+}
+
+func TestPRead(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Write 100 sequences of 'gVisor is awesome'.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ n, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+ if n != int64(len(data)) {
+ t.Errorf("fd.Write got short write length %d, want %d", n, len(data))
+ }
+
+ // Read various sizes from various offsets.
+ sizes := []int{0, 1, 2, 10, 20, 50, 100, 1000}
+ offsets := []int{0, 1, 2, 10, 20, 50, 100, 1000, len(data) - 100, len(data) - 1, len(data), len(data) + 1}
+
+ for _, size := range sizes {
+ for _, offset := range offsets {
+ name := fmt.Sprintf("PRead offset=%d size=%d", offset, size)
+ t.Run(name, func(t *testing.T) {
+ var (
+ wantRead []byte
+ wantErr error
+ )
+ if offset < len(data) {
+ wantRead = data[offset:]
+ } else if size > 0 {
+ wantErr = io.EOF
+ }
+ if offset+size < len(data) {
+ wantRead = wantRead[:size]
+ }
+ buf := make([]byte, size)
+ n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), int64(offset), vfs.ReadOptions{})
+ if err != wantErr {
+ t.Errorf("fd.PRead got err %v want %v", err, wantErr)
+ }
+ if n != int64(len(wantRead)) {
+ t.Errorf("fd.PRead got %d bytes want %d", n, len(wantRead))
+ }
+ if got := string(buf[:n]); got != string(wantRead) {
+ t.Errorf("fd.PRead got %q want %q", got, string(wantRead))
+ }
+ })
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ // Fill the file with some data.
+ data := bytes.Repeat([]byte("gVisor is awsome"), 100)
+ written, err := fd.Write(ctx, usermem.BytesIOSequence(data), vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("fd.Write failed: %v", err)
+ }
+
+ // Size should be same as written.
+ sizeStatOpts := vfs.StatOptions{Mask: linux.STATX_SIZE}
+ stat, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := int64(stat.Size), written; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+
+ // Truncate down.
+ newSize := uint64(10)
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateDown, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateDown.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should only read newSize worth of data.
+ buf := make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateDown.Mtime.ToNsec(); got <= stat.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, stat.Mtime)
+ }
+ if got := statAfterTruncateDown.Ctime.ToNsec(); got <= stat.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate up.
+ newSize = 100
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ // Size should be updated.
+ statAfterTruncateUp, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ if got, want := statAfterTruncateUp.Size, newSize; got != want {
+ t.Errorf("fd.Stat got size %d, want %d", got, want)
+ }
+ // We should read newSize worth of data.
+ buf = make([]byte, 1000)
+ if n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0, vfs.ReadOptions{}); err != nil && err != io.EOF {
+ t.Fatalf("fd.PRead failed: %v", err)
+ } else if uint64(n) != newSize {
+ t.Errorf("fd.PRead got size %d, want %d", n, newSize)
+ }
+ // Bytes should be null after 10, since we previously truncated to 10.
+ for i := uint64(10); i < newSize; i++ {
+ if buf[i] != 0 {
+ t.Errorf("fd.PRead got byte %d=%x, want 0", i, buf[i])
+ break
+ }
+ }
+ // Mtime and Ctime should be bumped.
+ if got := statAfterTruncateUp.Mtime.ToNsec(); got <= statAfterTruncateDown.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want > %v", got, statAfterTruncateDown.Mtime)
+ }
+ if got := statAfterTruncateUp.Ctime.ToNsec(); got <= statAfterTruncateDown.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want > %v", got, stat.Ctime)
+ }
+
+ // Truncate to the current size.
+ newSize = statAfterTruncateUp.Size
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: newSize,
+ },
+ }); err != nil {
+ t.Errorf("fd.Truncate failed: %v", err)
+ }
+ statAfterTruncateNoop, err := fd.Stat(ctx, sizeStatOpts)
+ if err != nil {
+ t.Fatalf("fd.Stat failed: %v", err)
+ }
+ // Mtime and Ctime should not be bumped, since operation is a noop.
+ if got := statAfterTruncateNoop.Mtime.ToNsec(); got != statAfterTruncateUp.Mtime.ToNsec() {
+ t.Errorf("fd.Stat got Mtime %v, want %v", got, statAfterTruncateUp.Mtime)
+ }
+ if got := statAfterTruncateNoop.Ctime.ToNsec(); got != statAfterTruncateUp.Ctime.ToNsec() {
+ t.Errorf("fd.Stat got Ctime %v, want %v", got, statAfterTruncateUp.Ctime)
+ }
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/socket_file.go b/pkg/sentry/fsimpl/tmpfs/socket_file.go
new file mode 100644
index 000000000..3ed650474
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/socket_file.go
@@ -0,0 +1,34 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+)
+
+// socketFile is a socket (=S_IFSOCK) tmpfs file.
+type socketFile struct {
+ inode inode
+ ep transport.BoundEndpoint
+}
+
+func (fs *filesystem) newSocketFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, ep transport.BoundEndpoint) *inode {
+ file := &socketFile{ep: ep}
+ file.inode.init(file, fs, kuid, kgid, mode)
+ file.inode.nlink = 1 // from parent directory
+ return &file.inode
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/stat_test.go b/pkg/sentry/fsimpl/tmpfs/stat_test.go
new file mode 100644
index 000000000..f7ee4aab2
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/stat_test.go
@@ -0,0 +1,236 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+func TestStatAfterCreate(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ got, err := fd.Stat(ctx, vfs.StatOptions{})
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Atime, Ctime, Mtime should all be current time (non-zero).
+ atime, ctime, mtime := got.Atime.ToNsec(), got.Ctime.ToNsec(), got.Mtime.ToNsec()
+ if atime != ctime || ctime != mtime {
+ t.Errorf("got atime=%d ctime=%d mtime=%d, wanted equal values", atime, ctime, mtime)
+ }
+ if atime == 0 {
+ t.Errorf("got atime=%d, want non-zero", atime)
+ }
+
+ // Btime should be 0, as it is not set by tmpfs.
+ if btime := got.Btime.ToNsec(); btime != 0 {
+ t.Errorf("got btime %d, want 0", got.Btime.ToNsec())
+ }
+
+ // Size should be 0 (except for directories, which make up a size
+ // of 20 per entry, including the "." and ".." entries present in
+ // otherwise-empty directories).
+ wantSize := uint64(0)
+ if typ == "dir" {
+ wantSize = 40
+ }
+ if got.Size != wantSize {
+ t.Errorf("got size %d, want %d", got.Size, wantSize)
+ }
+
+ // Nlink should be 1 for files, 2 for dirs.
+ wantNlink := uint32(1)
+ if typ == "dir" {
+ wantNlink = 2
+ }
+ if got.Nlink != wantNlink {
+ t.Errorf("got nlink %d, want %d", got.Nlink, wantNlink)
+ }
+
+ // UID and GID are set from context creds.
+ creds := auth.CredentialsFromContext(ctx)
+ if got.UID != uint32(creds.EffectiveKUID) {
+ t.Errorf("got uid %d, want %d", got.UID, uint32(creds.EffectiveKUID))
+ }
+ if got.GID != uint32(creds.EffectiveKGID) {
+ t.Errorf("got gid %d, want %d", got.GID, uint32(creds.EffectiveKGID))
+ }
+
+ // Mode.
+ wantMode := uint16(mode)
+ switch typ {
+ case "file":
+ wantMode |= linux.S_IFREG
+ case "dir":
+ wantMode |= linux.S_IFDIR
+ case "pipe":
+ wantMode |= linux.S_IFIFO
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+
+ if got.Mode != wantMode {
+ t.Errorf("got mode %x, want %x", got.Mode, wantMode)
+ }
+
+ // Ino.
+ if got.Ino == 0 {
+ t.Errorf("got ino %d, want not 0", got.Ino)
+ }
+ })
+ }
+}
+
+func TestSetStatAtime(t *testing.T) {
+ ctx := contexttest.Context(t)
+ fd, cleanup, err := newFileFD(ctx, 0644)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v", err)
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v", err)
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+}
+
+func TestSetStat(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mode := linux.FileMode(0644)
+
+ // Run with different file types.
+ for _, typ := range []string{"file", "dir", "pipe"} {
+ t.Run(fmt.Sprintf("type=%q", typ), func(t *testing.T) {
+ var (
+ fd *vfs.FileDescription
+ cleanup func()
+ err error
+ )
+ switch typ {
+ case "file":
+ fd, cleanup, err = newFileFD(ctx, mode)
+ case "dir":
+ fd, cleanup, err = newDirFD(ctx, mode)
+ case "pipe":
+ fd, cleanup, err = newPipeFD(ctx, mode)
+ default:
+ panic(fmt.Sprintf("unknown typ %q", typ))
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ allStatOptions := vfs.StatOptions{Mask: linux.STATX_ALL}
+
+ // Get initial stat.
+ initialStat, err := fd.Stat(ctx, allStatOptions)
+ if err != nil {
+ t.Fatalf("Stat failed: %v", err)
+ }
+
+ // Set atime, but without the mask.
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: linux.Statx{
+ Mask: 0,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }}); err != nil {
+ t.Errorf("SetStat atime without mask failed: %v", err)
+ }
+ // Atime should be unchanged.
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != initialStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, initialStat.Atime)
+ }
+
+ // Set atime, this time included in the mask.
+ setStat := linux.Statx{
+ Mask: linux.STATX_ATIME,
+ Atime: linux.NsecToStatxTimestamp(100),
+ }
+ if err := fd.SetStat(ctx, vfs.SetStatOptions{Stat: setStat}); err != nil {
+ t.Errorf("SetStat atime with mask failed: %v", err)
+ }
+ if gotStat, err := fd.Stat(ctx, allStatOptions); err != nil {
+ t.Errorf("Stat got error: %v", err)
+ } else if gotStat.Atime != setStat.Atime {
+ t.Errorf("Stat got atime %d, want %d", gotStat.Atime, setStat.Atime)
+ }
+ })
+ }
+}
diff --git a/pkg/sentry/fsimpl/memfs/symlink.go b/pkg/sentry/fsimpl/tmpfs/symlink.go
index b2ac2cbeb..b0de5fabe 100644
--- a/pkg/sentry/fsimpl/memfs/symlink.go
+++ b/pkg/sentry/fsimpl/tmpfs/symlink.go
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package memfs
+package tmpfs
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -23,11 +24,11 @@ type symlink struct {
target string // immutable
}
-func (fs *filesystem) newSymlink(creds *auth.Credentials, target string) *inode {
+func (fs *filesystem) newSymlink(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode, target string) *inode {
link := &symlink{
target: target,
}
- link.inode.init(link, fs, creds, 0777)
+ link.inode.init(link, fs, kuid, kgid, linux.S_IFLNK|mode)
link.inode.nlink = 1 // from parent directory
return &link.inode
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
new file mode 100644
index 000000000..de2af6d01
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -0,0 +1,775 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tmpfs provides an in-memory filesystem whose contents are
+// application-mutable, consistent with Linux's tmpfs.
+//
+// Lock order:
+//
+// filesystem.mu
+// inode.mu
+// regularFileFD.offMu
+// *** "memmap.Mappable locks" below this point
+// regularFile.mapsMu
+// *** "memmap.Mappable locks taken by Translate" below this point
+// regularFile.dataMu
+// directory.iterMu
+package tmpfs
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/pgalloc"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs/memxattr"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Name is the default filesystem name.
+const Name = "tmpfs"
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // memFile is used to allocate pages to for regular files.
+ memFile *pgalloc.MemoryFile
+
+ // clock is a realtime clock used to set timestamps in file operations.
+ clock time.Clock
+
+ // devMinor is the filesystem's minor device number. devMinor is immutable.
+ devMinor uint32
+
+ // mu serializes changes to the Dentry tree.
+ mu sync.RWMutex
+
+ nextInoMinusOne uint64 // accessed using atomic memory operations
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// FilesystemOpts is used to pass configuration data to tmpfs.
+type FilesystemOpts struct {
+ // RootFileType is the FileType of the filesystem root. Valid values
+ // are: S_IFDIR, S_IFREG, and S_IFLNK. Defaults to S_IFDIR.
+ RootFileType uint16
+
+ // RootSymlinkTarget is the target of the root symlink. Only valid if
+ // RootFileType == S_IFLNK.
+ RootSymlinkTarget string
+
+ // FilesystemType allows setting a different FilesystemType for this
+ // tmpfs filesystem. This allows tmpfs to "impersonate" other
+ // filesystems, like ramdiskfs and cgroupfs.
+ FilesystemType vfs.FilesystemType
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx)
+ if memFileProvider == nil {
+ panic("MemoryFileProviderFromContext returned nil")
+ }
+
+ rootFileType := uint16(linux.S_IFDIR)
+ newFSType := vfs.FilesystemType(&fstype)
+ tmpfsOpts, ok := opts.InternalData.(FilesystemOpts)
+ if ok {
+ if tmpfsOpts.RootFileType != 0 {
+ rootFileType = tmpfsOpts.RootFileType
+ }
+ if tmpfsOpts.FilesystemType != nil {
+ newFSType = tmpfsOpts.FilesystemType
+ }
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ rootMode := linux.FileMode(0777)
+ if rootFileType == linux.S_IFDIR {
+ rootMode = 01777
+ }
+ modeStr, ok := mopts["mode"]
+ if ok {
+ delete(mopts, "mode")
+ mode, err := strconv.ParseUint(modeStr, 8, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid mode: %q", modeStr)
+ return nil, nil, syserror.EINVAL
+ }
+ rootMode = linux.FileMode(mode & 07777)
+ }
+ rootKUID := creds.EffectiveKUID
+ uidStr, ok := mopts["uid"]
+ if ok {
+ delete(mopts, "uid")
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid uid: %q", uidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKUID = kuid
+ }
+ rootKGID := creds.EffectiveKGID
+ gidStr, ok := mopts["gid"]
+ if ok {
+ delete(mopts, "gid")
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
+ if err != nil {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: invalid gid: %q", gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
+ return nil, nil, syserror.EINVAL
+ }
+ rootKGID = kgid
+ }
+ if len(mopts) != 0 {
+ ctx.Warningf("tmpfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+ clock := time.RealtimeClockFromContext(ctx)
+ fs := filesystem{
+ memFile: memFileProvider.MemoryFile(),
+ clock: clock,
+ devMinor: devMinor,
+ }
+ fs.vfsfs.Init(vfsObj, newFSType, &fs)
+
+ var root *dentry
+ switch rootFileType {
+ case linux.S_IFREG:
+ root = fs.newDentry(fs.newRegularFile(rootKUID, rootKGID, rootMode))
+ case linux.S_IFLNK:
+ root = fs.newDentry(fs.newSymlink(rootKUID, rootKGID, rootMode, tmpfsOpts.RootSymlinkTarget))
+ case linux.S_IFDIR:
+ root = &fs.newDirectory(rootKUID, rootKGID, rootMode).dentry
+ default:
+ fs.vfsfs.DecRef(ctx)
+ return nil, nil, fmt.Errorf("invalid tmpfs root file type: %#o", rootFileType)
+ }
+ return &fs.vfsfs, &root.vfsd, nil
+}
+
+// NewFilesystem returns a new tmpfs filesystem.
+func NewFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*vfs.Filesystem, *vfs.Dentry, error) {
+ return FilesystemType{}.GetFilesystem(ctx, vfsObj, creds, "", vfs.GetFilesystemOptions{})
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ // parent is this dentry's parent directory. Each referenced dentry holds a
+ // reference on parent.dentry. If this dentry is a filesystem root, parent
+ // is nil. parent is protected by filesystem.mu.
+ parent *dentry
+
+ // name is the name of this dentry in its parent. If this dentry is a
+ // filesystem root, name is the empty string. name is protected by
+ // filesystem.mu.
+ name string
+
+ // dentryEntry (ugh) links dentries into their parent directory.childList.
+ dentryEntry
+
+ // inode is the inode represented by this dentry. Multiple Dentries may
+ // share a single non-directory inode (with hard links). inode is
+ // immutable.
+ //
+ // tmpfs doesn't count references on dentries; because the dentry tree is
+ // the sole source of truth, it is by definition always consistent with the
+ // state of the filesystem. However, it does count references on inodes,
+ // because inode resources are released when all references are dropped.
+ // dentry therefore forwards reference counting directly to inode.
+ inode *inode
+}
+
+func (fs *filesystem) newDentry(inode *inode) *dentry {
+ d := &dentry{
+ inode: inode,
+ }
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ d.inode.incRef()
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ return d.inode.tryIncRef()
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef(ctx context.Context) {
+ d.inode.decRef(ctx)
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
+ if d.inode.isDir() {
+ events |= linux.IN_ISDIR
+ }
+
+ // tmpfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates
+ // that d was deleted.
+ deleted := d.vfsd.IsDead()
+
+ d.inode.fs.mu.RLock()
+ // The ordering below is important, Linux always notifies the parent first.
+ if d.parent != nil {
+ d.parent.inode.watches.Notify(ctx, d.name, events, cookie, et, deleted)
+ }
+ d.inode.watches.Notify(ctx, "", events, cookie, et, deleted)
+ d.inode.fs.mu.RUnlock()
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ return &d.inode.watches
+}
+
+// OnZeroWatches implements vfs.Dentry.OnZeroWatches.
+func (d *dentry) OnZeroWatches(context.Context) {}
+
+// inode represents a filesystem object.
+type inode struct {
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // A reference is held on all inodes as long as they are reachable in the
+ // filesystem tree, i.e. nlink is nonzero. This reference is dropped when
+ // nlink reaches 0.
+ refs inodeRefs
+
+ // xattrs implements extended attributes.
+ //
+ // TODO(b/148380782): Support xattrs other than user.*
+ xattrs memxattr.SimpleExtendedAttributes
+
+ // Inode metadata. Writing multiple fields atomically requires holding
+ // mu, othewise atomic operations can be used.
+ mu sync.Mutex
+ mode uint32 // file type and mode
+ nlink uint32 // protected by filesystem.mu instead of inode.mu
+ uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic
+ gid uint32 // auth.KGID, but ...
+ ino uint64 // immutable
+
+ // Linux's tmpfs has no concept of btime.
+ atime int64 // nanoseconds
+ ctime int64 // nanoseconds
+ mtime int64 // nanoseconds
+
+ locks vfs.FileLocks
+
+ // Inotify watches for this inode.
+ watches vfs.Watches
+
+ impl interface{} // immutable
+}
+
+const maxLinks = math.MaxUint32
+
+func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) {
+ if mode.FileType() == 0 {
+ panic("file type is required in FileMode")
+ }
+ i.fs = fs
+ i.mode = uint32(mode)
+ i.uid = uint32(kuid)
+ i.gid = uint32(kgid)
+ i.ino = atomic.AddUint64(&fs.nextInoMinusOne, 1)
+ // Tmpfs creation sets atime, ctime, and mtime to current time.
+ now := fs.clock.Now().Nanoseconds()
+ i.atime = now
+ i.ctime = now
+ i.mtime = now
+ // i.nlink initialized by caller
+ i.impl = impl
+ i.refs.EnableLeakCheck()
+}
+
+// incLinksLocked increments i's link count.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
+// i.nlink < maxLinks.
+func (i *inode) incLinksLocked() {
+ if i.nlink == 0 {
+ panic("tmpfs.inode.incLinksLocked() called with no existing links")
+ }
+ if i.nlink == maxLinks {
+ panic("tmpfs.inode.incLinksLocked() called with maximum link count")
+ }
+ atomic.AddUint32(&i.nlink, 1)
+}
+
+// decLinksLocked decrements i's link count. If the link count reaches 0, we
+// remove a reference on i as well.
+//
+// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0.
+func (i *inode) decLinksLocked(ctx context.Context) {
+ if i.nlink == 0 {
+ panic("tmpfs.inode.decLinksLocked() called with no existing links")
+ }
+ if atomic.AddUint32(&i.nlink, ^uint32(0)) == 0 {
+ i.decRef(ctx)
+ }
+}
+
+func (i *inode) incRef() {
+ i.refs.IncRef()
+}
+
+func (i *inode) tryIncRef() bool {
+ return i.refs.TryIncRef()
+}
+
+func (i *inode) decRef(ctx context.Context) {
+ i.refs.DecRef(func() {
+ i.watches.HandleDeletion(ctx)
+ if regFile, ok := i.impl.(*regularFile); ok {
+ // Release memory used by regFile to store data. Since regFile is
+ // no longer usable, we don't need to grab any locks or update any
+ // metadata.
+ regFile.data.DropAll(regFile.memFile)
+ }
+ })
+}
+
+func (i *inode) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ return vfs.GenericCheckPermissions(creds, ats, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid)))
+}
+
+// Go won't inline this function, and returning linux.Statx (which is quite
+// big) means spending a lot of time in runtime.duffcopy(), so instead it's an
+// output parameter.
+//
+// Note that Linux does not guarantee to return consistent data (in the case of
+// a concurrent modification), so we do not require holding inode.mu.
+func (i *inode) statTo(stat *linux.Statx) {
+ stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK |
+ linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE |
+ linux.STATX_BLOCKS | linux.STATX_ATIME | linux.STATX_CTIME |
+ linux.STATX_MTIME
+ stat.Blksize = usermem.PageSize
+ stat.Nlink = atomic.LoadUint32(&i.nlink)
+ stat.UID = atomic.LoadUint32(&i.uid)
+ stat.GID = atomic.LoadUint32(&i.gid)
+ stat.Mode = uint16(atomic.LoadUint32(&i.mode))
+ stat.Ino = i.ino
+ stat.Atime = linux.NsecToStatxTimestamp(i.atime)
+ stat.Ctime = linux.NsecToStatxTimestamp(i.ctime)
+ stat.Mtime = linux.NsecToStatxTimestamp(i.mtime)
+ stat.DevMajor = linux.UNNAMED_MAJOR
+ stat.DevMinor = i.fs.devMinor
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ stat.Mask |= linux.STATX_SIZE | linux.STATX_BLOCKS
+ stat.Size = uint64(atomic.LoadUint64(&impl.size))
+ // TODO(jamieliu): This should be impl.data.Span() / 512, but this is
+ // too expensive to compute here. Cache it in regularFile.
+ stat.Blocks = allocatedBlocksForSize(stat.Size)
+ case *directory:
+ // "20" is mm/shmem.c:BOGO_DIRENT_SIZE.
+ stat.Size = 20 * (2 + uint64(atomic.LoadInt64(&impl.numChildren)))
+ // stat.Blocks is 0.
+ case *symlink:
+ stat.Size = uint64(len(impl.target))
+ // stat.Blocks is 0.
+ case *namedPipe, *socketFile:
+ // stat.Size and stat.Blocks are 0.
+ case *deviceFile:
+ // stat.Size and stat.Blocks are 0.
+ stat.RdevMajor = impl.major
+ stat.RdevMinor = impl.minor
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+}
+
+func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
+ stat := &opts.Stat
+ if stat.Mask == 0 {
+ return nil
+ }
+ if stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME|linux.STATX_SIZE) != 0 {
+ return syserror.EPERM
+ }
+ mode := linux.FileMode(atomic.LoadUint32(&i.mode))
+ if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil {
+ return err
+ }
+ i.mu.Lock()
+ defer i.mu.Unlock()
+ var (
+ needsMtimeBump bool
+ needsCtimeBump bool
+ )
+ mask := stat.Mask
+ if mask&linux.STATX_MODE != 0 {
+ ft := atomic.LoadUint32(&i.mode) & linux.S_IFMT
+ atomic.StoreUint32(&i.mode, ft|uint32(stat.Mode&^linux.S_IFMT))
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_UID != 0 {
+ atomic.StoreUint32(&i.uid, stat.UID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_GID != 0 {
+ atomic.StoreUint32(&i.gid, stat.GID)
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ updated, err := impl.truncateLocked(stat.Size)
+ if err != nil {
+ return err
+ }
+ if updated {
+ needsMtimeBump = true
+ needsCtimeBump = true
+ }
+ case *directory:
+ return syserror.EISDIR
+ default:
+ return syserror.EINVAL
+ }
+ }
+ now := i.fs.clock.Now().Nanoseconds()
+ if mask&linux.STATX_ATIME != 0 {
+ if stat.Atime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.atime, now)
+ } else {
+ atomic.StoreInt64(&i.atime, stat.Atime.ToNsecCapped())
+ }
+ needsCtimeBump = true
+ }
+ if mask&linux.STATX_MTIME != 0 {
+ if stat.Mtime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.mtime, now)
+ } else {
+ atomic.StoreInt64(&i.mtime, stat.Mtime.ToNsecCapped())
+ }
+ needsCtimeBump = true
+ // Ignore the mtime bump, since we just set it ourselves.
+ needsMtimeBump = false
+ }
+ if mask&linux.STATX_CTIME != 0 {
+ if stat.Ctime.Nsec == linux.UTIME_NOW {
+ atomic.StoreInt64(&i.ctime, now)
+ } else {
+ atomic.StoreInt64(&i.ctime, stat.Ctime.ToNsecCapped())
+ }
+ // Ignore the ctime bump, since we just set it ourselves.
+ needsCtimeBump = false
+ }
+ if needsMtimeBump {
+ atomic.StoreInt64(&i.mtime, now)
+ }
+ if needsCtimeBump {
+ atomic.StoreInt64(&i.ctime, now)
+ }
+
+ return nil
+}
+
+// allocatedBlocksForSize returns the number of 512B blocks needed to
+// accommodate the given size in bytes, as appropriate for struct
+// stat::st_blocks and struct statx::stx_blocks. (Note that this 512B block
+// size is independent of the "preferred block size for I/O", struct
+// stat::st_blksize and struct statx::stx_blksize.)
+func allocatedBlocksForSize(size uint64) uint64 {
+ return (size + 511) / 512
+}
+
+func (i *inode) direntType() uint8 {
+ switch impl := i.impl.(type) {
+ case *regularFile:
+ return linux.DT_REG
+ case *directory:
+ return linux.DT_DIR
+ case *symlink:
+ return linux.DT_LNK
+ case *socketFile:
+ return linux.DT_SOCK
+ case *namedPipe:
+ return linux.DT_FIFO
+ case *deviceFile:
+ switch impl.kind {
+ case vfs.BlockDevice:
+ return linux.DT_BLK
+ case vfs.CharDevice:
+ return linux.DT_CHR
+ default:
+ panic(fmt.Sprintf("unknown vfs.DeviceKind: %v", impl.kind))
+ }
+ default:
+ panic(fmt.Sprintf("unknown inode type: %T", i.impl))
+ }
+}
+
+func (i *inode) isDir() bool {
+ return linux.FileMode(i.mode).FileType() == linux.S_IFDIR
+}
+
+func (i *inode) touchAtime(mnt *vfs.Mount) {
+ if mnt.Flags.NoATime {
+ return
+ }
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return
+ }
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.atime, now)
+ i.mu.Unlock()
+ mnt.EndWrite()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCtime() {
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite().
+func (i *inode) touchCMtime() {
+ now := i.fs.clock.Now().Nanoseconds()
+ i.mu.Lock()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+ i.mu.Unlock()
+}
+
+// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds
+// inode.mu.
+func (i *inode) touchCMtimeLocked() {
+ now := i.fs.clock.Now().Nanoseconds()
+ atomic.StoreInt64(&i.mtime, now)
+ atomic.StoreInt64(&i.ctime, now)
+}
+
+func (i *inode) listxattr(size uint64) ([]string, error) {
+ return i.xattrs.Listxattr(size)
+}
+
+func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) {
+ if err := i.checkPermissions(creds, vfs.MayRead); err != nil {
+ return "", err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return "", syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return "", syserror.ENODATA
+ }
+ return i.xattrs.Getxattr(opts)
+}
+
+func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Setxattr(opts)
+}
+
+func (i *inode) removexattr(creds *auth.Credentials, name string) error {
+ if err := i.checkPermissions(creds, vfs.MayWrite); err != nil {
+ return err
+ }
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+ if !i.userXattrSupported() {
+ return syserror.EPERM
+ }
+ return i.xattrs.Removexattr(name)
+}
+
+// Extended attributes in the user.* namespace are only supported for regular
+// files and directories.
+func (i *inode) userXattrSupported() bool {
+ filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode)
+ return filetype == linux.S_IFREG || filetype == linux.S_IFDIR
+}
+
+// fileDescription is embedded by tmpfs implementations of
+// vfs.FileDescriptionImpl.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+}
+
+func (fd *fileDescription) filesystem() *filesystem {
+ return fd.vfsfd.Mount().Filesystem().Impl().(*filesystem)
+}
+
+func (fd *fileDescription) dentry() *dentry {
+ return fd.vfsfd.Dentry().Impl().(*dentry)
+}
+
+func (fd *fileDescription) inode() *inode {
+ return fd.dentry().inode
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ var stat linux.Statx
+ fd.inode().statTo(&stat)
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ creds := auth.CredentialsFromContext(ctx)
+ d := fd.dentry()
+ if err := d.inode.setStat(ctx, creds, &opts); err != nil {
+ return err
+ }
+
+ if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 {
+ d.InotifyWithParent(ctx, ev, 0, vfs.InodeEvent)
+ }
+ return nil
+}
+
+// Listxattr implements vfs.FileDescriptionImpl.Listxattr.
+func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ return fd.inode().listxattr(size)
+}
+
+// Getxattr implements vfs.FileDescriptionImpl.Getxattr.
+func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) {
+ return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts)
+}
+
+// Setxattr implements vfs.FileDescriptionImpl.Setxattr.
+func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error {
+ d := fd.dentry()
+ if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// Removexattr implements vfs.FileDescriptionImpl.Removexattr.
+func (fd *fileDescription) Removexattr(ctx context.Context, name string) error {
+ d := fd.dentry()
+ if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil {
+ return err
+ }
+
+ // Generate inotify events.
+ d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent)
+ return nil
+}
+
+// NewMemfd creates a new tmpfs regular file and file description that can back
+// an anonymous fd created by memfd_create.
+func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) {
+ fs, ok := mount.Filesystem().Impl().(*filesystem)
+ if !ok {
+ panic("NewMemfd() called with non-tmpfs mount")
+ }
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with
+ // S_IRWXUGO.
+ inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777)
+ rf := inode.impl.(*regularFile)
+ if allowSeals {
+ rf.seals = 0
+ }
+
+ d := fs.newDentry(inode)
+ defer d.DecRef(ctx)
+ d.name = name
+
+ // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with
+ // FMODE_READ | FMODE_WRITE.
+ var fd regularFileFD
+ fd.Init(&inode.locks)
+ flags := uint32(linux.O_RDWR)
+ if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
+
+// Sync implements vfs.FileDescriptionImpl.Sync. It does nothing because all
+// filesystem state is in-memory.
+func (*fileDescription) Sync(context.Context) error {
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
new file mode 100644
index 000000000..6f3e3ae6f
--- /dev/null
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go
@@ -0,0 +1,156 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tmpfs
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// nextFileID is used to generate unique file names.
+var nextFileID int64
+
+// newTmpfsRoot creates a new tmpfs mount, and returns the root. If the error
+// is not nil, then cleanup should be called when the root is no longer needed.
+func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentry, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+
+ vfsObj := &vfs.VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err)
+ }
+
+ vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ })
+ mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{})
+ if err != nil {
+ return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err)
+ }
+ root := mntns.Root()
+ return vfsObj, root, func() {
+ root.DecRef(ctx)
+ mntns.DecRef(ctx)
+ }, nil
+}
+
+// newFileFD creates a new file in a new tmpfs mount, and returns the FD. If
+// the returned err is not nil, then cleanup should be called when the FD is no
+// longer needed.
+func newFileFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ filename := fmt.Sprintf("tmpfs-test-file-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the file that will be write/read.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(filename),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL,
+ Mode: linux.ModeRegular | mode,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create file %q: %v", filename, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newDirFD is like newFileFD, but for directories.
+func newDirFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ dirname := fmt.Sprintf("tmpfs-test-dir-%d", atomic.AddInt64(&nextFileID, 1))
+
+ // Create the dir.
+ if err := vfsObj.MkdirAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.MkdirOptions{
+ Mode: linux.ModeDirectory | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create directory %q: %v", dirname, err)
+ }
+
+ // Open the dir and return it.
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(dirname),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY | linux.O_DIRECTORY,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open directory %q: %v", dirname, err)
+ }
+
+ return fd, cleanup, nil
+}
+
+// newPipeFD is like newFileFD, but for pipes.
+func newPipeFD(ctx context.Context, mode linux.FileMode) (*vfs.FileDescription, func(), error) {
+ creds := auth.CredentialsFromContext(ctx)
+ vfsObj, root, cleanup, err := newTmpfsRoot(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ name := fmt.Sprintf("tmpfs-test-%d", atomic.AddInt64(&nextFileID, 1))
+
+ if err := vfsObj.MknodAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.MknodOptions{
+ Mode: linux.ModeNamedPipe | mode,
+ }); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to create pipe %q: %v", name, err)
+ }
+
+ fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(name),
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR,
+ })
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("failed to open pipe %q: %v", name, err)
+ }
+
+ return fd, cleanup, nil
+}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
new file mode 100644
index 000000000..28d2a4bcb
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "verity",
+ srcs = [
+ "filesystem.go",
+ "verity.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
new file mode 100644
index 000000000..78c6074bd
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -0,0 +1,333 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package verity
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements vfs.FilesystemImpl.Sync.
+func (fs *filesystem) Sync(ctx context.Context) error {
+ // All files should be read-only.
+ return nil
+}
+
+var dentrySlicePool = sync.Pool{
+ New: func() interface{} {
+ ds := make([]*dentry, 0, 4) // arbitrary non-zero initial capacity
+ return &ds
+ },
+}
+
+func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
+ if ds == nil {
+ ds = dentrySlicePool.Get().(*[]*dentry)
+ }
+ *ds = append(*ds, d)
+ return ds
+}
+
+// Preconditions: ds != nil.
+func putDentrySlice(ds *[]*dentry) {
+ // Allow dentries to be GC'd.
+ for i := range *ds {
+ (*ds)[i] = nil
+ }
+ *ds = (*ds)[:0]
+ dentrySlicePool.Put(ds)
+}
+
+// renameMuRUnlockAndCheckDrop calls fs.renameMu.RUnlock(), then calls
+// dentry.checkDropLocked on all dentries in *ds with fs.renameMu locked for
+// writing.
+//
+// ds is a pointer-to-pointer since defer evaluates its arguments immediately,
+// but dentry slices are allocated lazily, and it's much easier to say "defer
+// fs.renameMuRUnlockAndCheckDrop(&ds)" than "defer func() {
+// fs.renameMuRUnlockAndCheckDrop(ds) }()" to work around this.
+func (fs *filesystem) renameMuRUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+ fs.renameMu.RUnlock()
+ if *ds == nil {
+ return
+ }
+ if len(**ds) != 0 {
+ fs.renameMu.Lock()
+ for _, d := range **ds {
+ d.checkDropLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ }
+ putDentrySlice(*ds)
+}
+
+func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*dentry) {
+ if *ds == nil {
+ fs.renameMu.Unlock()
+ return
+ }
+ for _, d := range **ds {
+ d.checkDropLocked(ctx)
+ }
+ fs.renameMu.Unlock()
+ putDentrySlice(*ds)
+}
+
+// resolveLocked resolves rp to an existing file.
+func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
+ // TODO(b/159261227): Implement resolveLocked.
+ return nil, nil
+}
+
+// walkParentDirLocked resolves all but the last path component of rp to an
+// existing directory, starting from the given directory (which is usually
+// rp.Start().Impl().(*dentry)). It does not check that the returned directory
+// is searchable by the provider of rp.
+//
+// Preconditions: fs.renameMu must be locked. !rp.Done().
+func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+ // TODO(b/159261227): Implement walkParentDirLocked.
+ return nil, nil
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *filesystem) AccessAt(ctx context.Context, rp *vfs.ResolvingPath, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // Verity file system is read-only.
+ if ats&vfs.MayWrite != 0 {
+ return syserror.EROFS
+ }
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return err
+ }
+ return d.checkPermissions(creds, ats)
+}
+
+// GetDentryAt implements vfs.FilesystemImpl.GetDentryAt.
+func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetDentryOptions) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ if opts.CheckSearchable {
+ if !d.isDir() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
+ return nil, err
+ }
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// GetParentDentryAt implements vfs.FilesystemImpl.GetParentDentryAt.
+func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ start := rp.Start().Impl().(*dentry)
+ d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
+ if err != nil {
+ return nil, err
+ }
+ d.IncRef()
+ return &d.vfsd, nil
+}
+
+// LinkAt implements vfs.FilesystemImpl.LinkAt.
+func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// MkdirAt implements vfs.FilesystemImpl.MkdirAt.
+func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// MknodAt implements vfs.FilesystemImpl.MknodAt.
+func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// OpenAt implements vfs.FilesystemImpl.OpenAt.
+func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ //TODO(b/159261227): Implement OpenAt.
+ return nil, nil
+}
+
+// ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt.
+func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ //TODO(b/162787271): Provide integrity check for ReadlinkAt.
+ return fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ })
+}
+
+// RenameAt implements vfs.FilesystemImpl.RenameAt.
+func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// RmdirAt implements vfs.FilesystemImpl.RmdirAt.
+func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// SetStatAt implements vfs.FilesystemImpl.SetStatAt.
+func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// StatAt implements vfs.FilesystemImpl.StatAt.
+func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+
+ var stat linux.Statx
+ stat, err = fs.vfsfs.VirtualFilesystem().StatAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ }, &opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ return stat, nil
+}
+
+// StatFSAt implements vfs.FilesystemImpl.StatFSAt.
+func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) {
+ // TODO(b/159261227): Implement StatFSAt.
+ return linux.Statfs{}, nil
+}
+
+// SymlinkAt implements vfs.FilesystemImpl.SymlinkAt.
+func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// UnlinkAt implements vfs.FilesystemImpl.UnlinkAt.
+func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ if _, err := fs.resolveLocked(ctx, rp, &ds); err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt.
+func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return nil, err
+ }
+ lowerVD := d.lowerVD
+ return fs.vfsfs.VirtualFilesystem().ListxattrAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, size)
+}
+
+// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt.
+func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) {
+ var ds *[]*dentry
+ fs.renameMu.RLock()
+ defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds)
+ d, err := fs.resolveLocked(ctx, rp, &ds)
+ if err != nil {
+ return "", err
+ }
+ lowerVD := d.lowerVD
+ return fs.vfsfs.VirtualFilesystem().GetxattrAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: lowerVD,
+ Start: lowerVD,
+ }, &opts)
+}
+
+// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt.
+func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt.
+func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error {
+ // Verity file system is read-only.
+ return syserror.EROFS
+}
+
+// PrependPath implements vfs.FilesystemImpl.PrependPath.
+func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDentry, b *fspath.Builder) error {
+ fs.renameMu.RLock()
+ defer fs.renameMu.RUnlock()
+ mnt := vd.Mount()
+ d := vd.Dentry().Impl().(*dentry)
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
new file mode 100644
index 000000000..cb29d33a5
--- /dev/null
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -0,0 +1,355 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package verity provides a filesystem implementation that is a wrapper of
+// another file system.
+// The verity file system provides integrity check for the underlying file
+// system by providing verification for path traversals and each read.
+// The verity file system is read-only, except for one case: when
+// allowRuntimeEnable is true, additional Merkle files can be generated using
+// the FS_IOC_ENABLE_VERITY ioctl.
+package verity
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Name is the default filesystem name.
+const Name = "verity"
+
+// testOnlyDebugging allows verity file system to return error instead of
+// crashing the application when a malicious action is detected. This should
+// only be set for tests.
+var testOnlyDebugging bool
+
+// FilesystemType implements vfs.FilesystemType.
+type FilesystemType struct{}
+
+// filesystem implements vfs.FilesystemImpl.
+type filesystem struct {
+ vfsfs vfs.Filesystem
+
+ // creds is a copy of the filesystem's creator's credentials, which are
+ // used for accesses to the underlying file system. creds is immutable.
+ creds *auth.Credentials
+
+ // allowRuntimeEnable is true if using ioctl with FS_IOC_ENABLE_VERITY
+ // to build Merkle trees in the verity file system is allowed. If this
+ // is false, no new Merkle trees can be built, and only the files that
+ // had Merkle trees before startup (e.g. from a host filesystem mounted
+ // with gofer fs) can be verified.
+ allowRuntimeEnable bool
+
+ // lowerMount is the underlying file system mount.
+ lowerMount *vfs.Mount
+
+ // rootDentry is the mount root Dentry for this file system, which
+ // stores the root hash of the whole file system in bytes.
+ rootDentry *dentry
+
+ // renameMu synchronizes renaming with non-renaming operations in order
+ // to ensure consistent lock ordering between dentry.dirMu in different
+ // dentries.
+ renameMu sync.RWMutex
+}
+
+// InternalFilesystemOptions may be passed as
+// vfs.GetFilesystemOptions.InternalData to FilesystemType.GetFilesystem.
+type InternalFilesystemOptions struct {
+ // RootMerkleFileName is the name of the verity root Merkle tree file.
+ RootMerkleFileName string
+
+ // LowerName is the name of the filesystem wrapped by verity fs.
+ LowerName string
+
+ // RootHash is the root hash of the overall verity file system.
+ RootHash []byte
+
+ // AllowRuntimeEnable specifies whether the verity file system allows
+ // enabling verification for files (i.e. building Merkle trees) during
+ // runtime.
+ AllowRuntimeEnable bool
+
+ // LowerGetFSOptions is the file system option for the lower layer file
+ // system wrapped by verity file system.
+ LowerGetFSOptions vfs.GetFilesystemOptions
+
+ // TestOnlyDebugging allows verity file system to return error instead
+ // of crashing the application when a malicious action is detected. This
+ // should only be set for tests.
+ TestOnlyDebugging bool
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ //TODO(b/159261227): Implement GetFilesystem.
+ return nil, nil, nil
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ fs.lowerMount.DecRef(ctx)
+}
+
+// dentry implements vfs.DentryImpl.
+type dentry struct {
+ vfsd vfs.Dentry
+
+ refs int64
+
+ // fs is the owning filesystem. fs is immutable.
+ fs *filesystem
+
+ // mode, uid and gid are the file mode, owner, and group of the file in
+ // the underlying file system.
+ mode uint32
+ uid uint32
+ gid uint32
+
+ // parent is the dentry corresponding to this dentry's parent directory.
+ // name is this dentry's name in parent. If this dentry is a filesystem
+ // root, parent is nil and name is the empty string. parent and name are
+ // protected by fs.renameMu.
+ parent *dentry
+ name string
+
+ // If this dentry represents a directory, children maps the names of
+ // children for which dentries have been instantiated to those dentries,
+ // and dirents (if not nil) is a cache of dirents as returned by
+ // directoryFDs representing this directory. children is protected by
+ // dirMu.
+ dirMu sync.Mutex
+ children map[string]*dentry
+
+ // lowerVD is the VirtualDentry in the underlying file system.
+ lowerVD vfs.VirtualDentry
+
+ // lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree
+ // in the underlying file system.
+ lowerMerkleVD vfs.VirtualDentry
+
+ // rootHash is the rootHash for the current file or directory.
+ rootHash []byte
+}
+
+// newDentry creates a new dentry representing the given verity file. The
+// dentry initially has no references; it is the caller's responsibility to set
+// the dentry's reference count and/or call dentry.destroy() as appropriate.
+// The dentry is initially invalid in that it contains no underlying dentry;
+// the caller is responsible for setting them.
+func (fs *filesystem) newDentry() *dentry {
+ d := &dentry{
+ fs: fs,
+ }
+ d.vfsd.Init(d)
+ return d
+}
+
+// IncRef implements vfs.DentryImpl.IncRef.
+func (d *dentry) IncRef() {
+ atomic.AddInt64(&d.refs, 1)
+}
+
+// TryIncRef implements vfs.DentryImpl.TryIncRef.
+func (d *dentry) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&d.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef implements vfs.DentryImpl.DecRef.
+func (d *dentry) DecRef(ctx context.Context) {
+ if refs := atomic.AddInt64(&d.refs, -1); refs == 0 {
+ d.fs.renameMu.Lock()
+ d.checkDropLocked(ctx)
+ d.fs.renameMu.Unlock()
+ } else if refs < 0 {
+ panic("verity.dentry.DecRef() called without holding a reference")
+ }
+}
+
+// checkDropLocked should be called after d's reference count becomes 0 or it
+// becomes deleted.
+func (d *dentry) checkDropLocked(ctx context.Context) {
+ // Dentries with a positive reference count must be retained. Dentries
+ // with a negative reference count have already been destroyed.
+ if atomic.LoadInt64(&d.refs) != 0 {
+ return
+ }
+ // Refs is still zero; destroy it.
+ d.destroyLocked(ctx)
+ return
+}
+
+// destroyLocked destroys the dentry.
+//
+// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0.
+func (d *dentry) destroyLocked(ctx context.Context) {
+ switch atomic.LoadInt64(&d.refs) {
+ case 0:
+ // Mark the dentry destroyed.
+ atomic.StoreInt64(&d.refs, -1)
+ case -1:
+ panic("verity.dentry.destroyLocked() called on already destroyed dentry")
+ default:
+ panic("verity.dentry.destroyLocked() called with references on the dentry")
+ }
+
+ if d.lowerVD.Ok() {
+ d.lowerVD.DecRef(ctx)
+ }
+
+ if d.lowerMerkleVD.Ok() {
+ d.lowerMerkleVD.DecRef(ctx)
+ }
+
+ if d.parent != nil {
+ d.parent.dirMu.Lock()
+ if !d.vfsd.IsDead() {
+ delete(d.parent.children, d.name)
+ }
+ d.parent.dirMu.Unlock()
+ if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 {
+ d.parent.checkDropLocked(ctx)
+ } else if refs < 0 {
+ panic("verity.dentry.DecRef() called without holding a reference")
+ }
+ }
+}
+
+// InotifyWithParent implements vfs.DentryImpl.InotifyWithParent.
+func (d *dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et vfs.EventType) {
+ //TODO(b/159261227): Implement InotifyWithParent.
+}
+
+// Watches implements vfs.DentryImpl.Watches.
+func (d *dentry) Watches() *vfs.Watches {
+ //TODO(b/159261227): Implement Watches.
+ return nil
+}
+
+// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches.
+func (d *dentry) OnZeroWatches(context.Context) {
+ //TODO(b/159261227): Implement OnZeroWatches.
+}
+
+func (d *dentry) isSymlink() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK
+}
+
+func (d *dentry) isDir() bool {
+ return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR
+}
+
+func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error {
+ return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid)))
+}
+
+func (d *dentry) readlink(ctx context.Context) (string, error) {
+ return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{
+ Root: d.lowerVD,
+ Start: d.lowerVD,
+ })
+}
+
+// FileDescription implements vfs.FileDescriptionImpl for verity fds.
+// FileDescription is a wrapper of the underlying lowerFD, with support to build
+// Merkle trees through the Linux fs-verity API to verify contents read from
+// lowerFD.
+type fileDescription struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // d is the corresponding dentry to the fileDescription.
+ d *dentry
+
+ // isDir specifies whehter the fileDescription points to a directory.
+ isDir bool
+
+ // lowerFD is the FileDescription corresponding to the file in the
+ // underlying file system.
+ lowerFD *vfs.FileDescription
+
+ // merkleReader is the read-only FileDescription corresponding to the
+ // Merkle tree file in the underlying file system.
+ merkleReader *vfs.FileDescription
+
+ // merkleWriter is the FileDescription corresponding to the Merkle tree
+ // file in the underlying file system for writing. This should only be
+ // used when allowRuntimeEnable is set to true.
+ merkleWriter *vfs.FileDescription
+
+ // parentMerkleWriter is the FileDescription of the Merkle tree for the
+ // directory that contains the current file/directory. This is only used
+ // if allowRuntimeEnable is set to true.
+ parentMerkleWriter *vfs.FileDescription
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (fd *fileDescription) Release(ctx context.Context) {
+ fd.lowerFD.DecRef(ctx)
+ fd.merkleReader.DecRef(ctx)
+ if fd.merkleWriter != nil {
+ fd.merkleWriter.DecRef(ctx)
+ }
+ if fd.parentMerkleWriter != nil {
+ fd.parentMerkleWriter.DecRef(ctx)
+ }
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ // TODO(b/162788573): Add integrity check for metadata.
+ stat, err := fd.lowerFD.Stat(ctx, opts)
+ if err != nil {
+ return linux.Statx{}, err
+ }
+ return stat, nil
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ // Verity files are read-only.
+ return syserror.EPERM
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index 359468ccc..e6933aa70 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,7 +9,6 @@ go_library(
"getcpu_arm64.s",
"hostcpu.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
visibility = ["//:sandbox"],
)
@@ -18,5 +16,5 @@ go_test(
name = "hostcpu_test",
size = "small",
srcs = ["hostcpu_test.go"],
- embed = [":hostcpu"],
+ library = ":hostcpu",
)
diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD
new file mode 100644
index 000000000..364a78306
--- /dev/null
+++ b/pkg/sentry/hostfd/BUILD
@@ -0,0 +1,17 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "hostfd",
+ srcs = [
+ "hostfd.go",
+ "hostfd_unsafe.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/safemem",
+ "//pkg/sync",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/hostfd/hostfd.go b/pkg/sentry/hostfd/hostfd.go
new file mode 100644
index 000000000..70dd9cafb
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostfd provides efficient I/O with host file descriptors.
+package hostfd
+
+import (
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// ReadWriterAt implements safemem.Reader and safemem.Writer by reading from
+// and writing to a host file descriptor respectively. ReadWriterAts should be
+// obtained by calling GetReadWriterAt.
+//
+// Clients should usually prefer to use Preadv2 and Pwritev2 directly.
+type ReadWriterAt struct {
+ fd int32
+ offset int64
+ flags uint32
+}
+
+var rwpool = sync.Pool{
+ New: func() interface{} {
+ return &ReadWriterAt{}
+ },
+}
+
+// GetReadWriterAt returns a ReadWriterAt that reads from / writes to the given
+// host file descriptor, starting at the given offset and using the given
+// preadv2(2)/pwritev2(2) flags. If offset is -1, the host file descriptor's
+// offset is used instead. Users are responsible for ensuring that fd remains
+// valid for the lifetime of the returned ReadWriterAt, and must call
+// PutReadWriterAt when it is no longer needed.
+func GetReadWriterAt(fd int32, offset int64, flags uint32) *ReadWriterAt {
+ rw := rwpool.Get().(*ReadWriterAt)
+ *rw = ReadWriterAt{
+ fd: fd,
+ offset: offset,
+ flags: flags,
+ }
+ return rw
+}
+
+// PutReadWriterAt releases a ReadWriterAt returned by a previous call to
+// GetReadWriterAt that is no longer in use.
+func PutReadWriterAt(rw *ReadWriterAt) {
+ rwpool.Put(rw)
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (rw *ReadWriterAt) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Preadv2(rw.fd, dsts, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (rw *ReadWriterAt) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ n, err := Pwritev2(rw.fd, srcs, rw.offset, rw.flags)
+ if rw.offset >= 0 {
+ rw.offset += int64(n)
+ }
+ return n, err
+}
diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go
new file mode 100644
index 000000000..cd4dc67fb
--- /dev/null
+++ b/pkg/sentry/hostfd/hostfd_unsafe.go
@@ -0,0 +1,85 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostfd
+
+import (
+ "io"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/safemem"
+)
+
+// Preadv2 reads up to dsts.NumBytes() bytes from host file descriptor fd into
+// dsts. offset and flags are interpreted as for preadv2(2).
+//
+// Preconditions: !dsts.IsEmpty().
+func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ if flags == 0 && dsts.NumBlocks() == 1 {
+ // Use read() or pread() to avoid iovec allocation and copying.
+ dst := dsts.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_READ, uintptr(fd), dst.Addr(), uintptr(dst.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PREAD64, uintptr(fd), dst.Addr(), uintptr(dst.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ if n == 0 {
+ return 0, io.EOF
+ }
+ return uint64(n), nil
+}
+
+// Pwritev2 writes up to srcs.NumBytes() from srcs into host file descriptor
+// fd. offset and flags are interpreted as for pwritev2(2).
+//
+// Preconditions: !srcs.IsEmpty().
+func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint64, error) {
+ // No buffering is necessary regardless of safecopy; host syscalls will
+ // return EFAULT if appropriate, instead of raising SIGBUS.
+ var (
+ n uintptr
+ e syscall.Errno
+ )
+ if flags == 0 && srcs.NumBlocks() == 1 {
+ // Use write() or pwrite() to avoid iovec allocation and copying.
+ src := srcs.Head()
+ if offset == -1 {
+ n, _, e = syscall.Syscall(unix.SYS_WRITE, uintptr(fd), src.Addr(), uintptr(src.Len()))
+ } else {
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITE64, uintptr(fd), src.Addr(), uintptr(src.Len()), uintptr(offset), 0 /* pos_h */, 0 /* unused */)
+ }
+ } else {
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags))
+ }
+ if e != 0 {
+ return 0, e
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/hostmm/BUILD b/pkg/sentry/hostmm/BUILD
index 67831d5a1..61c78569d 100644
--- a/pkg/sentry/hostmm/BUILD
+++ b/pkg/sentry/hostmm/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,11 +8,10 @@ go_library(
"cgroup.go",
"hostmm.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/hostmm",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/fd",
"//pkg/log",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/hostmm/hostmm.go b/pkg/sentry/hostmm/hostmm.go
index 19335ca73..506c7864a 100644
--- a/pkg/sentry/hostmm/hostmm.go
+++ b/pkg/sentry/hostmm/hostmm.go
@@ -24,7 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// NotifyCurrentMemcgPressureCallback requests that f is called whenever the
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 99481e05e..5bba9de0b 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(
default_visibility = ["//:sandbox"],
@@ -10,11 +10,12 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/inet",
deps = [
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/tcpip",
+ "//pkg/tcpip/stack",
],
)
diff --git a/pkg/sentry/inet/context.go b/pkg/sentry/inet/context.go
index 4eda7dd1f..e8cc1bffd 100644
--- a/pkg/sentry/inet/context.go
+++ b/pkg/sentry/inet/context.go
@@ -15,7 +15,7 @@
package inet
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the inet package's type for context.Context.Value keys.
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 6217100b2..fbe6d6aa6 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -15,7 +15,10 @@
// Package inet defines semantics for IP stacks.
package inet
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
// Stack represents a TCP/IP stack.
type Stack interface {
@@ -28,6 +31,10 @@ type Stack interface {
// interface indexes to a slice of associated interface address properties.
InterfaceAddrs() map[int32][]InterfaceAddr
+ // AddInterfaceAddr adds an address to the network interface identified by
+ // index.
+ AddInterfaceAddr(idx int32, addr InterfaceAddr) error
+
// SupportsIPv6 returns true if the stack supports IPv6 connectivity.
SupportsIPv6() bool
@@ -52,6 +59,12 @@ type Stack interface {
// settings.
SetTCPSACKEnabled(enabled bool) error
+ // TCPRecovery returns the TCP loss detection algorithm.
+ TCPRecovery() (TCPLossRecovery, error)
+
+ // SetTCPRecovery attempts to change TCP loss detection algorithm.
+ SetTCPRecovery(recovery TCPLossRecovery) error
+
// Statistics reports stack statistics.
Statistics(stat interface{}, arg string) error
@@ -61,6 +74,16 @@ type Stack interface {
// Resume restarts the network stack after restore.
Resume()
+ // RegisteredEndpoints returns all endpoints which are currently registered.
+ RegisteredEndpoints() []stack.TransportEndpoint
+
+ // CleanupEndpoints returns endpoints currently in the cleanup state.
+ CleanupEndpoints() []stack.TransportEndpoint
+
+ // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+ // for restoring a stack after a save.
+ RestoreCleanupEndpoints([]stack.TransportEndpoint)
+
// Forwarding returns if packet forwarding between NICs is enabled.
Forwarding(protocol tcpip.NetworkProtocolNumber) bool
@@ -181,3 +204,14 @@ type StatSNMPUDP [8]uint64
// StatSNMPUDPLite describes UdpLite line of /proc/net/snmp.
type StatSNMPUDPLite [8]uint64
+
+// TCPLossRecovery indicates TCP loss detection and recovery methods to use.
+type TCPLossRecovery int32
+
+// Loss recovery constants from include/net/tcp.h which are used to set
+// /proc/sys/net/ipv4/tcp_recovery.
+const (
+ TCP_RACK_LOSS_DETECTION TCPLossRecovery = 1 << iota
+ TCP_RACK_STATIC_REO_WND
+ TCP_RACK_NO_DUPTHRESH
+)
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..029af3025
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ //
+ // At afterLoad(), creator will be used to create network stack. Stateify
+ // needs to wait for this field to be loaded before calling afterLoad().
+ creator NetworkStackCreator `state:"wait"`
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index c6907cfcb..1779cc6f3 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -16,6 +16,7 @@ package inet
import (
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
// TestStack is a dummy implementation of Stack for tests.
@@ -27,6 +28,7 @@ type TestStack struct {
TCPRecvBufSize TCPBufferSize
TCPSendBufSize TCPBufferSize
TCPSACKFlag bool
+ Recovery TCPLossRecovery
IPForwarding bool
}
@@ -50,6 +52,12 @@ func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
return s.InterfaceAddrsMap
}
+// AddInterfaceAddr implements Stack.AddInterfaceAddr.
+func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
+ s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr)
+ return nil
+}
+
// SupportsIPv6 implements Stack.SupportsIPv6.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
@@ -88,6 +96,17 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
return nil
}
+// TCPRecovery implements Stack.TCPRecovery.
+func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) {
+ return s.Recovery, nil
+}
+
+// SetTCPRecovery implements Stack.SetTCPRecovery.
+func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error {
+ s.Recovery = recovery
+ return nil
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *TestStack) Statistics(stat interface{}, arg string) error {
return nil
@@ -99,9 +118,21 @@ func (s *TestStack) RouteTable() []Route {
}
// Resume implements Stack.Resume.
-func (s *TestStack) Resume() {
+func (s *TestStack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return nil
}
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
+ return nil
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+
// Forwarding implements inet.Stack.Forwarding.
func (s *TestStack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
return s.IPForwarding
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e041c51b3..5416a310d 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,8 +1,5 @@
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools:defs.bzl", "go_library", "go_test", "proto_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -35,7 +32,7 @@ go_template_instance(
out = "seqatomic_taskgoroutineschedinfo_unsafe.go",
package = "kernel",
suffix = "TaskGoroutineSchedInfo",
- template = "//third_party/gvsync:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "TaskGoroutineSchedInfo",
},
@@ -78,36 +75,24 @@ go_template_instance(
)
proto_library(
- name = "uncaught_signal_proto",
+ name = "uncaught_signal",
srcs = ["uncaught_signal.proto"],
visibility = ["//visibility:public"],
deps = ["//pkg/sentry/arch:registers_proto"],
)
-cc_proto_library(
- name = "uncaught_signal_cc_proto",
- visibility = ["//visibility:public"],
- deps = [":uncaught_signal_proto"],
-)
-
-go_proto_library(
- name = "uncaught_signal_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
- proto = ":uncaught_signal_proto",
- visibility = ["//visibility:public"],
- deps = ["//pkg/sentry/arch:registers_go_proto"],
-)
-
go_library(
name = "kernel",
srcs = [
"abstract_socket_namespace.go",
+ "aio.go",
"context.go",
"fd_table.go",
"fd_table_unsafe.go",
"fs_context.go",
"ipc_namespace.go",
"kernel.go",
+ "kernel_opts.go",
"kernel_state.go",
"pending_signals.go",
"pending_signals_list.go",
@@ -147,6 +132,7 @@ go_library(
"task_stop.go",
"task_syscall.go",
"task_usermem.go",
+ "task_work.go",
"thread_group.go",
"threads.go",
"timekeeper.go",
@@ -156,7 +142,6 @@ go_library(
"vdso.go",
"version.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel",
imports = [
"gvisor.dev/gvisor/pkg/bpf",
"gvisor.dev/gvisor/pkg/sentry/device",
@@ -171,18 +156,27 @@ go_library(
"//pkg/binary",
"//pkg/bits",
"//pkg/bpf",
+ "//pkg/context",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/metric",
"//pkg/refs",
+ "//pkg/refs_vfs2",
+ "//pkg/safemem",
"//pkg/secio",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/hostcpu",
"//pkg/sentry/inet",
"//pkg/sentry/kernel/auth",
@@ -198,7 +192,6 @@ go_library(
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/time",
@@ -206,15 +199,18 @@ go_library(
"//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/state",
"//pkg/state/statefile",
+ "//pkg/state/wire",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/stack",
+ "//pkg/usermem",
"//pkg/waiter",
- "//third_party/gvsync",
+ "//tools/go_marshal/marshal",
],
)
@@ -227,12 +223,12 @@ go_test(
"task_test.go",
"timekeeper_test.go",
],
- embed = [":kernel"],
+ library = ":kernel",
deps = [
"//pkg/abi",
+ "//pkg/context",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
"//pkg/sentry/fs/filetest",
"//pkg/sentry/kernel/sched",
@@ -240,7 +236,8 @@ go_test(
"//pkg/sentry/pgalloc",
"//pkg/sentry/time",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 244655b5c..1b9721534 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -15,28 +15,21 @@
package kernel
import (
- "sync"
+ "fmt"
"syscall"
- "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs_vfs2"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
)
// +stateify savable
type abstractEndpoint struct {
- ep transport.BoundEndpoint
- wr *refs.WeakRef
- name string
- ns *AbstractSocketNamespace
-}
-
-// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (e *abstractEndpoint) WeakRefGone() {
- e.ns.mu.Lock()
- if e.ns.endpoints[e.name].ep == e.ep {
- delete(e.ns.endpoints, e.name)
- }
- e.ns.mu.Unlock()
+ ep transport.BoundEndpoint
+ socket refs_vfs2.RefCounter
+ name string
+ ns *AbstractSocketNamespace
}
// AbstractSocketNamespace is used to implement the Linux abstract socket functionality.
@@ -45,7 +38,11 @@ func (e *abstractEndpoint) WeakRefGone() {
type AbstractSocketNamespace struct {
mu sync.Mutex `state:"nosave"`
- // Keeps mapping from name to endpoint.
+ // Keeps a mapping from name to endpoint. AbstractSocketNamespace does not hold
+ // any references on any sockets that it contains; when retrieving a socket,
+ // TryIncRef() must be called in case the socket is concurrently being
+ // destroyed. It is the responsibility of the socket to remove itself from the
+ // abstract socket namespace when it is destroyed.
endpoints map[string]abstractEndpoint
}
@@ -57,16 +54,16 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace {
}
// A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on
-// its backing object.
+// its backing socket.
type boundEndpoint struct {
transport.BoundEndpoint
- rc refs.RefCounter
+ socket refs_vfs2.RefCounter
}
// Release implements transport.BoundEndpoint.Release.
-func (e *boundEndpoint) Release() {
- e.rc.DecRef()
- e.BoundEndpoint.Release()
+func (e *boundEndpoint) Release(ctx context.Context) {
+ e.socket.DecRef(ctx)
+ e.BoundEndpoint.Release(ctx)
}
// BoundEndpoint retrieves the endpoint bound to the given name. The return
@@ -80,32 +77,59 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
return nil
}
- rc := ep.wr.Get()
- if rc == nil {
- delete(a.endpoints, name)
+ if !ep.socket.TryIncRef() {
+ // The socket has reached zero references and is being destroyed.
return nil
}
- return &boundEndpoint{ep.ep, rc}
+ return &boundEndpoint{ep.ep, ep.socket}
}
// Bind binds the given socket.
//
-// When the last reference managed by rc is dropped, ep may be removed from the
+// When the last reference managed by socket is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
+ // Check if there is already a socket (which has not yet been destroyed) bound at name.
if ep, ok := a.endpoints[name]; ok {
- if rc := ep.wr.Get(); rc != nil {
- rc.DecRef()
+ if ep.socket.TryIncRef() {
+ ep.socket.DecRef(ctx)
return syscall.EADDRINUSE
}
}
ae := abstractEndpoint{ep: ep, name: name, ns: a}
- ae.wr = refs.NewWeakRef(rc, &ae)
+ ae.socket = socket
a.endpoints[name] = ae
return nil
}
+
+// Remove removes the specified socket at name from the abstract socket
+// namespace, if it has not yet been replaced.
+func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ ep, ok := a.endpoints[name]
+ if !ok {
+ // We never delete a map entry apart from a socket's destructor (although the
+ // map entry may be overwritten). Therefore, a socket should exist, even if it
+ // may not be the one we expect.
+ panic(fmt.Sprintf("expected socket to exist at '%s' in abstract socket namespace", name))
+ }
+
+ // A Bind() operation may race with callers of Remove(), e.g. in the
+ // following case:
+ // socket1 reaches zero references and begins destruction
+ // a.Bind("foo", ep, socket2) replaces socket1 with socket2
+ // socket1's destructor calls a.Remove("foo", socket1)
+ //
+ // Therefore, we need to check that the socket at name is what we expect
+ // before modifying the map.
+ if ep.socket == socket {
+ delete(a.endpoints, name)
+ }
+}
diff --git a/pkg/sentry/kernel/aio.go b/pkg/sentry/kernel/aio.go
new file mode 100644
index 000000000..0ac78c0b8
--- /dev/null
+++ b/pkg/sentry/kernel/aio.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// AIOCallback is an function that does asynchronous I/O on behalf of a task.
+type AIOCallback func(context.Context)
+
+// QueueAIO queues an AIOCallback which will be run asynchronously.
+func (t *Task) QueueAIO(cb AIOCallback) {
+ ctx := taskAsyncContext{t: t}
+ wg := &t.TaskSet().aioGoroutines
+ wg.Add(1)
+ go func() {
+ cb(ctx)
+ wg.Done()
+ }()
+}
+
+type taskAsyncContext struct {
+ context.NoopSleeper
+ t *Task
+}
+
+// Debugf implements log.Logger.Debugf.
+func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
+ ctx.t.Debugf(format, v...)
+}
+
+// Infof implements log.Logger.Infof.
+func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
+ ctx.t.Infof(format, v...)
+}
+
+// Warningf implements log.Logger.Warningf.
+func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
+ ctx.t.Warningf(format, v...)
+}
+
+// IsLogging implements log.Logger.IsLogging.
+func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
+ return ctx.t.IsLogging(level)
+}
+
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
+// Value implements context.Context.Value.
+func (ctx taskAsyncContext) Value(key interface{}) interface{} {
+ return ctx.t.Value(key)
+}
diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD
index 51de4568a..2bc49483a 100644
--- a/pkg/sentry/kernel/auth/BUILD
+++ b/pkg/sentry/kernel/auth/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_credentials_unsafe.go",
package = "auth",
suffix = "Credentials",
- template = "//third_party/gvsync:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "Credentials",
},
@@ -57,13 +57,13 @@ go_library(
"id_map_set.go",
"user_namespace.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/auth",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/bits",
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go
index 5c0e7d6b6..ef5723127 100644
--- a/pkg/sentry/kernel/auth/context.go
+++ b/pkg/sentry/kernel/auth/context.go
@@ -15,7 +15,7 @@
package auth
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the auth package's type for context.Context.Value keys.
diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go
index e057d2c6d..6862f2ef5 100644
--- a/pkg/sentry/kernel/auth/credentials.go
+++ b/pkg/sentry/kernel/auth/credentials.go
@@ -232,3 +232,31 @@ func (c *Credentials) UseGID(gid GID) (KGID, error) {
}
return NoID, syserror.EPERM
}
+
+// SetUID translates the provided uid to the root user namespace and updates c's
+// uids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetUID(uid UID) error {
+ kuid := c.UserNamespace.MapToKUID(uid)
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKUID = kuid
+ c.EffectiveKUID = kuid
+ c.SavedKUID = kuid
+ return nil
+}
+
+// SetGID translates the provided gid to the root user namespace and updates c's
+// gids to it. This performs no permissions or capabilities checks, the caller
+// is responsible for ensuring the calling context is permitted to modify c.
+func (c *Credentials) SetGID(gid GID) error {
+ kgid := c.UserNamespace.MapToKGID(gid)
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ c.RealKGID = kgid
+ c.EffectiveKGID = kgid
+ c.SavedKGID = kgid
+ return nil
+}
diff --git a/pkg/sentry/kernel/auth/id_map.go b/pkg/sentry/kernel/auth/id_map.go
index 3d74bc610..28cbe159d 100644
--- a/pkg/sentry/kernel/auth/id_map.go
+++ b/pkg/sentry/kernel/auth/id_map.go
@@ -16,7 +16,7 @@ package auth
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/auth/user_namespace.go b/pkg/sentry/kernel/auth/user_namespace.go
index af28ccc65..9dd52c860 100644
--- a/pkg/sentry/kernel/auth/user_namespace.go
+++ b/pkg/sentry/kernel/auth/user_namespace.go
@@ -16,8 +16,8 @@ package auth
import (
"math"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..dd5f0f5fa 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,8 +15,9 @@
package kernel
import (
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the kernel package's type for context.Context.Value keys.
@@ -97,39 +98,17 @@ func TaskFromContext(ctx context.Context) *Task {
return nil
}
-// AsyncContext returns a context.Context that may be used by goroutines that
-// do work on behalf of t and therefore share its contextual values, but are
-// not t's task goroutine (e.g. asynchronous I/O).
-func (t *Task) AsyncContext() context.Context {
- return taskAsyncContext{t: t}
-}
-
-type taskAsyncContext struct {
- context.NoopSleeper
- t *Task
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+ return time.Time{}, false
}
-// Debugf implements log.Logger.Debugf.
-func (ctx taskAsyncContext) Debugf(format string, v ...interface{}) {
- ctx.t.Debugf(format, v...)
-}
-
-// Infof implements log.Logger.Infof.
-func (ctx taskAsyncContext) Infof(format string, v ...interface{}) {
- ctx.t.Infof(format, v...)
-}
-
-// Warningf implements log.Logger.Warningf.
-func (ctx taskAsyncContext) Warningf(format string, v ...interface{}) {
- ctx.t.Warningf(format, v...)
-}
-
-// IsLogging implements log.Logger.IsLogging.
-func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
- return ctx.t.IsLogging(level)
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+ return nil
}
-// Value implements context.Context.Value.
-func (ctx taskAsyncContext) Value(key interface{}) interface{} {
- return ctx.t.Value(key)
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+ return nil
}
diff --git a/pkg/sentry/kernel/contexttest/BUILD b/pkg/sentry/kernel/contexttest/BUILD
index 3a88a585c..9d26392c0 100644
--- a/pkg/sentry/kernel/contexttest/BUILD
+++ b/pkg/sentry/kernel/contexttest/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,11 +6,10 @@ go_library(
name = "contexttest",
testonly = 1,
srcs = ["contexttest.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest",
visibility = ["//pkg/sentry:internal"],
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/kernel",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
diff --git a/pkg/sentry/kernel/contexttest/contexttest.go b/pkg/sentry/kernel/contexttest/contexttest.go
index 82f9d8922..22c340e56 100644
--- a/pkg/sentry/kernel/contexttest/contexttest.go
+++ b/pkg/sentry/kernel/contexttest/contexttest.go
@@ -19,8 +19,8 @@ package contexttest
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index 3361e8b7d..75eedd5a2 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,15 +22,16 @@ go_library(
"epoll_list.go",
"epoll_state.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/epoll",
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
"//pkg/refs",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -42,9 +42,9 @@ go_test(
srcs = [
"epoll_test.go",
],
- embed = [":epoll"],
+ library = ":epoll",
deps = [
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs/filetest",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 9c0a4e1b4..15519f0df 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -18,31 +18,19 @@ package epoll
import (
"fmt"
- "sync"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
-// Event describes the event mask that was observed and the user data to be
-// returned when one of the events occurs. It has this format to match the linux
-// format to avoid extra copying/allocation when writing events to userspace.
-type Event struct {
- // Events is the event mask containing the set of events that have been
- // observed on an entry.
- Events uint32
-
- // Data is an opaque 64-bit value provided by the caller when adding the
- // entry, and returned to the caller when the entry reports an event.
- Data [2]int32
-}
-
// EntryFlags is a bitmask that holds an entry's flags.
type EntryFlags int
@@ -88,8 +76,8 @@ type pollEntry struct {
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
// weakReferenceGone is called when the file in the weak reference is destroyed.
// The poll entry is removed in response to this.
-func (p *pollEntry) WeakRefGone() {
- p.epoll.RemoveEntry(p.id)
+func (p *pollEntry) WeakRefGone(ctx context.Context) {
+ p.epoll.RemoveEntry(ctx, p.id)
}
// EventPoll holds all the state associated with an event poll object, that is,
@@ -119,7 +107,7 @@ type EventPoll struct {
// different lock to avoid circular lock acquisition order involving
// the wait queue mutexes and mu. The full order is mu, observed file
// wait queue mutex, then listsMu; this allows listsMu to be acquired
- // when readyCallback is called.
+ // when (*pollEntry).Callback is called.
//
// An entry is always in one of the following lists:
// readyList -- when there's a chance that it's ready to have
@@ -128,7 +116,7 @@ type EventPoll struct {
// readEvents() functions always call the entry's file
// Readiness() function to confirm it's ready.
// waitingList -- when there's no chance that the entry is ready,
- // so it's waiting for the readyCallback to be called
+ // so it's waiting for the (*pollEntry).Callback to be called
// on it before it gets moved to the readyList.
// disabledList -- when the entry is disabled. This happens when
// a one-shot entry gets delivered via readEvents().
@@ -156,14 +144,14 @@ func NewEventPoll(ctx context.Context) *fs.File {
// name matches fs/eventpoll.c:epoll_create1.
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), fmt.Sprintf("anon_inode:[eventpoll]"))
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{}, &EventPoll{
files: make(map[FileIdentifier]*pollEntry),
})
}
// Release implements fs.FileOperations.Release.
-func (e *EventPoll) Release() {
+func (e *EventPoll) Release(ctx context.Context) {
// We need to take the lock now because files may be attempting to
// remove entries in parallel if they get destroyed.
e.mu.Lock()
@@ -172,8 +160,9 @@ func (e *EventPoll) Release() {
// Go through all entries and clean up.
for _, entry := range e.files {
entry.id.File.EventUnregister(&entry.waiter)
- entry.file.Drop()
+ entry.file.Drop(ctx)
}
+ e.files = nil
}
// Read implements fs.FileOperations.Read.
@@ -226,9 +215,9 @@ func (e *EventPoll) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// ReadEvents returns up to max available events.
-func (e *EventPoll) ReadEvents(max int) []Event {
+func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
var local pollEntryList
- var ret []Event
+ var ret []linux.EpollEvent
e.listsMu.Lock()
@@ -250,7 +239,7 @@ func (e *EventPoll) ReadEvents(max int) []Event {
}
// Add event to the array that will be returned to caller.
- ret = append(ret, Event{
+ ret = append(ret, linux.EpollEvent{
Events: uint32(ready),
Data: entry.userData,
})
@@ -280,23 +269,23 @@ func (e *EventPoll) ReadEvents(max int) []Event {
return ret
}
-// readyCallback is called when one of the files we're polling becomes ready. It
-// moves said file to the readyList if it's currently in the waiting list.
-type readyCallback struct{}
-
// Callback implements waiter.EntryCallback.Callback.
-func (*readyCallback) Callback(w *waiter.Entry) {
- entry := w.Context.(*pollEntry)
- e := entry.epoll
+//
+// Callback is called when one of the files we're polling becomes ready. It
+// moves said file to the readyList if it's currently in the waiting list.
+func (p *pollEntry) Callback(*waiter.Entry) {
+ e := p.epoll
e.listsMu.Lock()
- if entry.curList == &e.waitingList {
- e.waitingList.Remove(entry)
- e.readyList.PushBack(entry)
- entry.curList = &e.readyList
+ if p.curList == &e.waitingList {
+ e.waitingList.Remove(p)
+ e.readyList.PushBack(p)
+ p.curList = &e.readyList
+ e.listsMu.Unlock()
e.Notify(waiter.EventIn)
+ return
}
e.listsMu.Unlock()
@@ -319,7 +308,7 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
// Check if the file happens to already be in a ready state.
ready := f.Readiness(entry.mask) & entry.mask
if ready != 0 {
- (*readyCallback).Callback(nil, &entry.waiter)
+ entry.Callback(&entry.waiter)
}
}
@@ -389,10 +378,9 @@ func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.Ev
userData: data,
epoll: e,
flags: flags,
- waiter: waiter.Entry{Callback: &readyCallback{}},
mask: mask,
}
- entry.waiter.Context = entry
+ entry.waiter.Callback = entry
e.files[id] = entry
entry.file = refs.NewWeakRef(id.File, entry)
@@ -415,7 +403,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter
}
// Unregister the old mask and remove entry from the list it's in, so
- // readyCallback is guaranteed to not be called on this entry anymore.
+ // (*pollEntry).Callback is guaranteed to not be called on this entry anymore.
entry.id.File.EventUnregister(&entry.waiter)
// Remove entry from whatever list it's in. This ensure that no other
@@ -435,7 +423,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter
}
// RemoveEntry a files from the collection of observed files.
-func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
+func (e *EventPoll) RemoveEntry(ctx context.Context, id FileIdentifier) error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -457,7 +445,7 @@ func (e *EventPoll) RemoveEntry(id FileIdentifier) error {
// Remove file from map, and drop weak reference.
delete(e.files, id)
- entry.file.Drop()
+ entry.file.Drop(ctx)
return nil
}
diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go
index a0d35d350..7c61e0258 100644
--- a/pkg/sentry/kernel/epoll/epoll_state.go
+++ b/pkg/sentry/kernel/epoll/epoll_state.go
@@ -21,8 +21,7 @@ import (
// afterLoad is invoked by stateify.
func (p *pollEntry) afterLoad() {
- p.waiter = waiter.Entry{Callback: &readyCallback{}}
- p.waiter.Context = p
+ p.waiter.Callback = p
p.file = refs.NewWeakRef(p.id.File, p)
p.id.File.EventRegister(&p.waiter, p.mask)
}
@@ -38,11 +37,14 @@ func (e *EventPoll) afterLoad() {
}
}
- for it := e.waitingList.Front(); it != nil; it = it.Next() {
- if it.id.File.Readiness(it.mask) != 0 {
- e.waitingList.Remove(it)
- e.readyList.PushBack(it)
- it.curList = &e.readyList
+ for it := e.waitingList.Front(); it != nil; {
+ entry := it
+ it = it.Next()
+
+ if entry.id.File.Readiness(entry.mask) != 0 {
+ e.waitingList.Remove(entry)
+ e.readyList.PushBack(entry)
+ entry.curList = &e.readyList
e.Notify(waiter.EventIn)
}
}
diff --git a/pkg/sentry/kernel/epoll/epoll_test.go b/pkg/sentry/kernel/epoll/epoll_test.go
index 4a20d4c82..55b505593 100644
--- a/pkg/sentry/kernel/epoll/epoll_test.go
+++ b/pkg/sentry/kernel/epoll/epoll_test.go
@@ -17,7 +17,7 @@ package epoll
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs/filetest"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -26,7 +26,8 @@ func TestFileDestroyed(t *testing.T) {
f := filetest.NewTestFile(t)
id := FileIdentifier{f, 12}
- efile := NewEventPoll(contexttest.Context(t))
+ ctx := contexttest.Context(t)
+ efile := NewEventPoll(ctx)
e := efile.FileOperations.(*EventPoll)
if err := e.AddEntry(id, 0, waiter.EventIn, [2]int32{}); err != nil {
t.Fatalf("addEntry failed: %v", err)
@@ -44,7 +45,7 @@ func TestFileDestroyed(t *testing.T) {
}
// Destroy the file. Check that we get no more events.
- f.DecRef()
+ f.DecRef(ctx)
evt = e.ReadEvents(1)
if len(evt) != 0 {
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index e65b961e8..9983a32e5 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,22 +1,21 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "eventfd",
srcs = ["eventfd.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/eventfd",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fdnotifier",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -25,10 +24,10 @@ go_test(
name = "eventfd_test",
size = "small",
srcs = ["eventfd_test.go"],
- embed = [":eventfd"],
+ library = ":eventfd",
deps = [
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/contexttest",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/eventfd/eventfd.go b/pkg/sentry/kernel/eventfd/eventfd.go
index 12f0d429b..bbf568dfc 100644
--- a/pkg/sentry/kernel/eventfd/eventfd.go
+++ b/pkg/sentry/kernel/eventfd/eventfd.go
@@ -18,17 +18,17 @@ package eventfd
import (
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,7 +70,7 @@ func New(ctx context.Context, initVal uint64, semMode bool) *fs.File {
// name matches fs/eventfd.c:eventfd_file_create.
dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[eventfd]")
// Release the initial dirent reference after NewFile takes a reference.
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &EventOperations{
val: initVal,
semMode: semMode,
@@ -106,7 +106,7 @@ func (e *EventOperations) HostFD() (int, error) {
}
// Release implements fs.FileOperations.Release.
-func (e *EventOperations) Release() {
+func (e *EventOperations) Release(context.Context) {
e.mu.Lock()
defer e.mu.Unlock()
if e.hostfd >= 0 {
diff --git a/pkg/sentry/kernel/eventfd/eventfd_test.go b/pkg/sentry/kernel/eventfd/eventfd_test.go
index 018c7f3ef..9b4892f74 100644
--- a/pkg/sentry/kernel/eventfd/eventfd_test.go
+++ b/pkg/sentry/kernel/eventfd/eventfd_test.go
@@ -17,8 +17,8 @@ package eventfd
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 49d81b712..2b3955598 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -1,17 +1,18 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "fasync",
srcs = ["fasync.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/fasync",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 6b0bb0324..153d2cd9b 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -16,20 +16,25 @@
package fasync
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new FileAsync.
+// New creates a new fs.FileAsync.
func New() fs.FileAsync {
return &FileAsync{}
}
+// NewVFS2 creates a new vfs.FileAsync.
+func NewVFS2() vfs.FileAsync {
+ return &FileAsync{}
+}
+
// FileAsync sends signals when the registered file is ready for IO.
//
// +stateify savable
@@ -171,3 +176,13 @@ func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kern
a.recipientTG = nil
a.recipientPG = recipient
}
+
+// ClearOwner unsets the current signal recipient.
+func (a *FileAsync) ClearOwner() {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.requester = nil
+ a.recipientT = nil
+ a.recipientTG = nil
+ a.recipientPG = nil
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 11f613a11..ce53af69b 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -1,4 +1,4 @@
-// Copyright 2018 Google LLC
+// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,19 +15,21 @@
package kernel
import (
- "bytes"
"fmt"
"math"
- "sync"
+ "strings"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// FDFlags define flags for an individual descriptor.
@@ -62,10 +64,14 @@ func (f FDFlags) ToLinuxFDFlags() (mask uint) {
// Note that this is immutable and can only be changed via operations on the
// descriptorTable.
//
+// It contains both VFS1 and VFS2 file types, but only one of them can be set.
+//
// +stateify savable
type descriptor struct {
- file *fs.File
- flags FDFlags
+ // TODO(gvisor.dev/issue/1624): Remove fs.File.
+ file *fs.File
+ fileVFS2 *vfs.FileDescription
+ flags FDFlags
}
// FDTable is used to manage File references and flags.
@@ -75,9 +81,6 @@ type FDTable struct {
refs.AtomicRefCount
k *Kernel
- // uid is a unique identifier.
- uid uint64
-
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -95,32 +98,38 @@ type FDTable struct {
func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
m := make(map[int32]descriptor)
- f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
+ f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
m[fd] = descriptor{
- file: file,
- flags: flags,
+ file: file,
+ fileVFS2: fileVFS2,
+ flags: flags,
}
})
return m
}
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
+ ctx := context.Background()
f.init() // Initialize table.
for fd, d := range m {
- f.set(fd, d.file, d.flags)
-
- // Note that we do _not_ need to acquire a extra table
- // reference here. The table reference will already be
- // accounted for in the file, so we drop the reference taken by
- // set above.
- d.file.DecRef()
+ f.setAll(fd, d.file, d.fileVFS2, d.flags)
+
+ // Note that we do _not_ need to acquire a extra table reference here. The
+ // table reference will already be accounted for in the file, so we drop the
+ // reference taken by set above.
+ switch {
+ case d.file != nil:
+ d.file.DecRef(ctx)
+ case d.fileVFS2 != nil:
+ d.fileVFS2.DecRef(ctx)
+ }
}
}
// drop drops the table reference.
func (f *FDTable) drop(file *fs.File) {
// Release locks.
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lock.UniqueID(f.uid), lock.LockRange{0, lock.LockEOF})
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF})
// Send inotify events.
d := file.Dirent
@@ -136,34 +145,47 @@ func (f *FDTable) drop(file *fs.File) {
d.InotifyEvent(ev, 0)
// Drop the table reference.
- file.DecRef()
+ file.DecRef(context.Background())
}
-// ID returns a unique identifier for this FDTable.
-func (f *FDTable) ID() uint64 {
- return f.uid
+// dropVFS2 drops the table reference.
+func (f *FDTable) dropVFS2(file *vfs.FileDescription) {
+ // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the
+ // entire file.
+ ctx := context.Background()
+ err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET)
+ if err != nil && err != syserror.ENOLCK {
+ panic(fmt.Sprintf("UnlockPOSIX failed: %v", err))
+ }
+
+ // Generate inotify events.
+ ev := uint32(linux.IN_CLOSE_NOWRITE)
+ if file.IsWritable() {
+ ev = linux.IN_CLOSE_WRITE
+ }
+ file.Dentry().InotifyWithParent(ctx, ev, 0, vfs.PathEvent)
+
+ // Drop the table's reference.
+ file.DecRef(ctx)
}
// NewFDTable allocates a new FDTable that may be used by tasks in k.
func (k *Kernel) NewFDTable() *FDTable {
- f := &FDTable{
- k: k,
- uid: atomic.AddUint64(&k.fdMapUids, 1),
- }
+ f := &FDTable{k: k}
f.init()
return f
}
// destroy removes all of the file descriptors from the map.
-func (f *FDTable) destroy() {
- f.RemoveIf(func(*fs.File, FDFlags) bool {
+func (f *FDTable) destroy(ctx context.Context) {
+ f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool {
return true
})
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
-func (f *FDTable) DecRef() {
- f.DecRefWithDestructor(f.destroy)
+func (f *FDTable) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, f.destroy)
}
// Size returns the number of file descriptor slots currently allocated.
@@ -172,35 +194,66 @@ func (f *FDTable) Size() int {
return int(size)
}
-// forEach iterates over all non-nil files.
+// forEach iterates over all non-nil files in sorted order.
//
// It is the caller's responsibility to acquire an appropriate lock.
-func (f *FDTable) forEach(fn func(fd int32, file *fs.File, flags FDFlags)) {
+func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
+ // retries tracks the number of failed TryIncRef attempts for the same FD.
+ retries := 0
fd := int32(0)
for {
- file, flags, ok := f.get(fd)
+ file, fileVFS2, flags, ok := f.getAll(fd)
if !ok {
break
}
- if file != nil {
+ switch {
+ case file != nil:
if !file.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, FileOps: %+v", fd, file, file.FileOperations))
+ }
continue // Race caught.
}
- fn(int32(fd), file, flags)
- file.DecRef()
+ fn(fd, file, nil, flags)
+ file.DecRef(ctx)
+ case fileVFS2 != nil:
+ if !fileVFS2.TryIncRef() {
+ retries++
+ if retries > 1000 {
+ panic(fmt.Sprintf("File in FD table has been destroyed. FD: %d, File: %+v, Impl: %+v", fd, fileVFS2, fileVFS2.Impl()))
+ }
+ continue // Race caught.
+ }
+ fn(fd, nil, fileVFS2, flags)
+ fileVFS2.DecRef(ctx)
}
+ retries = 0
fd++
}
}
// String is a stringer for FDTable.
func (f *FDTable) String() string {
- var b bytes.Buffer
- f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
- n, _ := file.Dirent.FullName(nil /* root */)
- b.WriteString(fmt.Sprintf("\tfd:%d => name %s\n", fd, n))
+ var buf strings.Builder
+ ctx := context.Background()
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ switch {
+ case file != nil:
+ n, _ := file.Dirent.FullName(nil /* root */)
+ fmt.Fprintf(&buf, "\tfd:%d => name %s\n", fd, n)
+
+ case fileVFS2 != nil:
+ vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem()
+ name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry())
+ if err != nil {
+ fmt.Fprintf(&buf, "<err: %v>\n", err)
+ return
+ }
+ fmt.Fprintf(&buf, "\tfd:%d => name %s\n", fd, name)
+ }
})
- return b.String()
+ return buf.String()
}
// NewFDs allocates new FDs guaranteed to be the lowest number available
@@ -258,18 +311,125 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
+// greater than or equal to the fd parameter. All files will share the set
+// flags. Success is guaranteed to be all or none.
+func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return nil, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if fd >= end {
+ return nil, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ if fd < f.next {
+ fd = f.next
+ }
+
+ // Install all entries.
+ for i := fd; i < end && len(fds) < len(files); i++ {
+ if d, _, _ := f.getVFS2(i); d == nil {
+ f.setVFS2(i, files[len(fds)], flags) // Set the descriptor.
+ fds = append(fds, i) // Record the file descriptor.
+ }
+ }
+
+ // Failure? Unwind existing FDs.
+ if len(fds) < len(files) {
+ for _, i := range fds {
+ f.setVFS2(i, nil, FDFlags{}) // Zap entry.
+ }
+ return nil, syscall.EMFILE
+ }
+
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fds[len(fds)-1] + 1
+ }
+
+ return fds, nil
+}
+
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.getVFS2(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error {
+ return f.newFDAt(ctx, fd, file, nil, flags)
+}
+
+// NewFDAtVFS2 sets the file reference for the given FD. If there is an active
+// reference for that FD, the ref count for that existing reference is
+// decremented.
+func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return f.newFDAt(ctx, fd, nil, file, flags)
+}
+
+func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error {
if fd < 0 {
// Don't accept negative FDs.
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -278,7 +438,9 @@ func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FD
}
// Install the entry.
- f.set(fd, file, flags)
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.setAll(fd, file, fileVFS2, flags)
return nil
}
@@ -305,6 +467,29 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error {
return nil
}
+// SetFlagsVFS2 sets the flags for the given file descriptor.
+//
+// True is returned iff flags were changed.
+func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error {
+ if fd < 0 {
+ // Don't accept negative FDs.
+ return syscall.EBADF
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ file, _, _ := f.getVFS2(fd)
+ if file == nil {
+ // No file found.
+ return syscall.EBADF
+ }
+
+ // Update the flags.
+ f.setVFS2(fd, file, flags)
+ return nil
+}
+
// Get returns a reference to the file and the flags for the FD or nil if no
// file is defined for the given fd.
//
@@ -330,10 +515,38 @@ func (f *FDTable) Get(fd int32) (*fs.File, FDFlags) {
}
}
-// GetFDs returns a list of valid fds.
-func (f *FDTable) GetFDs() []int32 {
+// GetVFS2 returns a reference to the file and the flags for the FD or nil if no
+// file is defined for the given fd.
+//
+// N.B. Callers are required to use DecRef when they are done.
+//
+//go:nosplit
+func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
+ if fd < 0 {
+ return nil, FDFlags{}
+ }
+
+ for {
+ file, flags, _ := f.getVFS2(fd)
+ if file != nil {
+ if !file.TryIncRef() {
+ continue // Race caught.
+ }
+ // Reference acquired.
+ return file, flags
+ }
+ // No file available.
+ return nil, FDFlags{}
+ }
+}
+
+// GetFDs returns a sorted list of valid fds.
+//
+// Precondition: The caller must be running on the task goroutine, or Task.mu
+// must be locked.
+func (f *FDTable) GetFDs(ctx context.Context) []int32 {
fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
- f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
+ f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
fds = append(fds, fd)
})
return fds
@@ -342,9 +555,21 @@ func (f *FDTable) GetFDs() []int32 {
// GetRefs returns a stable slice of references to all files and bumps the
// reference count on each. The caller must use DecRef on each reference when
// they're done using the slice.
-func (f *FDTable) GetRefs() []*fs.File {
+func (f *FDTable) GetRefs(ctx context.Context) []*fs.File {
files := make([]*fs.File, 0, f.Size())
- f.forEach(func(_ int32, file *fs.File, flags FDFlags) {
+ f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ file.IncRef() // Acquire a reference for caller.
+ files = append(files, file)
+ })
+ return files
+}
+
+// GetRefsVFS2 returns a stable slice of references to all files and bumps the
+// reference count on each. The caller must use DecRef on each reference when
+// they're done using the slice.
+func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription {
+ files := make([]*vfs.FileDescription, 0, f.Size())
+ f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) {
file.IncRef() // Acquire a reference for caller.
files = append(files, file)
})
@@ -352,13 +577,18 @@ func (f *FDTable) GetRefs() []*fs.File {
}
// Fork returns an independent FDTable.
-func (f *FDTable) Fork() *FDTable {
+func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
- f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
- clone.set(fd, file, flags)
+ switch {
+ case file != nil:
+ clone.set(fd, file, flags)
+ case fileVFS2 != nil:
+ clone.setVFS2(fd, fileVFS2, flags)
+ }
})
return clone
}
@@ -366,9 +596,9 @@ func (f *FDTable) Fork() *FDTable {
// Remove removes an FD from and returns a non-file iff successful.
//
// N.B. Callers are required to use DecRef when they are done.
-func (f *FDTable) Remove(fd int32) *fs.File {
+func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) {
if fd < 0 {
- return nil
+ return nil, nil
}
f.mu.Lock()
@@ -379,21 +609,28 @@ func (f *FDTable) Remove(fd int32) *fs.File {
f.next = fd
}
- orig, _, _ := f.get(fd)
- if orig != nil {
- orig.IncRef() // Reference for caller.
- f.set(fd, nil, FDFlags{}) // Zap entry.
+ orig, orig2, _, _ := f.getAll(fd)
+
+ // Add reference for caller.
+ switch {
+ case orig != nil:
+ orig.IncRef()
+ case orig2 != nil:
+ orig2.IncRef()
+ }
+ if orig != nil || orig2 != nil {
+ f.setAll(fd, nil, nil, FDFlags{}) // Zap entry.
}
- return orig
+ return orig, orig2
}
// RemoveIf removes all FDs where cond is true.
-func (f *FDTable) RemoveIf(cond func(*fs.File, FDFlags) bool) {
+func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) {
f.mu.Lock()
defer f.mu.Unlock()
- f.forEach(func(fd int32, file *fs.File, flags FDFlags) {
- if cond(file, flags) {
+ f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ if cond(file, fileVFS2, flags) {
f.set(fd, nil, FDFlags{}) // Clear from table.
// Update current available position.
if fd < f.next {
diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go
index 2bcb6216a..e3f30ba2a 100644
--- a/pkg/sentry/kernel/fd_table_test.go
+++ b/pkg/sentry/kernel/fd_table_test.go
@@ -1,4 +1,4 @@
-// Copyright 2018 Google LLC
+// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,14 +16,14 @@ package kernel
import (
"runtime"
- "sync"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/filetest"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -150,13 +150,13 @@ func TestFDTable(t *testing.T) {
t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref)
}
- ref := fdTable.Remove(1)
+ ref, _ := fdTable.Remove(1)
if ref == nil {
t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success")
}
- ref.DecRef()
+ ref.DecRef(ctx)
- if ref := fdTable.Remove(1); ref != nil {
+ if ref, _ := fdTable.Remove(1); ref != nil {
t.Fatalf("r.Remove(1) for a removed FD: got success, want failure")
}
})
@@ -191,7 +191,7 @@ func BenchmarkFDLookupAndDecRef(b *testing.B) {
b.StartTimer() // Benchmark.
for i := 0; i < b.N; i++ {
tf, _ := fdTable.Get(fds[i%len(fds)])
- tf.DecRef()
+ tf.DecRef(ctx)
}
})
}
@@ -219,7 +219,7 @@ func BenchmarkFDLookupAndDecRefConcurrent(b *testing.B) {
defer wg.Done()
for i := 0; i < each; i++ {
tf, _ := fdTable.Get(fds[i%len(fds)])
- tf.DecRef()
+ tf.DecRef(ctx)
}
}()
}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index e009df974..7fd97dc53 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2018 Google LLC
+// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
)
type descriptorTable struct {
@@ -41,15 +42,38 @@ func (f *FDTable) init() {
//
//go:nosplit
func (f *FDTable) get(fd int32) (*fs.File, FDFlags, bool) {
+ file, _, flags, ok := f.getAll(fd)
+ return file, flags, ok
+}
+
+// getVFS2 gets a file entry.
+//
+// The boolean indicates whether this was in range.
+//
+//go:nosplit
+func (f *FDTable) getVFS2(fd int32) (*vfs.FileDescription, FDFlags, bool) {
+ _, file, flags, ok := f.getAll(fd)
+ return file, flags, ok
+}
+
+// getAll gets a file entry.
+//
+// The boolean indicates whether this was in range.
+//
+//go:nosplit
+func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, bool) {
slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
if fd >= int32(len(slice)) {
- return nil, FDFlags{}, false
+ return nil, nil, FDFlags{}, false
}
d := (*descriptor)(atomic.LoadPointer(&slice[fd]))
if d == nil {
- return nil, FDFlags{}, true
+ return nil, nil, FDFlags{}, true
}
- return d.file, d.flags, true
+ if d.file != nil && d.fileVFS2 != nil {
+ panic("VFS1 and VFS2 files set")
+ }
+ return d.file, d.fileVFS2, d.flags, true
}
// set sets an entry.
@@ -59,6 +83,30 @@ func (f *FDTable) get(fd int32) (*fs.File, FDFlags, bool) {
//
// Precondition: mu must be held.
func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) {
+ f.setAll(fd, file, nil, flags)
+}
+
+// setVFS2 sets an entry.
+//
+// This handles accounting changes, as well as acquiring and releasing the
+// reference needed by the table iff the file is different.
+//
+// Precondition: mu must be held.
+func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) {
+ f.setAll(fd, nil, file, flags)
+}
+
+// setAll sets an entry.
+//
+// This handles accounting changes, as well as acquiring and releasing the
+// reference needed by the table iff the file is different.
+//
+// Precondition: mu must be held.
+func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
+ if file != nil && fileVFS2 != nil {
+ panic("VFS1 and VFS2 files set")
+ }
+
slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
// Grow the table as required.
@@ -71,33 +119,51 @@ func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) {
atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
}
- // Create the new element.
- var d *descriptor
- if file != nil {
- d = &descriptor{
- file: file,
- flags: flags,
+ var desc *descriptor
+ if file != nil || fileVFS2 != nil {
+ desc = &descriptor{
+ file: file,
+ fileVFS2: fileVFS2,
+ flags: flags,
}
}
// Update the single element.
- orig := (*descriptor)(atomic.SwapPointer(&slice[fd], unsafe.Pointer(d)))
+ orig := (*descriptor)(atomic.SwapPointer(&slice[fd], unsafe.Pointer(desc)))
// Acquire a table reference.
- if file != nil && (orig == nil || file != orig.file) {
- file.IncRef()
+ if desc != nil {
+ switch {
+ case desc.file != nil:
+ if orig == nil || desc.file != orig.file {
+ desc.file.IncRef()
+ }
+ case desc.fileVFS2 != nil:
+ if orig == nil || desc.fileVFS2 != orig.fileVFS2 {
+ desc.fileVFS2.IncRef()
+ }
+ }
}
// Drop the table reference.
- if orig != nil && file != orig.file {
- f.drop(orig.file)
+ if orig != nil {
+ switch {
+ case orig.file != nil:
+ if desc == nil || desc.file != orig.file {
+ f.drop(orig.file)
+ }
+ case orig.fileVFS2 != nil:
+ if desc == nil || desc.fileVFS2 != orig.fileVFS2 {
+ f.dropVFS2(orig.fileVFS2)
+ }
+ }
}
// Adjust used.
switch {
- case orig == nil && file != nil:
+ case orig == nil && desc != nil:
atomic.AddInt32(&f.used, 1)
- case orig != nil && file == nil:
+ case orig != nil && desc == nil:
atomic.AddInt32(&f.used, -1)
}
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index ded27d668..8f2d36d5a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -16,10 +16,12 @@ package kernel
import (
"fmt"
- "sync"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
)
// FSContext contains filesystem context.
@@ -37,10 +39,16 @@ type FSContext struct {
// destroyed.
root *fs.Dirent
+ // rootVFS2 is the filesystem root.
+ rootVFS2 vfs.VirtualDentry
+
// cwd is the current working directory. Will be nil iff the FSContext
// has been destroyed.
cwd *fs.Dirent
+ // cwdVFS2 is the current working directory.
+ cwdVFS2 vfs.VirtualDentry
+
// umask is the current file mode creation mask. When a thread using this
// context invokes a syscall that creates a file, bits set in umask are
// removed from the permissions that the file is created with.
@@ -60,6 +68,19 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
return &f
}
+// NewFSContextVFS2 returns a new filesystem context.
+func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext {
+ root.IncRef()
+ cwd.IncRef()
+ f := FSContext{
+ rootVFS2: root,
+ cwdVFS2: cwd,
+ umask: umask,
+ }
+ f.EnableLeakCheck("kernel.FSContext")
+ return &f
+}
+
// destroy is the destructor for an FSContext.
//
// This will call DecRef on both root and cwd Dirents. If either call to
@@ -69,22 +90,28 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext {
// Note that there may still be calls to WorkingDirectory() or RootDirectory()
// (that return nil). This is because valid references may still be held via
// proc files or other mechanisms.
-func (f *FSContext) destroy() {
+func (f *FSContext) destroy(ctx context.Context) {
// Hold f.mu so that we don't race with RootDirectory() and
// WorkingDirectory().
f.mu.Lock()
defer f.mu.Unlock()
- f.root.DecRef()
- f.root = nil
-
- f.cwd.DecRef()
- f.cwd = nil
+ if VFS2Enabled {
+ f.rootVFS2.DecRef(ctx)
+ f.rootVFS2 = vfs.VirtualDentry{}
+ f.cwdVFS2.DecRef(ctx)
+ f.cwdVFS2 = vfs.VirtualDentry{}
+ } else {
+ f.root.DecRef(ctx)
+ f.root = nil
+ f.cwd.DecRef(ctx)
+ f.cwd = nil
+ }
}
// DecRef implements RefCounter.DecRef with destructor f.destroy.
-func (f *FSContext) DecRef() {
- f.DecRefWithDestructor(f.destroy)
+func (f *FSContext) DecRef(ctx context.Context) {
+ f.DecRefWithDestructor(ctx, f.destroy)
}
// Fork forks this FSContext.
@@ -93,12 +120,21 @@ func (f *FSContext) DecRef() {
func (f *FSContext) Fork() *FSContext {
f.mu.Lock()
defer f.mu.Unlock()
- f.cwd.IncRef()
- f.root.IncRef()
+
+ if VFS2Enabled {
+ f.cwdVFS2.IncRef()
+ f.rootVFS2.IncRef()
+ } else {
+ f.cwd.IncRef()
+ f.root.IncRef()
+ }
+
return &FSContext{
- cwd: f.cwd,
- root: f.root,
- umask: f.umask,
+ cwd: f.cwd,
+ root: f.root,
+ cwdVFS2: f.cwdVFS2,
+ rootVFS2: f.rootVFS2,
+ umask: f.umask,
}
}
@@ -109,17 +145,28 @@ func (f *FSContext) Fork() *FSContext {
func (f *FSContext) WorkingDirectory() *fs.Dirent {
f.mu.Lock()
defer f.mu.Unlock()
- if f.cwd != nil {
- f.cwd.IncRef()
- }
+
+ f.cwd.IncRef()
return f.cwd
}
+// WorkingDirectoryVFS2 returns the current working directory.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.cwdVFS2.IncRef()
+ return f.cwdVFS2
+}
+
// SetWorkingDirectory sets the current working directory.
// This will take an extra reference on the Dirent.
//
// This is not a valid call after destroy.
-func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
+func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) {
if d == nil {
panic("FSContext.SetWorkingDirectory called with nil dirent")
}
@@ -134,7 +181,21 @@ func (f *FSContext) SetWorkingDirectory(d *fs.Dirent) {
old := f.cwd
f.cwd = d
d.IncRef()
- old.DecRef()
+ old.DecRef(ctx)
+}
+
+// SetWorkingDirectoryVFS2 sets the current working directory.
+// This will take an extra reference on the VirtualDentry.
+//
+// This is not a valid call after destroy.
+func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDentry) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ old := f.cwdVFS2
+ f.cwdVFS2 = d
+ d.IncRef()
+ old.DecRef(ctx)
}
// RootDirectory returns the current filesystem root.
@@ -150,11 +211,23 @@ func (f *FSContext) RootDirectory() *fs.Dirent {
return f.root
}
+// RootDirectoryVFS2 returns the current filesystem root.
+//
+// This will return nil if called after destroy(), otherwise it will return a
+// Dirent with a reference taken.
+func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ f.rootVFS2.IncRef()
+ return f.rootVFS2
+}
+
// SetRootDirectory sets the root directory.
// This will take an extra reference on the Dirent.
//
// This is not a valid call after free.
-func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
+func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) {
if d == nil {
panic("FSContext.SetRootDirectory called with nil dirent")
}
@@ -169,7 +242,29 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old := f.root
f.root = d
d.IncRef()
- old.DecRef()
+ old.DecRef(ctx)
+}
+
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(ctx context.Context, vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef(ctx)
}
// Umask returns the current umask.
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 34286c7a8..daa2dae76 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,7 +8,7 @@ go_template_instance(
out = "atomicptr_bucket_unsafe.go",
package = "futex",
suffix = "Bucket",
- template = "//third_party/gvsync:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "bucket",
},
@@ -34,15 +33,15 @@ go_library(
"futex.go",
"waiter_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/futex",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
"//pkg/sentry/memmap",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
@@ -50,6 +49,10 @@ go_test(
name = "futex_test",
size = "small",
srcs = ["futex_test.go"],
- embed = [":futex"],
- deps = ["//pkg/sentry/usermem"],
+ library = ":futex",
+ deps = [
+ "//pkg/context",
+ "//pkg/sync",
+ "//pkg/usermem",
+ ],
)
diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go
index 278cc8143..e4dcc4d40 100644
--- a/pkg/sentry/kernel/futex/futex.go
+++ b/pkg/sentry/kernel/futex/futex.go
@@ -18,12 +18,12 @@
package futex
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// KeyKind indicates the type of a Key.
@@ -67,9 +67,9 @@ type Key struct {
Offset uint64
}
-func (k *Key) release() {
+func (k *Key) release(t Target) {
if k.MappingIdentity != nil {
- k.MappingIdentity.DecRef()
+ k.MappingIdentity.DecRef(t)
}
k.Mappable = nil
k.MappingIdentity = nil
@@ -95,6 +95,8 @@ func (k *Key) matches(k2 *Key) bool {
// Target abstracts memory accesses and keys.
type Target interface {
+ context.Context
+
// SwapUint32 gives access to usermem.IO.SwapUint32.
SwapUint32(addr usermem.Addr, new uint32) (uint32, error)
@@ -297,7 +299,7 @@ func (b *bucket) wakeWaiterLocked(w *Waiter) {
// bucket "to".
//
// Preconditions: b and to must be locked.
-func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
+func (b *bucket) requeueLocked(t Target, to *bucket, key, nkey *Key, n int) int {
done := 0
for w := b.waiters.Front(); done < n && w != nil; {
if !w.key.matches(key) {
@@ -309,7 +311,7 @@ func (b *bucket) requeueLocked(to *bucket, key, nkey *Key, n int) int {
requeued := w
w = w.Next() // Next iteration.
b.waiters.Remove(requeued)
- requeued.key.release()
+ requeued.key.release(t)
requeued.key = nkey.clone()
to.waiters.PushBack(requeued)
requeued.bucket.Store(to)
@@ -457,7 +459,7 @@ func (m *Manager) Wake(t Target, addr usermem.Addr, private bool, bitmask uint32
r := b.wakeLocked(&k, bitmask, n)
b.mu.Unlock()
- k.release()
+ k.release(t)
return r, nil
}
@@ -466,12 +468,12 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch
if err != nil {
return 0, err
}
- defer k1.release()
+ defer k1.release(t)
k2, err := getKey(t, naddr, private)
if err != nil {
return 0, err
}
- defer k2.release()
+ defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
defer b1.mu.Unlock()
@@ -489,7 +491,7 @@ func (m *Manager) doRequeue(t Target, addr, naddr usermem.Addr, private bool, ch
done := b1.wakeLocked(&k1, ^uint32(0), nwake)
// Requeue the number required.
- b1.requeueLocked(b2, &k1, &k2, nreq)
+ b1.requeueLocked(t, b2, &k1, &k2, nreq)
return done, nil
}
@@ -516,12 +518,12 @@ func (m *Manager) WakeOp(t Target, addr1, addr2 usermem.Addr, private bool, nwak
if err != nil {
return 0, err
}
- defer k1.release()
+ defer k1.release(t)
k2, err := getKey(t, addr2, private)
if err != nil {
return 0, err
}
- defer k2.release()
+ defer k2.release(t)
b1, b2 := m.lockBuckets(&k1, &k2)
defer b1.mu.Unlock()
@@ -572,7 +574,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo
// Perform our atomic check.
if err := check(t, addr, val); err != nil {
b.mu.Unlock()
- w.key.release()
+ w.key.release(t)
return err
}
@@ -586,7 +588,7 @@ func (m *Manager) WaitPrepare(w *Waiter, t Target, addr usermem.Addr, private bo
// WaitComplete must be called when a Waiter previously added by WaitPrepare is
// no longer eligible to be woken.
-func (m *Manager) WaitComplete(w *Waiter) {
+func (m *Manager) WaitComplete(w *Waiter, t Target) {
// Remove w from the bucket it's in.
for {
b := w.bucket.Load()
@@ -618,7 +620,7 @@ func (m *Manager) WaitComplete(w *Waiter) {
}
// Release references held by the waiter.
- w.key.release()
+ w.key.release(t)
}
// LockPI attempts to lock the futex following the Priority-inheritance futex
@@ -649,13 +651,13 @@ func (m *Manager) LockPI(w *Waiter, t Target, addr usermem.Addr, tid uint32, pri
success, err := m.lockPILocked(w, t, addr, tid, b, try)
if err != nil {
- w.key.release()
+ w.key.release(t)
b.mu.Unlock()
return false, err
}
if success || try {
// Release waiter if it's not going to be a wait.
- w.key.release()
+ w.key.release(t)
}
b.mu.Unlock()
return success, nil
@@ -718,10 +720,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3
}
}
-// UnlockPI unlock the futex following the Priority-inheritance futex
-// rules. The address provided must contain the caller's TID. If there are
-// waiters, TID of the next waiter (FIFO) is set to the given address, and the
-// waiter woken up. If there are no waiters, 0 is set to the address.
+// UnlockPI unlocks the futex following the Priority-inheritance futex rules.
+// The address provided must contain the caller's TID. If there are waiters,
+// TID of the next waiter (FIFO) is set to the given address, and the waiter
+// woken up. If there are no waiters, 0 is set to the address.
func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error {
k, err := getKey(t, addr, private)
if err != nil {
@@ -731,7 +733,7 @@ func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool
err = m.unlockPILocked(t, addr, tid, b, &k)
- k.release()
+ k.release(t)
b.mu.Unlock()
return err
}
diff --git a/pkg/sentry/kernel/futex/futex_test.go b/pkg/sentry/kernel/futex/futex_test.go
index 65e5d1428..d0128c548 100644
--- a/pkg/sentry/kernel/futex/futex_test.go
+++ b/pkg/sentry/kernel/futex/futex_test.go
@@ -17,40 +17,46 @@ package futex
import (
"math"
"runtime"
- "sync"
"sync/atomic"
"syscall"
"testing"
"unsafe"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// testData implements the Target interface, and allows us to
// treat the address passed for futex operations as an index in
// a byte slice for testing simplicity.
-type testData []byte
+type testData struct {
+ context.Context
+ data []byte
+}
const sizeofInt32 = 4
func newTestData(size uint) testData {
- return make([]byte, size)
+ return testData{
+ data: make([]byte, size),
+ }
}
func (t testData) SwapUint32(addr usermem.Addr, new uint32) (uint32, error) {
- val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t[addr])), new)
+ val := atomic.SwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), new)
return val, nil
}
func (t testData) CompareAndSwapUint32(addr usermem.Addr, old, new uint32) (uint32, error) {
- if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t[addr])), old, new) {
+ if atomic.CompareAndSwapUint32((*uint32)(unsafe.Pointer(&t.data[addr])), old, new) {
return old, nil
}
- return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
func (t testData) LoadUint32(addr usermem.Addr) (uint32, error) {
- return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t[addr]))), nil
+ return atomic.LoadUint32((*uint32)(unsafe.Pointer(&t.data[addr]))), nil
}
func (t testData) GetSharedKey(addr usermem.Addr) (Key, error) {
@@ -83,7 +89,7 @@ func TestFutexWake(t *testing.T) {
// Start waiting for wakeup.
w := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w)
+ defer m.WaitComplete(w, d)
// Perform a wakeup.
if n, err := m.Wake(d, 0, private, ^uint32(0), 1); err != nil || n != 1 {
@@ -106,7 +112,7 @@ func TestFutexWakeBitmask(t *testing.T) {
// Start waiting for wakeup.
w := newPreparedTestWaiter(t, m, d, 0, private, 0, 0x0000ffff)
- defer m.WaitComplete(w)
+ defer m.WaitComplete(w, d)
// Perform a wakeup using the wrong bitmask.
if n, err := m.Wake(d, 0, private, 0xffff0000, 1); err != nil || n != 0 {
@@ -141,7 +147,7 @@ func TestFutexWakeTwo(t *testing.T) {
var ws [3]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform two wakeups.
@@ -174,9 +180,9 @@ func TestFutexWakeUnrelated(t *testing.T) {
// Start two waiters waiting for wakeup on different addresses.
w1 := newPreparedTestWaiter(t, m, d, 0*sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 1*sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform two wakeups on the second address.
if n, err := m.Wake(d, 1*sizeofInt32, private, ^uint32(0), 2); err != nil || n != 1 {
@@ -216,9 +222,9 @@ func TestWakeOpFirstNonEmpty(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address 0.
if n, err := m.WakeOp(d, 0, sizeofInt32, private, 10, 0, 0); err != nil || n != 2 {
@@ -244,9 +250,9 @@ func TestWakeOpSecondNonEmpty(t *testing.T) {
// Add two waiters on address sizeofInt32.
w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address sizeofInt32 (contingent on
// d.Op(0), which should succeed).
@@ -273,9 +279,9 @@ func TestWakeOpSecondNonEmptyFailingOp(t *testing.T) {
// Add two waiters on address sizeofInt32.
w1 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Perform 10 wakeups on address sizeofInt32 (contingent on
// d.Op(1), which should fail).
@@ -302,15 +308,15 @@ func TestWakeOpAllNonEmpty(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Add two waiters on address sizeofInt32.
w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w3)
+ defer m.WaitComplete(w3, d)
w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w4)
+ defer m.WaitComplete(w4, d)
// Perform 10 wakeups on address 0 (unconditionally), and 10
// wakeups on address sizeofInt32 (contingent on d.Op(0), which
@@ -344,15 +350,15 @@ func TestWakeOpAllNonEmptyFailingOp(t *testing.T) {
// Add two waiters on address 0.
w1 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w1)
+ defer m.WaitComplete(w1, d)
w2 := newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(w2)
+ defer m.WaitComplete(w2, d)
// Add two waiters on address sizeofInt32.
w3 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w3)
+ defer m.WaitComplete(w3, d)
w4 := newPreparedTestWaiter(t, m, d, sizeofInt32, private, 0, ^uint32(0))
- defer m.WaitComplete(w4)
+ defer m.WaitComplete(w4, d)
// Perform 10 wakeups on address 0 (unconditionally), and 10
// wakeups on address sizeofInt32 (contingent on d.Op(1), which
@@ -388,7 +394,7 @@ func TestWakeOpSameAddress(t *testing.T) {
var ws [4]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
@@ -422,7 +428,7 @@ func TestWakeOpSameAddressFailingOp(t *testing.T) {
var ws [4]*Waiter
for i := range ws {
ws[i] = newPreparedTestWaiter(t, m, d, 0, private, 0, ^uint32(0))
- defer m.WaitComplete(ws[i])
+ defer m.WaitComplete(ws[i], d)
}
// Perform 1 wakeup on address 0 (unconditionally), and 1 wakeup
@@ -472,7 +478,7 @@ func (t *testMutex) Lock() {
for {
// Attempt to grab the lock.
if atomic.CompareAndSwapUint32(
- (*uint32)(unsafe.Pointer(&t.d[t.a])),
+ (*uint32)(unsafe.Pointer(&t.d.data[t.a])),
testMutexUnlocked,
testMutexLocked) {
// Lock held.
@@ -490,7 +496,7 @@ func (t *testMutex) Lock() {
panic("WaitPrepare returned unexpected error: " + err.Error())
}
<-w.C
- t.m.WaitComplete(w)
+ t.m.WaitComplete(w, t.d)
}
}
@@ -498,7 +504,7 @@ func (t *testMutex) Lock() {
// This will notify any waiters via the futex manager.
func (t *testMutex) Unlock() {
// Unlock.
- atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d[t.a])), testMutexUnlocked)
+ atomic.StoreUint32((*uint32)(unsafe.Pointer(&t.d.data[t.a])), testMutexUnlocked)
// Notify all waiters.
t.m.Wake(t.d, t.a, true, ^uint32(0), math.MaxInt32)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index fcfe7a16d..1028d13c6 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -34,21 +34,25 @@ package kernel
import (
"errors"
"fmt"
- "io"
"path/filepath"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/eventchannel"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ oldtimerfd "gvisor.dev/gvisor/pkg/sentry/fs/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -66,10 +70,21 @@ import (
"gvisor.dev/gvisor/pkg/sentry/unimpl"
uspb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// VFS2Enabled is set to true when VFS2 is enabled. Added as a global for allow
+// easy access everywhere. To be removed once VFS2 becomes the default.
+var VFS2Enabled = false
+
+// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow
+// easy access everywhere. To be removed once FUSE is completed.
+var FUSEEnabled = false
+
// Kernel represents an emulated Linux kernel. It must be initialized by calling
// Init() or LoadFrom().
//
@@ -104,7 +119,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -183,11 +198,6 @@ type Kernel struct {
// cpuClockTickerSetting is protected by runningTasksMu.
cpuClockTickerSetting ktime.Setting
- // fdMapUids is an ever-increasing counter for generating FDTable uids.
- //
- // fdMapUids is mutable, and is accessed using atomic memory operations.
- fdMapUids uint64
-
// uniqueID is used to generate unique identifiers.
//
// uniqueID is mutable, and is accessed using atomic memory operations.
@@ -234,6 +244,36 @@ type Kernel struct {
// events. This is initialized lazily on the first unimplemented
// syscall.
unimplementedSyscallEmitter eventchannel.Emitter `state:"nosave"`
+
+ // SpecialOpts contains special kernel options.
+ SpecialOpts
+
+ // VFS keeps the filesystem state used across the kernel.
+ vfs vfs.VirtualFilesystem
+
+ // hostMount is the Mount used for file descriptors that were imported
+ // from the host.
+ hostMount *vfs.Mount
+
+ // pipeMount is the Mount used for pipes created by the pipe() and pipe2()
+ // syscalls (as opposed to named pipes created by mknod()).
+ pipeMount *vfs.Mount
+
+ // shmMount is the Mount used for anonymous files created by the
+ // memfd_create() syscalls. It is analagous to Linux's shm_mnt.
+ shmMount *vfs.Mount
+
+ // socketMount is the Mount used for sockets created by the socket() and
+ // socketpair() syscalls. There are several cases where a socket dentry will
+ // not be contained in socketMount:
+ // 1. Socket files created by mknod()
+ // 2. Socket fds imported from the host (Kernel.hostMount is used for these)
+ // 3. Socket files created by binding Unix sockets to a file path
+ socketMount *vfs.Mount
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
}
// InitKernelArgs holds arguments to Init.
@@ -247,8 +287,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -293,6 +334,9 @@ func (k *Kernel) Init(args InitKernelArgs) error {
if args.Timekeeper == nil {
return fmt.Errorf("Timekeeper is nil")
}
+ if args.Timekeeper.clocks == nil {
+ return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()")
+ }
if args.RootUserNamespace == nil {
return fmt.Errorf("RootUserNamespace is nil")
}
@@ -307,7 +351,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -327,13 +374,55 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
+
+ if VFS2Enabled {
+ ctx := k.SupervisorContext()
+ if err := k.vfs.Init(ctx); err != nil {
+ return fmt.Errorf("failed to initialize VFS: %v", err)
+ }
+
+ pipeFilesystem, err := pipefs.NewFilesystem(&k.vfs)
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs filesystem: %v", err)
+ }
+ defer pipeFilesystem.DecRef(ctx)
+ pipeMount, err := k.vfs.NewDisconnectedMount(pipeFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create pipefs mount: %v", err)
+ }
+ k.pipeMount = pipeMount
+
+ tmpfsFilesystem, tmpfsRoot, err := tmpfs.NewFilesystem(ctx, &k.vfs, auth.NewRootCredentials(k.rootUserNamespace))
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs filesystem: %v", err)
+ }
+ defer tmpfsFilesystem.DecRef(ctx)
+ defer tmpfsRoot.DecRef(ctx)
+ shmMount, err := k.vfs.NewDisconnectedMount(tmpfsFilesystem, tmpfsRoot, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create tmpfs mount: %v", err)
+ }
+ k.shmMount = shmMount
+
+ socketFilesystem, err := sockfs.NewFilesystem(&k.vfs)
+ if err != nil {
+ return fmt.Errorf("failed to create sockfs filesystem: %v", err)
+ }
+ defer socketFilesystem.DecRef(ctx)
+ socketMount, err := k.vfs.NewDisconnectedMount(socketFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return fmt.Errorf("failed to create sockfs mount: %v", err)
+ }
+ k.socketMount = socketMount
+ }
+
return nil
}
// SaveTo saves the state of k to w.
//
// Preconditions: The kernel must be paused throughout the call to SaveTo.
-func (k *Kernel) SaveTo(w io.Writer) error {
+func (k *Kernel) SaveTo(w wire.Writer) error {
saveStart := time.Now()
ctx := k.SupervisorContext()
@@ -342,8 +431,8 @@ func (k *Kernel) SaveTo(w io.Writer) error {
defer k.extMu.Unlock()
// Stop time.
- k.pauseTimeLocked()
- defer k.resumeTimeLocked()
+ k.pauseTimeLocked(ctx)
+ defer k.resumeTimeLocked(ctx)
// Evict all evictable MemoryFile allocations.
k.mf.StartEvictions()
@@ -359,18 +448,16 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Remove all epoll waiter objects from underlying wait queues.
// NOTE: for programs to resume execution in future snapshot scenarios,
// we will need to re-establish these waiter objects after saving.
- k.tasks.unregisterEpollWaiters()
+ k.tasks.unregisterEpollWaiters(ctx)
// Clear the dirent cache before saving because Dirents must be Loaded in a
// particular order (parents before children), and Loading dirents from a cache
// breaks that order.
- if err := k.flushMountSourceRefs(); err != nil {
+ if err := k.flushMountSourceRefs(ctx); err != nil {
return err
}
- // Ensure that all pending asynchronous work is complete:
- // - inode and mount release
- // - asynchronuous IO
+ // Ensure that all inode and mount release operations have completed.
fs.AsyncBarrier()
// Once all fs work has completed (flushed references have all been released),
@@ -391,23 +478,23 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+ if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
// Save the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Save(w, k, &stats); err != nil {
+ stats, err := state.Save(k.SupervisorContext(), w, k)
+ if err != nil {
return err
}
- log.Infof("Kernel save stats: %s", &stats)
+ log.Infof("Kernel save stats: %s", stats.String())
log.Infof("Kernel save took [%s].", time.Since(kernelStart))
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(w); err != nil {
+ if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -419,7 +506,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// flushMountSourceRefs flushes the MountSources for all mounted filesystems
// and open FDs.
-func (k *Kernel) flushMountSourceRefs() error {
+func (k *Kernel) flushMountSourceRefs(ctx context.Context) error {
// Flush all mount sources for currently mounted filesystems in each task.
flushed := make(map[*fs.MountNamespace]struct{})
k.tasks.mu.RLock()
@@ -435,17 +522,22 @@ func (k *Kernel) flushMountSourceRefs() error {
// There may be some open FDs whose filesystems have been unmounted. We
// must flush those as well.
- return k.tasks.forEachFDPaused(func(file *fs.File) error {
+ return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
file.Dirent.Inode.MountSource.FlushDirentRefs()
return nil
})
}
-// forEachFDPaused applies the given function to each open file descriptor in each
-// task.
+// forEachFDPaused applies the given function to each open file descriptor in
+// each task.
//
// Precondition: Must be called with the kernel paused.
-func (ts *TaskSet) forEachFDPaused(f func(*fs.File) error) (err error) {
+func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.FileDescription) error) (err error) {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return nil
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
for t := range ts.Root.tids {
@@ -453,8 +545,8 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File) error) (err error) {
if t.fdTable == nil {
continue
}
- t.fdTable.forEach(func(_ int32, file *fs.File, _ FDFlags) {
- if lastErr := f(file); lastErr != nil && err == nil {
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fileVFS2 *vfs.FileDescription, _ FDFlags) {
+ if lastErr := f(file, fileVFS2); lastErr != nil && err == nil {
err = lastErr
}
})
@@ -463,7 +555,8 @@ func (ts *TaskSet) forEachFDPaused(f func(*fs.File) error) (err error) {
}
func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
- return ts.forEachFDPaused(func(file *fs.File) error {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error {
if flags := file.Flags(); !flags.Write {
return nil
}
@@ -474,12 +567,9 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error {
syncErr := file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
if err := fs.SaveFileFsyncError(syncErr); err != nil {
name, _ := file.Dirent.FullName(nil /* root */)
- // Wrap this error in ErrSaveRejection
- // so that it will trigger a save
- // error, rather than a panic. This
- // also allows us to distinguish Fsync
- // errors from state file errors in
- // state.Save.
+ // Wrap this error in ErrSaveRejection so that it will trigger a save
+ // error, rather than a panic. This also allows us to distinguish Fsync
+ // errors from state file errors in state.Save.
return fs.ErrSaveRejection{
Err: fmt.Errorf("%q was not sufficiently synced: %v", name, err),
}
@@ -513,27 +603,40 @@ func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error {
return nil
}
-func (ts *TaskSet) unregisterEpollWaiters() {
+func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) {
+ // TODO(gvisor.dev/issue/1663): Add save support for VFS2.
+ if VFS2Enabled {
+ return
+ }
+
ts.mu.RLock()
defer ts.mu.RUnlock()
+
+ // Tasks that belong to the same process could potentially point to the
+ // same FDTable. So we retain a map of processed ones to avoid
+ // processing the same FDTable multiple times.
+ processed := make(map[*FDTable]struct{})
for t := range ts.Root.tids {
// We can skip locking Task.mu here since the kernel is paused.
- if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ FDFlags) {
- if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
- e.UnregisterEpollWaiters()
- }
- })
+ if t.fdTable == nil {
+ continue
}
+ if _, ok := processed[t.fdTable]; ok {
+ continue
+ }
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) {
+ if e, ok := file.FileOperations.(*epoll.EventPoll); ok {
+ e.UnregisterEpollWaiters()
+ }
+ })
+ processed[t.fdTable] = struct{}{}
}
}
// LoadFrom returns a new Kernel loaded from args.
-func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
+func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -542,7 +645,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(r, &features, nil); err != nil {
+ if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -557,16 +660,20 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
- var stats state.Stats
- if err := state.Load(r, k, &stats); err != nil {
+ stats, err := state.Load(k.SupervisorContext(), r, k)
+ if err != nil {
return err
}
- log.Infof("Kernel load stats: %s", &stats)
+ log.Infof("Kernel load stats: %s", stats.String())
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(r); err != nil {
+ if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -622,7 +729,7 @@ type CreateProcessArgs struct {
// File is a passed host FD pointing to a file to load as the init binary.
//
// This is checked if and only if Filename is "".
- File *fs.File
+ File fsbridge.File
// Argvv is a list of arguments.
Argv []string
@@ -671,6 +778,13 @@ type CreateProcessArgs struct {
// increment it).
MountNamespace *fs.MountNamespace
+ // MountNamespaceVFS2 optionally contains the mount namespace for this
+ // process. If nil, the init process's mount namespace is used.
+ //
+ // Anyone setting MountNamespaceVFS2 must donate a reference (i.e.
+ // increment it).
+ MountNamespaceVFS2 *vfs.MountNamespace
+
// ContainerID is the container that the process belongs to.
ContainerID string
}
@@ -709,13 +823,26 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} {
return ctx.args.Credentials
case fs.CtxRoot:
if ctx.args.MountNamespace != nil {
- // MountNamespace.Root() will take a reference on the root
- // dirent for us.
+ // MountNamespace.Root() will take a reference on the root dirent for us.
return ctx.args.MountNamespace.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.args.MountNamespaceVFS2 == nil {
+ return nil
+ }
+ // MountNamespaceVFS2.Root() takes a reference on the root dirent for us.
+ return ctx.args.MountNamespaceVFS2.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2 takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -755,34 +882,77 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
defer k.extMu.Unlock()
log.Infof("EXEC: %v", args.Argv)
- // Grab the mount namespace.
- mounts := args.MountNamespace
- if mounts == nil {
- mounts = k.GlobalInit().Leader().MountNamespace()
- mounts.IncRef()
- }
-
- tg := k.newThreadGroup(mounts, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits, k.monotonicClock)
ctx := args.NewContext(k)
- // Get the root directory from the MountNamespace.
- root := mounts.Root()
- // The call to newFSContext below will take a reference on root, so we
- // don't need to hold this one.
- defer root.DecRef()
-
- // Grab the working directory.
- remainingTraversals := uint(args.MaxSymlinkTraversals)
- wd := root // Default.
- if args.WorkingDirectory != "" {
- var err error
- wd, err = mounts.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
- if err != nil {
- return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ var (
+ opener fsbridge.Lookup
+ fsContext *FSContext
+ mntns *fs.MountNamespace
+ )
+
+ if VFS2Enabled {
+ mntnsVFS2 := args.MountNamespaceVFS2
+ if mntnsVFS2 == nil {
+ // MountNamespaceVFS2 adds a reference to the namespace, which is
+ // transferred to the new process.
+ mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2()
+ }
+ // Get the root directory from the MountNamespace.
+ root := args.MountNamespaceVFS2.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef(ctx)
+
+ // Grab the working directory.
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: wd,
+ Path: fspath.Parse(args.WorkingDirectory),
+ FollowFinalSymlink: true,
+ }
+ var err error
+ wd, err = k.VFS().GetDentryAt(ctx, args.Credentials, &pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef(ctx)
+ }
+ opener = fsbridge.NewVFSLookup(mntnsVFS2, root, wd)
+ fsContext = NewFSContextVFS2(root, wd, args.Umask)
+
+ } else {
+ mntns = args.MountNamespace
+ if mntns == nil {
+ mntns = k.GlobalInit().Leader().MountNamespace()
+ mntns.IncRef()
+ }
+ // Get the root directory from the MountNamespace.
+ root := mntns.Root()
+ // The call to newFSContext below will take a reference on root, so we
+ // don't need to hold this one.
+ defer root.DecRef(ctx)
+
+ // Grab the working directory.
+ remainingTraversals := args.MaxSymlinkTraversals
+ wd := root // Default.
+ if args.WorkingDirectory != "" {
+ var err error
+ wd, err = mntns.FindInode(ctx, root, nil, args.WorkingDirectory, &remainingTraversals)
+ if err != nil {
+ return nil, 0, fmt.Errorf("failed to find initial working directory %q: %v", args.WorkingDirectory, err)
+ }
+ defer wd.DecRef(ctx)
}
- defer wd.DecRef()
+ opener = fsbridge.NewFSLookup(mntns, root, wd)
+ fsContext = newFSContext(root, wd, args.Umask)
}
+ tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits)
+
// Check which file to start from.
switch {
case args.Filename != "":
@@ -803,15 +973,14 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
}
// Create a fresh task context.
- remainingTraversals = uint(args.MaxSymlinkTraversals)
+ remainingTraversals := args.MaxSymlinkTraversals
loadArgs := loader.LoadArgs{
- Mounts: mounts,
- Root: root,
- WorkingDirectory: wd,
+ Opener: opener,
RemainingTraversals: &remainingTraversals,
ResolveFinal: true,
Filename: args.Filename,
File: args.File,
+ CloseOnExec: false,
Argv: args.Argv,
Envv: args.Envv,
Features: k.featureSet,
@@ -831,18 +1000,22 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
Kernel: k,
ThreadGroup: tg,
TaskContext: tc,
- FSContext: newFSContext(root, wd, args.Umask),
+ FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
AbstractSocketNamespace: args.AbstractSocketNamespace,
+ MountNamespaceVFS2: args.MountNamespaceVFS2,
ContainerID: args.ContainerID,
}
- if _, err := k.tasks.NewTask(config); err != nil {
+ t, err := k.tasks.NewTask(config)
+ if err != nil {
return nil, 0, err
}
+ t.traceExecEvent(tc) // Simulate exec for tracing.
// Success.
tgid := k.tasks.Root.IDOfThreadGroup(tg)
@@ -882,7 +1055,7 @@ func (k *Kernel) Start() error {
// If k was created by LoadKernelFrom, timers were stopped during
// Kernel.SaveTo and need to be resumed. If k was created by NewKernel,
// this is a no-op.
- k.resumeTimeLocked()
+ k.resumeTimeLocked(k.SupervisorContext())
// Start task goroutines.
k.tasks.mu.RLock()
defer k.tasks.mu.RUnlock()
@@ -896,7 +1069,7 @@ func (k *Kernel) Start() error {
//
// Preconditions: Any task goroutines running in k must be stopped. k.extMu
// must be locked.
-func (k *Kernel) pauseTimeLocked() {
+func (k *Kernel) pauseTimeLocked(ctx context.Context) {
// k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before
// Kernel.Start().
if k.cpuClockTicker != nil {
@@ -918,9 +1091,15 @@ func (k *Kernel) pauseTimeLocked() {
// This means we'll iterate FDTables shared by multiple tasks repeatedly,
// but ktime.Timer.Pause is idempotent so this is harmless.
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.PauseTimer()
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
+ tfd.PauseTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*oldtimerfd.TimerOperations); ok {
+ tfd.PauseTimer()
+ }
}
})
}
@@ -934,7 +1113,7 @@ func (k *Kernel) pauseTimeLocked() {
//
// Preconditions: Any task goroutines running in k must be stopped. k.extMu
// must be locked.
-func (k *Kernel) resumeTimeLocked() {
+func (k *Kernel) resumeTimeLocked(ctx context.Context) {
if k.cpuClockTicker != nil {
k.cpuClockTicker.Resume()
}
@@ -948,9 +1127,15 @@ func (k *Kernel) resumeTimeLocked() {
}
}
if t.fdTable != nil {
- t.fdTable.forEach(func(_ int32, file *fs.File, _ FDFlags) {
- if tfd, ok := file.FileOperations.(*timerfd.TimerOperations); ok {
- tfd.ResumeTimer()
+ t.fdTable.forEach(ctx, func(_ int32, file *fs.File, fd *vfs.FileDescription, _ FDFlags) {
+ if VFS2Enabled {
+ if tfd, ok := fd.Impl().(*timerfd.TimerFileDescription); ok {
+ tfd.ResumeTimer()
+ }
+ } else {
+ if tfd, ok := file.FileOperations.(*oldtimerfd.TimerOperations); ok {
+ tfd.ResumeTimer()
+ }
}
})
}
@@ -1067,13 +1252,22 @@ func (k *Kernel) Kill(es ExitStatus) {
}
// Pause requests that all tasks in k temporarily stop executing, and blocks
-// until all tasks in k have stopped. Multiple calls to Pause nest and require
-// an equal number of calls to Unpause to resume execution.
+// until all tasks and asynchronous I/O operations in k have stopped. Multiple
+// calls to Pause nest and require an equal number of calls to Unpause to
+// resume execution.
func (k *Kernel) Pause() {
k.extMu.Lock()
k.tasks.BeginExternalStop()
k.extMu.Unlock()
k.tasks.runningGoroutines.Wait()
+ k.tasks.aioGoroutines.Wait()
+}
+
+// ReceiveTaskStates receives full states for all tasks.
+func (k *Kernel) ReceiveTaskStates() {
+ k.extMu.Lock()
+ k.tasks.PullFullState()
+ k.extMu.Unlock()
}
// Unpause ends the effect of a previous call to Pause. If Unpause is called
@@ -1095,6 +1289,14 @@ func (k *Kernel) SendExternalSignal(info *arch.SignalInfo, context string) {
k.sendExternalSignal(info, context)
}
+// SendExternalSignalThreadGroup injects a signal into an specific ThreadGroup.
+// This function doesn't skip signals like SendExternalSignal does.
+func (k *Kernel) SendExternalSignalThreadGroup(tg *ThreadGroup, info *arch.SignalInfo) error {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ return tg.SendSignal(info)
+}
+
// SendContainerSignal sends the given signal to all processes inside the
// namespace that match the given container ID.
func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
@@ -1117,6 +1319,22 @@ func (k *Kernel) SendContainerSignal(cid string, info *arch.SignalInfo) error {
return lastErr
}
+// RebuildTraceContexts rebuilds the trace context for all tasks.
+//
+// Unfortunately, if these are built while tracing is not enabled, then we will
+// not have meaningful trace data. Rebuilding here ensures that we can do so
+// after tracing has been enabled.
+func (k *Kernel) RebuildTraceContexts() {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ k.tasks.mu.RLock()
+ defer k.tasks.mu.RUnlock()
+
+ for t, tid := range k.tasks.Root.tids {
+ t.rebuildTraceContext(tid)
+ }
+}
+
// FeatureSet returns the FeatureSet.
func (k *Kernel) FeatureSet() *cpuid.FeatureSet {
return k.featureSet
@@ -1157,10 +1375,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
@@ -1172,6 +1389,11 @@ func (k *Kernel) GlobalInit() *ThreadGroup {
return k.globalInit
}
+// TestOnly_SetGlobalInit sets the thread group with ID 1 in the root PID namespace.
+func (k *Kernel) TestOnly_SetGlobalInit(tg *ThreadGroup) {
+ k.globalInit = tg
+}
+
// ApplicationCores returns the number of CPUs visible to sandboxed
// applications.
func (k *Kernel) ApplicationCores() uint {
@@ -1255,6 +1477,11 @@ func (k *Kernel) NowMonotonic() int64 {
return now
}
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ return ktime.TcpipAfterFunc(k.realtimeClock, d, f)
+}
+
// SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or
// LoadFrom.
func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) {
@@ -1285,13 +1512,14 @@ func (k *Kernel) SupervisorContext() context.Context {
// +stateify savable
type SocketEntry struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- ID uint64 // Socket table entry number.
+ k *Kernel
+ Sock *refs.WeakRef
+ SockVFS2 *vfs.FileDescription
+ ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketEntry) WeakRefGone() {
+func (s *SocketEntry) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1310,7 +1538,30 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Unlock()
}
+// RecordSocketVFS2 adds a VFS2 socket to the system-wide socket table for
+// tracking.
+//
+// Precondition: Caller must hold a reference to sock.
+//
+// Note that the socket table will not hold a reference on the
+// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{
+ k: k,
+ ID: id,
+ SockVFS2: sock,
+ }
+ k.sockets.PushBack(s)
+ k.extMu.Unlock()
+}
+
// ListSockets returns a snapshot of all sockets.
+//
+// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// to get a reference on a socket in the table.
func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
var socks []*SocketEntry
@@ -1321,6 +1572,7 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
log.Logger
@@ -1351,8 +1603,24 @@ func (ctx supervisorContext) Value(key interface{}) interface{} {
return ctx.k.globalInit.mounts.Root()
}
return nil
+ case vfs.CtxRoot:
+ if ctx.k.globalInit == nil {
+ return vfs.VirtualDentry{}
+ }
+ mntns := ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
+ defer mntns.DecRef(ctx)
+ // Root() takes a reference on the root dirent for us.
+ return mntns.Root()
+ case vfs.CtxMountNamespace:
+ if ctx.k.globalInit == nil {
+ return nil
+ }
+ // MountNamespaceVFS2() takes a reference for us.
+ return ctx.k.GlobalInit().Leader().MountNamespaceVFS2()
case fs.CtxDirentCacheLimiter:
return ctx.k.DirentCacheLimiter
+ case inet.CtxStack:
+ return ctx.k.RootNetworkNamespace().Stack()
case ktime.CtxRealtimeClock:
return ctx.k.RealtimeClock()
case limits.CtxLimits:
@@ -1396,3 +1664,36 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
Registers: t.Arch().StateData().Proto(),
})
}
+
+// VFS returns the virtual filesystem for the kernel.
+func (k *Kernel) VFS() *vfs.VirtualFilesystem {
+ return &k.vfs
+}
+
+// SetHostMount sets the hostfs mount.
+func (k *Kernel) SetHostMount(mnt *vfs.Mount) {
+ if k.hostMount != nil {
+ panic("Kernel.hostMount cannot be set more than once")
+ }
+ k.hostMount = mnt
+}
+
+// HostMount returns the hostfs mount.
+func (k *Kernel) HostMount() *vfs.Mount {
+ return k.hostMount
+}
+
+// PipeMount returns the pipefs mount.
+func (k *Kernel) PipeMount() *vfs.Mount {
+ return k.pipeMount
+}
+
+// ShmMount returns the tmpfs mount.
+func (k *Kernel) ShmMount() *vfs.Mount {
+ return k.shmMount
+}
+
+// SocketMount returns the sockfs mount.
+func (k *Kernel) SocketMount() *vfs.Mount {
+ return k.socketMount
+}
diff --git a/pkg/sentry/kernel/kernel_opts.go b/pkg/sentry/kernel/kernel_opts.go
new file mode 100644
index 000000000..2e66ec587
--- /dev/null
+++ b/pkg/sentry/kernel/kernel_opts.go
@@ -0,0 +1,20 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+// SpecialOpts contains non-standard options for the kernel.
+//
+// +stateify savable
+type SpecialOpts struct{}
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index d7a7d1169..4486848d2 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,13 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools:defs.bzl", "go_library", "proto_library")
package(licenses = ["notice"])
go_library(
name = "memevent",
srcs = ["memory_events.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent",
visibility = ["//:sandbox"],
deps = [
":memory_events_go_proto",
@@ -16,24 +13,12 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/usage",
+ "//pkg/sync",
],
)
proto_library(
- name = "memory_events_proto",
+ name = "memory_events",
srcs = ["memory_events.proto"],
visibility = ["//visibility:public"],
)
-
-cc_proto_library(
- name = "memory_events_cc_proto",
- visibility = ["//visibility:public"],
- deps = [":memory_events_proto"],
-)
-
-go_proto_library(
- name = "memory_events_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
- proto = ":memory_events_proto",
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/sentry/kernel/memevent/memory_events.go b/pkg/sentry/kernel/memevent/memory_events.go
index b0d98e7f0..200565bb8 100644
--- a/pkg/sentry/kernel/memevent/memory_events.go
+++ b/pkg/sentry/kernel/memevent/memory_events.go
@@ -17,7 +17,6 @@
package memevent
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/eventchannel"
@@ -26,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
pb "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
)
var totalTicks = metric.MustCreateNewUint64Metric("/memory_events/ticks", false /*sync*/, "Total number of memory event periods that have elapsed since startup.")
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 9d34f6d4d..449643118 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,49 +1,36 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
-go_template_instance(
- name = "buffer_list",
- out = "buffer_list.go",
- package = "pipe",
- prefix = "buffer",
- template = "//pkg/ilist:generic_list",
- types = {
- "Element": "*buffer",
- "Linker": "*buffer",
- },
-)
-
go_library(
name = "pipe",
srcs = [
- "buffer.go",
- "buffer_list.go",
"device.go",
"node.go",
"pipe.go",
+ "pipe_unsafe.go",
"pipe_util.go",
"reader.go",
"reader_writer.go",
"vfs.go",
"writer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/pipe",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/amutex",
+ "//pkg/buffer",
+ "//pkg/context",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
- "//pkg/sentry/safemem",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
@@ -52,17 +39,16 @@ go_test(
name = "pipe_test",
size = "small",
srcs = [
- "buffer_test.go",
"node_test.go",
"pipe_test.go",
],
- embed = [":pipe"],
+ library = ":pipe",
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
deleted file mode 100644
index 95bee2d37..000000000
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ /dev/null
@@ -1,115 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package pipe
-
-import (
- "io"
- "sync"
-
- "gvisor.dev/gvisor/pkg/sentry/safemem"
-)
-
-// buffer encapsulates a queueable byte buffer.
-//
-// Note that the total size is slightly less than two pages. This
-// is done intentionally to ensure that the buffer object aligns
-// with runtime internals. We have no hard size or alignment
-// requirements. This two page size will effectively minimize
-// internal fragmentation, but still have a large enough chunk
-// to limit excessive segmentation.
-//
-// +stateify savable
-type buffer struct {
- data [8144]byte
- read int
- write int
- bufferEntry
-}
-
-// Reset resets internal data.
-//
-// This must be called before use.
-func (b *buffer) Reset() {
- b.read = 0
- b.write = 0
-}
-
-// Empty indicates the buffer is empty.
-//
-// This indicates there is no data left to read.
-func (b *buffer) Empty() bool {
- return b.read == b.write
-}
-
-// Full indicates the buffer is full.
-//
-// This indicates there is no capacity left to write.
-func (b *buffer) Full() bool {
- return b.write == len(b.data)
-}
-
-// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
-func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.write:]))
- n, err := safemem.CopySeq(dst, srcs)
- b.write += int(n)
- return n, err
-}
-
-// WriteFromReader writes to the buffer from an io.Reader.
-func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
- dst := b.data[b.write:]
- if count < int64(len(dst)) {
- dst = b.data[b.write:][:count]
- }
- n, err := r.Read(dst)
- b.write += n
- return int64(n), err
-}
-
-// ReadToBlocks implements safemem.Reader.ReadToBlocks.
-func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
- n, err := safemem.CopySeq(dsts, src)
- b.read += int(n)
- return n, err
-}
-
-// ReadToWriter reads from the buffer into an io.Writer.
-func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
- src := b.data[b.read:b.write]
- if count < int64(len(src)) {
- src = b.data[b.read:][:count]
- }
- n, err := w.Write(src)
- if !dup {
- b.read += n
- }
- return int64(n), err
-}
-
-// bufferPool is a pool for buffers.
-var bufferPool = sync.Pool{
- New: func() interface{} {
- return new(buffer)
- },
-}
-
-// newBuffer grabs a new buffer from the pool.
-func newBuffer() *buffer {
- b := bufferPool.Get().(*buffer)
- b.Reset()
- return b
-}
diff --git a/pkg/sentry/kernel/pipe/node.go b/pkg/sentry/kernel/pipe/node.go
index 4a19ab7ce..6497dc4ba 100644
--- a/pkg/sentry/kernel/pipe/node.go
+++ b/pkg/sentry/kernel/pipe/node.go
@@ -15,12 +15,11 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -94,7 +93,7 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
if i.p.isNamed && !flags.NonBlocking && !i.p.HasWriters() {
if !waitFor(&i.mu, &i.wWakeup, ctx) {
- r.DecRef()
+ r.DecRef(ctx)
return nil, syserror.ErrInterrupted
}
}
@@ -112,12 +111,12 @@ func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.Fi
// On a nonblocking, write-only open, the open fails with ENXIO if the
// read side isn't open yet.
if flags.NonBlocking {
- w.DecRef()
+ w.DecRef(ctx)
return nil, syserror.ENXIO
}
if !waitFor(&i.mu, &i.rWakeup, ctx) {
- w.DecRef()
+ w.DecRef(ctx)
return nil, syserror.ErrInterrupted
}
}
diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go
index 16fa80abe..ce0db5583 100644
--- a/pkg/sentry/kernel/pipe/node_test.go
+++ b/pkg/sentry/kernel/pipe/node_test.go
@@ -18,11 +18,11 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type sleeper struct {
@@ -167,7 +167,7 @@ func TestClosedReaderBlocksWriteOpen(t *testing.T) {
f := NewInodeOperations(ctx, perms, newNamedPipe(t))
rFile, _ := testOpenOrDie(ctx, t, f, fs.FileFlags{Read: true, NonBlocking: true}, nil)
- rFile.DecRef()
+ rFile.DecRef(ctx)
wDone := make(chan struct{})
// This open for write should block because the reader is now gone.
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 1a1b38f83..297e8f28f 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -17,12 +17,13 @@ package pipe
import (
"fmt"
- "sync"
"sync/atomic"
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,10 +71,10 @@ type Pipe struct {
// mu protects all pipe internal state below.
mu sync.Mutex `state:"nosave"`
- // data is the buffer queue of pipe contents.
+ // view is the underlying set of buffers.
//
// This is protected by mu.
- data bufferList
+ view buffer.View
// max is the maximum size of the pipe in bytes. When this max has been
// reached, writers will get EWOULDBLOCK.
@@ -81,11 +82,6 @@ type Pipe struct {
// This is protected by mu.
max int64
- // size is the current size of the pipe in bytes.
- //
- // This is protected by mu.
- size int64
-
// hadWriter indicates if this pipe ever had a writer. Note that this
// does not necessarily indicate there is *currently* a writer, just
// that there has been a writer at some point since the pipe was
@@ -156,7 +152,7 @@ func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.
d := fs.NewDirent(ctx, fs.NewInode(ctx, iops, ms, sattr), fmt.Sprintf("pipe:[%d]", ino))
// The p.Open calls below will each take a reference on the Dirent. We
// must drop the one we already have.
- defer d.DecRef()
+ defer d.DecRef(ctx)
return p.Open(ctx, d, fs.FileFlags{Read: true}), p.Open(ctx, d, fs.FileFlags{Write: true})
}
@@ -196,7 +192,7 @@ type readOps struct {
limit func(int64)
// read performs the actual read operation.
- read func(*buffer) (int64, error)
+ read func(*buffer.View) (int64, error)
}
// read reads data from the pipe into dst and returns the number of bytes
@@ -211,82 +207,27 @@ func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
-
- // Is the pipe empty?
- if p.size == 0 {
- if !p.HasWriters() {
- // There are no writers, return EOF.
- return 0, nil
- }
- return 0, syserror.ErrWouldBlock
- }
-
- // Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
- }
-
- done := int64(0)
- for ops.left() > 0 {
- // Pop the first buffer.
- first := p.data.Front()
- if first == nil {
- break
- }
-
- // Copy user data.
- n, err := ops.read(first)
- done += int64(n)
- p.size -= n
-
- // Empty buffer?
- if first.Empty() {
- // Push to the free list.
- p.data.Remove(first)
- bufferPool.Put(first)
- }
-
- // Handle errors.
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
+ return p.readLocked(ctx, ops)
}
-// dup duplicates all data from this pipe into the given writer.
-//
-// There is no blocking behavior implemented here. The writer may propagate
-// some blocking error. All the writes must be complete writes.
-func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
- p.mu.Lock()
- defer p.mu.Unlock()
-
+func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) {
// Is the pipe empty?
- if p.size == 0 {
+ if p.view.Size() == 0 {
if !p.HasWriters() {
- // See above.
+ // There are no writers, return EOF.
return 0, nil
}
return 0, syserror.ErrWouldBlock
}
// Limit how much we consume.
- if ops.left() > p.size {
- ops.limit(p.size)
+ if ops.left() > p.view.Size() {
+ ops.limit(p.view.Size())
}
- done := int64(0)
- for buf := p.data.Front(); buf != nil; buf = buf.Next() {
- n, err := ops.read(buf)
- done += n
- if err != nil {
- return done, err
- }
- }
-
- return done, nil
+ // Copy user data; the read op is responsible for trimming.
+ done, err := ops.read(&p.view)
+ return done, err
}
type writeOps struct {
@@ -297,7 +238,7 @@ type writeOps struct {
limit func(int64)
// write should write to the provided buffer.
- write func(*buffer) (int64, error)
+ write func(*buffer.View) (int64, error)
}
// write writes data from sv into the pipe and returns the number of bytes
@@ -308,7 +249,10 @@ type writeOps struct {
func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
+ return p.writeLocked(ctx, ops)
+}
+func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) {
// Can't write to a pipe with no readers.
if !p.HasReaders() {
return 0, syscall.EPIPE
@@ -317,35 +261,28 @@ func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
wanted := ops.left()
- if avail := p.max - p.size; wanted > avail {
+ avail := p.max - p.view.Size()
+ if wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
ops.limit(avail)
}
- done := int64(0)
- for ops.left() > 0 {
- // Need a new buffer?
- last := p.data.Back()
- if last == nil || last.Full() {
- // Add a new buffer to the data list.
- last = newBuffer()
- p.data.PushBack(last)
- }
-
- // Copy user data.
- n, err := ops.write(last)
- done += int64(n)
- p.size += n
+ // Copy user data.
+ done, err := ops.write(&p.view)
+ if err != nil {
+ return done, err
+ }
- // Handle errors.
- if err != nil {
- return done, err
- }
+ if done < avail {
+ // Non-failure, but short write.
+ return done, nil
}
- if wanted > done {
- // Partial write due to full pipe.
+ if done < wanted {
+ // Partial write due to full pipe. Note that this could also be
+ // the short write case above, we would expect a second call
+ // and the write to return zero bytes in this case.
return done, syserror.ErrWouldBlock
}
@@ -396,7 +333,7 @@ func (p *Pipe) HasWriters() bool {
// Precondition: mu must be held.
func (p *Pipe) rReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasReaders() && p.data.Front() != nil {
+ if p.HasReaders() && p.view.Size() != 0 {
ready |= waiter.EventIn
}
if !p.HasWriters() && p.hadWriter {
@@ -422,7 +359,7 @@ func (p *Pipe) rReadiness() waiter.EventMask {
// Precondition: mu must be held.
func (p *Pipe) wReadinessLocked() waiter.EventMask {
ready := waiter.EventMask(0)
- if p.HasWriters() && p.size < p.max {
+ if p.HasWriters() && p.view.Size() < p.max {
ready |= waiter.EventOut
}
if !p.HasReaders() {
@@ -451,7 +388,7 @@ func (p *Pipe) rwReadiness() waiter.EventMask {
func (p *Pipe) queued() int64 {
p.mu.Lock()
defer p.mu.Unlock()
- return p.size
+ return p.view.Size()
}
// FifoSize implements fs.FifoSizer.FifoSize.
@@ -474,7 +411,7 @@ func (p *Pipe) SetFifoSize(size int64) (int64, error) {
}
p.mu.Lock()
defer p.mu.Unlock()
- if size < p.size {
+ if size < p.view.Size() {
return 0, syserror.EBUSY
}
p.max = size
diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go
index e3a14b665..fe97e9800 100644
--- a/pkg/sentry/kernel/pipe/pipe_test.go
+++ b/pkg/sentry/kernel/pipe/pipe_test.go
@@ -18,17 +18,17 @@ import (
"bytes"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
func TestPipeRW(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, 65536, 4096)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := []byte("here's some bytes")
wantN := int64(len(msg))
@@ -47,8 +47,8 @@ func TestPipeRW(t *testing.T) {
func TestPipeReadBlock(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, 65536, 4096)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
n, err := r.Readv(ctx, usermem.BytesIOSequence(make([]byte, 1)))
if n != 0 || err != syserror.ErrWouldBlock {
@@ -62,8 +62,8 @@ func TestPipeWriteBlock(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := make([]byte, capacity+1)
n, err := w.Writev(ctx, usermem.BytesIOSequence(msg))
@@ -77,8 +77,8 @@ func TestPipeWriteUntilEnd(t *testing.T) {
ctx := contexttest.Context(t)
r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes)
- defer r.DecRef()
- defer w.DecRef()
+ defer r.DecRef(ctx)
+ defer w.DecRef(ctx)
msg := []byte("here's some bytes")
diff --git a/pkg/sentry/kernel/pipe/pipe_unsafe.go b/pkg/sentry/kernel/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..dd60cba24
--- /dev/null
+++ b/pkg/sentry/kernel/pipe/pipe_unsafe.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pipe
+
+import (
+ "unsafe"
+)
+
+// lockTwoPipes locks both x.mu and y.mu in an order that is guaranteed to be
+// consistent for both lockTwoPipes(x, y) and lockTwoPipes(y, x), such that
+// concurrent calls cannot deadlock.
+//
+// Preconditions: x != y.
+func lockTwoPipes(x, y *Pipe) {
+ // Lock the two pipes in order of increasing address.
+ if uintptr(unsafe.Pointer(x)) < uintptr(unsafe.Pointer(y)) {
+ x.mu.Lock()
+ y.mu.Lock()
+ } else {
+ y.mu.Lock()
+ x.mu.Lock()
+ }
+}
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go
index ef9641e6a..6d58b682f 100644
--- a/pkg/sentry/kernel/pipe/pipe_util.go
+++ b/pkg/sentry/kernel/pipe/pipe_util.go
@@ -17,14 +17,15 @@ package pipe
import (
"io"
"math"
- "sync"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -32,7 +33,7 @@ import (
// the old fs architecture.
// Release cleans up the pipe's state.
-func (p *Pipe) Release() {
+func (p *Pipe) Release(context.Context) {
p.rClose()
p.wClose()
@@ -49,9 +50,10 @@ func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error)
limit: func(l int64) {
dst = dst.TakeFirst64(l)
},
- read: func(buf *buffer) (int64, error) {
- n, err := dst.CopyOutFrom(ctx, buf)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, view)
dst = dst.DropFirst64(n)
+ view.TrimFront(n)
return n, err
},
})
@@ -70,16 +72,15 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool)
limit: func(l int64) {
count = l
},
- read: func(buf *buffer) (int64, error) {
- n, err := buf.ReadToWriter(w, count, dup)
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToWriter(w, count)
+ if !dup {
+ view.TrimFront(n)
+ }
count -= n
return n, err
},
}
- if dup {
- // There is no notification for dup operations.
- return p.dup(ctx, ops)
- }
n, err := p.read(ctx, ops)
if n > 0 {
p.Notify(waiter.EventOut)
@@ -96,8 +97,8 @@ func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error)
limit: func(l int64) {
src = src.TakeFirst64(l)
},
- write: func(buf *buffer) (int64, error) {
- n, err := src.CopyInTo(ctx, buf)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := src.CopyInTo(ctx, view)
src = src.DropFirst64(n)
return n, err
},
@@ -117,8 +118,8 @@ func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, e
limit: func(l int64) {
count = l
},
- write: func(buf *buffer) (int64, error) {
- n, err := buf.WriteFromReader(r, count)
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromReader(r, count)
count -= n
return n, err
},
@@ -143,7 +144,7 @@ func (p *Pipe) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArgume
if v > math.MaxInt32 {
v = math.MaxInt32 // Silently truncate.
}
- // Copy result to user-space.
+ // Copy result to userspace.
_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
AddressSpaceActive: true,
})
diff --git a/pkg/sentry/kernel/pipe/reader.go b/pkg/sentry/kernel/pipe/reader.go
index 7724b4452..ac18785c0 100644
--- a/pkg/sentry/kernel/pipe/reader.go
+++ b/pkg/sentry/kernel/pipe/reader.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -29,7 +30,7 @@ type Reader struct {
// Release implements fs.FileOperations.Release.
//
// This overrides ReaderWriter.Release.
-func (r *Reader) Release() {
+func (r *Reader) Release(context.Context) {
r.Pipe.rClose()
// Wake up writers.
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index b4d29fc77..b2b5691ee 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -17,11 +17,11 @@ package pipe
import (
"io"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// ReaderWriter satisfies the FileOperations interface and services both
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 6416e0dd8..28f998e45 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -15,14 +15,16 @@
package pipe
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -50,38 +52,44 @@ type VFSPipe struct {
}
// NewVFSPipe returns an initialized VFSPipe.
-func NewVFSPipe(sizeBytes, atomicIOBytes int64) *VFSPipe {
+func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe {
var vp VFSPipe
- initPipe(&vp.pipe, true /* isNamed */, sizeBytes, atomicIOBytes)
+ initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes)
return &vp
}
-// NewVFSPipeFD opens a named pipe. Named pipes have special blocking semantics
-// during open:
+// ReaderWriterPair returns read-only and write-only FDs for vp.
//
-// "Normally, opening the FIFO blocks until the other end is opened also. A
-// process can open a FIFO in nonblocking mode. In this case, opening for
-// read-only will succeed even if no-one has opened on the write side yet,
-// opening for write-only will fail with ENXIO (no such device or address)
-// unless the other end has already been opened. Under Linux, opening a FIFO
-// for read and write will succeed both in blocking and nonblocking mode. POSIX
-// leaves this behavior undefined. This can be used to open a FIFO for writing
-// while there are no readers available." - fifo(7)
-func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
+// Preconditions: statusFlags should not contain an open access mode.
+func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+ // Connected pipes share the same locks.
+ locks := &vfs.FileLocks{}
+ return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+}
+
+// Open opens the pipe represented by vp.
+func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
vp.mu.Lock()
defer vp.mu.Unlock()
- readable := vfs.MayReadFileWithOpenFlags(flags)
- writable := vfs.MayWriteFileWithOpenFlags(flags)
+ readable := vfs.MayReadFileWithOpenFlags(statusFlags)
+ writable := vfs.MayWriteFileWithOpenFlags(statusFlags)
if !readable && !writable {
return nil, syserror.EINVAL
}
- vfd, err := vp.open(rp, vfsd, vfsfd, flags)
- if err != nil {
- return nil, err
- }
+ fd := vp.newFD(mnt, vfsd, statusFlags, locks)
+ // Named pipes have special blocking semantics during open:
+ //
+ // "Normally, opening the FIFO blocks until the other end is opened also. A
+ // process can open a FIFO in nonblocking mode. In this case, opening for
+ // read-only will succeed even if no-one has opened on the write side yet,
+ // opening for write-only will fail with ENXIO (no such device or address)
+ // unless the other end has already been opened. Under Linux, opening a
+ // FIFO for read and write will succeed both in blocking and nonblocking
+ // mode. POSIX leaves this behavior undefined. This can be used to open a
+ // FIFO for writing while there are no readers available." - fifo(7)
switch {
case readable && writable:
// Pipes opened for read-write always succeed without blocking.
@@ -90,23 +98,26 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd
case readable:
newHandleLocked(&vp.rWakeup)
- // If this pipe is being opened as nonblocking and there's no
+ // If this pipe is being opened as blocking and there's no
// writer, we have to wait for a writer to open the other end.
- if flags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ if vp.pipe.isNamed && statusFlags&linux.O_NONBLOCK == 0 && !vp.pipe.HasWriters() && !waitFor(&vp.mu, &vp.wWakeup, ctx) {
+ fd.DecRef(ctx)
return nil, syserror.EINTR
}
case writable:
newHandleLocked(&vp.wWakeup)
- if !vp.pipe.HasReaders() {
- // Nonblocking, write-only opens fail with ENXIO when
- // the read side isn't open yet.
- if flags&linux.O_NONBLOCK != 0 {
+ if vp.pipe.isNamed && !vp.pipe.HasReaders() {
+ // Non-blocking, write-only opens fail with ENXIO when the read
+ // side isn't open yet.
+ if statusFlags&linux.O_NONBLOCK != 0 {
+ fd.DecRef(ctx)
return nil, syserror.ENXIO
}
// Wait for a reader to open the other end.
if !waitFor(&vp.mu, &vp.rWakeup, ctx) {
+ fd.DecRef(ctx)
return nil, syserror.EINTR
}
}
@@ -115,102 +126,102 @@ func (vp *VFSPipe) NewVFSPipeFD(ctx context.Context, rp *vfs.ResolvingPath, vfsd
panic("invalid pipe flags: must be readable, writable, or both")
}
- return vfd, nil
+ return fd, nil
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, vfsfd *vfs.FileDescription, flags uint32) (*VFSPipeFD, error) {
- var fd VFSPipeFD
- fd.flags = flags
- fd.readable = vfs.MayReadFileWithOpenFlags(flags)
- fd.writable = vfs.MayWriteFileWithOpenFlags(flags)
- fd.vfsfd = vfsfd
- fd.pipe = &vp.pipe
- if fd.writable {
- // The corresponding Mount.EndWrite() is in VFSPipe.Release().
- if err := rp.Mount().CheckBeginWrite(); err != nil {
- return nil, err
- }
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription {
+ fd := &VFSPipeFD{
+ pipe: &vp.pipe,
}
+ fd.LockFD.Init(locks)
+ fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ })
switch {
- case fd.readable && fd.writable:
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
vp.pipe.rOpen()
vp.pipe.wOpen()
- case fd.readable:
+ case fd.vfsfd.IsReadable():
vp.pipe.rOpen()
- case fd.writable:
+ case fd.vfsfd.IsWritable():
vp.pipe.wOpen()
default:
panic("invalid pipe flags: must be readable, writable, or both")
}
- return &fd, nil
+ return &fd.vfsfd
}
-// VFSPipeFD implements a subset of vfs.FileDescriptionImpl for pipes. It is
-// expected that filesystesm will use this in a struct implementing
-// vfs.FileDescriptionImpl.
+// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
+// non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to
+// other FileDescriptions for splice(2) and tee(2).
type VFSPipeFD struct {
- pipe *Pipe
- flags uint32
- readable bool
- writable bool
- vfsfd *vfs.FileDescription
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ pipe *Pipe
}
// Release implements vfs.FileDescriptionImpl.Release.
-func (fd *VFSPipeFD) Release() {
+func (fd *VFSPipeFD) Release(context.Context) {
var event waiter.EventMask
- if fd.readable {
+ if fd.vfsfd.IsReadable() {
fd.pipe.rClose()
- event |= waiter.EventIn
+ event |= waiter.EventOut
}
- if fd.writable {
+ if fd.vfsfd.IsWritable() {
fd.pipe.wClose()
- event |= waiter.EventOut
+ event |= waiter.EventIn | waiter.EventHUp
}
if event == 0 {
panic("invalid pipe flags: must be readable, writable, or both")
}
- if fd.writable {
- fd.vfsfd.VirtualDentry().Mount().EndWrite()
+ fd.pipe.Notify(event)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ switch {
+ case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
+ return fd.pipe.rwReadiness()
+ case fd.vfsfd.IsReadable():
+ return fd.pipe.rReadiness()
+ case fd.vfsfd.IsWritable():
+ return fd.pipe.wReadiness()
+ default:
+ panic("pipe FD is neither readable nor writable")
}
+}
- fd.pipe.Notify(event)
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (fd *VFSPipeFD) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ESPIPE
}
-// OnClose implements vfs.FileDescriptionImpl.OnClose.
-func (fd *VFSPipeFD) OnClose(_ context.Context) error {
- return nil
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.pipe.EventRegister(e, mask)
}
-// PRead implements vfs.FileDescriptionImpl.PRead.
-func (fd *VFSPipeFD) PRead(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.ReadOptions) (int64, error) {
- return 0, syserror.ESPIPE
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *VFSPipeFD) EventUnregister(e *waiter.Entry) {
+ fd.pipe.EventUnregister(e)
}
// Read implements vfs.FileDescriptionImpl.Read.
func (fd *VFSPipeFD) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) {
- if !fd.readable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Read(ctx, dst)
}
-// PWrite implements vfs.FileDescriptionImpl.PWrite.
-func (fd *VFSPipeFD) PWrite(_ context.Context, _ usermem.IOSequence, _ int64, _ vfs.WriteOptions) (int64, error) {
- return 0, syserror.ESPIPE
-}
-
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) {
- if !fd.writable {
- return 0, syserror.EINVAL
- }
-
return fd.pipe.Write(ctx, src)
}
@@ -218,3 +229,240 @@ func (fd *VFSPipeFD) Write(ctx context.Context, src usermem.IOSequence, _ vfs.Wr
func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return fd.pipe.Ioctl(ctx, uio, args)
}
+
+// PipeSize implements fcntl(F_GETPIPE_SZ).
+func (fd *VFSPipeFD) PipeSize() int64 {
+ // Inline Pipe.FifoSize() rather than calling it with nil Context and
+ // fs.File and ignoring the returned error (which is always nil).
+ fd.pipe.mu.Lock()
+ defer fd.pipe.mu.Unlock()
+ return fd.pipe.max
+}
+
+// SetPipeSize implements fcntl(F_SETPIPE_SZ).
+func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) {
+ return fd.pipe.SetFifoSize(size)
+}
+
+// IOSequence returns a useremm.IOSequence that reads up to count bytes from,
+// or writes up to count bytes to, fd.
+func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence {
+ return usermem.IOSequence{
+ IO: fd,
+ Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}),
+ }
+}
+
+// CopyIn implements usermem.IO.CopyIn.
+func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(dst))
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return int64(len(dst))
+ },
+ limit: func(l int64) {
+ dst = dst[:l]
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadAt(dst, 0)
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// CopyOut implements usermem.IO.CopyOut.
+func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) {
+ origCount := int64(len(src))
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return int64(len(src))
+ },
+ limit: func(l int64) {
+ src = src[:l]
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Append(src)
+ return int64(len(src)), nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return int(n), syserror.ErrWouldBlock
+ }
+ return int(n), err
+}
+
+// ZeroOut implements usermem.IO.ZeroOut.
+func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) {
+ origCount := toZero
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return toZero
+ },
+ limit: func(l int64) {
+ toZero = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ view.Grow(view.Size()+toZero, true /* zero */)
+ return toZero, nil
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyInTo implements usermem.IO.CopyInTo.
+func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.read(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(view *buffer.View) (int64, error) {
+ n, err := view.ReadToSafememWriter(dst, uint64(count))
+ view.TrimFront(int64(n))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventOut)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// CopyOutFrom implements usermem.IO.CopyOutFrom.
+func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, src safemem.Reader, opts usermem.IOOpts) (int64, error) {
+ count := ars.NumBytes()
+ if count == 0 {
+ return 0, nil
+ }
+ origCount := count
+ n, err := fd.pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(view *buffer.View) (int64, error) {
+ n, err := view.WriteFromSafememReader(src, uint64(count))
+ return int64(n), err
+ },
+ })
+ if n > 0 {
+ fd.pipe.Notify(waiter.EventIn)
+ }
+ if err == nil && n != origCount {
+ return n, syserror.ErrWouldBlock
+ }
+ return n, err
+}
+
+// SwapUint32 implements usermem.IO.SwapUint32.
+func (fd *VFSPipeFD) SwapUint32(ctx context.Context, addr usermem.Addr, new uint32, opts usermem.IOOpts) (uint32, error) {
+ // How did a pipe get passed as the virtual address space to futex(2)?
+ panic("VFSPipeFD.SwapUint32 called unexpectedly")
+}
+
+// CompareAndSwapUint32 implements usermem.IO.CompareAndSwapUint32.
+func (fd *VFSPipeFD) CompareAndSwapUint32(ctx context.Context, addr usermem.Addr, old, new uint32, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.CompareAndSwapUint32 called unexpectedly")
+}
+
+// LoadUint32 implements usermem.IO.LoadUint32.
+func (fd *VFSPipeFD) LoadUint32(ctx context.Context, addr usermem.Addr, opts usermem.IOOpts) (uint32, error) {
+ panic("VFSPipeFD.LoadUint32 called unexpectedly")
+}
+
+// Splice reads up to count bytes from src and writes them to dst. It returns
+// the number of bytes moved.
+//
+// Preconditions: count > 0.
+func Splice(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, true /* removeFromSrc */)
+}
+
+// Tee reads up to count bytes from src and writes them to dst, without
+// removing the read bytes from src. It returns the number of bytes copied.
+//
+// Preconditions: count > 0.
+func Tee(ctx context.Context, dst, src *VFSPipeFD, count int64) (int64, error) {
+ return spliceOrTee(ctx, dst, src, count, false /* removeFromSrc */)
+}
+
+// Preconditions: count > 0.
+func spliceOrTee(ctx context.Context, dst, src *VFSPipeFD, count int64, removeFromSrc bool) (int64, error) {
+ if dst.pipe == src.pipe {
+ return 0, syserror.EINVAL
+ }
+
+ lockTwoPipes(dst.pipe, src.pipe)
+ defer dst.pipe.mu.Unlock()
+ defer src.pipe.mu.Unlock()
+
+ n, err := dst.pipe.writeLocked(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(dstView *buffer.View) (int64, error) {
+ return src.pipe.readLocked(ctx, readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(srcView *buffer.View) (int64, error) {
+ n, err := srcView.ReadToSafememWriter(dstView, uint64(count))
+ if n > 0 && removeFromSrc {
+ srcView.TrimFront(int64(n))
+ }
+ return int64(n), err
+ },
+ })
+ },
+ })
+ if n > 0 {
+ dst.pipe.Notify(waiter.EventIn)
+ src.pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (fd *VFSPipeFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (fd *VFSPipeFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/kernel/pipe/writer.go b/pkg/sentry/kernel/pipe/writer.go
index 5bc6aa931..ef4b70ca3 100644
--- a/pkg/sentry/kernel/pipe/writer.go
+++ b/pkg/sentry/kernel/pipe/writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -29,7 +30,7 @@ type Writer struct {
// Release implements fs.FileOperations.Release.
//
// This overrides ReaderWriter.Release.
-func (w *Writer) Release() {
+func (w *Writer) Release(context.Context) {
w.Pipe.wClose()
// Wake up readers.
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 3be171cdc..619b0cb7c 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -20,8 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// ptraceOptions are the subset of options controlling a task's ptrace behavior
@@ -184,7 +184,6 @@ func (t *Task) CanTrace(target *Task, attach bool) bool {
if targetCreds.PermittedCaps&^callerCreds.PermittedCaps != 0 {
return false
}
- // TODO: Yama LSM
return true
}
@@ -1019,6 +1018,9 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if err != nil {
return err
}
+
+ t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
+
ar := ars.Head()
n, err := target.Arch().PtraceGetRegSet(uintptr(addr), &usermem.IOReadWriter{
Ctx: t,
@@ -1045,10 +1047,14 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if err != nil {
return err
}
+
+ mm := t.MemoryManager()
+ t.p.PullFullState(mm.AddressSpace(), t.Arch())
+
ar := ars.Head()
n, err := target.Arch().PtraceSetRegSet(uintptr(addr), &usermem.IOReadWriter{
Ctx: t,
- IO: t.MemoryManager(),
+ IO: mm,
Addr: ar.Start,
Opts: usermem.IOOpts{
AddressSpaceActive: true,
@@ -1057,6 +1063,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error {
if err != nil {
return err
}
+ t.p.FullStateChanged()
ar.End -= usermem.Addr(n)
return t.CopyOutIovecs(data, usermem.AddrRangeSeqOf(ar))
diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go
index 5514cf432..cef1276ec 100644
--- a/pkg/sentry/kernel/ptrace_amd64.go
+++ b/pkg/sentry/kernel/ptrace_amd64.go
@@ -18,8 +18,8 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// ptraceArch implements arch-specific ptrace commands.
diff --git a/pkg/sentry/kernel/ptrace_arm64.go b/pkg/sentry/kernel/ptrace_arm64.go
index 0acdf769d..d971b96b3 100644
--- a/pkg/sentry/kernel/ptrace_arm64.go
+++ b/pkg/sentry/kernel/ptrace_arm64.go
@@ -17,9 +17,8 @@
package kernel
import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// ptraceArch implements arch-specific ptrace commands.
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 24ea002ba..18416643b 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -15,17 +15,29 @@
package kernel
import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/hostcpu"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// Restartable sequences, as described in https://lwn.net/Articles/650333/.
+// Restartable sequences.
+//
+// We support two different APIs for restartable sequences.
+//
+// 1. The upstream interface added in v4.18.
+// 2. The interface described in https://lwn.net/Articles/650333/.
+//
+// Throughout this file and other parts of the kernel, the latter is referred
+// to as "old rseq". This interface was never merged upstream, but is supported
+// for a limited set of applications that use it regardless.
-// RSEQCriticalRegion describes a restartable sequence critical region.
+// OldRSeqCriticalRegion describes an old rseq critical region.
//
// +stateify savable
-type RSEQCriticalRegion struct {
+type OldRSeqCriticalRegion struct {
// When a task in this thread group has its CPU preempted (as defined by
// platform.ErrContextCPUPreempted) or has a signal delivered to an
// application handler while its instruction pointer is in CriticalSection,
@@ -35,86 +47,347 @@ type RSEQCriticalRegion struct {
Restart usermem.Addr
}
-// RSEQAvailable returns true if t supports restartable sequences.
-func (t *Task) RSEQAvailable() bool {
+// RSeqAvailable returns true if t supports (old and new) restartable sequences.
+func (t *Task) RSeqAvailable() bool {
return t.k.useHostCores && t.k.Platform.DetectsCPUPreemption()
}
-// RSEQCriticalRegion returns a copy of t's thread group's current restartable
-// sequence.
-func (t *Task) RSEQCriticalRegion() RSEQCriticalRegion {
- return *t.tg.rscr.Load().(*RSEQCriticalRegion)
+// SetRSeq registers addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) SetRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr != 0 {
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EINVAL
+ }
+ return syserror.EBUSY
+ }
+
+ // rseq must be aligned and correctly sized.
+ if addr&(linux.AlignOfRSeq-1) != 0 {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if _, ok := t.MemoryManager().CheckIORange(addr, linux.SizeOfRSeq); !ok {
+ return syserror.EFAULT
+ }
+
+ t.rseqAddr = addr
+ t.rseqSignature = signature
+
+ // Initialize the CPUID.
+ //
+ // Linux implicitly does this on return from userspace, where failure
+ // would cause SIGSEGV.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return syserror.EFAULT
+ }
+
+ return nil
+}
+
+// ClearRSeq unregisters addr as this thread's rseq structure.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) ClearRSeq(addr usermem.Addr, length, signature uint32) error {
+ if t.rseqAddr == 0 {
+ return syserror.EINVAL
+ }
+ if t.rseqAddr != addr {
+ return syserror.EINVAL
+ }
+ if length != linux.SizeOfRSeq {
+ return syserror.EINVAL
+ }
+ if t.rseqSignature != signature {
+ return syserror.EPERM
+ }
+
+ if err := t.rseqClearCPU(); err != nil {
+ return err
+ }
+
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+
+ if t.oldRSeqCPUAddr == 0 {
+ // rseqCPU no longer needed.
+ t.rseqCPU = -1
+ }
+
+ return nil
}
-// SetRSEQCriticalRegion replaces t's thread group's restartable sequence.
+// OldRSeqCriticalRegion returns a copy of t's thread group's current
+// old restartable sequence.
+func (t *Task) OldRSeqCriticalRegion() OldRSeqCriticalRegion {
+ return *t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+}
+
+// SetOldRSeqCriticalRegion replaces t's thread group's old restartable
+// sequence.
//
-// Preconditions: t.RSEQAvailable() == true.
-func (t *Task) SetRSEQCriticalRegion(rscr RSEQCriticalRegion) error {
+// Preconditions: t.RSeqAvailable() == true.
+func (t *Task) SetOldRSeqCriticalRegion(r OldRSeqCriticalRegion) error {
// These checks are somewhat more lenient than in Linux, which (bizarrely)
- // requires rscr.CriticalSection to be non-empty and rscr.Restart to be
- // outside of rscr.CriticalSection, even if rscr.CriticalSection.Start == 0
+ // requires r.CriticalSection to be non-empty and r.Restart to be
+ // outside of r.CriticalSection, even if r.CriticalSection.Start == 0
// (which disables the critical region).
- if rscr.CriticalSection.Start == 0 {
- rscr.CriticalSection.End = 0
- rscr.Restart = 0
- t.tg.rscr.Store(&rscr)
+ if r.CriticalSection.Start == 0 {
+ r.CriticalSection.End = 0
+ r.Restart = 0
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
- if rscr.CriticalSection.Start >= rscr.CriticalSection.End {
+ if r.CriticalSection.Start >= r.CriticalSection.End {
return syserror.EINVAL
}
- if rscr.CriticalSection.Contains(rscr.Restart) {
+ if r.CriticalSection.Contains(r.Restart) {
return syserror.EINVAL
}
- // TODO(jamieliu): check that rscr.CriticalSection and rscr.Restart are in
- // the application address range, for consistency with Linux
- t.tg.rscr.Store(&rscr)
+ // TODO(jamieliu): check that r.CriticalSection and r.Restart are in
+ // the application address range, for consistency with Linux.
+ t.tg.oldRSeqCritical.Store(&r)
return nil
}
-// RSEQCPUAddr returns the address that RSEQ will keep updated with t's CPU
-// number.
+// OldRSeqCPUAddr returns the address that old rseq will keep updated with t's
+// CPU number.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) RSEQCPUAddr() usermem.Addr {
- return t.rseqCPUAddr
+func (t *Task) OldRSeqCPUAddr() usermem.Addr {
+ return t.oldRSeqCPUAddr
}
-// SetRSEQCPUAddr replaces the address that RSEQ will keep updated with t's CPU
-// number.
+// SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with
+// t's CPU number.
//
-// Preconditions: t.RSEQAvailable() == true. The caller must be running on the
+// Preconditions: t.RSeqAvailable() == true. The caller must be running on the
// task goroutine. t's AddressSpace must be active.
-func (t *Task) SetRSEQCPUAddr(addr usermem.Addr) error {
- t.rseqCPUAddr = addr
- if addr != 0 {
- t.rseqCPU = int32(hostcpu.GetCPU())
- if err := t.rseqCopyOutCPU(); err != nil {
- t.rseqCPUAddr = 0
- t.rseqCPU = -1
- return syserror.EINVAL // yes, EINVAL, not err or EFAULT
- }
- } else {
- t.rseqCPU = -1
+func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error {
+ t.oldRSeqCPUAddr = addr
+
+ // Check that addr is writable.
+ //
+ // N.B. rseqUpdateCPU may fail on a bad t.rseqAddr as well. That's
+ // unfortunate, but unlikely in a correct program.
+ if err := t.rseqUpdateCPU(); err != nil {
+ t.oldRSeqCPUAddr = 0
+ return syserror.EINVAL // yes, EINVAL, not err or EFAULT
}
return nil
}
// Preconditions: The caller must be running on the task goroutine. t's
// AddressSpace must be active.
-func (t *Task) rseqCopyOutCPU() error {
+func (t *Task) rseqUpdateCPU() error {
+ if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 {
+ t.rseqCPU = -1
+ return nil
+ }
+
+ t.rseqCPU = int32(hostcpu.GetCPU())
+
+ // Update both CPUs, even if one fails.
+ rerr := t.rseqCopyOutCPU()
+ oerr := t.oldRSeqCopyOutCPU()
+
+ if rerr != nil {
+ return rerr
+ }
+ return oerr
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) oldRSeqCopyOutCPU() error {
+ if t.oldRSeqCPUAddr == 0 {
+ return nil
+ }
+
buf := t.CopyScratchBuffer(4)
usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU))
- _, err := t.CopyOutBytes(t.rseqCPUAddr, buf)
+ _, err := t.CopyOutBytes(t.oldRSeqCPUAddr, buf)
return err
}
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqCopyOutCPU() error {
+ if t.rseqAddr == 0 {
+ return nil
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, uint32(t.rseqCPU)) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], uint32(t.rseqCPU)) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqClearCPU() error {
+ buf := t.CopyScratchBuffer(8)
+ // CPUIDStart and CPUID are the first two fields in linux.RSeq.
+ usermem.ByteOrder.PutUint32(buf, 0) // CPUIDStart
+ usermem.ByteOrder.PutUint32(buf[4:], linux.RSEQ_CPU_ID_UNINITIALIZED) // CPUID
+ // N.B. This write is not atomic, but since this occurs on the task
+ // goroutine then as long as userspace uses a single-instruction read
+ // it can't see an invalid value.
+ _, err := t.CopyOutBytes(t.rseqAddr, buf)
+ return err
+}
+
+// rseqAddrInterrupt checks if IP is in a critical section, and aborts if so.
+//
+// This is a bit complex since both the RSeq and RSeqCriticalSection structs
+// are stored in userspace. So we must:
+//
+// 1. Copy in the address of RSeqCriticalSection from RSeq.
+// 2. Copy in RSeqCriticalSection itself.
+// 3. Validate critical section struct version, address range, abort address.
+// 4. Validate the abort signature (4 bytes preceding abort IP match expected
+// signature).
+// 5. Clear address of RSeqCriticalSection from RSeq.
+// 6. Finally, conditionally abort.
+//
+// See kernel/rseq.c:rseq_ip_fixup for reference.
+//
+// Preconditions: The caller must be running on the task goroutine. t's
+// AddressSpace must be active.
+func (t *Task) rseqAddrInterrupt() {
+ if t.rseqAddr == 0 {
+ return
+ }
+
+ critAddrAddr, ok := t.rseqAddr.AddLength(linux.OffsetOfRSeqCriticalSection)
+ if !ok {
+ // SetRSeq should validate this.
+ panic(fmt.Sprintf("t.rseqAddr (%#x) not large enough", t.rseqAddr))
+ }
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ t.Debugf("Only 64-bit rseq supported.")
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ buf := t.CopyScratchBuffer(8)
+ if _, err := t.CopyInBytes(critAddrAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ critAddr := usermem.Addr(usermem.ByteOrder.Uint64(buf))
+ if critAddr == 0 {
+ return
+ }
+
+ var cs linux.RSeqCriticalSection
+ if _, err := cs.CopyIn(t, critAddr); err != nil {
+ t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ if cs.Version != 0 {
+ t.Debugf("Unknown version in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ start := usermem.Addr(cs.Start)
+ critRange, ok := start.ToRange(cs.PostCommitOffset)
+ if !ok {
+ t.Debugf("Invalid start and offset in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ abort := usermem.Addr(cs.Abort)
+ if critRange.Contains(abort) {
+ t.Debugf("Abort in critical section in %+v", cs)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Verify signature.
+ sigAddr := abort - linux.SizeOfRSeqSignature
+
+ buf = t.CopyScratchBuffer(linux.SizeOfRSeqSignature)
+ if _, err := t.CopyInBytes(sigAddr, buf); err != nil {
+ t.Debugf("Failed to copy critical section signature from %#x for rseq: %v", sigAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ sig := usermem.ByteOrder.Uint32(buf)
+ if sig != t.rseqSignature {
+ t.Debugf("Mismatched rseq signature %d != %d", sig, t.rseqSignature)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Clear the critical section address.
+ //
+ // NOTE(b/143949567): We don't support any rseq flags, so we always
+ // restart if we are in the critical section, and thus *always* clear
+ // critAddrAddr.
+ if _, err := t.MemoryManager().ZeroOut(t, critAddrAddr, int64(t.Arch().Width()), usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ t.Debugf("Failed to clear critical section address from %#x for rseq: %v", critAddrAddr, err)
+ t.forceSignal(linux.SIGSEGV, false /* unconditional */)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ return
+ }
+
+ // Finally we can actually decide whether or not to restart.
+ if !critRange.Contains(usermem.Addr(t.Arch().IP())) {
+ return
+ }
+
+ t.Arch().SetIP(uintptr(cs.Abort))
+}
+
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) rseqInterrupt() {
- rscr := t.tg.rscr.Load().(*RSEQCriticalRegion)
- if ip := t.Arch().IP(); rscr.CriticalSection.Contains(usermem.Addr(ip)) {
- t.Debugf("Interrupted RSEQ critical section at %#x; restarting at %#x", ip, rscr.Restart)
- t.Arch().SetIP(uintptr(rscr.Restart))
- t.Arch().SetRSEQInterruptedIP(ip)
+func (t *Task) oldRSeqInterrupt() {
+ r := t.tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
+ if ip := t.Arch().IP(); r.CriticalSection.Contains(usermem.Addr(ip)) {
+ t.Debugf("Interrupted rseq critical section at %#x; restarting at %#x", ip, r.Restart)
+ t.Arch().SetIP(uintptr(r.Restart))
+ t.Arch().SetOldRSeqInterruptedIP(ip)
}
}
+
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) rseqInterrupt() {
+ t.rseqAddrInterrupt()
+ t.oldRSeqInterrupt()
+}
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
index 98ea7a0d8..1b82e087b 100644
--- a/pkg/sentry/kernel/sched/BUILD
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"cpuset.go",
"sched.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/sched",
visibility = ["//pkg/sentry:internal"],
)
@@ -17,5 +15,5 @@ go_test(
name = "sched_test",
size = "small",
srcs = ["cpuset_test.go"],
- embed = [":sched"],
+ library = ":sched",
)
diff --git a/pkg/sentry/kernel/seccomp.go b/pkg/sentry/kernel/seccomp.go
index 2347dcf36..c38c5a40c 100644
--- a/pkg/sentry/kernel/seccomp.go
+++ b/pkg/sentry/kernel/seccomp.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const maxSyscallFilterInstructions = 1 << 15
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index f4c00cd86..65e5427c1 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,15 +21,15 @@ go_library(
"semaphore.go",
"waiter_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
"//pkg/syserror",
],
)
@@ -39,11 +38,11 @@ go_test(
name = "semaphore_test",
size = "small",
srcs = ["semaphore_test.go"],
- embed = [":semaphore"],
+ library = ":semaphore",
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/kernel/auth",
"//pkg/syserror",
],
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index 93fe68a3e..c00fa1138 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -17,14 +17,14 @@ package semaphore
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -302,7 +302,7 @@ func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Cred
return syserror.ERANGE
}
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = val
sem.pid = pid
s.changeTime = ktime.NowFromContext(ctx)
@@ -336,7 +336,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti
for i, val := range vals {
sem := &s.sems[i]
- // TODO(b/29354920): Clear undo entries in all processes
+ // TODO(gvisor.dev/issue/137): Clear undo entries in all processes.
sem.value = int16(val)
sem.pid = pid
sem.wakeWaiters()
@@ -481,7 +481,7 @@ func (s *Set) executeOps(ctx context.Context, ops []linux.Sembuf, pid int32) (ch
}
// All operations succeeded, apply them.
- // TODO(b/29354920): handle undo operations.
+ // TODO(gvisor.dev/issue/137): handle undo operations.
for i, v := range tmpVals {
s.sems[i].value = v
s.sems[i].wakeWaiters()
@@ -554,6 +554,7 @@ func (s *sem) wakeWaiters() {
for w := s.waiters.Front(); w != nil; {
if s.value < w.value {
// Still blocked, skip it.
+ w = w.Next()
continue
}
w.ch <- struct{}{}
diff --git a/pkg/sentry/kernel/semaphore/semaphore_test.go b/pkg/sentry/kernel/semaphore/semaphore_test.go
index c235f6ca4..e47acefdf 100644
--- a/pkg/sentry/kernel/semaphore/semaphore_test.go
+++ b/pkg/sentry/kernel/semaphore/semaphore_test.go
@@ -18,8 +18,8 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 047b5214d..5c4c622c2 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -16,6 +16,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/syserror"
@@ -70,7 +71,7 @@ func (s *Session) incRef() {
//
// Precondition: callers must hold TaskSet.mu for writing.
func (s *Session) decRef() {
- s.refs.DecRefWithDestructor(func() {
+ s.refs.DecRefWithDestructor(nil, func(context.Context) {
// Remove translations from the leader.
for ns := s.leader.pidns; ns != nil; ns = ns.parent {
id := ns.sids[s]
@@ -162,7 +163,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) {
}
alive := true
- pg.refs.DecRefWithDestructor(func() {
+ pg.refs.DecRefWithDestructor(nil, func(context.Context) {
alive = false // don't bother with handleOrphan.
// Remove translations from the originator.
@@ -246,7 +247,7 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
var lastErr error
for tg := range tasks.Root.tgids {
- if tg.ProcessGroup() == pg {
+ if tg.processGroup == pg {
tg.signalHandlers.mu.Lock()
infoCopy := *info
if err := tg.leader.sendSignalLocked(&infoCopy, true /*group*/); err != nil {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index cd48945e6..c211fc8d0 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,22 +8,21 @@ go_library(
"device.go",
"shm.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/shm",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/refs",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index 5bd610f68..13ec7afe0 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -35,21 +35,20 @@ package shm
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Key represents a shm segment key. Analogous to a file name.
@@ -71,9 +70,20 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
// shms maps segment ids to segments.
+ //
+ // shms holds all referenced segments, which are removed on the last
+ // DecRef. Thus, it cannot itself hold a reference on the Shm.
+ //
+ // Since removal only occurs after the last (unlocked) DecRef, there
+ // exists a short window during which a Shm still exists in Shm, but is
+ // unreferenced. Users must use TryIncRef to determine if the Shm is
+ // still valid.
shms map[ID]*Shm
// keysToShms maps segment keys to segments.
+ //
+ // Shms in keysToShms are guaranteed to be referenced, as they are
+ // removed by disassociateKey before the last DecRef.
keysToShms map[Key]*Shm
// Sum of the sizes of all existing segments rounded up to page size, in
@@ -95,10 +105,18 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
}
// FindByID looks up a segment given an ID.
+//
+// FindByID returns a reference on Shm.
func (r *Registry) FindByID(id ID) *Shm {
r.mu.Lock()
defer r.mu.Unlock()
- return r.shms[id]
+ s := r.shms[id]
+ // Take a reference on s. If TryIncRef fails, s has reached the last
+ // DecRef, but hasn't quite been removed from r.shms yet.
+ if s != nil && s.TryIncRef() {
+ return s
+ }
+ return nil
}
// dissociateKey removes the association between a segment and its key,
@@ -119,6 +137,8 @@ func (r *Registry) dissociateKey(s *Shm) {
// FindOrCreate looks up or creates a segment in the registry. It's functionally
// analogous to open(2).
+//
+// FindOrCreate returns a reference on Shm.
func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size uint64, mode linux.FileMode, private, create, exclusive bool) (*Shm, error) {
if (create || private) && (size < linux.SHMMIN || size > linux.SHMMAX) {
// "A new segment was to be created and size is less than SHMMIN or
@@ -166,6 +186,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
return nil, syserror.EEXIST
}
+ shm.IncRef()
return shm, nil
}
@@ -193,7 +214,14 @@ func (r *Registry) FindOrCreate(ctx context.Context, pid int32, key Key, size ui
// Need to create a new segment.
creator := fs.FileOwnerFromContext(ctx)
perms := fs.FilePermsFromMode(mode)
- return r.newShm(ctx, pid, key, creator, perms, size)
+ s, err := r.newShm(ctx, pid, key, creator, perms, size)
+ if err != nil {
+ return nil, err
+ }
+ // The initial reference is held by s itself. Take another to return to
+ // the caller.
+ s.IncRef()
+ return s, nil
}
// newShm creates a new segment in the registry.
@@ -296,22 +324,26 @@ func (r *Registry) remove(s *Shm) {
// Shm represents a single shared memory segment.
//
-// Shm segment are backed directly by an allocation from platform
-// memory. Segments are always mapped as a whole, greatly simplifying how
-// mappings are tracked. However note that mremap and munmap calls may cause the
-// vma for a segment to become fragmented; which requires special care when
-// unmapping a segment. See mm/shm.go.
+// Shm segment are backed directly by an allocation from platform memory.
+// Segments are always mapped as a whole, greatly simplifying how mappings are
+// tracked. However note that mremap and munmap calls may cause the vma for a
+// segment to become fragmented; which requires special care when unmapping a
+// segment. See mm/shm.go.
//
// Segments persist until they are explicitly marked for destruction via
-// shmctl(SHM_RMID).
+// MarkDestroyed().
//
// Shm implements memmap.Mappable and memmap.MappingIdentity.
//
// +stateify savable
type Shm struct {
- // AtomicRefCount tracks the number of references to this segment from
- // maps. A segment always holds a reference to itself, until it's marked for
+ // AtomicRefCount tracks the number of references to this segment.
+ //
+ // A segment holds a reference to itself until it is marked for
// destruction.
+ //
+ // In addition to direct users, the MemoryManager will hold references
+ // via MappingIdentity.
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
@@ -337,7 +369,7 @@ type Shm struct {
// fr is the offset into mfp.MemoryFile() that backs this contents of this
// segment. Immutable.
- fr platform.FileRange
+ fr memmap.FileRange
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
@@ -399,8 +431,8 @@ func (s *Shm) InodeID() uint64 {
// DecRef overrides refs.RefCount.DecRef with a destructor.
//
// Precondition: Caller must not hold s.mu.
-func (s *Shm) DecRef() {
- s.DecRefWithDestructor(s.destroy)
+func (s *Shm) DecRef(ctx context.Context) {
+ s.DecRefWithDestructor(ctx, s.destroy)
}
// Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm
@@ -428,7 +460,7 @@ func (s *Shm) AddMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.A
func (s *Shm) RemoveMapping(ctx context.Context, _ memmap.MappingSpace, _ usermem.AddrRange, _ uint64, _ bool) {
s.mu.Lock()
defer s.mu.Unlock()
- // TODO(b/38173783): RemoveMapping may be called during task exit, when ctx
+ // RemoveMapping may be called during task exit, when ctx
// is context.Background. Gracefully handle missing clocks. Failing to
// update the detach time in these cases is ok, since no one can observe the
// omission.
@@ -484,9 +516,8 @@ type AttachOpts struct {
// ConfigureAttach creates an mmap configuration for the segment with the
// requested attach options.
//
-// ConfigureAttach returns with a ref on s on success. The caller should drop
-// this once the map is installed. This reference prevents s from being
-// destroyed before the returned configuration is used.
+// Postconditions: The returned MMapOpts are valid only as long as a reference
+// continues to be held on s.
func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts AttachOpts) (memmap.MMapOpts, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -504,7 +535,6 @@ func (s *Shm) ConfigureAttach(ctx context.Context, addr usermem.Addr, opts Attac
// in the user namespace that governs its IPC namespace." - man shmat(2)
return memmap.MMapOpts{}, syserror.EACCES
}
- s.IncRef()
return memmap.MMapOpts{
Length: s.size,
Offset: 0,
@@ -549,10 +579,15 @@ func (s *Shm) IPCStat(ctx context.Context) (*linux.ShmidDS, error) {
}
creds := auth.CredentialsFromContext(ctx)
- nattach := uint64(s.ReadRefs())
- // Don't report the self-reference we keep prior to being marked for
- // destruction. However, also don't report a count of -1 for segments marked
- // as destroyed, with no mappings.
+ // Use the reference count as a rudimentary count of the number of
+ // attaches. We exclude:
+ //
+ // 1. The reference the caller holds.
+ // 2. The self-reference held by s prior to destruction.
+ //
+ // Note that this may still overcount by including transient references
+ // used in concurrent calls.
+ nattach := uint64(s.ReadRefs()) - 1
if !s.pendingDestruction {
nattach--
}
@@ -607,7 +642,7 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error {
return nil
}
-func (s *Shm) destroy() {
+func (s *Shm) destroy(context.Context) {
s.mfp.MemoryFile().DecRef(s.fr)
s.registry.remove(s)
}
@@ -616,22 +651,21 @@ func (s *Shm) destroy() {
// destroyed once it has no references. MarkDestroyed may be called multiple
// times, and is safe to call after a segment has already been destroyed. See
// shmctl(IPC_RMID).
-func (s *Shm) MarkDestroyed() {
+func (s *Shm) MarkDestroyed(ctx context.Context) {
s.registry.dissociateKey(s)
s.mu.Lock()
- // Only drop the segment's self-reference once, when destruction is
- // requested. Otherwise, repeated calls to shmctl(IPC_RMID) would force a
- // segment to be destroyed prematurely, potentially with active maps to the
- // segment's address range. Remaining references are dropped when the
- // segment is detached or unmaped.
+ defer s.mu.Unlock()
if !s.pendingDestruction {
s.pendingDestruction = true
- s.mu.Unlock() // Must release s.mu before calling s.DecRef.
- s.DecRef()
+ // Drop the self-reference so destruction occurs when all
+ // external references are gone.
+ //
+ // N.B. This cannot be the final DecRef, as the caller also
+ // holds a reference.
+ s.DecRef(ctx)
return
}
- s.mu.Unlock()
}
// checkOwnership verifies whether a segment may be accessed by ctx as an
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index 02eede93d..e8cce37d0 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -38,6 +38,9 @@ const SignalPanic = linux.SIGUSR2
// Preconditions: Kernel must have an init process.
func (k *Kernel) sendExternalSignal(info *arch.SignalInfo, context string) {
switch linux.Signal(info.Signo) {
+ case linux.SIGURG:
+ // Sent by the Go 1.14+ runtime for asynchronous goroutine preemption.
+
case platform.SignalInterrupt:
// Assume that a call to platform.Context.Interrupt() misfired.
diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go
index a16f3d57f..768fda220 100644
--- a/pkg/sentry/kernel/signal_handlers.go
+++ b/pkg/sentry/kernel/signal_handlers.go
@@ -15,10 +15,9 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sync"
)
// SignalHandlers holds information about signal actions.
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
index 50b69d154..3eb78e91b 100644
--- a/pkg/sentry/kernel/signalfd/BUILD
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -1,22 +1,22 @@
-package(licenses = ["notice"])
+load("//tools:defs.bzl", "go_library")
-load("//tools/go_stateify:defs.bzl", "go_library")
+licenses(["notice"])
go_library(
name = "signalfd",
srcs = ["signalfd.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 4b08d7d72..b07e1c1bd 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -16,17 +16,16 @@
package signalfd
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/anon"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -77,7 +76,7 @@ func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
}
// Release implements fs.FileOperations.Release.
-func (s *SignalOperations) Release() {}
+func (s *SignalOperations) Release(context.Context) {}
// Mask returns the signal mask.
func (s *SignalOperations) Mask() linux.SignalSet {
diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go
index 220fa73a2..413111faf 100644
--- a/pkg/sentry/kernel/syscalls.go
+++ b/pkg/sentry/kernel/syscalls.go
@@ -16,20 +16,20 @@ package kernel
import (
"fmt"
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// maxSyscallNum is the highest supported syscall number.
//
// The types below create fast lookup slices for all syscalls. This maximum
// serves as a sanity check that we don't allocate huge slices for a very large
-// syscall.
+// syscall. This is checked during registration.
const maxSyscallNum = 2000
// SyscallSupportLevel is a syscall support levels.
@@ -209,65 +209,71 @@ type Stracer interface {
// SyscallEnter is called on syscall entry.
//
// The returned private data is passed to SyscallExit.
- //
- // TODO(gvisor.dev/issue/155): remove kernel imports from the strace
- // package so that the type can be used directly.
SyscallEnter(t *Task, sysno uintptr, args arch.SyscallArguments, flags uint32) interface{}
// SyscallExit is called on syscall exit.
SyscallExit(context interface{}, t *Task, sysno, rval uintptr, err error)
}
-// SyscallTable is a lookup table of system calls. Critically, a SyscallTable
-// is *immutable*. In order to make supporting suspend and resume sane, they
-// must be uniquely registered and may not change during operation.
+// SyscallTable is a lookup table of system calls.
//
-// +stateify savable
+// Note that a SyscallTable is not savable directly. Instead, they are saved as
+// an OS/Arch pair and lookup happens again on restore.
type SyscallTable struct {
// OS is the operating system that this syscall table implements.
- OS abi.OS `state:"wait"`
+ OS abi.OS
// Arch is the architecture that this syscall table targets.
- Arch arch.Arch `state:"wait"`
+ Arch arch.Arch
// The OS version that this syscall table implements.
- Version Version `state:"manual"`
+ Version Version
// AuditNumber is a numeric constant that represents the syscall table. If
// non-zero, auditNumber must be one of the AUDIT_ARCH_* values defined by
// linux/audit.h.
- AuditNumber uint32 `state:"manual"`
+ AuditNumber uint32
// Table is the collection of functions.
- Table map[uintptr]Syscall `state:"manual"`
+ Table map[uintptr]Syscall
// lookup is a fixed-size array that holds the syscalls (indexed by
// their numbers). It is used for fast look ups.
- lookup []SyscallFn `state:"manual"`
+ lookup []SyscallFn
// Emulate is a collection of instruction addresses to emulate. The
// keys are addresses, and the values are system call numbers.
- Emulate map[usermem.Addr]uintptr `state:"manual"`
+ Emulate map[usermem.Addr]uintptr
// The function to call in case of a missing system call.
- Missing MissingFn `state:"manual"`
+ Missing MissingFn
// Stracer traces this syscall table.
- Stracer Stracer `state:"manual"`
+ Stracer Stracer
// External is used to handle an external callback.
- External func(*Kernel) `state:"manual"`
+ External func(*Kernel)
// ExternalFilterBefore is called before External is called before the syscall is executed.
// External is not called if it returns false.
- ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterBefore func(*Task, uintptr, arch.SyscallArguments) bool
// ExternalFilterAfter is called before External is called after the syscall is executed.
// External is not called if it returns false.
- ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool `state:"manual"`
+ ExternalFilterAfter func(*Task, uintptr, arch.SyscallArguments) bool
// FeatureEnable stores the strace and one-shot enable bits.
- FeatureEnable SyscallFlagsTable `state:"manual"`
+ FeatureEnable SyscallFlagsTable
+}
+
+// MaxSysno returns the largest system call number.
+func (s *SyscallTable) MaxSysno() (max uintptr) {
+ for num := range s.Table {
+ if num > max {
+ max = num
+ }
+ }
+ return max
}
// allSyscallTables contains all known tables.
@@ -290,6 +296,20 @@ func LookupSyscallTable(os abi.OS, a arch.Arch) (*SyscallTable, bool) {
// RegisterSyscallTable registers a new syscall table for use by a Kernel.
func RegisterSyscallTable(s *SyscallTable) {
+ if max := s.MaxSysno(); max > maxSyscallNum {
+ panic(fmt.Sprintf("SyscallTable %+v contains too large syscall number %d", s, max))
+ }
+ if _, ok := LookupSyscallTable(s.OS, s.Arch); ok {
+ panic(fmt.Sprintf("Duplicate SyscallTable registered for OS %v Arch %v", s.OS, s.Arch))
+ }
+ allSyscallTables = append(allSyscallTables, s)
+ s.Init()
+}
+
+// Init initializes the system call table.
+//
+// This should normally be called only during registration.
+func (s *SyscallTable) Init() {
if s.Table == nil {
// Ensure non-nil lookup table.
s.Table = make(map[uintptr]Syscall)
@@ -299,35 +319,16 @@ func RegisterSyscallTable(s *SyscallTable) {
s.Emulate = make(map[usermem.Addr]uintptr)
}
- var max uintptr
- for num := range s.Table {
- if num > max {
- max = num
- }
- }
-
- if max > maxSyscallNum {
- panic(fmt.Sprintf("SyscallTable %+v contains too large syscall number %d", s, max))
- }
-
- s.lookup = make([]SyscallFn, max+1)
+ max := s.MaxSysno() // Checked during RegisterSyscallTable.
// Initialize the fast-lookup table.
+ s.lookup = make([]SyscallFn, max+1)
for num, sc := range s.Table {
s.lookup[num] = sc.Fn
}
+ // Initialize all features.
s.FeatureEnable.init(s.Table, max)
-
- if _, ok := LookupSyscallTable(s.OS, s.Arch); ok {
- panic(fmt.Sprintf("Duplicate SyscallTable registered for OS %v Arch %v", s.OS, s.Arch))
- }
-
- // Save a reference to this table.
- //
- // This is required for a Kernel to find the table and for save/restore
- // operations below.
- allSyscallTables = append(allSyscallTables, s)
}
// Lookup returns the syscall implementation, if one exists.
@@ -339,6 +340,14 @@ func (s *SyscallTable) Lookup(sysno uintptr) SyscallFn {
return nil
}
+// LookupName looks up a syscall name.
+func (s *SyscallTable) LookupName(sysno uintptr) string {
+ if sc, ok := s.Table[sysno]; ok {
+ return sc.Name
+ }
+ return fmt.Sprintf("sys_%d", sysno) // Unlikely.
+}
+
// LookupEmulate looks up an emulation syscall number.
func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) {
sysno, ok := s.Emulate[addr]
diff --git a/pkg/sentry/kernel/syscalls_state.go b/pkg/sentry/kernel/syscalls_state.go
index 00358326b..90f890495 100644
--- a/pkg/sentry/kernel/syscalls_state.go
+++ b/pkg/sentry/kernel/syscalls_state.go
@@ -14,16 +14,34 @@
package kernel
-import "fmt"
+import (
+ "fmt"
-// afterLoad is invoked by stateify.
-func (s *SyscallTable) afterLoad() {
- otherTable, ok := LookupSyscallTable(s.OS, s.Arch)
- if !ok {
- // Couldn't find a reference?
- panic(fmt.Sprintf("syscall table not found for OS %v Arch %v", s.OS, s.Arch))
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// syscallTableInfo is used to reload the SyscallTable.
+//
+// +stateify savable
+type syscallTableInfo struct {
+ OS abi.OS
+ Arch arch.Arch
+}
+
+// saveSt saves the SyscallTable.
+func (tc *TaskContext) saveSt() syscallTableInfo {
+ return syscallTableInfo{
+ OS: tc.st.OS,
+ Arch: tc.st.Arch,
}
+}
- // Copy the table.
- *s = *otherTable
+// loadSt loads the SyscallTable.
+func (tc *TaskContext) loadSt(sti syscallTableInfo) {
+ st, ok := LookupSyscallTable(sti.OS, sti.Arch)
+ if !ok {
+ panic(fmt.Sprintf("syscall table not found for OS %v, Arch %v", sti.OS, sti.Arch))
+ }
+ tc.st = st // Save the table reference.
}
diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go
index 8227ecf1d..a83ce219c 100644
--- a/pkg/sentry/kernel/syslog.go
+++ b/pkg/sentry/kernel/syslog.go
@@ -17,7 +17,8 @@ package kernel
import (
"fmt"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// syslog represents a sentry-global kernel log.
@@ -97,6 +98,15 @@ func (s *syslog) Log() []byte {
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...)
}
+ if VFS2Enabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...)
+ if FUSEEnabled {
+ time += rand.Float64() / 2
+ s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...)
+ }
+ }
+
time += rand.Float64() / 2
s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...)
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index c82ef5486..5aee699e7 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -15,13 +15,14 @@
package kernel
import (
- "sync"
+ gocontext "context"
+ "runtime/trace"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -34,9 +35,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/unimpl"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/third_party/gvsync"
)
// Task represents a thread of execution in the untrusted app. It
@@ -65,6 +68,21 @@ type Task struct {
// runState is exclusive to the task goroutine.
runState taskRunState
+ // taskWorkCount represents the current size of the task work queue. It is
+ // used to avoid acquiring taskWorkMu when the queue is empty.
+ //
+ // Must accessed with atomic memory operations.
+ taskWorkCount int32
+
+ // taskWorkMu protects taskWork.
+ taskWorkMu sync.Mutex `state:"nosave"`
+
+ // taskWork is a queue of work to be executed before resuming user execution.
+ // It is similar to the task_work mechanism in Linux.
+ //
+ // taskWork is exclusive to the task goroutine.
+ taskWork []TaskWorker
+
// haveSyscallReturn is true if tc.Arch().Return() represents a value
// returned by a syscall (or set by ptrace after a syscall).
//
@@ -83,7 +101,7 @@ type Task struct {
//
// gosched is protected by goschedSeq. gosched is owned by the task
// goroutine.
- goschedSeq gvsync.SeqCount `state:"nosave"`
+ goschedSeq sync.SeqCount `state:"nosave"`
gosched TaskGoroutineSchedInfo
// yieldCount is the number of times the task goroutine has called
@@ -390,7 +408,14 @@ type Task struct {
// logPrefix is a string containing the task's thread ID in the root PID
// namespace, and is prepended to log messages emitted by Task.Infof etc.
- logPrefix atomic.Value `state:".(string)"`
+ logPrefix atomic.Value `state:"nosave"`
+
+ // traceContext and traceTask are both used for tracing, and are
+ // updated along with the logPrefix in updateInfoLocked.
+ //
+ // These are exclusive to the task goroutine.
+ traceContext gocontext.Context `state:"nosave"`
+ traceTask *trace.Task `state:"nosave"`
// creds is the task's credentials.
//
@@ -415,6 +440,11 @@ type Task struct {
// abstractSockets is protected by mu.
abstractSockets *AbstractSocketNamespace
+ // mountNamespaceVFS2 is the task's mount namespace.
+ //
+ // It is protected by mu. It is owned by the task goroutine.
+ mountNamespaceVFS2 *vfs.MountNamespace
+
// parentDeathSignal is sent to this task's thread group when its parent exits.
//
// parentDeathSignal is protected by mu.
@@ -469,29 +499,51 @@ type Task struct {
// bit.
//
// numaPolicy and numaNodeMask are protected by mu.
- numaPolicy int32
+ numaPolicy linux.NumaPolicy
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
- // If rseqPreempted is true, before the next call to p.Switch(), interrupt
- // RSEQ critical regions as defined by tg.rseq and write the task
- // goroutine's CPU number to rseqCPUAddr. rseqCPU is the last CPU number
- // written to rseqCPUAddr.
+ // If rseqPreempted is true, before the next call to p.Switch(),
+ // interrupt rseq critical regions as defined by rseqAddr and
+ // tg.oldRSeqCritical and write the task goroutine's CPU number to
+ // rseqAddr/oldRSeqCPUAddr.
//
- // If rseqCPUAddr is 0, rseqCPU is -1.
+ // We support two ABIs for restartable sequences:
//
- // rseqCPUAddr, rseqCPU, and rseqPreempted are exclusive to the task
- // goroutine.
+ // 1. The upstream interface added in v4.18,
+ // 2. An "old" interface never merged upstream. In the implementation,
+ // this is referred to as "old rseq".
+ //
+ // rseqPreempted is exclusive to the task goroutine.
rseqPreempted bool `state:"nosave"`
- rseqCPUAddr usermem.Addr
- rseqCPU int32
+
+ // rseqCPU is the last CPU number written to rseqAddr/oldRSeqCPUAddr.
+ //
+ // If rseq is unused, rseqCPU is -1 for convenient use in
+ // platform.Context.Switch.
+ //
+ // rseqCPU is exclusive to the task goroutine.
+ rseqCPU int32
+
+ // oldRSeqCPUAddr is a pointer to the userspace old rseq CPU variable.
+ //
+ // oldRSeqCPUAddr is exclusive to the task goroutine.
+ oldRSeqCPUAddr usermem.Addr
+
+ // rseqAddr is a pointer to the userspace linux.RSeq structure.
+ //
+ // rseqAddr is exclusive to the task goroutine.
+ rseqAddr usermem.Addr
+
+ // rseqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ //
+ // rseqSignature is exclusive to the task goroutine.
+ rseqSignature uint32
// copyScratchBuffer is a buffer available to CopyIn/CopyOut
// implementations that require an intermediate buffer to copy data
@@ -513,6 +565,10 @@ type Task struct {
// futexWaiter is exclusive to the task goroutine.
futexWaiter *futex.Waiter `state:"nosave"`
+ // robustList is a pointer to the head of the tasks's robust futex
+ // list.
+ robustList usermem.Addr
+
// startTime is the real time at which the task started. It is set when
// a Task is created or invokes execve(2).
//
@@ -528,14 +584,6 @@ func (t *Task) loadPtraceTracer(tracer *Task) {
t.ptraceTracer.Store(tracer)
}
-func (t *Task) saveLogPrefix() string {
- return t.logPrefix.Load().(string)
-}
-
-func (t *Task) loadLogPrefix(prefix string) {
- t.logPrefix.Store(prefix)
-}
-
func (t *Task) saveSyscallFilters() []bpf.Program {
if f := t.syscallFilters.Load(); f != nil {
return f.([]bpf.Program)
@@ -549,6 +597,7 @@ func (t *Task) loadSyscallFilters(filters []bpf.Program) {
// afterLoad is invoked by stateify.
func (t *Task) afterLoad() {
+ t.updateInfoLocked()
t.interruptChan = make(chan struct{}, 1)
t.gosched.State = TaskGoroutineNonexistent
if t.stop != nil {
@@ -611,6 +660,11 @@ func (t *Task) Value(key interface{}) interface{} {
return int32(t.ThreadGroup().ID())
case fs.CtxRoot:
return t.fsContext.RootDirectory()
+ case vfs.CtxRoot:
+ return t.fsContext.RootDirectoryVFS2()
+ case vfs.CtxMountNamespace:
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
case fs.CtxDirentCacheLimiter:
return t.k.DirentCacheLimiter
case inet.CtxStack:
@@ -674,11 +728,19 @@ func (t *Task) SyscallRestartBlock() SyscallRestartBlock {
// Preconditions: The caller must be running on the task goroutine, or t.mu
// must be locked.
func (t *Task) IsChrooted() bool {
+ if VFS2Enabled {
+ realRoot := t.mountNamespaceVFS2.Root()
+ defer realRoot.DecRef(t)
+ root := t.fsContext.RootDirectoryVFS2()
+ defer root.DecRef(t)
+ return root != realRoot
+ }
+
realRoot := t.tg.mounts.Root()
- defer realRoot.DecRef()
+ defer realRoot.DecRef(t)
root := t.fsContext.RootDirectory()
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(t)
}
return root != realRoot
}
@@ -709,14 +771,22 @@ func (t *Task) FDTable() *FDTable {
return t.fdTable
}
-// GetFile is a convenience wrapper t.FDTable().GetFile.
+// GetFile is a convenience wrapper for t.FDTable().Get.
//
-// Precondition: same as FDTable.
+// Precondition: same as FDTable.Get.
func (t *Task) GetFile(fd int32) *fs.File {
f, _ := t.fdTable.Get(fd)
return f
}
+// GetFileVFS2 is a convenience wrapper for t.FDTable().GetVFS2.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) GetFileVFS2(fd int32) *vfs.FileDescription {
+ f, _ := t.fdTable.GetVFS2(fd)
+ return f
+}
+
// NewFDs is a convenience wrapper for t.FDTable().NewFDs.
//
// This automatically passes the task as the context.
@@ -726,6 +796,15 @@ func (t *Task) NewFDs(fd int32, files []*fs.File, flags FDFlags) ([]int32, error
return t.fdTable.NewFDs(t, fd, files, flags)
}
+// NewFDsVFS2 is a convenience wrapper for t.FDTable().NewFDsVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDsVFS2(fd int32, files []*vfs.FileDescription, flags FDFlags) ([]int32, error) {
+ return t.fdTable.NewFDsVFS2(t, fd, files, flags)
+}
+
// NewFDFrom is a convenience wrapper for t.FDTable().NewFDs with a single file.
//
// This automatically passes the task as the context.
@@ -739,6 +818,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -748,6 +836,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
@@ -761,6 +858,15 @@ func (t *Task) MountNamespace() *fs.MountNamespace {
return t.tg.mounts
}
+// MountNamespaceVFS2 returns t's MountNamespace. A reference is taken on the
+// returned mount namespace.
+func (t *Task) MountNamespaceVFS2() *vfs.MountNamespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.mountNamespaceVFS2.IncRef()
+ return t.mountNamespaceVFS2
+}
+
// AbstractSockets returns t's AbstractSocketNamespace.
func (t *Task) AbstractSockets() *AbstractSocketNamespace {
return t.abstractSockets
@@ -770,3 +876,30 @@ func (t *Task) AbstractSockets() *AbstractSocketNamespace {
func (t *Task) ContainerID() string {
return t.containerID
}
+
+// OOMScoreAdj gets the task's thread group's OOM score adjustment.
+func (t *Task) OOMScoreAdj() int32 {
+ return atomic.LoadInt32(&t.tg.oomScoreAdj)
+}
+
+// SetOOMScoreAdj sets the task's thread group's OOM score adjustment. The
+// value should be between -1000 and 1000 inclusive.
+func (t *Task) SetOOMScoreAdj(adj int32) error {
+ if adj > 1000 || adj < -1000 {
+ return syserror.EINVAL
+ }
+ atomic.StoreInt32(&t.tg.oomScoreAdj, adj)
+ return nil
+}
+
+// UID returns t's uid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) UID() uint32 {
+ return uint32(t.Credentials().EffectiveKUID)
+}
+
+// GID returns t's gid.
+// TODO(gvisor.dev/issue/170): This method is not namespaced yet.
+func (t *Task) GID() uint32 {
+ return uint32(t.Credentials().EffectiveKGID)
+}
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index dd69939f9..4a4a69ee2 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -16,6 +16,7 @@ package kernel
import (
"runtime"
+ "runtime/trace"
"time"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -133,19 +134,24 @@ func (t *Task) block(C <-chan struct{}, timerChan <-chan struct{}) error {
runtime.Gosched()
}
+ region := trace.StartRegion(t.traceContext, blockRegion)
select {
case <-C:
+ region.End()
t.SleepFinish(true)
+ // Woken by event.
return nil
case <-interrupt:
+ region.End()
t.SleepFinish(false)
// Return the indicated error on interrupt.
return syserror.ErrInterrupted
case <-timerChan:
- // We've timed out.
+ region.End()
t.SleepFinish(true)
+ // We've timed out.
return syserror.ETIMEDOUT
}
}
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index 0916fd658..9d7a9128f 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -15,10 +15,13 @@
package kernel
import (
+ "sync/atomic"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// SharingOptions controls what resources are shared by a new task created by
@@ -54,8 +57,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -159,6 +161,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
return 0, nil, syserror.EINVAL
}
+ // Pull task registers and FPU state, a cloned task will inherit the
+ // state of the current task.
+ t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
+
// "If CLONE_NEWUSER is specified along with other CLONE_NEW* flags in a
// single clone(2) or unshare(2) call, the user namespace is guaranteed to
// be created first, giving the child (clone(2)) or caller (unshare(2))
@@ -199,6 +205,17 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
+ // TODO(b/63601033): Implement CLONE_NEWNS.
+ mntnsVFS2 := t.mountNamespaceVFS2
+ if mntnsVFS2 != nil {
+ mntnsVFS2.IncRef()
+ }
+
tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace)
if err != nil {
return 0, nil, err
@@ -224,7 +241,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
var fdTable *FDTable
if opts.NewFiles {
- fdTable = t.fdTable.Fork()
+ fdTable = t.fdTable.Fork(t)
} else {
fdTable = t.fdTable
fdTable.IncRef()
@@ -236,14 +253,22 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else if opts.NewPIDNamespace {
pidns = pidns.NewChild(userns)
}
+
tg := t.tg
+ rseqAddr := usermem.Addr(0)
+ rseqSignature := uint32(0)
if opts.NewThreadGroup {
- tg.mounts.IncRef()
+ if tg.mounts != nil {
+ tg.mounts.IncRef()
+ }
sh := t.tg.signalHandlers
if opts.NewSignalHandlers {
sh = sh.Fork()
}
- tg = t.k.newThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy(), t.k.monotonicClock)
+ tg = t.k.NewThreadGroup(tg.mounts, pidns, sh, opts.TerminationSignal, tg.limits.GetCopy())
+ tg.oomScoreAdj = atomic.LoadInt32(&t.tg.oomScoreAdj)
+ rseqAddr = t.rseqAddr
+ rseqSignature = t.rseqSignature
}
cfg := &TaskConfig{
@@ -255,11 +280,14 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
AbstractSocketNamespace: t.abstractSockets,
+ MountNamespaceVFS2: mntnsVFS2,
+ RSeqAddr: rseqAddr,
+ RSeqSignature: rseqSignature,
ContainerID: t.ContainerID(),
}
if opts.NewThreadGroup {
@@ -267,13 +295,10 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
- tg.release()
+ tg.release(t)
}
return 0, nil, err
}
@@ -299,6 +324,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
// nt that it must receive before its task goroutine starts running.
tid := nt.k.tasks.Root.IDOfTask(nt)
defer nt.Start(tid)
+ t.traceCloneEvent(tid)
// "If fork/clone and execve are allowed by @prog, any child processes will
// be constrained to the same filters and system call ABI as the parent." -
@@ -465,7 +491,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
@@ -488,7 +514,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
var oldFDTable *FDTable
if opts.NewFiles {
oldFDTable = t.fdTable
- t.fdTable = oldFDTable.Fork()
+ t.fdTable = oldFDTable.Fork(t)
}
var oldFSContext *FSContext
if opts.NewFSContext {
@@ -497,10 +523,10 @@ func (t *Task) Unshare(opts *SharingOptions) error {
}
t.mu.Unlock()
if oldFDTable != nil {
- oldFDTable.DecRef()
+ oldFDTable.DecRef(t)
}
if oldFSContext != nil {
- oldFSContext.DecRef()
+ oldFSContext.DecRef(t)
}
return nil
}
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index bb5560acf..9fa528384 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -18,13 +18,13 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/loader"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var errNoSyscalls = syserr.New("no syscall table found", linux.ENOEXEC)
@@ -49,7 +49,7 @@ type TaskContext struct {
fu *futex.Manager
// st is the task's syscall table.
- st *SyscallTable
+ st *SyscallTable `state:".(syscallTableInfo)"`
}
// release releases all resources held by the TaskContext. release is called by
@@ -58,7 +58,6 @@ func (tc *TaskContext) release() {
// Nil out pointers so that if the task is saved after release, it doesn't
// follow the pointers to possibly now-invalid objects.
if tc.MemoryManager != nil {
- // TODO(b/38173783)
tc.MemoryManager.DecUsers(context.Background())
tc.MemoryManager = nil
}
@@ -136,11 +135,11 @@ func (t *Task) Stack() *arch.Stack {
func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*TaskContext, *syserr.Error) {
// If File is not nil, we should load that instead of resolving Filename.
if args.File != nil {
- args.Filename = args.File.MappedName(ctx)
+ args.Filename = args.File.PathnameWithDeleted(ctx)
}
// Prepare a new user address space to load into.
- m := mm.NewMemoryManager(k, k)
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
defer m.DecUsers(ctx)
args.MemoryManager = m
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 17a089b90..5e4fb3e3a 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -69,6 +69,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/mm"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -129,6 +130,7 @@ type runSyscallAfterExecStop struct {
}
func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
+ t.traceExecEvent(r.tc)
t.tg.pidns.owner.mu.Lock()
t.tg.execing = nil
if t.killed() {
@@ -189,16 +191,25 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.updateRSSLocked()
// Restartable sequence state is discarded.
t.rseqPreempted = false
- t.rseqCPUAddr = 0
t.rseqCPU = -1
- t.tg.rscr.Store(&RSEQCriticalRegion{})
+ t.rseqAddr = 0
+ t.rseqSignature = 0
+ t.oldRSeqCPUAddr = 0
+ t.tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
t.tg.pidns.owner.mu.Unlock()
+ oldFDTable := t.fdTable
+ t.fdTable = t.fdTable.Fork(t)
+ oldFDTable.DecRef(t)
+
// Remove FDs with the CloseOnExec flag set.
- t.fdTable.RemoveIf(func(file *fs.File, flags FDFlags) bool {
+ t.fdTable.RemoveIf(t, func(_ *fs.File, _ *vfs.FileDescription, flags FDFlags) bool {
return flags.CloseOnExec
})
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// NOTE(b/30815691): We currently do not implement privileged
// executables (set-user/group-ID bits and file capabilities). This
// allows us to unconditionally enable user dumpability on the new mm.
@@ -215,8 +226,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.tc = *r.tc
t.mu.Unlock()
t.unstopVforkParent()
+ t.p.FullStateChanged()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
- t.MemoryManager().Activate()
+ t.MemoryManager().Activate(t)
t.ptraceExec(oldTID)
return (*runSyscallExit)(nil)
@@ -253,7 +265,7 @@ func (t *Task) promoteLocked() {
t.tg.leader = t
t.Infof("Becoming TID %d (in root PID namespace)", t.tg.pidns.owner.Root.tids[t])
- t.updateLogPrefixLocked()
+ t.updateInfoLocked()
// Reap the original leader. If it has a tracer, detach it instead of
// waiting for it to acknowledge the original leader's death.
oldLeader.exitParentNotified = true
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 535f03e50..c165d6cb1 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -236,6 +236,7 @@ func (*runExit) execute(t *Task) taskRunState {
type runExitMain struct{}
func (*runExitMain) execute(t *Task) taskRunState {
+ t.traceExitEvent()
lastExiter := t.exitThreadGroup()
// If the task has a cleartid, and the thread group wasn't killed by a
@@ -252,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState {
}
}
+ // Handle the robust futex list.
+ t.exitRobustList()
+
// Deactivate the address space and update max RSS before releasing the
// task's MM.
t.Deactivate()
@@ -265,13 +269,20 @@ func (*runExitMain) execute(t *Task) taskRunState {
// Releasing the MM unblocks a blocked CLONE_VFORK parent.
t.unstopVforkParent()
- t.fsContext.DecRef()
- t.fdTable.DecRef()
+ t.fsContext.DecRef(t)
+ t.fdTable.DecRef(t)
+
+ t.mu.Lock()
+ if t.mountNamespaceVFS2 != nil {
+ t.mountNamespaceVFS2.DecRef(t)
+ t.mountNamespaceVFS2 = nil
+ }
+ t.mu.Unlock()
// If this is the last task to exit from the thread group, release the
// thread group's resources.
if lastExiter {
- t.tg.release()
+ t.tg.release(t)
}
// Detach tracees.
diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go
index c211b5b74..4b535c949 100644
--- a/pkg/sentry/kernel/task_futex.go
+++ b/pkg/sentry/kernel/task_futex.go
@@ -15,8 +15,9 @@
package kernel
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Futex returns t's futex manager.
@@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) {
func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) {
return t.MemoryManager().GetSharedFutexKey(t, addr)
}
+
+// GetRobustList sets the robust futex list for the task.
+func (t *Task) GetRobustList() usermem.Addr {
+ t.mu.Lock()
+ addr := t.robustList
+ t.mu.Unlock()
+ return addr
+}
+
+// SetRobustList sets the robust futex list for the task.
+func (t *Task) SetRobustList(addr usermem.Addr) {
+ t.mu.Lock()
+ t.robustList = addr
+ t.mu.Unlock()
+}
+
+// exitRobustList walks the robust futex list, marking locks dead and notifying
+// wakers. It corresponds to Linux's exit_robust_list(). Following Linux,
+// errors are silently ignored.
+func (t *Task) exitRobustList() {
+ t.mu.Lock()
+ addr := t.robustList
+ t.robustList = 0
+ t.mu.Unlock()
+
+ if addr == 0 {
+ return
+ }
+
+ var rl linux.RobustListHead
+ if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil {
+ return
+ }
+
+ next := rl.List
+ done := 0
+ var pendingLockAddr usermem.Addr
+ if rl.ListOpPending != 0 {
+ pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset)
+ }
+
+ // Wake up normal elements.
+ for usermem.Addr(next) != addr {
+ // We traverse to the next element of the list before we
+ // actually wake anything. This prevents the race where waking
+ // this futex causes a modification of the list.
+ thisLockAddr := usermem.Addr(next + rl.FutexOffset)
+
+ // Try to decode the next element in the list before waking the
+ // current futex. But don't check the error until after we've
+ // woken the current futex. Linux does it in this order too
+ _, nextErr := t.CopyIn(usermem.Addr(next), &next)
+
+ // Wakeup the current futex if it's not pending.
+ if thisLockAddr != pendingLockAddr {
+ t.wakeRobustListOne(thisLockAddr)
+ }
+
+ // If there was an error copying the next futex, we must bail.
+ if nextErr != nil {
+ break
+ }
+
+ // This is a user structure, so it could be a massive list, or
+ // even contain a loop if they are trying to mess with us. We
+ // cap traversal to prevent that.
+ done++
+ if done >= linux.ROBUST_LIST_LIMIT {
+ break
+ }
+ }
+
+ // Is there a pending entry to wake?
+ if pendingLockAddr != 0 {
+ t.wakeRobustListOne(pendingLockAddr)
+ }
+}
+
+// wakeRobustListOne wakes a single futex from the robust list.
+func (t *Task) wakeRobustListOne(addr usermem.Addr) {
+ // Bit 0 in address signals PI futex.
+ pi := addr&1 == 1
+ addr = addr &^ 1
+
+ // Load the futex.
+ f, err := t.LoadUint32(addr)
+ if err != nil {
+ // Can't read this single value? Ignore the problem.
+ // We can wake the other futexes in the list.
+ return
+ }
+
+ tid := uint32(t.ThreadID())
+ for {
+ // Is this held by someone else?
+ if f&linux.FUTEX_TID_MASK != tid {
+ return
+ }
+
+ // This thread is dying and it's holding this futex. We need to
+ // set the owner died bit and wake up any waiters.
+ newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED
+ if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil {
+ return
+ } else if curF != f {
+ // Futex changed out from under us. Try again...
+ f = curF
+ continue
+ }
+
+ // Wake waiters if there are any.
+ if f&linux.FUTEX_WAITERS != 0 {
+ private := f&linux.FUTEX_PRIVATE_FLAG != 0
+ if pi {
+ t.Futex().UnlockPI(t, addr, tid, private)
+ return
+ }
+ t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1)
+ }
+
+ // Done.
+ return
+ }
+}
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index ce3e6ef28..0325967e4 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -455,7 +455,7 @@ func (t *Task) SetKeepCaps(k bool) {
t.creds.Store(creds)
}
-// updateCredsForExec updates t.creds to reflect an execve().
+// updateCredsForExecLocked updates t.creds to reflect an execve().
//
// NOTE(b/30815691): We currently do not implement privileged executables
// (set-user/group-ID bits and file capabilities). This allows us to make a lot
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index a29e9b9eb..d23cea802 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -16,36 +16,40 @@ package kernel
import (
"fmt"
+ "runtime/trace"
"sort"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
// maxStackDebugBytes is the maximum number of user stack bytes that may be
// printed by debugDumpStack.
maxStackDebugBytes = 1024
+ // maxCodeDebugBytes is the maximum number of user code bytes that may be
+ // printed by debugDumpCode.
+ maxCodeDebugBytes = 128
)
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
@@ -60,6 +64,7 @@ func (t *Task) IsLogging(level log.Level) bool {
func (t *Task) DebugDumpState() {
t.debugDumpRegisters()
t.debugDumpStack()
+ t.debugDumpCode()
if mm := t.MemoryManager(); mm != nil {
t.Debugf("Mappings:\n%s", mm)
}
@@ -127,11 +132,120 @@ func (t *Task) debugDumpStack() {
}
}
-// updateLogPrefix updates the task's cached log prefix to reflect its
-// current thread ID.
+// debugDumpCode logs user code contents at log level debug.
+//
+// Preconditions: The caller must be running on the task goroutine.
+func (t *Task) debugDumpCode() {
+ if !t.IsLogging(log.Debug) {
+ return
+ }
+ m := t.MemoryManager()
+ if m == nil {
+ t.Debugf("Memory manager for task is gone, skipping application code dump.")
+ return
+ }
+ t.Debugf("Code:")
+ // Print code on both sides of the instruction register.
+ start := usermem.Addr(t.Arch().IP()) - maxCodeDebugBytes/2
+ // Round addr down to a 16-byte boundary.
+ start &= ^usermem.Addr(15)
+ // Print 16 bytes per line, one byte at a time.
+ for offset := uint64(0); offset < maxCodeDebugBytes; offset += 16 {
+ addr, ok := start.AddLength(offset)
+ if !ok {
+ break
+ }
+ var data [16]byte
+ n, err := m.CopyIn(t, addr, data[:], usermem.IOOpts{
+ IgnorePermissions: true,
+ })
+ // Print as much of the line as we can, even if an error was
+ // encountered.
+ if n > 0 {
+ t.Debugf("%x: % x", addr, data[:n])
+ }
+ if err != nil {
+ t.Debugf("Error reading stack at address %x: %v", addr+usermem.Addr(n), err)
+ break
+ }
+ }
+}
+
+// trace definitions.
+//
+// Note that all region names are prefixed by ':' in order to ensure that they
+// are lexically ordered before all system calls, which use the naked system
+// call name (e.g. "read") for maximum clarity.
+const (
+ traceCategory = "task"
+ runRegion = ":run"
+ blockRegion = ":block"
+ cpuidRegion = ":cpuid"
+ faultRegion = ":fault"
+)
+
+// updateInfoLocked updates the task's cached log prefix and tracing
+// information to reflect its current thread ID.
//
// Preconditions: The task's owning TaskSet.mu must be locked.
-func (t *Task) updateLogPrefixLocked() {
+func (t *Task) updateInfoLocked() {
// Use the task's TID in the root PID namespace for logging.
- t.logPrefix.Store(fmt.Sprintf("[% 4d] ", t.tg.pidns.owner.Root.tids[t]))
+ tid := t.tg.pidns.owner.Root.tids[t]
+ t.logPrefix.Store(fmt.Sprintf("[% 4d] ", tid))
+ t.rebuildTraceContext(tid)
+}
+
+// rebuildTraceContext rebuilds the trace context.
+//
+// Precondition: the passed tid must be the tid in the root namespace.
+func (t *Task) rebuildTraceContext(tid ThreadID) {
+ // Re-initialize the trace context.
+ if t.traceTask != nil {
+ t.traceTask.End()
+ }
+
+ // Note that we define the "task type" to be the dynamic TID. This does
+ // not align perfectly with the documentation for "tasks" in the
+ // tracing package. Tasks may be assumed to be bounded by analysis
+ // tools. However, if we just use a generic "task" type here, then the
+ // "user-defined tasks" page on the tracing dashboard becomes nearly
+ // unusable, as it loads all traces from all tasks.
+ //
+ // We can assume that the number of tasks in the system is not
+ // arbitrarily large (in general it won't be, especially for cases
+ // where we're collecting a brief profile), so using the TID is a
+ // reasonable compromise in this case.
+ t.traceContext, t.traceTask = trace.NewTask(t, fmt.Sprintf("tid:%d", tid))
+}
+
+// traceCloneEvent is called when a new task is spawned.
+//
+// ntid must be the new task's ThreadID in the root namespace.
+func (t *Task) traceCloneEvent(ntid ThreadID) {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "spawn: %d", ntid)
+}
+
+// traceExitEvent is called when a task exits.
+func (t *Task) traceExitEvent() {
+ if !trace.IsEnabled() {
+ return
+ }
+ trace.Logf(t.traceContext, traceCategory, "exit status: 0x%x", t.exitStatus.Status())
+}
+
+// traceExecEvent is called when a task calls exec.
+func (t *Task) traceExecEvent(tc *TaskContext) {
+ if !trace.IsEnabled() {
+ return
+ }
+ file := tc.MemoryManager.Executable()
+ if file == nil {
+ trace.Logf(t.traceContext, traceCategory, "exec: << unknown >>")
+ return
+ }
+ defer file.DecRef(t)
+ trace.Logf(t.traceContext, traceCategory, "exec: %s", file.PathnameWithDeleted(t))
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go
index c92266c59..abaf29216 100644
--- a/pkg/sentry/kernel/task_run.go
+++ b/pkg/sentry/kernel/task_run.go
@@ -17,6 +17,7 @@ package kernel
import (
"bytes"
"runtime"
+ "runtime/trace"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,7 +26,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// A taskRunState is a reified state in the task state machine. See README.md
@@ -95,6 +96,7 @@ func (t *Task) run(threadID uintptr) {
t.tg.liveGoroutines.Done()
t.tg.pidns.owner.liveGoroutines.Done()
t.tg.pidns.owner.runningGoroutines.Done()
+ t.p.Release()
// Keep argument alive because stack trace for dead variables may not be correct.
runtime.KeepAlive(threadID)
@@ -125,13 +127,39 @@ func (t *Task) doStop() {
}
}
+func (*runApp) handleCPUIDInstruction(t *Task) error {
+ if len(arch.CPUIDInstruction) == 0 {
+ // CPUID emulation isn't supported, but this code can be
+ // executed, because the ptrace platform returns
+ // ErrContextSignalCPUID on page faults too. Look at
+ // pkg/sentry/platform/ptrace/ptrace.go:context.Switch for more
+ // details.
+ return platform.ErrContextSignal
+ }
+ // Is this a CPUID instruction?
+ region := trace.StartRegion(t.traceContext, cpuidRegion)
+ expected := arch.CPUIDInstruction[:]
+ found := make([]byte, len(expected))
+ _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
+ if err == nil && bytes.Equal(expected, found) {
+ // Skip the cpuid instruction.
+ t.Arch().CPUIDEmulate(t)
+ t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
+ region.End()
+
+ return nil
+ }
+ region.End() // Not an actual CPUID, but required copy-in.
+ return platform.ErrContextSignal
+}
+
// The runApp state checks for interrupts before executing untrusted
// application code.
//
// +stateify savable
type runApp struct{}
-func (*runApp) execute(t *Task) taskRunState {
+func (app *runApp) execute(t *Task) taskRunState {
if t.interrupted() {
// Checkpointing instructs tasks to stop by sending an interrupt, so we
// must check for stops before entering runInterrupt (instead of
@@ -139,7 +167,22 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runInterrupt)(nil)
}
- // We're about to switch to the application again. If there's still a
+ // Execute any task work callbacks before returning to user space.
+ if atomic.LoadInt32(&t.taskWorkCount) > 0 {
+ t.taskWorkMu.Lock()
+ queue := t.taskWork
+ t.taskWork = nil
+ atomic.StoreInt32(&t.taskWorkCount, 0)
+ t.taskWorkMu.Unlock()
+
+ // Do not hold taskWorkMu while executing task work, which may register
+ // more work.
+ for _, work := range queue {
+ work.TaskWork(t)
+ }
+ }
+
+ // We're about to switch to the application again. If there's still an
// unhandled SyscallRestartErrno that wasn't translated to an EINTR,
// restart the syscall that was interrupted. If there's a saved signal
// mask, restore it. (Note that restoring the saved signal mask may unblock
@@ -168,12 +211,22 @@ func (*runApp) execute(t *Task) taskRunState {
// Apply restartable sequences.
if t.rseqPreempted {
t.rseqPreempted = false
- if t.rseqCPUAddr != 0 {
+ if t.rseqAddr != 0 || t.oldRSeqCPUAddr != 0 {
+ // Linux writes the CPU on every preemption. We only do
+ // so if it changed. Thus we may delay delivery of
+ // SIGSEGV if rseqAddr/oldRSeqCPUAddr is invalid.
cpu := int32(hostcpu.GetCPU())
if t.rseqCPU != cpu {
t.rseqCPU = cpu
if err := t.rseqCopyOutCPU(); err != nil {
- t.Warningf("Failed to copy CPU to %#x for RSEQ: %v", t.rseqCPUAddr, err)
+ t.Debugf("Failed to copy CPU to %#x for rseq: %v", t.rseqAddr, err)
+ t.forceSignal(linux.SIGSEGV, false)
+ t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
+ // Re-enter the task run loop for signal delivery.
+ return (*runApp)(nil)
+ }
+ if err := t.oldRSeqCopyOutCPU(); err != nil {
+ t.Debugf("Failed to copy CPU to %#x for old rseq: %v", t.oldRSeqCPUAddr, err)
t.forceSignal(linux.SIGSEGV, false)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
// Re-enter the task run loop for signal delivery.
@@ -205,9 +258,11 @@ func (*runApp) execute(t *Task) taskRunState {
t.tg.pidns.owner.mu.RUnlock()
}
+ region := trace.StartRegion(t.traceContext, runRegion)
t.accountTaskGoroutineEnter(TaskGoroutineRunningApp)
- info, at, err := t.p.Switch(t.MemoryManager().AddressSpace(), t.Arch(), t.rseqCPU)
+ info, at, err := t.p.Switch(t, t.MemoryManager(), t.Arch(), t.rseqCPU)
t.accountTaskGoroutineLeave(TaskGoroutineRunningApp)
+ region.End()
if clearSinglestep {
t.Arch().ClearSingleStep()
@@ -224,15 +279,7 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextSignalCPUID:
- // Is this a CPUID instruction?
- expected := arch.CPUIDInstruction[:]
- found := make([]byte, len(expected))
- _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found)
- if err == nil && bytes.Equal(expected, found) {
- // Skip the cpuid instruction.
- t.Arch().CPUIDEmulate(t)
- t.Arch().SetIP(t.Arch().IP() + uintptr(len(expected)))
-
+ if err := app.handleCPUIDInstruction(t); err == nil {
// Resume execution.
return (*runApp)(nil)
}
@@ -251,8 +298,10 @@ func (*runApp) execute(t *Task) taskRunState {
// an application-generated signal and we should continue execution
// normally.
if at.Any() {
+ region := trace.StartRegion(t.traceContext, faultRegion)
addr := usermem.Addr(info.Addr())
err := t.MemoryManager().HandleUserFault(t, addr, at, usermem.Addr(t.Arch().Stack()))
+ region.End()
if err == nil {
// The fault was handled appropriately.
// We can resume running the application.
@@ -260,6 +309,12 @@ func (*runApp) execute(t *Task) taskRunState {
}
// Is this a vsyscall that we need emulate?
+ //
+ // Note that we don't track vsyscalls as part of a
+ // specific trace region. This is because regions don't
+ // stack, and the actual system call will count as a
+ // region. We should be able to easily identify
+ // vsyscalls by having a <fault><syscall> pair.
if at.Execute {
if sysno, ok := t.tc.st.LookupEmulate(addr); ok {
return t.doVsyscall(addr, sysno)
@@ -306,7 +361,7 @@ func (*runApp) execute(t *Task) taskRunState {
return (*runApp)(nil)
case platform.ErrContextCPUPreempted:
- // Ensure that RSEQ critical sections are interrupted and per-thread
+ // Ensure that rseq critical sections are interrupted and per-thread
// CPU values are updated before the next platform.Context.Switch().
t.rseqPreempted = true
return (*runApp)(nil)
@@ -314,7 +369,7 @@ func (*runApp) execute(t *Task) taskRunState {
default:
// What happened? Can't continue.
t.Warningf("Unexpected SwitchToApp error: %v", err)
- t.PrepareExit(ExitStatus{Code: t.ExtractErrno(err, -1)})
+ t.PrepareExit(ExitStatus{Code: ExtractErrno(err, -1)})
return (*runExit)(nil)
}
}
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 8b148db35..09366b60c 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -653,14 +653,14 @@ func (t *Task) SetNiceness(n int) {
}
// NumaPolicy returns t's current numa policy.
-func (t *Task) NumaPolicy() (policy int32, nodeMask uint64) {
+func (t *Task) NumaPolicy() (policy linux.NumaPolicy, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
return t.numaPolicy, t.numaNodeMask
}
// SetNumaPolicy sets t's numa policy.
-func (t *Task) SetNumaPolicy(policy int32, nodeMask uint64) {
+func (t *Task) SetNumaPolicy(policy linux.NumaPolicy, nodeMask uint64) {
t.mu.Lock()
defer t.mu.Unlock()
t.numaPolicy = policy
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 39cd1340d..cff2a8365 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -26,8 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -174,7 +174,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS
fallthrough
case (sre == ERESTARTSYS && !act.IsRestart()):
t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(syserror.EINTR, -1)))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1)))
default:
t.Debugf("Restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo)
t.Arch().RestartSyscall()
@@ -255,17 +255,32 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct)
}
}
+ mm := t.MemoryManager()
// Set up the signal handler. If we have a saved signal mask, the signal
// handler should run with the current mask, but sigreturn should restore
// the saved one.
- st := &arch.Stack{t.Arch(), t.MemoryManager(), sp}
+ st := &arch.Stack{t.Arch(), mm, sp}
mask := t.signalMask
if t.haveSavedSignalMask {
mask = t.savedSignalMask
}
+
+ // Set up the restorer.
+ // x86-64 should always uses SA_RESTORER, but this flag is optional on other platforms.
+ // Please see the linux code as reference:
+ // linux/arch/x86/kernel/signal.c:__setup_rt_frame()
+ // If SA_RESTORER is not configured, we can use the sigreturn trampolines
+ // the vdso provides instead.
+ // Please see the linux code as reference:
+ // linux/arch/arm64/kernel/signal.c:setup_return()
+ if act.Flags&linux.SA_RESTORER == 0 {
+ act.Restorer = mm.VDSOSigReturn()
+ }
+
if err := t.Arch().SignalSetup(st, &act, info, &alt, mask); err != nil {
return err
}
+ t.p.FullStateChanged()
t.haveSavedSignalMask = false
// Add our signal mask.
@@ -297,6 +312,7 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) {
// Restore our signal mask. SIGKILL and SIGSTOP should not be blocked.
t.SetSignalMask(sigset &^ UnblockableSignals)
+ t.p.FullStateChanged()
return ctrlResume, nil
}
@@ -513,8 +529,6 @@ func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
if t.stop != nil {
return false
}
- // - TODO(b/38173783): No special case for when t is also the sending task,
- // because the identity of the sender is unknown.
// - Do not choose tasks that have already been interrupted, as they may be
// busy handling another signal.
if len(t.interruptChan) != 0 {
@@ -625,6 +639,7 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) {
// SignalStack returns the task-private signal stack.
func (t *Task) SignalStack() arch.SignalStack {
+ t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
alt := t.signalStack
if t.onSignalStack(alt) {
alt.Flags |= arch.SignalStackFlagOnStack
@@ -705,7 +720,7 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a
func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
n := t.Arch().NewSignalAct()
n.SerializeFrom(s)
- _, err := t.CopyOut(addr, n)
+ _, err := n.CopyOut(t, addr)
return err
}
@@ -714,7 +729,7 @@ func (t *Task) CopyOutSignalAct(addr usermem.Addr, s *arch.SignalAct) error {
func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
n := t.Arch().NewSignalAct()
var s arch.SignalAct
- if _, err := t.CopyIn(addr, n); err != nil {
+ if _, err := n.CopyIn(t, addr); err != nil {
return s, err
}
n.DeserializeTo(&s)
@@ -726,7 +741,7 @@ func (t *Task) CopyInSignalAct(addr usermem.Addr) (arch.SignalAct, error) {
func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error {
n := t.Arch().NewSignalStack()
n.SerializeFrom(s)
- _, err := t.CopyOut(addr, n)
+ _, err := n.CopyOut(t, addr)
return err
}
@@ -735,7 +750,7 @@ func (t *Task) CopyOutSignalStack(addr usermem.Addr, s *arch.SignalStack) error
func (t *Task) CopyInSignalStack(addr usermem.Addr) (arch.SignalStack, error) {
n := t.Arch().NewSignalStack()
var s arch.SignalStack
- if _, err := t.CopyIn(addr, n); err != nil {
+ if _, err := n.CopyIn(t, addr); err != nil {
return s, err
}
n.DeserializeTo(&s)
@@ -1039,6 +1054,8 @@ func (*runInterrupt) execute(t *Task) taskRunState {
// Are there signals pending?
if info := t.dequeueSignalLocked(t.signalMask); info != nil {
+ t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
+
if linux.SignalSetOf(linux.Signal(info.Signo))&StopSignals != 0 {
// Indicate that we've dequeued a stop signal before unlocking the
// signal mutex; initiateGroupStop will check for races with
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index ae6fc4025..64c1e120a 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,11 +17,14 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// TaskConfig defines the configuration of a new Task (see below).
@@ -63,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -79,6 +81,16 @@ type TaskConfig struct {
// AbstractSocketNamespace is the AbstractSocketNamespace of the new task.
AbstractSocketNamespace *AbstractSocketNamespace
+ // MountNamespaceVFS2 is the MountNamespace of the new task.
+ MountNamespaceVFS2 *vfs.MountNamespace
+
+ // RSeqAddr is a pointer to the the userspace linux.RSeq structure.
+ RSeqAddr usermem.Addr
+
+ // RSeqSignature is the signature that the rseq abort IP must be signed
+ // with.
+ RSeqSignature uint32
+
// ContainerID is the container the new task belongs to.
ContainerID string
}
@@ -90,8 +102,11 @@ func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) {
t, err := ts.newTask(cfg)
if err != nil {
cfg.TaskContext.release()
- cfg.FSContext.DecRef()
- cfg.FDTable.DecRef()
+ cfg.FSContext.DecRef(t)
+ cfg.FDTable.DecRef(t)
+ if cfg.MountNamespaceVFS2 != nil {
+ cfg.MountNamespaceVFS2.DecRef(t)
+ }
return nil, err
}
return t, nil
@@ -108,26 +123,29 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
parent: cfg.Parent,
children: make(map[*Task]struct{}),
},
- runState: (*runApp)(nil),
- interruptChan: make(chan struct{}, 1),
- signalMask: cfg.SignalMask,
- signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
- tc: *tc,
- fsContext: cfg.FSContext,
- fdTable: cfg.FDTable,
- p: cfg.Kernel.Platform.NewContext(),
- k: cfg.Kernel,
- ptraceTracees: make(map[*Task]struct{}),
- allowedCPUMask: cfg.AllowedCPUMask.Copy(),
- ioUsage: &usage.IO{},
- niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
- utsns: cfg.UTSNamespace,
- ipcns: cfg.IPCNamespace,
- abstractSockets: cfg.AbstractSocketNamespace,
- rseqCPU: -1,
- futexWaiter: futex.NewWaiter(),
- containerID: cfg.ContainerID,
+ runState: (*runApp)(nil),
+ interruptChan: make(chan struct{}, 1),
+ signalMask: cfg.SignalMask,
+ signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable},
+ tc: *tc,
+ fsContext: cfg.FSContext,
+ fdTable: cfg.FDTable,
+ p: cfg.Kernel.Platform.NewContext(),
+ k: cfg.Kernel,
+ ptraceTracees: make(map[*Task]struct{}),
+ allowedCPUMask: cfg.AllowedCPUMask.Copy(),
+ ioUsage: &usage.IO{},
+ niceness: cfg.Niceness,
+ netns: cfg.NetworkNamespace,
+ utsns: cfg.UTSNamespace,
+ ipcns: cfg.IPCNamespace,
+ abstractSockets: cfg.AbstractSocketNamespace,
+ mountNamespaceVFS2: cfg.MountNamespaceVFS2,
+ rseqCPU: -1,
+ rseqAddr: cfg.RSeqAddr,
+ rseqSignature: cfg.RSeqSignature,
+ futexWaiter: futex.NewWaiter(),
+ containerID: cfg.ContainerID,
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
@@ -154,10 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
// Below this point, newTask is expected not to fail (there is no rollback
// of assignTIDsLocked or any of the following).
- // Logging on t's behalf will panic if t.logPrefix hasn't been initialized.
- // This is the earliest point at which we can do so (since t now has thread
- // IDs).
- t.updateLogPrefixLocked()
+ // Logging on t's behalf will panic if t.logPrefix hasn't been
+ // initialized. This is the earliest point at which we can do so
+ // (since t now has thread IDs).
+ t.updateInfoLocked()
if cfg.InheritParent != nil {
t.parent = cfg.InheritParent.parent
diff --git a/pkg/sentry/kernel/task_stop.go b/pkg/sentry/kernel/task_stop.go
index 10c6e455c..296735d32 100644
--- a/pkg/sentry/kernel/task_stop.go
+++ b/pkg/sentry/kernel/task_stop.go
@@ -205,6 +205,22 @@ func (ts *TaskSet) BeginExternalStop() {
}
}
+// PullFullState receives full states for all tasks.
+func (ts *TaskSet) PullFullState() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ if ts.Root == nil {
+ return
+ }
+ for t := range ts.Root.tids {
+ t.Activate()
+ if mm := t.MemoryManager(); mm != nil {
+ t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch())
+ }
+ t.Deactivate()
+ }
+}
+
// EndExternalStop indicates the end of an external stop started by a previous
// call to TaskSet.BeginExternalStop. EndExternalStop does not wait for task
// goroutines to resume.
diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go
index b543d536a..a5903b0b5 100644
--- a/pkg/sentry/kernel/task_syscall.go
+++ b/pkg/sentry/kernel/task_syscall.go
@@ -17,6 +17,7 @@ package kernel
import (
"fmt"
"os"
+ "runtime/trace"
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -24,8 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel
@@ -160,6 +161,10 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
ctrl = ctrlStopAndReinvokeSyscall
} else {
fn := s.Lookup(sysno)
+ var region *trace.Region // Only non-nil if tracing == true.
+ if trace.IsEnabled() {
+ region = trace.StartRegion(t.traceContext, s.LookupName(sysno))
+ }
if fn != nil {
// Call our syscall implementation.
rval, ctrl, err = fn(t, args)
@@ -167,6 +172,9 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
// Use the missing function if not found.
rval, err = t.SyscallTable().Missing(t, sysno, args)
}
+ if region != nil {
+ region.End()
+ }
}
if bits.IsOn32(fe, ExternalAfterEnable) && (s.ExternalFilterAfter == nil || s.ExternalFilterAfter(t, sysno, args)) {
@@ -186,6 +194,19 @@ func (t *Task) executeSyscall(sysno uintptr, args arch.SyscallArguments) (rval u
//
// The syscall path is very hot; avoid defer.
func (t *Task) doSyscall() taskRunState {
+ // Save value of the register which is clobbered in the following
+ // t.Arch().SetReturn(-ENOSYS) operation. This is dedicated to arm64.
+ //
+ // On x86, register rax was shared by syscall number and return
+ // value, and at the entry of the syscall handler, the rax was
+ // saved to regs.orig_rax which was exposed to userspace.
+ // But on arm64, syscall number was passed through X8, and the X0
+ // was shared by the first syscall argument and return value. The
+ // X0 was saved to regs.orig_x0 which was not exposed to userspace.
+ // So we have to do the same operation here to save the X0 value
+ // into the task context.
+ t.Arch().SyscallSaveOrig()
+
sysno := t.Arch().SyscallNo()
args := t.Arch().SyscallArgs()
@@ -261,6 +282,7 @@ func (*runSyscallAfterSyscallEnterStop) execute(t *Task) taskRunState {
return (*runSyscallExit)(nil)
}
args := t.Arch().SyscallArgs()
+
return t.doSyscallInvoke(sysno, args)
}
@@ -290,7 +312,7 @@ func (t *Task) doSyscallInvoke(sysno uintptr, args arch.SyscallArguments) taskRu
return ctrl.next
}
} else if err != nil {
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
t.haveSyscallReturn = true
} else {
t.Arch().SetReturn(rval)
@@ -409,7 +431,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// A return is not emulated in this case.
return (*runApp)(nil)
}
- t.Arch().SetReturn(uintptr(-t.ExtractErrno(err, int(sysno))))
+ t.Arch().SetReturn(uintptr(-ExtractErrno(err, int(sysno))))
}
t.Arch().SetIP(t.Arch().Value(caller))
t.Arch().SetStack(t.Arch().Stack() + uintptr(t.Arch().Width()))
@@ -419,7 +441,7 @@ func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, calle
// ExtractErrno extracts an integer error number from the error.
// The syscall number is purely for context in the error case. Use -1 if
// syscall number is unknown.
-func (t *Task) ExtractErrno(err error, sysno int) int {
+func ExtractErrno(err error, sysno int) int {
switch err := err.(type) {
case nil:
return 0
@@ -433,11 +455,11 @@ func (t *Task) ExtractErrno(err error, sysno int) int {
// handled (and the SIGBUS is delivered).
return int(syscall.EFAULT)
case *os.PathError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.LinkError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
case *os.SyscallError:
- return t.ExtractErrno(err.Err, sysno)
+ return ExtractErrno(err.Err, sysno)
default:
if errno, ok := syserror.TranslateError(err); ok {
return int(errno)
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 518bfe1bd..b02044ad2 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -18,8 +18,8 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// MAX_RW_COUNT is the maximum size in bytes of a single read or write.
@@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
if mm := t.MemoryManager(); mm != nil {
- if err := mm.Activate(); err != nil {
+ if err := mm.Activate(t); err != nil {
panic("unable to activate mm: " + err.Error())
}
}
diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go
new file mode 100644
index 000000000..dda5a433a
--- /dev/null
+++ b/pkg/sentry/kernel/task_work.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync/atomic"
+
+// TaskWorker is a deferred task.
+//
+// This must be savable.
+type TaskWorker interface {
+ // TaskWork will be executed prior to returning to user space. Note that
+ // TaskWork may call RegisterWork again, but this will not be executed until
+ // the next return to user space, unlike in Linux. This effectively allows
+ // registration of indefinite user return hooks, but not by default.
+ TaskWork(t *Task)
+}
+
+// RegisterWork can be used to register additional task work that will be
+// performed prior to returning to user space. See TaskWorker.TaskWork for
+// semantics regarding registration.
+func (t *Task) RegisterWork(work TaskWorker) {
+ t.taskWorkMu.Lock()
+ defer t.taskWorkMu.Unlock()
+ atomic.AddInt32(&t.taskWorkCount, 1)
+ t.taskWork = append(t.taskWork, work)
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 72568d296..0b34c0099 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -15,7 +15,6 @@
package kernel
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,6 +24,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -238,8 +238,8 @@ type ThreadGroup struct {
// execed is protected by the TaskSet mutex.
execed bool
- // rscr is the thread group's RSEQ critical region.
- rscr atomic.Value `state:".(*RSEQCriticalRegion)"`
+ // oldRSeqCritical is the thread group's old rseq critical region.
+ oldRSeqCritical atomic.Value `state:".(*OldRSeqCriticalRegion)"`
// mounts is the thread group's mount namespace. This does not really
// correspond to a "mount namespace" in Linux, but is more like a
@@ -254,37 +254,44 @@ type ThreadGroup struct {
//
// tty is protected by the signal mutex.
tty *TTY
+
+ // oomScoreAdj is the thread group's OOM score adjustment. This is
+ // currently not used but is maintained for consistency.
+ // TODO(gvisor.dev/issue/1967)
+ //
+ // oomScoreAdj is accessed using atomic memory operations.
+ oomScoreAdj int32
}
-// newThreadGroup returns a new, empty thread group in PID namespace ns. The
+// NewThreadGroup returns a new, empty thread group in PID namespace pidns. The
// thread group leader will send its parent terminationSignal when it exits.
// The new thread group isn't visible to the system until a task has been
// created inside of it by a successful call to TaskSet.NewTask.
-func (k *Kernel) newThreadGroup(mounts *fs.MountNamespace, ns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet, monotonicClock *timekeeperClock) *ThreadGroup {
+func (k *Kernel) NewThreadGroup(mntns *fs.MountNamespace, pidns *PIDNamespace, sh *SignalHandlers, terminationSignal linux.Signal, limits *limits.LimitSet) *ThreadGroup {
tg := &ThreadGroup{
threadGroupNode: threadGroupNode{
- pidns: ns,
+ pidns: pidns,
},
signalHandlers: sh,
terminationSignal: terminationSignal,
ioUsage: &usage.IO{},
limits: limits,
- mounts: mounts,
+ mounts: mntns,
}
tg.itimerRealTimer = ktime.NewTimer(k.monotonicClock, &itimerRealListener{tg: tg})
tg.timers = make(map[linux.TimerID]*IntervalTimer)
- tg.rscr.Store(&RSEQCriticalRegion{})
+ tg.oldRSeqCritical.Store(&OldRSeqCriticalRegion{})
return tg
}
-// saveRscr is invoked by stateify.
-func (tg *ThreadGroup) saveRscr() *RSEQCriticalRegion {
- return tg.rscr.Load().(*RSEQCriticalRegion)
+// saveOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) saveOldRSeqCritical() *OldRSeqCriticalRegion {
+ return tg.oldRSeqCritical.Load().(*OldRSeqCriticalRegion)
}
-// loadRscr is invoked by stateify.
-func (tg *ThreadGroup) loadRscr(rscr *RSEQCriticalRegion) {
- tg.rscr.Store(rscr)
+// loadOldRSeqCritical is invoked by stateify.
+func (tg *ThreadGroup) loadOldRSeqCritical(r *OldRSeqCriticalRegion) {
+ tg.oldRSeqCritical.Store(r)
}
// SignalHandlers returns the signal handlers used by tg.
@@ -301,7 +308,7 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet {
}
// release releases the thread group's resources.
-func (tg *ThreadGroup) release() {
+func (tg *ThreadGroup) release(t *Task) {
// Timers must be destroyed without holding the TaskSet or signal mutexes
// since timers send signals with Timer.mu locked.
tg.itimerRealTimer.Destroy()
@@ -317,7 +324,9 @@ func (tg *ThreadGroup) release() {
for _, it := range its {
it.DestroyTimer()
}
- tg.mounts.DecRef()
+ if tg.mounts != nil {
+ tg.mounts.DecRef(t)
+ }
}
// forEachChildThreadGroupLocked indicates over all child ThreadGroups.
@@ -357,7 +366,8 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
// terminal is stolen, and all processes that had it as controlling
// terminal lose it." - tty_ioctl(4)
if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
- if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 {
+ // Stealing requires CAP_SYS_ADMIN in the root user namespace.
+ if creds := auth.CredentialsFromContext(tg.leader); !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) || arg != 1 {
return syserror.EPERM
}
// Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 8267929a6..872e1a82d 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -16,9 +16,9 @@ package kernel
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -87,6 +87,13 @@ type TaskSet struct {
// at time of save (but note that this is not necessarily the same thing as
// sync.WaitGroup's zero value).
runningGoroutines sync.WaitGroup `state:"nosave"`
+
+ // aioGoroutines is the number of goroutines running async I/O
+ // callbacks.
+ //
+ // aioGoroutines is not saved but is required to be zero at the time of
+ // save.
+ aioGoroutines sync.WaitGroup `state:"nosave"`
}
// newTaskSet returns a new, empty TaskSet.
diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD
index 31847e1df..2817aa3ba 100644
--- a/pkg/sentry/kernel/time/BUILD
+++ b/pkg/sentry/kernel/time/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,13 +6,14 @@ go_library(
name = "time",
srcs = [
"context.go",
+ "tcpip.go",
"time.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/time",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/waiter",
],
diff --git a/pkg/sentry/kernel/time/context.go b/pkg/sentry/kernel/time/context.go
index 8ef483dd3..00b729d88 100644
--- a/pkg/sentry/kernel/time/context.go
+++ b/pkg/sentry/kernel/time/context.go
@@ -15,7 +15,7 @@
package time
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the time package's type for context.Context.Value keys.
diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go
new file mode 100644
index 000000000..c4474c0cf
--- /dev/null
+++ b/pkg/sentry/kernel/time/tcpip.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package time
+
+import (
+ "sync"
+ "time"
+)
+
+// TcpipAfterFunc waits for duration to elapse according to clock then runs fn.
+// The timer is started immediately and will fire exactly once.
+func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer {
+ timer := &TcpipTimer{
+ clock: clock,
+ }
+ timer.notifier = functionNotifier{
+ fn: func() {
+ // tcpip.Timer.Stop() explicitly states that the function is called in a
+ // separate goroutine that Stop() does not synchronize with.
+ // Timer.Destroy() synchronizes with calls to TimerListener.Notify().
+ // This is semantically meaningful because, in the former case, it's
+ // legal to call tcpip.Timer.Stop() while holding locks that may also be
+ // taken by the function, but this isn't so in the latter case. Most
+ // immediately, Timer calls TimerListener.Notify() while holding
+ // Timer.mu. A deadlock occurs without spawning a goroutine:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify()
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock!
+ //
+ // Spawning a goroutine avoids the deadlock:
+ // T1: (Timer expires)
+ // => Timer.Tick() <- Timer.mu.Lock() called
+ // => TimerListener.Notify() <- Launches T2
+ // T2:
+ // => Timer.Stop()
+ // => Timer.Destroy() <- Timer.mu.Lock() called, blocks
+ // T1:
+ // => (returns) <- Timer.mu.Unlock() called
+ // T2:
+ // => (continues) <- No deadlock!
+ go func() {
+ timer.Stop()
+ fn()
+ }()
+ },
+ }
+ timer.Reset(duration)
+ return timer
+}
+
+// TcpipTimer is a resettable timer with variable duration expirations.
+// Implements tcpip.Timer, which does not define a Destroy method; instead, all
+// resources are released after timer expiration and calls to Timer.Stop.
+//
+// Must be created by AfterFunc.
+type TcpipTimer struct {
+ // clock is the time source. clock is immutable.
+ clock Clock
+
+ // notifier is called when the Timer expires. notifier is immutable.
+ notifier functionNotifier
+
+ // mu protects t.
+ mu sync.Mutex
+
+ // t stores the latest running Timer. This is replaced whenever Reset is
+ // called since Timer cannot be restarted once it has been Destroyed by Stop.
+ //
+ // This field is nil iff Stop has been called.
+ t *Timer
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (r *TcpipTimer) Stop() bool {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ return false
+ }
+ _, lastSetting := r.t.Swap(Setting{})
+ r.t.Destroy()
+ r.t = nil
+ return lastSetting.Enabled
+}
+
+// Reset implements tcpip.Timer.Reset.
+func (r *TcpipTimer) Reset(d time.Duration) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.t == nil {
+ r.t = NewTimer(r.clock, &r.notifier)
+ }
+
+ r.t.Swap(Setting{
+ Enabled: true,
+ Period: 0,
+ Next: r.clock.Now().Add(d),
+ })
+}
+
+// functionNotifier is a TimerListener that runs a function.
+//
+// functionNotifier cannot be saved or loaded.
+type functionNotifier struct {
+ fn func()
+}
+
+// Notify implements ktime.TimerListener.Notify.
+func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) {
+ f.fn()
+ return Setting{}, false
+}
+
+// Destroy implements ktime.TimerListener.Destroy.
+func (f *functionNotifier) Destroy() {}
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index 107394183..e959700f2 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -19,10 +19,10 @@ package time
import (
"fmt"
"math"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -245,7 +245,7 @@ type Clock interface {
type WallRateClock struct{}
// WallTimeUntil implements Clock.WallTimeUntil.
-func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
+func (*WallRateClock) WallTimeUntil(t, now Time) time.Duration {
return t.Sub(now)
}
@@ -254,16 +254,16 @@ func (WallRateClock) WallTimeUntil(t, now Time) time.Duration {
type NoClockEvents struct{}
// Readiness implements waiter.Waitable.Readiness.
-func (NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*NoClockEvents) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (*NoClockEvents) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (NoClockEvents) EventUnregister(e *waiter.Entry) {
+func (*NoClockEvents) EventUnregister(e *waiter.Entry) {
}
// ClockEventsQueue implements waiter.Waitable by wrapping waiter.Queue and
@@ -273,7 +273,7 @@ type ClockEventsQueue struct {
}
// Readiness implements waiter.Waitable.Readiness.
-func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (*ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
return 0
}
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 76417342a..7c4fefb16 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -16,14 +16,15 @@ package kernel
import (
"fmt"
- "sync"
+ "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/log"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Timekeeper manages all of the kernel clocks.
@@ -48,6 +49,9 @@ type Timekeeper struct {
// It is set only once, by SetClocks.
monotonicOffset int64 `state:"nosave"`
+ // monotonicLowerBound is the lowerBound for monotonic time.
+ monotonicLowerBound int64 `state:"nosave"`
+
// restored, if non-nil, indicates that this Timekeeper was restored
// from a state file. The clocks are not set until restored is closed.
restored chan struct{} `state:"nosave"`
@@ -86,7 +90,7 @@ type Timekeeper struct {
// NewTimekeeper does not take ownership of paramPage.
//
// SetClocks must be called on the returned Timekeeper before it is usable.
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) {
return &Timekeeper{
params: NewVDSOParamPage(mfp, paramPage),
}, nil
@@ -182,6 +186,7 @@ func (t *Timekeeper) startUpdater() {
timer := time.NewTicker(sentrytime.ApproxUpdateInterval)
t.wg.Add(1)
go func() { // S/R-SAFE: stopped during save.
+ defer t.wg.Done()
for {
// Start with an update immediately, so the clocks are
// ready ASAP.
@@ -205,9 +210,6 @@ func (t *Timekeeper) startUpdater() {
p.realtimeBaseRef = int64(realtimeParams.BaseRef)
p.realtimeFrequency = realtimeParams.Frequency
}
-
- log.Debugf("Updating VDSO parameters: %+v", p)
-
return p
}); err != nil {
log.Warningf("Unable to update VDSO parameter page: %v", err)
@@ -216,7 +218,6 @@ func (t *Timekeeper) startUpdater() {
select {
case <-timer.C:
case <-t.stop:
- t.wg.Done()
return
}
}
@@ -271,6 +272,21 @@ func (t *Timekeeper) GetTime(c sentrytime.ClockID) (int64, error) {
now, err := t.clocks.GetTime(c)
if err == nil && c == sentrytime.Monotonic {
now += t.monotonicOffset
+ for {
+ // It's possible that the clock is shaky. This may be due to
+ // platform issues, e.g. the KVM platform relies on the guest
+ // TSC and host TSC, which may not be perfectly in sync. To
+ // work around this issue, ensure that the monotonic time is
+ // always bounded by the last time read.
+ oldLowerBound := atomic.LoadInt64(&t.monotonicLowerBound)
+ if now < oldLowerBound {
+ now = oldLowerBound
+ break
+ }
+ if atomic.CompareAndSwapInt64(&t.monotonicLowerBound, oldLowerBound, now) {
+ break
+ }
+ }
}
return now, err
}
diff --git a/pkg/sentry/kernel/timekeeper_test.go b/pkg/sentry/kernel/timekeeper_test.go
index 849c5b646..cf2f7ca72 100644
--- a/pkg/sentry/kernel/timekeeper_test.go
+++ b/pkg/sentry/kernel/timekeeper_test.go
@@ -17,12 +17,12 @@ package kernel
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// mockClocks is a sentrytime.Clocks that simply returns the times in the
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
index 34f84487a..d0e0810e8 100644
--- a/pkg/sentry/kernel/tty.go
+++ b/pkg/sentry/kernel/tty.go
@@ -14,15 +14,28 @@
package kernel
-import "sync"
+import "gvisor.dev/gvisor/pkg/sync"
// TTY defines the relationship between a thread group and its controlling
// terminal.
//
// +stateify savable
type TTY struct {
+ // Index is the terminal index. It is immutable.
+ Index uint32
+
mu sync.Mutex `state:"nosave"`
// tg is protected by mu.
tg *ThreadGroup
}
+
+// TTY returns the thread group's controlling terminal. If nil, there is no
+// controlling terminal.
+func (tg *ThreadGroup) TTY() *TTY {
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+ return tg.tty
+}
diff --git a/pkg/sentry/kernel/uts_namespace.go b/pkg/sentry/kernel/uts_namespace.go
index 0a563e715..8ccf04bd1 100644
--- a/pkg/sentry/kernel/uts_namespace.go
+++ b/pkg/sentry/kernel/uts_namespace.go
@@ -15,9 +15,8 @@
package kernel
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
)
// UTSNamespace represents a UTS namespace, a holder of two system identifiers:
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index fdd10c56c..290c32466 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -18,10 +18,10 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// vdsoParams are the parameters exposed to the VDSO.
@@ -58,7 +58,7 @@ type vdsoParams struct {
type VDSOParamPage struct {
// The parameter page is fr, allocated from mfp.MemoryFile().
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
// seq is the current sequence count written to the page.
//
@@ -81,7 +81,7 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
//
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
-func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
return &VDSOParamPage{mfp: mfp, fr: fr}
}
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 156e67bf8..cf591c4c1 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,11 +9,11 @@ go_library(
"limits.go",
"linux.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/limits",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
+ "//pkg/sync",
],
)
@@ -24,5 +23,5 @@ go_test(
srcs = [
"limits_test.go",
],
- embed = [":limits"],
+ library = ":limits",
)
diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go
index 6972749ed..77e1fe217 100644
--- a/pkg/sentry/limits/context.go
+++ b/pkg/sentry/limits/context.go
@@ -15,7 +15,7 @@
package limits
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the limit package's type for context.Context.Value keys.
diff --git a/pkg/sentry/limits/limits.go b/pkg/sentry/limits/limits.go
index b6c22656b..31b9e9ff6 100644
--- a/pkg/sentry/limits/limits.go
+++ b/pkg/sentry/limits/limits.go
@@ -16,8 +16,9 @@
package limits
import (
- "sync"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// LimitType defines a type of resource limit.
diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD
index 2890393bd..34bdb0b69 100644
--- a/pkg/sentry/loader/BUILD
+++ b/pkg/sentry/loader/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_embed_data")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_embed_data", "go_library")
package(licenses = ["notice"])
@@ -20,31 +19,28 @@ go_library(
"vdso_state.go",
":vdso_bin",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/loader",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/context",
"//pkg/cpuid",
"//pkg/log",
"//pkg/rand",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/anon",
- "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/safemem",
"//pkg/sentry/uniqueid",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
- "//pkg/waiter",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index 3ea037e4d..20dd1cc21 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -23,16 +23,16 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -90,18 +90,27 @@ type elfInfo struct {
sharedObject bool
}
+// fullReader interface extracts the ReadFull method from fsbridge.File so that
+// client code does not need to define an entire fsbridge.File when only read
+// functionality is needed.
+//
+// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap
+// vfs.FileDescription's PRead/Read instead.
+type fullReader interface {
+ // ReadFull is the same as fsbridge.File.ReadFull.
+ ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error)
+}
+
// parseHeader parse the ELF header, verifying that this is a supported ELF
// file and returning the ELF program headers.
//
// This is similar to elf.NewFile, except that it is more strict about what it
// accepts from the ELF, and it doesn't parse unnecessary parts of the file.
-//
-// ctx may be nil if f does not need it.
-func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
+func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
// Check ident first; it will tell us the endianness of the rest of the
// structs.
var ident [elf.EI_NIDENT]byte
- _, err := readFull(ctx, f, usermem.BytesIOSequence(ident[:]), 0)
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(ident[:]), 0)
if err != nil {
log.Infof("Error reading ELF ident: %v", err)
// The entire ident array always exists.
@@ -137,7 +146,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
var hdr elf.Header64
hdrBuf := make([]byte, header64Size)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(hdrBuf), 0)
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
if err != nil {
log.Infof("Error reading ELF header: %v", err)
// The entire header always exists.
@@ -187,7 +196,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
}
phdrBuf := make([]byte, totalPhdrSize)
- _, err = readFull(ctx, f, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
+ _, err = f.ReadFull(ctx, usermem.BytesIOSequence(phdrBuf), int64(hdr.Phoff))
if err != nil {
log.Infof("Error reading ELF phdrs: %v", err)
// If phdrs were specified, they should all exist.
@@ -227,7 +236,7 @@ func parseHeader(ctx context.Context, f *fs.File) (elfInfo, error) {
// mapSegment maps a phdr into the Task. offset is the offset to apply to
// phdr.Vaddr.
-func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
+func mapSegment(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, phdr *elf.ProgHeader, offset usermem.Addr) error {
// We must make a page-aligned mapping.
adjust := usermem.Addr(phdr.Vaddr).PageOffset()
@@ -272,7 +281,7 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
}
defer func() {
if mopts.MappingIdentity != nil {
- mopts.MappingIdentity.DecRef()
+ mopts.MappingIdentity.DecRef(ctx)
}
}()
if err := f.ConfigureMMap(ctx, &mopts); err != nil {
@@ -395,7 +404,7 @@ type loadedELF struct {
//
// Preconditions:
// * f is an ELF file
-func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
+func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) {
first := true
var start, end usermem.Addr
var interpreter string
@@ -408,6 +417,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
start = vaddr
}
if vaddr < end {
+ // NOTE(b/37474556): Linux allows out-of-order
+ // segments, in violation of the spec.
ctx.Infof("PT_LOAD headers out-of-order. %#x < %#x", vaddr, end)
return loadedELF{}, syserror.ENOEXEC
}
@@ -429,7 +440,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
}
path := make([]byte, phdr.Filesz)
- _, err := readFull(ctx, f, usermem.BytesIOSequence(path), int64(phdr.Off))
+ _, err := f.ReadFull(ctx, usermem.BytesIOSequence(path), int64(phdr.Off))
if err != nil {
// If an interpreter was specified, it should exist.
ctx.Infof("Error reading PT_INTERP path: %v", err)
@@ -562,7 +573,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// Preconditions:
// * f is an ELF file
// * f is the first ELF loaded into m
-func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f *fs.File) (loadedELF, arch.Context, error) {
+func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) {
info, err := parseHeader(ctx, f)
if err != nil {
ctx.Infof("Failed to parse initial ELF: %v", err)
@@ -600,7 +611,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS
//
// Preconditions:
// * f is an ELF file
-func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, initial loadedELF) (loadedELF, error) {
+func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) {
info, err := parseHeader(ctx, f)
if err != nil {
if err == syserror.ENOEXEC {
@@ -644,17 +655,17 @@ func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error
// resolved, the interpreter should still be resolved if it is
// a symlink.
args.ResolveFinal = true
+ // Refresh the traversal limit.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
args.Filename = bin.interpreter
- d, i, err := openPath(ctx, args)
+ intFile, err := openPath(ctx, args)
if err != nil {
ctx.Infof("Error opening interpreter %s: %v", bin.interpreter, err)
return loadedELF{}, nil, err
}
- defer i.DecRef()
- // We don't need the Dirent.
- d.DecRef()
+ defer intFile.DecRef(ctx)
- interp, err = loadInterpreterELF(ctx, args.MemoryManager, i, bin)
+ interp, err = loadInterpreterELF(ctx, args.MemoryManager, intFile, bin)
if err != nil {
ctx.Infof("Error loading interpreter: %v", err)
return loadedELF{}, nil, err
diff --git a/pkg/sentry/loader/interpreter.go b/pkg/sentry/loader/interpreter.go
index ccf909cac..3886b4d33 100644
--- a/pkg/sentry/loader/interpreter.go
+++ b/pkg/sentry/loader/interpreter.go
@@ -18,10 +18,10 @@ import (
"bytes"
"io"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -37,9 +37,9 @@ const (
)
// parseInterpreterScript returns the interpreter path and argv.
-func parseInterpreterScript(ctx context.Context, filename string, f *fs.File, argv []string) (newpath string, newargv []string, err error) {
+func parseInterpreterScript(ctx context.Context, filename string, f fsbridge.File, argv []string) (newpath string, newargv []string, err error) {
line := make([]byte, interpMaxLineLength)
- n, err := readFull(ctx, f, usermem.BytesIOSequence(line), 0)
+ n, err := f.ReadFull(ctx, usermem.BytesIOSequence(line), 0)
// Short read is OK.
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index 818941762..8d6802ea3 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -23,16 +23,17 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// LoadArgs holds specifications for an executable file to be loaded.
@@ -40,16 +41,6 @@ type LoadArgs struct {
// MemoryManager is the memory manager to load the executable into.
MemoryManager *mm.MemoryManager
- // Mounts is the mount namespace in which to look up Filename.
- Mounts *fs.MountNamespace
-
- // Root is the root directory under which to look up Filename.
- Root *fs.Dirent
-
- // WorkingDirectory is the working directory under which to look up
- // Filename.
- WorkingDirectory *fs.Dirent
-
// RemainingTraversals is the maximum number of symlinks to follow to
// resolve Filename. This counter is passed by reference to keep it
// updated throughout the call stack.
@@ -64,7 +55,18 @@ type LoadArgs struct {
// File is an open fs.File object of the executable. If File is not
// nil, then File will be loaded and Filename will be ignored.
- File *fs.File
+ //
+ // The caller is responsible for checking that the user can execute this file.
+ File fsbridge.File
+
+ // Opener is used to open the executable file when 'File' is nil.
+ Opener fsbridge.Lookup
+
+ // CloseOnExec indicates that the executable (or one of its parent
+ // directories) was opened with O_CLOEXEC. If the executable is an
+ // interpreter script, then cause an ENOENT error to occur, since the
+ // script would otherwise be inaccessible to the interpreter.
+ CloseOnExec bool
// Argv is the vector of arguments to pass to the executable.
Argv []string
@@ -77,135 +79,41 @@ type LoadArgs struct {
Features *cpuid.FeatureSet
}
-// readFull behaves like io.ReadFull for an *fs.File.
-func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
- var total int64
- for dst.NumBytes() > 0 {
- n, err := f.Preadv(ctx, dst, offset+total)
- total += n
- if err == io.EOF && total != 0 {
- return total, io.ErrUnexpectedEOF
- } else if err != nil {
- return total, err
- }
- dst = dst.DropFirst64(n)
- }
- return total, nil
-}
-
-// openPath opens args.Filename for loading.
+// openPath opens args.Filename and checks that it is valid for loading.
//
-// openPath returns the fs.Dirent and an *fs.File for args.Filename, which is
-// not installed in the Task FDTable. The caller takes ownership of both.
+// openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not
+// installed in the Task FDTable. The caller takes ownership of both.
//
// args.Filename must be a readable, executable, regular file.
-func openPath(ctx context.Context, args LoadArgs) (*fs.Dirent, *fs.File, error) {
- var err error
+func openPath(ctx context.Context, args LoadArgs) (fsbridge.File, error) {
if args.Filename == "" {
ctx.Infof("cannot open empty name")
- return nil, nil, syserror.ENOENT
- }
-
- var d *fs.Dirent
- if args.ResolveFinal {
- d, err = args.Mounts.FindInode(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- } else {
- d, err = args.Mounts.FindLink(ctx, args.Root, args.WorkingDirectory, args.Filename, args.RemainingTraversals)
- }
- if err != nil {
- return nil, nil, err
- }
-
- // Open file will take a reference to Dirent, so destroy this one.
- defer d.DecRef()
-
- if !args.ResolveFinal && fs.IsSymlink(d.Inode.StableAttr) {
- return nil, nil, syserror.ELOOP
- }
-
- return openFile(ctx, nil, d, args.Filename)
-}
-
-// openFile takes that file's Dirent and performs checks on it. If provided a
-// *fs.Dirent and not a *fs.File, it creates a *fs.File object from the Dirent's
-// Inode and performs checks on that.
-//
-// openFile returns an *fs.File and *fs.Dirent, and the caller takes ownership
-// of both.
-//
-// "dirent" and "file" must not both be nil and point to a readable, executable, regular file.
-func openFile(ctx context.Context, file *fs.File, dirent *fs.Dirent, name string) (*fs.Dirent, *fs.File, error) {
- // file and dirent must not be nil.
- if dirent == nil && file == nil {
- ctx.Infof("dirent and file cannot both be nil.")
- return nil, nil, syserror.ENOENT
- }
-
- if file != nil {
- dirent = file.Dirent
- }
-
- // Perform permissions checks on the file.
- if err := checkFile(ctx, dirent, name); err != nil {
- return nil, nil, err
+ return nil, syserror.ENOENT
}
- if file == nil {
- var ferr error
- if file, ferr = dirent.Inode.GetFile(ctx, dirent, fs.FileFlags{Read: true}); ferr != nil {
- return nil, nil, ferr
- }
- } else {
- // GetFile takes a reference to the created file, so make one in the case
- // that the file reference already existed.
- file.IncRef()
- }
-
- // We must be able to read at arbitrary offsets.
- if !file.Flags().Pread {
- file.DecRef()
- ctx.Infof("%s cannot be read at an offset: %+v", file.MappedName(ctx), file.Flags())
- return nil, nil, syserror.EACCES
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ opts := vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
}
-
- // Grab reference for caller.
- dirent.IncRef()
- return dirent, file, nil
+ return args.Opener.OpenPath(ctx, args.Filename, opts, args.RemainingTraversals, args.ResolveFinal)
}
-// checkFile performs file permissions checks for binaries called in openPath
-// and openFile
-func checkFile(ctx context.Context, d *fs.Dirent, name string) error {
- perms := fs.PermMask{
- // TODO(gvisor.dev/issue/160): Linux requires only execute
- // permission, not read. However, our backing filesystems may
- // prevent us from reading the file without read permission.
- //
- // Additionally, a task with a non-readable executable has
- // additional constraints on access via ptrace and procfs.
- Read: true,
- Execute: true,
- }
- if err := d.Inode.CheckPermission(ctx, perms); err != nil {
+// checkIsRegularFile prevents us from trying to execute a directory, pipe, etc.
+func checkIsRegularFile(ctx context.Context, file fsbridge.File, filename string) error {
+ t, err := file.Type(ctx)
+ if err != nil {
return err
}
-
- // If they claim it's a directory, then make sure.
- //
- // N.B. we reject directories below, but we must first reject
- // non-directories passed as directories.
- if len(name) > 0 && name[len(name)-1] == '/' && !fs.IsDir(d.Inode.StableAttr) {
- return syserror.ENOTDIR
- }
-
- // No exec-ing directories, pipes, etc!
- if !fs.IsRegular(d.Inode.StableAttr) {
- ctx.Infof("%s is not regular: %v", name, d.Inode.StableAttr)
+ if t != linux.ModeRegular {
+ ctx.Infof("%q is not a regular file: %v", filename, t)
return syserror.EACCES
}
-
return nil
-
}
// allocStack allocates and maps a stack in to any available part of the address space.
@@ -224,8 +132,10 @@ const (
maxLoaderAttempts = 6
)
-// loadExecutable loads an executable that is pointed to by args.File. If nil,
-// the path args.Filename is resolved and loaded. If the executable is an
+// loadExecutable loads an executable that is pointed to by args.File. The
+// caller is responsible for checking that the user can execute this file.
+// If nil, the path args.Filename is resolved and loaded (check that the user
+// can execute this file is done here in this case). If the executable is an
// interpreter script rather than an ELF, the binary of the corresponding
// interpreter will be loaded.
//
@@ -234,33 +144,29 @@ const (
// * arch.Context matching the binary arch
// * fs.Dirent of the binary file
// * Possibly updated args.Argv
-func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, *fs.Dirent, []string, error) {
+func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, fsbridge.File, []string, error) {
for i := 0; i < maxLoaderAttempts; i++ {
- var (
- d *fs.Dirent
- err error
- )
if args.File == nil {
- d, args.File, err = openPath(ctx, args)
+ var err error
+ args.File, err = openPath(ctx, args)
+ if err != nil {
+ ctx.Infof("Error opening %s: %v", args.Filename, err)
+ return loadedELF{}, nil, nil, nil, err
+ }
+ // Ensure file is release in case the code loops or errors out.
+ defer args.File.DecRef(ctx)
} else {
- d, args.File, err = openFile(ctx, args.File, nil, "")
- }
-
- if err != nil {
- ctx.Infof("Error opening %s: %v", args.Filename, err)
- return loadedELF{}, nil, nil, nil, err
+ if err := checkIsRegularFile(ctx, args.File, args.Filename); err != nil {
+ return loadedELF{}, nil, nil, nil, err
+ }
}
- defer args.File.DecRef()
- // We will return d in the successful case, but defer a DecRef
- // for intermediate loops and failure cases.
- defer d.DecRef()
// Check the header. Is this an ELF or interpreter script?
var hdr [4]uint8
// N.B. We assume that reading from a regular file cannot block.
- _, err = readFull(ctx, args.File, usermem.BytesIOSequence(hdr[:]), 0)
- // Allow unexpected EOF, as a valid executable could be only
- // three bytes (e.g., #!a).
+ _, err := args.File.ReadFull(ctx, usermem.BytesIOSequence(hdr[:]), 0)
+ // Allow unexpected EOF, as a valid executable could be only three bytes
+ // (e.g., #!a).
if err != nil && err != io.ErrUnexpectedEOF {
if err == io.EOF {
err = syserror.ENOEXEC
@@ -275,15 +181,22 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
ctx.Infof("Error loading ELF: %v", err)
return loadedELF{}, nil, nil, nil, err
}
- // An ELF is always terminal. Hold on to d.
- d.IncRef()
- return loaded, ac, d, args.Argv, err
+ // An ELF is always terminal. Hold on to file.
+ args.File.IncRef()
+ return loaded, ac, args.File, args.Argv, err
+
case bytes.Equal(hdr[:2], []byte(interpreterScriptMagic)):
+ if args.CloseOnExec {
+ return loadedELF{}, nil, nil, nil, syserror.ENOENT
+ }
args.Filename, args.Argv, err = parseInterpreterScript(ctx, args.Filename, args.File, args.Argv)
if err != nil {
ctx.Infof("Error loading interpreter script: %v", err)
return loadedELF{}, nil, nil, nil, err
}
+ // Refresh the traversal limit for the interpreter.
+ *args.RemainingTraversals = linux.MaxSymlinkTraversals
+
default:
ctx.Infof("Unknown magic: %v", hdr)
return loadedELF{}, nil, nil, nil, syserror.ENOEXEC
@@ -306,16 +219,16 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context
// * Load is called on the Task goroutine.
func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) {
// Load the executable itself.
- loaded, ac, d, newArgv, err := loadExecutable(ctx, args)
+ loaded, ac, file, newArgv, err := loadExecutable(ctx, args)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux())
}
- defer d.DecRef()
+ defer file.DecRef(ctx)
// Load the VDSO.
vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded)
if err != nil {
- return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux())
}
// Setup the heap. brk starts at the next page after the end of the
@@ -379,7 +292,16 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V
m.SetEnvvStart(sl.EnvvStart)
m.SetEnvvEnd(sl.EnvvEnd)
m.SetAuxv(auxv)
- m.SetExecutable(d)
+ m.SetExecutable(ctx, file)
+
+ symbolValue, err := getSymbolValueFromVDSO("rt_sigreturn")
+ if err != nil {
+ return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to find rt_sigreturn in vdso: %v", err), syserr.FromError(err).ToLinux())
+ }
+
+ // Found rt_sigretrun.
+ addr := uint64(vdsoAddr) + symbolValue - vdsoPrelink
+ m.SetVDSOSigReturn(addr)
ac.SetIP(uintptr(loaded.entry))
ac.SetStack(uintptr(stack.Bottom))
diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go
index ada28aea3..05a294fe6 100644
--- a/pkg/sentry/loader/vdso.go
+++ b/pkg/sentry/loader/vdso.go
@@ -15,28 +15,28 @@
package loader
import (
+ "bytes"
"debug/elf"
"fmt"
"io"
+ "strings"
"gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/anon"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+const vdsoPrelink = 0xffffffffff700000
+
type fileContext struct {
context.Context
}
@@ -50,50 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} {
}
}
-// byteReader implements fs.FileOperations for reading from a []byte source.
-type byteReader struct {
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoIoctl `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileNoopRelease `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- waiter.AlwaysReady `state:"nosave"`
-
+type byteFullReader struct {
data []byte
}
-var _ fs.FileOperations = (*byteReader)(nil)
-
-// newByteReaderFile creates a fake file to read data from.
-func newByteReaderFile(ctx context.Context, data []byte) *fs.File {
- // Create a fake inode.
- inode := fs.NewInode(
- ctx,
- &fsutil.SimpleFileInode{},
- fs.NewPseudoMountSource(ctx),
- fs.StableAttr{
- Type: fs.Anonymous,
- DeviceID: anon.PseudoDevice.DeviceID(),
- InodeID: anon.PseudoDevice.NextIno(),
- BlockSize: usermem.PageSize,
- })
-
- // Use the fake inode to create a fake dirent.
- dirent := fs.NewTransientDirent(inode)
- defer dirent.DecRef()
-
- // Use the fake dirent to make a fake file.
- flags := fs.FileFlags{Read: true, Pread: true}
- return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{
- data: data,
- })
-}
-
-func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) {
if offset < 0 {
return 0, syserror.EINVAL
}
@@ -104,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ
return int64(n), err
}
-func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
- panic("Write not supported")
-}
-
// validateVDSO checks that the VDSO can be loaded by loadVDSO.
//
// VDSOs are special (see below). Since we are going to map the VDSO directly
@@ -123,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq
// * PT_LOAD segments don't extend beyond the end of the file.
//
// ctx may be nil if f does not need it.
-func validateVDSO(ctx context.Context, f *fs.File, size uint64) (elfInfo, error) {
+func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) {
info, err := parseHeader(ctx, f)
if err != nil {
log.Infof("Unable to parse VDSO header: %v", err)
@@ -218,15 +175,35 @@ type VDSO struct {
phdrs []elf.ProgHeader `state:".([]elfProgHeader)"`
}
+// getSymbolValueFromVDSO returns the specific symbol value in vdso.so.
+func getSymbolValueFromVDSO(symbol string) (uint64, error) {
+ f, err := elf.NewFile(bytes.NewReader(vdsoBin))
+ if err != nil {
+ return 0, err
+ }
+ syms, err := f.Symbols()
+ if err != nil {
+ return 0, err
+ }
+
+ for _, sym := range syms {
+ if elf.ST_BIND(sym.Info) != elf.STB_LOCAL && sym.Section != elf.SHN_UNDEF {
+ if strings.Contains(sym.Name, symbol) {
+ return sym.Value, nil
+ }
+ }
+ }
+ return 0, fmt.Errorf("no %v in vdso.so", symbol)
+}
+
// PrepareVDSO validates the system VDSO and returns a VDSO, containing the
// param page for updating by the kernel.
-func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
- vdsoFile := newByteReaderFile(ctx, vdsoBin)
+func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) {
+ vdsoFile := &byteFullReader{data: vdsoBin}
// First make sure the VDSO is valid. vdsoFile does not use ctx, so a
// nil context can be passed.
info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin)))
- vdsoFile.DecRef()
if err != nil {
return nil, err
}
@@ -268,6 +245,8 @@ func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, er
// some applications may not be able to handle multiple [vdso]
// hints.
vdso: mm.NewSpecialMappable("", mfp, vdso),
+ os: info.os,
+ arch: info.arch,
phdrs: info.phdrs,
}, nil
}
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 3ef84245b..2c95669cd 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -29,23 +28,33 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "file_range",
+ out = "file_range.go",
+ package = "memmap",
+ prefix = "File",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
go_library(
name = "memmap",
srcs = [
+ "file_range.go",
"mappable_range.go",
"mapping_set.go",
"mapping_set_impl.go",
"memmap.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/memmap",
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/context",
"//pkg/log",
- "//pkg/refs",
- "//pkg/sentry/context",
- "//pkg/sentry/platform",
- "//pkg/sentry/usermem",
+ "//pkg/safemem",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
@@ -53,6 +62,6 @@ go_test(
name = "memmap_test",
size = "small",
srcs = ["mapping_set_test.go"],
- embed = [":memmap"],
- deps = ["//pkg/sentry/usermem"],
+ library = ":memmap",
+ deps = ["//pkg/usermem"],
)
diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go
index 0a5b7ce45..d609c1ae0 100644
--- a/pkg/sentry/memmap/mapping_set.go
+++ b/pkg/sentry/memmap/mapping_set.go
@@ -18,7 +18,7 @@ import (
"fmt"
"math"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// MappingSet maps offsets into a Mappable to mappings of those offsets. It is
diff --git a/pkg/sentry/memmap/mapping_set_test.go b/pkg/sentry/memmap/mapping_set_test.go
index f9b11a59c..d39efe38f 100644
--- a/pkg/sentry/memmap/mapping_set_test.go
+++ b/pkg/sentry/memmap/mapping_set_test.go
@@ -18,7 +18,7 @@ import (
"reflect"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type testMappingSpace struct {
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 03b99aaea..65d83096f 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -18,14 +18,13 @@ package memmap
import (
"fmt"
- "gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Mappable represents a memory-mappable object, a mutable mapping from uint64
-// offsets to (platform.File, uint64 File offset) pairs.
+// offsets to (File, uint64 File offset) pairs.
//
// See mm/mm.go for Mappable's place in the lock order.
//
@@ -75,7 +74,7 @@ type Mappable interface {
// Translations are valid until invalidated by a callback to
// MappingSpace.Invalidate or until the caller removes its mapping of the
// translated range. Mappable implementations must ensure that at least one
- // reference is held on all pages in a platform.File that may be the result
+ // reference is held on all pages in a File that may be the result
// of a valid Translation.
//
// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
@@ -101,7 +100,7 @@ type Translation struct {
Source MappableRange
// File is the mapped file.
- File platform.File
+ File File
// Offset is the offset into File at which this Translation begins.
Offset uint64
@@ -111,9 +110,9 @@ type Translation struct {
Perms usermem.AccessType
}
-// FileRange returns the platform.FileRange represented by t.
-func (t Translation) FileRange() platform.FileRange {
- return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+// FileRange returns the FileRange represented by t.
+func (t Translation) FileRange() FileRange {
+ return FileRange{t.Offset, t.Offset + t.Source.Length()}
}
// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
@@ -235,8 +234,11 @@ type InvalidateOpts struct {
// coincidental; fs.File implements MappingIdentity, and some
// fs.InodeOperations implement Mappable.)
type MappingIdentity interface {
- // MappingIdentity is reference-counted.
- refs.RefCounter
+ // IncRef increments the MappingIdentity's reference count.
+ IncRef()
+
+ // DecRef decrements the MappingIdentity's reference count.
+ DecRef(ctx context.Context)
// MappedName returns the application-visible name shown in
// /proc/[pid]/maps.
@@ -358,4 +360,57 @@ type MMapOpts struct {
//
// TODO(jamieliu): Replace entirely with MappingIdentity?
Hint string
+
+ // Force means to skip validation checks of Addr and Length. It can be
+ // used to create special mappings below mm.layout.MinAddr and
+ // mm.layout.MaxAddr. It has to be used with caution.
+ //
+ // If Force is true, Unmap and Fixed must be true.
+ Force bool
+}
+
+// File represents a host file that may be mapped into an platform.AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
}
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a804b8b5c..f9d0837a1 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,14 +7,14 @@ go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "mm",
prefix = "fileRefcount",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "int32",
"Functions": "fileRefcountSetFunctions",
},
@@ -26,9 +25,10 @@ go_template_instance(
out = "vma_set.go",
consts = {
"minDegree": "8",
+ "trackGaps": "1",
},
imports = {
- "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem",
+ "usermem": "gvisor.dev/gvisor/pkg/usermem",
},
package = "mm",
prefix = "vma",
@@ -48,7 +48,7 @@ go_template_instance(
"minDegree": "8",
},
imports = {
- "usermem": "gvisor.dev/gvisor/pkg/sentry/usermem",
+ "usermem": "gvisor.dev/gvisor/pkg/usermem",
},
package = "mm",
prefix = "pma",
@@ -96,17 +96,18 @@ go_library(
"vma.go",
"vma_set.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/mm",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/atomicbitops",
+ "//pkg/context",
"//pkg/log",
"//pkg/refs",
+ "//pkg/safecopy",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/fs",
"//pkg/sentry/fs/proc/seqfile",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/futex",
"//pkg/sentry/kernel/shm",
@@ -114,13 +115,11 @@ go_library(
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
- "//pkg/sentry/platform/safecopy",
- "//pkg/sentry/safemem",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip/buffer",
- "//third_party/gvsync",
+ "//pkg/usermem",
],
)
@@ -128,16 +127,16 @@ go_test(
name = "mm_test",
size = "small",
srcs = ["mm_test.go"],
- embed = [":mm"],
+ library = ":mm",
deps = [
+ "//pkg/context",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
"//pkg/sentry/platform",
- "//pkg/sentry/usermem",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/mm/README.md b/pkg/sentry/mm/README.md
index e1322e373..f4d43d927 100644
--- a/pkg/sentry/mm/README.md
+++ b/pkg/sentry/mm/README.md
@@ -274,7 +274,7 @@ In the sentry:
methods
[`platform.AddressSpace.MapFile` and `platform.AddressSpace.Unmap`][platform].
-[memmap]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/memmap/memmap.go
-[mm]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/mm/mm.go
-[pgalloc]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/pgalloc/pgalloc.go
-[platform]: https://github.com/google/gvisor/blob/master/+/master/pkg/sentry/platform/platform.go
+[memmap]: https://github.com/google/gvisor/blob/master/pkg/sentry/memmap/memmap.go
+[mm]: https://github.com/google/gvisor/blob/master/pkg/sentry/mm/mm.go
+[pgalloc]: https://github.com/google/gvisor/blob/master/pkg/sentry/pgalloc/pgalloc.go
+[platform]: https://github.com/google/gvisor/blob/master/pkg/sentry/platform/platform.go
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index cfebcfd42..5c667117c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,9 +18,9 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// AddressSpace returns the platform.AddressSpace bound to mm.
@@ -39,11 +39,18 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
//
// When this MemoryManager is no longer needed by a task, it should call
// Deactivate to release the reference.
-func (mm *MemoryManager) Activate() error {
+func (mm *MemoryManager) Activate(ctx context.Context) error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
- if atomicbitops.IncUnlessZeroInt32(&mm.active) {
- return nil
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
}
for {
@@ -85,16 +92,20 @@ func (mm *MemoryManager) Activate() error {
if as == nil {
// AddressSpace is unavailable, we must wait.
//
- // activeMu must not be held while waiting, as the user
- // of the address space we are waiting on may attempt
- // to take activeMu.
- //
- // Don't call UninterruptibleSleepStart to register the
- // wait to allow the watchdog stuck task to trigger in
- // case a process is starved waiting for the address
- // space.
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+ // Mark this task sleeping while waiting for the address space to
+ // prevent the watchdog from reporting it as a stuck task.
+ ctx.UninterruptibleSleepStart(false)
+ }
<-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
continue
}
@@ -118,8 +129,15 @@ func (mm *MemoryManager) Activate() error {
func (mm *MemoryManager) Deactivate() {
// Fast path: this is not the last goroutine to deactivate the
// MemoryManager.
- if atomicbitops.DecUnlessOneInt32(&mm.active) {
- return
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
}
mm.activeMu.Lock()
@@ -183,8 +201,10 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, pre
if pma.needCOW {
perms.Write = false
}
- if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
- return err
+ if perms.Any() { // MapFile precondition
+ if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, precommit); err != nil {
+ return err
+ }
}
pseg = pseg.NextSegment()
}
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 1b746d030..16fea53c4 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -15,17 +15,15 @@
package mm
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// aioManager creates and manages asynchronous I/O contexts.
@@ -60,25 +58,27 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
}
a.contexts[id] = &AIOContext{
- done: make(chan struct{}, 1),
+ requestReady: make(chan struct{}, 1),
maxOutstanding: events,
}
return true
}
-// destroyAIOContext destroys an asynchronous I/O context.
+// destroyAIOContext destroys an asynchronous I/O context. It doesn't wait for
+// for pending requests to complete. Returns the destroyed AIOContext so it can
+// be drained.
//
-// False is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) bool {
+// Nil is returned if the context does not exist.
+func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
a.mu.Lock()
defer a.mu.Unlock()
ctx, ok := a.contexts[id]
if !ok {
- return false
+ return nil
}
delete(a.contexts, id)
ctx.destroy()
- return true
+ return ctx
}
// lookupAIOContext looks up the given context.
@@ -103,8 +103,8 @@ type ioResult struct {
//
// +stateify savable
type AIOContext struct {
- // done is the notification channel used for all requests.
- done chan struct{} `state:"nosave"`
+ // requestReady is the notification channel used for all requests.
+ requestReady chan struct{} `state:"nosave"`
// mu protects below.
mu sync.Mutex `state:"nosave"`
@@ -130,8 +130,14 @@ func (ctx *AIOContext) destroy() {
ctx.mu.Lock()
defer ctx.mu.Unlock()
ctx.dead = true
- if ctx.outstanding == 0 {
- close(ctx.done)
+ ctx.checkForDone()
+}
+
+// Preconditions: ctx.mu must be held by caller.
+func (ctx *AIOContext) checkForDone() {
+ if ctx.dead && ctx.outstanding == 0 {
+ close(ctx.requestReady)
+ ctx.requestReady = nil
}
}
@@ -155,11 +161,12 @@ func (ctx *AIOContext) PopRequest() (interface{}, bool) {
// Is there anything ready?
if e := ctx.results.Front(); e != nil {
- ctx.results.Remove(e)
- ctx.outstanding--
- if ctx.outstanding == 0 && ctx.dead {
- close(ctx.done)
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
+ ctx.outstanding--
+ ctx.results.Remove(e)
+ ctx.checkForDone()
return e.data, true
}
return nil, false
@@ -173,26 +180,58 @@ func (ctx *AIOContext) FinishRequest(data interface{}) {
// Push to the list and notify opportunistically. The channel notify
// here is guaranteed to be safe because outstanding must be non-zero.
- // The done channel is only closed when outstanding reaches zero.
+ // The requestReady channel is only closed when outstanding reaches zero.
ctx.results.PushBack(&ioResult{data: data})
select {
- case ctx.done <- struct{}{}:
+ case ctx.requestReady <- struct{}{}:
default:
}
}
// WaitChannel returns a channel that is notified when an AIO request is
-// completed.
-//
-// The boolean return value indicates whether or not the context is active.
-func (ctx *AIOContext) WaitChannel() (chan struct{}, bool) {
+// completed. Returns nil if the context is destroyed and there are no more
+// outstanding requests.
+func (ctx *AIOContext) WaitChannel() chan struct{} {
ctx.mu.Lock()
defer ctx.mu.Unlock()
- if ctx.outstanding == 0 && ctx.dead {
- return nil, false
+ return ctx.requestReady
+}
+
+// Dead returns true if the context has been destroyed.
+func (ctx *AIOContext) Dead() bool {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+ return ctx.dead
+}
+
+// CancelPendingRequest forgets about a request that hasn't yet completed.
+func (ctx *AIOContext) CancelPendingRequest() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ panic("AIOContext outstanding is going negative")
}
- return ctx.done, true
+ ctx.outstanding--
+ ctx.checkForDone()
+}
+
+// Drain drops all completed requests. Pending requests remain untouched.
+func (ctx *AIOContext) Drain() {
+ ctx.mu.Lock()
+ defer ctx.mu.Unlock()
+
+ if ctx.outstanding == 0 {
+ return
+ }
+ size := uint32(ctx.results.Len())
+ if ctx.outstanding < size {
+ panic("AIOContext outstanding is going negative")
+ }
+ ctx.outstanding -= size
+ ctx.results.Reset()
+ ctx.checkForDone()
}
// aioMappable implements memmap.MappingIdentity and memmap.Mappable for AIO
@@ -203,7 +242,7 @@ type aioMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
}
var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
@@ -219,8 +258,8 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) {
}
// DecRef implements refs.RefCounter.DecRef.
-func (m *aioMappable) DecRef() {
- m.AtomicRefCount.DecRefWithDestructor(func() {
+func (m *aioMappable) DecRef(ctx context.Context) {
+ m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
m.mfp.MemoryFile().DecRef(m.fr)
})
}
@@ -328,14 +367,14 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
if err != nil {
return 0, err
}
- defer m.DecRef()
+ defer m.DecRef(ctx)
addr, err := mm.MMap(ctx, memmap.MMapOpts{
Length: aioRingBufferSize,
MappingIdentity: m,
Mappable: m,
- // TODO(fvoznika): Linux does "do_mmap_pgoff(..., PROT_READ |
- // PROT_WRITE, ...)" in fs/aio.c:aio_setup_ring(); why do we make this
- // mapping read-only?
+ // Linux uses "do_mmap_pgoff(..., PROT_READ | PROT_WRITE, ...)" in
+ // fs/aio.c:aio_setup_ring(). Since we don't implement AIO_RING_MAGIC,
+ // user mode should not write to this page.
Perms: usermem.Read,
MaxPerms: usermem.Read,
})
@@ -350,11 +389,11 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
return id, nil
}
-// DestroyAIOContext destroys an asynchronous I/O context. It returns false if
-// the context does not exist.
-func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) bool {
+// DestroyAIOContext destroys an asynchronous I/O context. It returns the
+// destroyed context. nil if the context does not exist.
+func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
if _, ok := mm.LookupAIOContext(ctx, id); !ok {
- return false
+ return nil
}
// Only unmaps after it assured that the address is a valid aio context to
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index c37fc9f7b..3dabac1af 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -16,5 +16,5 @@ package mm
// afterLoad is invoked by stateify.
func (a *AIOContext) afterLoad() {
- a.done = make(chan struct{}, 1)
+ a.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/debug.go b/pkg/sentry/mm/debug.go
index df9adf708..c273c982e 100644
--- a/pkg/sentry/mm/debug.go
+++ b/pkg/sentry/mm/debug.go
@@ -18,7 +18,7 @@ import (
"bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
const (
diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go
index b03e7d020..fa776f9c6 100644
--- a/pkg/sentry/mm/io.go
+++ b/pkg/sentry/mm/io.go
@@ -15,11 +15,11 @@
package mm
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// There are two supported ways to copy data to/from application virtual
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 4e9ca1de6..09dbc06a4 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -18,27 +18,27 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
-func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
return &MemoryManager{
- p: p,
- mfp: mfp,
- haveASIO: p.SupportsAddressSpaceIO(),
- privateRefs: &privateRefs{},
- users: 1,
- auxv: arch.Auxv{},
- dumpability: UserDumpable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
}
}
@@ -57,6 +57,8 @@ func (mm *MemoryManager) SetMmapLayout(ac arch.Context, r *limits.LimitSet) (arc
// Fork creates a copy of mm with 1 user, as for Linux syscalls fork() or
// clone() (without CLONE_VM).
func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
+ mm.AddressSpace().PreFork()
+ defer mm.AddressSpace().PostFork()
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
mm.mappingMu.RLock()
@@ -80,9 +82,11 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- dumpability: mm.dumpability,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
+ vdsoSigReturnAddr: mm.vdsoSigReturnAddr,
}
// Copy vmas.
@@ -229,7 +233,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
// IncUsers increments mm's user count and returns true. If the user count is
// already 0, IncUsers does nothing and returns false.
func (mm *MemoryManager) IncUsers() bool {
- return atomicbitops.IncUnlessZeroInt32(&mm.users)
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
}
// DecUsers decrements mm's user count. If the user count reaches 0, all
@@ -248,7 +260,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
mm.executable = nil
mm.metadataMu.Unlock()
if exe != nil {
- exe.DecRef()
+ exe.DecRef(ctx)
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/metadata.go b/pkg/sentry/mm/metadata.go
index d2a01d48a..0cfd60f6c 100644
--- a/pkg/sentry/mm/metadata.go
+++ b/pkg/sentry/mm/metadata.go
@@ -15,9 +15,10 @@
package mm
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Dumpability describes if and how core dumps should be created.
@@ -132,7 +133,7 @@ func (mm *MemoryManager) SetAuxv(auxv arch.Auxv) {
//
// An additional reference will be taken in the case of a non-nil executable,
// which must be released by the caller.
-func (mm *MemoryManager) Executable() *fs.Dirent {
+func (mm *MemoryManager) Executable() fsbridge.File {
mm.metadataMu.Lock()
defer mm.metadataMu.Unlock()
@@ -147,15 +148,15 @@ func (mm *MemoryManager) Executable() *fs.Dirent {
// SetExecutable sets the executable.
//
// This takes a reference on d.
-func (mm *MemoryManager) SetExecutable(d *fs.Dirent) {
+func (mm *MemoryManager) SetExecutable(ctx context.Context, file fsbridge.File) {
mm.metadataMu.Lock()
// Grab a new reference.
- d.IncRef()
+ file.IncRef()
// Set the executable.
orig := mm.executable
- mm.executable = d
+ mm.executable = file
mm.metadataMu.Unlock()
@@ -164,6 +165,20 @@ func (mm *MemoryManager) SetExecutable(d *fs.Dirent) {
// Do this without holding the lock, since it may wind up doing some
// I/O to sync the dirent, etc.
if orig != nil {
- orig.DecRef()
+ orig.DecRef(ctx)
}
}
+
+// VDSOSigReturn returns the address of vdso_sigreturn.
+func (mm *MemoryManager) VDSOSigReturn() uint64 {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ return mm.vdsoSigReturnAddr
+}
+
+// SetVDSOSigReturn sets the address of vdso_sigreturn.
+func (mm *MemoryManager) SetVDSOSigReturn(addr uint64) {
+ mm.metadataMu.Lock()
+ defer mm.metadataMu.Unlock()
+ mm.vdsoSigReturnAddr = addr
+}
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index f350e0109..3e85964e4 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -25,7 +25,7 @@
// Locks taken by memmap.Mappable.Translate
// mm.privateRefs.mu
// platform.AddressSpace locks
-// platform.File locks
+// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
//
@@ -35,16 +35,15 @@
package mm
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// MemoryManager implements a virtual address space.
@@ -82,7 +81,7 @@ type MemoryManager struct {
users int32
// mappingMu is analogous to Linux's struct mm_struct::mmap_sem.
- mappingMu gvsync.DowngradableRWMutex `state:"nosave"`
+ mappingMu sync.RWMutex `state:"nosave"`
// vmas stores virtual memory areas. Since vmas are stored by value,
// clients should usually use vmaIterator.ValuePtr() instead of
@@ -125,7 +124,7 @@ type MemoryManager struct {
// activeMu is loosely analogous to Linux's struct
// mm_struct::page_table_lock.
- activeMu gvsync.DowngradableRWMutex `state:"nosave"`
+ activeMu sync.RWMutex `state:"nosave"`
// pmas stores platform mapping areas used to implement vmas. Since pmas
// are stored by value, clients should usually use pmaIterator.ValuePtr()
@@ -217,7 +216,7 @@ type MemoryManager struct {
// is not nil, it holds a reference on the Dirent.
//
// executable is protected by metadataMu.
- executable *fs.Dirent
+ executable fsbridge.File
// dumpability describes if and how this MemoryManager may be dumped to
// userspace.
@@ -228,6 +227,14 @@ type MemoryManager struct {
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report to be sleeping
+ // before trying to activate the address space. When set to true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
+
+ // vdsoSigReturnAddr is the address of 'vdso_sigreturn'.
+ vdsoSigReturnAddr uint64
}
// vma represents a virtual memory area.
@@ -280,7 +287,7 @@ type vma struct {
mlockMode memmap.MLockMode
// numaPolicy is the NUMA policy for this vma set by mbind().
- numaPolicy int32
+ numaPolicy linux.NumaPolicy
// numaNodemask is the NUMA nodemask for this vma set by mbind().
numaNodemask uint64
@@ -389,7 +396,7 @@ type pma struct {
// file is the file mapped by this pma. Only pmas for which file ==
// MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
// the corresponding file range while they exist.
- file platform.File `state:"nosave"`
+ file memmap.File `state:"nosave"`
// off is the offset into file at which this pma begins.
//
@@ -429,7 +436,7 @@ type pma struct {
private bool
// If internalMappings is not empty, it is the cached return value of
- // file.MapInternal for the platform.FileRange mapped by this pma.
+ // file.MapInternal for the memmap.FileRange mapped by this pma.
internalMappings safemem.BlockSeq `state:"nosave"`
}
@@ -462,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 {
func (fileRefcountSetFunctions) ClearValue(_ *int32) {
}
-func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) {
return rc1, rc1 == rc2
}
-func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) {
return rc, rc
}
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index 4d2bfaaed..fdc308542 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -17,21 +17,21 @@ package mm
import (
"testing"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func testMemoryManager(ctx context.Context) *MemoryManager {
p := platform.FromContext(ctx)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- mm := NewMemoryManager(p, mfp)
+ mm := NewMemoryManager(p, mfp, false)
mm.layout = arch.MmapLayout{
MinAddr: p.MinUserAddress(),
MaxAddr: p.MaxUserAddress(),
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index c976c6f45..930ec895f 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -17,14 +17,13 @@ package mm
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// existingPMAsLocked checks that pmas exist for all addresses in ar, and
@@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
}
}
-// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// Pin returns the memmap.File ranges currently mapped by addresses in ar in
// mm, acquiring a reference on the returned ranges which the caller must
// release by calling Unpin. If not all addresses are mapped, Pin returns a
// non-nil error. Note that Pin may return both a non-empty slice of
@@ -674,15 +673,15 @@ type PinnedRange struct {
Source usermem.AddrRange
// File is the mapped file.
- File platform.File
+ File memmap.File
// Offset is the offset into File at which this PinnedRange begins.
Offset uint64
}
-// FileRange returns the platform.File offsets mapped by pr.
-func (pr PinnedRange) FileRange() platform.FileRange {
- return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+// FileRange returns the memmap.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() memmap.FileRange {
+ return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
}
// Unpin releases the reference held by prs.
@@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf
}
// incPrivateRef acquires a reference on private pages in fr.
-func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) {
mm.privateRefs.mu.Lock()
defer mm.privateRefs.mu.Unlock()
refSet := &mm.privateRefs.refs
@@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
}
// decPrivateRef releases a reference on private pages in fr.
-func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
- var freed []platform.FileRange
+func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) {
+ var freed []memmap.FileRange
mm.privateRefs.mu.Lock()
refSet := &mm.privateRefs.refs
@@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa
// Discard internal mappings instead of trying to merge them, since merging
// them requires an allocation and getting them again from the
- // platform.File might not.
+ // memmap.File might not.
pma1.internalMappings = safemem.BlockSeq{}
return pma1, true
}
@@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error {
return nil
}
-func (pseg pmaIterator) fileRange() platform.FileRange {
+func (pseg pmaIterator) fileRange() memmap.FileRange {
return pseg.fileRangeOf(pseg.Range())
}
// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if checkInvariants {
if !pseg.Ok() {
panic("terminal pma iterator")
@@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
pma := pseg.ValuePtr()
pstart := pseg.Start()
- return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+ return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
}
diff --git a/pkg/sentry/mm/procfs.go b/pkg/sentry/mm/procfs.go
index 8c2246bb4..6efe5102b 100644
--- a/pkg/sentry/mm/procfs.go
+++ b/pkg/sentry/mm/procfs.go
@@ -19,10 +19,10 @@ import (
"fmt"
"strings"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs/proc/seqfile"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -66,8 +66,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.appendVMAMapsEntryLocked(ctx, vseg, buf)
}
@@ -81,7 +79,6 @@ func (mm *MemoryManager) ReadMapsDataInto(ctx context.Context, buf *bytes.Buffer
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallMapsEntry)
}
}
@@ -97,8 +94,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaMapsEntryLocked(ctx, vseg),
@@ -116,7 +111,6 @@ func (mm *MemoryManager) ReadMapsSeqFileData(ctx context.Context, handle seqfile
//
// Artifically adjust the seqfile handle so we only output vsyscall entry once.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallMapsEntry),
@@ -154,7 +148,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
// Do not include the guard page: fs/proc/task_mmu.c:show_map_vma() =>
// stack_guard_page_start().
- fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
+ lineLen, _ := fmt.Fprintf(b, "%08x-%08x %s%s %08x %02x:%02x %d ",
vseg.Start(), vseg.End(), vma.realPerms, private, vma.off, devMajor, devMinor, ino)
// Figure out our filename or hint.
@@ -171,7 +165,7 @@ func (mm *MemoryManager) appendVMAMapsEntryLocked(ctx context.Context, vseg vmaI
}
if s != "" {
// Per linux, we pad until the 74th character.
- if pad := 73 - b.Len(); pad > 0 {
+ if pad := 73 - lineLen; pad > 0 {
b.WriteString(strings.Repeat(" ", pad))
}
b.WriteString(s)
@@ -187,15 +181,12 @@ func (mm *MemoryManager) ReadSmapsDataInto(ctx context.Context, buf *bytes.Buffe
var start usermem.Addr
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
mm.vmaSmapsEntryIntoLocked(ctx, vseg, buf)
}
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
buf.WriteString(vsyscallSmapsEntry)
}
}
@@ -211,8 +202,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
start = *handle.(*usermem.Addr)
}
for vseg := mm.vmas.LowerBoundSegment(start); vseg.Ok(); vseg = vseg.NextSegment() {
- // FIXME(b/30793614): If we use a usermem.Addr for the handle, we get
- // "panic: autosave error: type usermem.Addr is not registered".
vmaAddr := vseg.End()
data = append(data, seqfile.SeqData{
Buf: mm.vmaSmapsEntryLocked(ctx, vseg),
@@ -223,7 +212,6 @@ func (mm *MemoryManager) ReadSmapsSeqFileData(ctx context.Context, handle seqfil
// We always emulate vsyscall, so advertise it here. See
// ReadMapsSeqFileData for additional commentary.
if start != vsyscallEnd {
- // FIXME(b/30793614): Can't get a pointer to constant vsyscallEnd.
vmaAddr := vsyscallEnd
data = append(data, seqfile.SeqData{
Buf: []byte(vsyscallSmapsEntry),
diff --git a/pkg/sentry/mm/save_restore.go b/pkg/sentry/mm/save_restore.go
index 93259c5a3..f56215d9a 100644
--- a/pkg/sentry/mm/save_restore.go
+++ b/pkg/sentry/mm/save_restore.go
@@ -17,7 +17,7 @@ package mm
import (
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// InvalidateUnsavable invokes memmap.Mappable.InvalidateUnsavable on all
diff --git a/pkg/sentry/mm/shm.go b/pkg/sentry/mm/shm.go
index b9f2d23e5..6432731d4 100644
--- a/pkg/sentry/mm/shm.go
+++ b/pkg/sentry/mm/shm.go
@@ -15,10 +15,10 @@
package mm
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/shm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// DetachShm unmaps a sysv shared memory segment.
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index ea2d7af74..4cdb52eb6 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -15,14 +15,13 @@
package mm
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// SpecialMappable implements memmap.MappingIdentity and memmap.Mappable with
@@ -35,7 +34,7 @@ type SpecialMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
name string
}
@@ -44,15 +43,15 @@ type SpecialMappable struct {
// SpecialMappable will use the given name in /proc/[pid]/maps.
//
// Preconditions: fr.Length() != 0.
-func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable {
m := SpecialMappable{mfp: mfp, fr: fr, name: name}
m.EnableLeakCheck("mm.SpecialMappable")
return &m
}
// DecRef implements refs.RefCounter.DecRef.
-func (m *SpecialMappable) DecRef() {
- m.AtomicRefCount.DecRefWithDestructor(func() {
+func (m *SpecialMappable) DecRef(ctx context.Context) {
+ m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) {
m.mfp.MemoryFile().DecRef(m.fr)
})
}
@@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
// store the SpecialMappable's contents.
-func (m *SpecialMappable) FileRange() platform.FileRange {
+func (m *SpecialMappable) FileRange() memmap.FileRange {
return m.fr
}
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index c2466c988..e74d4e1c1 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -19,14 +19,14 @@ import (
mrand "math/rand"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// HandleUserFault handles an application page fault. sp is the faulting
@@ -101,7 +101,7 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme
if err != nil {
return 0, err
}
- defer m.DecRef()
+ defer m.DecRef(ctx)
opts.MappingIdentity = m
opts.Mappable = m
}
@@ -974,7 +974,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
}
// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR).
-func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) {
+func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (linux.NumaPolicy, uint64, error) {
mm.mappingMu.RLock()
defer mm.mappingMu.RUnlock()
vseg := mm.vmas.FindSegment(addr)
@@ -986,7 +986,7 @@ func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) {
}
// SetNumaPolicy implements the semantics of Linux's mbind().
-func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy int32, nodemask uint64) error {
+func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy linux.NumaPolicy, nodemask uint64) error {
if !addr.IsPageAligned() {
return syserror.EINVAL
}
@@ -1191,7 +1191,7 @@ func (mm *MemoryManager) MSync(ctx context.Context, addr usermem.Addr, length ui
mr := vseg.mappableRangeOf(vseg.Range().Intersect(ar))
mm.mappingMu.RUnlock()
err := id.Msync(ctx, mr)
- id.DecRef()
+ id.DecRef(ctx)
if err != nil {
return err
}
diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go
index f2fd70799..c4e1989ed 100644
--- a/pkg/sentry/mm/vma.go
+++ b/pkg/sentry/mm/vma.go
@@ -18,13 +18,13 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Preconditions: mm.mappingMu must be locked for writing. opts must be valid
@@ -42,7 +42,12 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp
Map32Bit: opts.Map32Bit,
})
if err != nil {
- return vmaIterator{}, usermem.AddrRange{}, err
+ // Can't force without opts.Unmap and opts.Fixed.
+ if opts.Force && opts.Unmap && opts.Fixed {
+ addr = opts.Addr
+ } else {
+ return vmaIterator{}, usermem.AddrRange{}, err
+ }
}
ar, _ := addr.ToRange(opts.Length)
@@ -195,7 +200,7 @@ func (mm *MemoryManager) applicationAddrRange() usermem.AddrRange {
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextGap() {
+ for gap := mm.vmas.LowerBoundGap(bounds.Start); gap.Ok() && gap.Start() < bounds.End; gap = gap.NextLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift up to match the alignment?
if offset := uint64(gr.Start) % alignment; offset != 0 {
@@ -214,7 +219,7 @@ func (mm *MemoryManager) findLowestAvailableLocked(length, alignment uint64, bou
// Preconditions: mm.mappingMu must be locked.
func (mm *MemoryManager) findHighestAvailableLocked(length, alignment uint64, bounds usermem.AddrRange) (usermem.Addr, error) {
- for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevGap() {
+ for gap := mm.vmas.UpperBoundGap(bounds.End); gap.Ok() && gap.End() > bounds.Start; gap = gap.PrevLargeEnoughGap(usermem.Addr(length)) {
if gr := gap.availableRange().Intersect(bounds); uint64(gr.Length()) >= length {
// Can we shift down to match the alignment?
start := gr.End - usermem.Addr(length)
@@ -377,7 +382,7 @@ func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRa
vma.mappable.RemoveMapping(ctx, mm, vmaAR, vma.off, vma.canWriteMappableLocked())
}
if vma.id != nil {
- vma.id.DecRef()
+ vma.id.DecRef(ctx)
}
mm.usageAS -= uint64(vmaAR.Length())
if vma.isPrivateDataLocked() {
@@ -446,7 +451,7 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa
}
if vma2.id != nil {
- vma2.id.DecRef()
+ vma2.id.DecRef(context.Background())
}
return vma1, true
}
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index f404107af..7a3311a70 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -34,21 +33,42 @@ go_template_instance(
out = "usage_set.go",
consts = {
"minDegree": "10",
+ "trackGaps": "1",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "usage",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "usageInfo",
"Functions": "usageSetFunctions",
},
)
+go_template_instance(
+ name = "reclaim_set",
+ out = "reclaim_set.go",
+ consts = {
+ "minDegree": "10",
+ },
+ imports = {
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
+ },
+ package = "pgalloc",
+ prefix = "reclaim",
+ template = "//pkg/segment:generic_set",
+ types = {
+ "Key": "uint64",
+ "Range": "memmap.FileRange",
+ "Value": "reclaimSetValue",
+ "Functions": "reclaimSetFunctions",
+ },
+)
+
go_library(
name = "pgalloc",
srcs = [
@@ -57,23 +77,25 @@ go_library(
"evictable_range_set.go",
"pgalloc.go",
"pgalloc_unsafe.go",
+ "reclaim_set.go",
"save_restore.go",
"usage_set.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/pgalloc",
visibility = ["//pkg/sentry:internal"],
deps = [
+ "//pkg/context",
"//pkg/log",
"//pkg/memutil",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/hostmm",
- "//pkg/sentry/platform",
- "//pkg/sentry/safemem",
+ "//pkg/sentry/memmap",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
"//pkg/state",
+ "//pkg/state/wire",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
@@ -81,6 +103,6 @@ go_test(
name = "pgalloc_test",
size = "small",
srcs = ["pgalloc_test.go"],
- embed = [":pgalloc"],
- deps = ["//pkg/sentry/usermem"],
+ library = ":pgalloc",
+ deps = ["//pkg/usermem"],
)
diff --git a/pkg/sentry/pgalloc/context.go b/pkg/sentry/pgalloc/context.go
index 11ccf897b..d25215418 100644
--- a/pkg/sentry/pgalloc/context.go
+++ b/pkg/sentry/pgalloc/context.go
@@ -15,7 +15,7 @@
package pgalloc
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is this package's type for context.Context.Value keys.
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index f7f7298c4..46d3be58c 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -25,22 +25,22 @@ import (
"fmt"
"math"
"os"
- "sync"
"sync/atomic"
"syscall"
"time"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
- "gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
// users.
type MemoryFile struct {
// opts holds options passed to NewMemoryFile. opts is immutable.
@@ -108,12 +108,6 @@ type MemoryFile struct {
usageSwapped uint64
usageLast time.Time
- // minUnallocatedPage is the minimum page that may be unallocated.
- // i.e., there are no unallocated pages below minUnallocatedPage.
- //
- // minUnallocatedPage is protected by mu.
- minUnallocatedPage uint64
-
// fileSize is the size of the backing memory file in bytes. fileSize is
// always a power-of-two multiple of chunkSize.
//
@@ -146,11 +140,9 @@ type MemoryFile struct {
// is protected by mu.
reclaimable bool
- // minReclaimablePage is the minimum page that may be reclaimable.
- // i.e., all reclaimable pages are >= minReclaimablePage.
- //
- // minReclaimablePage is protected by mu.
- minReclaimablePage uint64
+ // relcaim is the collection of regions for reclaim. relcaim is protected
+ // by mu.
+ reclaim reclaimSet
// reclaimCond is signaled (with mu locked) when reclaimable or destroyed
// transitions from false to true.
@@ -180,6 +172,11 @@ type MemoryFileOpts struct {
// notifications to determine when eviction is necessary. This option has
// no effect unless DelayedEviction is DelayedEvictionEnabled.
UseHostMemcgPressure bool
+
+ // If ManualZeroing is true, MemoryFile must not assume that new pages
+ // obtained from the host are zero-filled, such that MemoryFile must manually
+ // zero newly-allocated pages.
+ ManualZeroing bool
}
// DelayedEvictionType is the type of MemoryFileOpts.DelayedEviction.
@@ -268,12 +265,10 @@ type evictableMemoryUserInfo struct {
}
const (
- chunkShift = 24
- chunkSize = 1 << chunkShift // 16 MB
+ chunkShift = 30
+ chunkSize = 1 << chunkShift // 1 GB
chunkMask = chunkSize - 1
- initialSize = chunkSize
-
// maxPage is the highest 64-bit page.
maxPage = math.MaxUint64 &^ (usermem.PageSize - 1)
)
@@ -297,19 +292,12 @@ func NewMemoryFile(file *os.File, opts MemoryFileOpts) (*MemoryFile, error) {
if err := file.Truncate(0); err != nil {
return nil, err
}
- if err := file.Truncate(initialSize); err != nil {
- return nil, err
- }
f := &MemoryFile{
- opts: opts,
- fileSize: initialSize,
- file: file,
- // No pages are reclaimable. DecRef will always be able to
- // decrease minReclaimablePage from this point.
- minReclaimablePage: maxPage,
- evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
+ opts: opts,
+ file: file,
+ evictable: make(map[EvictableMemoryUser]*evictableMemoryUserInfo),
}
- f.mappings.Store(make([]uintptr, initialSize/chunkSize))
+ f.mappings.Store(make([]uintptr, 0))
f.reclaimCond.L = &f.mu
if f.opts.DelayedEviction == DelayedEvictionEnabled && f.opts.UseHostMemcgPressure {
@@ -384,7 +372,7 @@ func (f *MemoryFile) Destroy() {
// to Allocate.
//
// Preconditions: length must be page-aligned and non-zero.
-func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) {
if length == 0 || length%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid allocation length: %#x", length))
}
@@ -399,39 +387,38 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
alignment = usermem.HugePageSize
}
- start, minUnallocatedPage := findUnallocatedRange(&f.usage, f.minUnallocatedPage, length, alignment)
- end := start + length
- // File offsets are int64s. Since length must be strictly positive, end
- // cannot legitimately be 0.
- if end < start || int64(end) <= 0 {
- return platform.FileRange{}, syserror.ENOMEM
+ // Find a range in the underlying file.
+ fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
+ if !ok {
+ return memmap.FileRange{}, syserror.ENOMEM
}
- // Expand the file if needed. Double the file size on each expansion;
- // uncommitted pages have effectively no cost.
- fileSize := f.fileSize
- for int64(end) > fileSize {
- if fileSize >= 2*fileSize {
- // fileSize overflow.
- return platform.FileRange{}, syserror.ENOMEM
- }
- fileSize *= 2
- }
- if fileSize > f.fileSize {
- if err := f.file.Truncate(fileSize); err != nil {
- return platform.FileRange{}, err
+ // Expand the file if needed.
+ if int64(fr.End) > f.fileSize {
+ // Round the new file size up to be chunk-aligned.
+ newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
+ if err := f.file.Truncate(newFileSize); err != nil {
+ return memmap.FileRange{}, err
}
- f.fileSize = fileSize
+ f.fileSize = newFileSize
f.mappingsMu.Lock()
oldMappings := f.mappings.Load().([]uintptr)
- newMappings := make([]uintptr, fileSize>>chunkShift)
+ newMappings := make([]uintptr, newFileSize>>chunkShift)
copy(newMappings, oldMappings)
f.mappings.Store(newMappings)
f.mappingsMu.Unlock()
}
+ if f.opts.ManualZeroing {
+ if err := f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ }); err != nil {
+ return memmap.FileRange{}, err
+ }
+ }
// Mark selected pages as in use.
- fr := platform.FileRange{start, end}
if !f.usage.Add(fr, usageInfo{
kind: kind,
refs: 1,
@@ -439,49 +426,79 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
panic(fmt.Sprintf("allocating %v: failed to insert into usage set:\n%v", fr, &f.usage))
}
- if minUnallocatedPage < start {
- f.minUnallocatedPage = minUnallocatedPage
- } else {
- // start was the first unallocated page. The next must be
- // somewhere beyond end.
- f.minUnallocatedPage = end
- }
-
return fr, nil
}
-// findUnallocatedRange returns the first unallocated page in usage of the
-// specified length and alignment beginning at page start and the first single
-// unallocated page.
-func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uint64, uint64) {
- // Only searched until the first page is found.
- firstPage := start
- foundFirstPage := false
- alignMask := alignment - 1
- for seg := usage.LowerBoundSegment(start); seg.Ok(); seg = seg.NextSegment() {
- r := seg.Range()
-
- if !foundFirstPage && r.Start > firstPage {
- foundFirstPage = true
+// findAvailableRange returns an available range in the usageSet.
+//
+// Note that scanning for available slots takes place from end first backwards,
+// then forwards. This heuristic has important consequence for how sequential
+// mappings can be merged in the host VMAs, given that addresses for both
+// application and sentry mappings are allocated top-down (from higher to
+// lower addresses). The file is also grown expoentially in order to create
+// space for mappings to be allocated downwards.
+//
+// Precondition: alignment must be a power of 2.
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) {
+ alignmentMask := alignment - 1
+
+ // Search for space in existing gaps, starting at the current end of the
+ // file and working backward.
+ lastGap := usage.LastGap()
+ gap := lastGap
+ for {
+ end := gap.End()
+ if end > uint64(fileSize) {
+ end = uint64(fileSize)
}
- if start >= r.End {
- // start was rounded up to an alignment boundary from the end
- // of a previous segment and is now beyond r.End.
- continue
+ // Try to allocate from the end of this gap, with the start of the
+ // allocated range aligned down to alignment.
+ unalignedStart := end - length
+ if unalignedStart > end {
+ // Negative overflow: this and all preceding gaps are too small to
+ // accommodate length.
+ break
+ }
+ if start := unalignedStart &^ alignmentMask; start >= gap.Start() {
+ return memmap.FileRange{start, start + length}, true
}
- // This segment represents allocated or reclaimable pages; only the
- // range from start to the segment's beginning is allocatable, and the
- // next allocatable range begins after the segment.
- if r.Start > start && r.Start-start >= length {
+
+ gap = gap.PrevLargeEnoughGap(length)
+ if !gap.Ok() {
break
}
- start = (r.End + alignMask) &^ alignMask
- if !foundFirstPage {
- firstPage = r.End
+ }
+
+ // Check that it's possible to fit this allocation at the end of a file of any size.
+ min := lastGap.Start()
+ min = (min + alignmentMask) &^ alignmentMask
+ if min+length < min {
+ // Overflow: allocation would exceed the range of uint64.
+ return memmap.FileRange{}, false
+ }
+
+ // Determine the minimum file size required to fit this allocation at its end.
+ for {
+ newFileSize := 2 * fileSize
+ if newFileSize <= fileSize {
+ if fileSize != 0 {
+ // Overflow: allocation would exceed the range of int64.
+ return memmap.FileRange{}, false
+ }
+ newFileSize = chunkSize
+ }
+ fileSize = newFileSize
+
+ unalignedStart := uint64(fileSize) - length
+ if unalignedStart > uint64(fileSize) {
+ // Negative overflow: fileSize is still inadequate.
+ continue
+ }
+ if start := unalignedStart &^ alignmentMask; start >= min {
+ return memmap.FileRange{start, start + length}, true
}
}
- return start, firstPage
}
// AllocateAndFill allocates memory of the given kind and fills it by calling
@@ -491,22 +508,22 @@ func findUnallocatedRange(usage *usageSet, start, length, alignment uint64) (uin
// by r.ReadToBlocks(), it returns that error.
//
// Preconditions: length > 0. length must be page-aligned.
-func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) {
fr, err := f.Allocate(length, kind)
if err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
dsts, err := f.MapInternal(fr, usermem.Write)
if err != nil {
f.DecRef(fr)
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
n, err := safemem.ReadFullToBlocks(r, dsts)
un := uint64(usermem.Addr(n).RoundDown())
if un < length {
// Free unused memory and update fr to contain only the memory that is
// still allocated.
- f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ f.DecRef(memmap.FileRange{fr.Start + un, fr.End})
fr.End = fr.Start + un
}
return fr, err
@@ -523,7 +540,7 @@ const (
// will read zeroes.
//
// Preconditions: fr.Length() > 0.
-func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -543,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error {
return nil
}
-func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
// Since we're changing the knownCommitted attribute, we need to merge
@@ -564,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
f.usage.MergeRange(fr)
}
-// IncRef implements platform.File.IncRef.
-func (f *MemoryFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (f *MemoryFile) IncRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -583,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
}
-// DecRef implements platform.File.DecRef.
-func (f *MemoryFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (f *MemoryFile) DecRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -602,6 +619,7 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
val.refs--
if val.refs == 0 {
+ f.reclaim.Add(seg.Range(), reclaimSetValue{})
freed = true
// Reclassify memory as System, until it's freed by the reclaim
// goroutine.
@@ -614,17 +632,13 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
if freed {
- if fr.Start < f.minReclaimablePage {
- // We've freed at least one lower page.
- f.minReclaimablePage = fr.Start
- }
f.reclaimable = true
f.reclaimCond.Signal()
}
}
-// MapInternal implements platform.File.MapInternal.
-func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
if !fr.WellFormed() || fr.Length() == 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -650,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (
// forEachMappingSlice invokes fn on a sequence of byte slices that
// collectively map all bytes in fr.
-func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error {
mappings := f.mappings.Load().([]uintptr)
for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
chunk := int(chunkStart >> chunkShift)
@@ -930,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
continue
case !populated && populatedRun:
// Finish the run by changing this segment.
- runRange := platform.FileRange{
+ runRange := memmap.FileRange{
Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
End: r.Start + uint64(i*usermem.PageSize),
}
@@ -995,7 +1009,7 @@ func (f *MemoryFile) File() *os.File {
return f.file
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (f *MemoryFile) FD() int {
return int(f.file.Fd())
}
@@ -1016,6 +1030,7 @@ func (f *MemoryFile) String() string {
// for allocation.
func (f *MemoryFile) runReclaim() {
for {
+ // N.B. We must call f.markReclaimed on the returned FrameRange.
fr, ok := f.findReclaimable()
if !ok {
break
@@ -1071,13 +1086,17 @@ func (f *MemoryFile) runReclaim() {
}
}
-func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+// findReclaimable finds memory that has been marked for reclaim.
+//
+// Note that there returned range will be removed from tracking. It
+// must be reclaimed (removed from f.usage) at this point.
+func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
for {
for {
if f.destroyed {
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
if f.reclaimable {
break
@@ -1089,27 +1108,24 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
f.reclaimCond.Wait()
}
- // Allocate returns the first usable range in offset order and is
- // currently a linear scan, so reclaiming from the beginning of the
- // file minimizes the expected latency of Allocate.
- for seg := f.usage.LowerBoundSegment(f.minReclaimablePage); seg.Ok(); seg = seg.NextSegment() {
- if seg.ValuePtr().refs == 0 {
- f.minReclaimablePage = seg.End()
- return seg.Range(), true
- }
+ // Allocate works from the back of the file inwards, so reclaim
+ // preserves this order to minimize the cost of the search.
+ if seg := f.reclaim.LastSegment(); seg.Ok() {
+ fr := seg.Range()
+ f.reclaim.Remove(seg)
+ return fr, true
}
- // No pages are reclaimable.
+ // Nothing is reclaimable.
f.reclaimable = false
- f.minReclaimablePage = maxPage
}
}
-func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+func (f *MemoryFile) markReclaimed(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
- // All of fr should be mapped to a single uncommitted reclaimable segment
- // accounted to System.
+ // All of fr should be mapped to a single uncommitted reclaimable
+ // segment accounted to System.
if !seg.Ok() {
panic(fmt.Sprintf("reclaimed pages %v include unreferenced pages:\n%v", fr, &f.usage))
}
@@ -1123,14 +1139,10 @@ func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
}); got != want {
panic(fmt.Sprintf("reclaimed pages %v in segment %v has incorrect state %v, wanted %v:\n%v", fr, seg.Range(), got, want, &f.usage))
}
- // Deallocate reclaimed pages. Even though all of seg is reclaimable, the
- // caller of markReclaimed may not have decommitted it, so we can only mark
- // fr as reclaimed.
+ // Deallocate reclaimed pages. Even though all of seg is reclaimable,
+ // the caller of markReclaimed may not have decommitted it, so we can
+ // only mark fr as reclaimed.
f.usage.Remove(f.usage.Isolate(seg, fr))
- if fr.Start < f.minUnallocatedPage {
- // We've deallocated at least one lower page.
- f.minUnallocatedPage = fr.Start
- }
}
// StartEvictions requests that f evict all evictable allocations. It does not
@@ -1210,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 {
func (usageSetFunctions) ClearValue(val *usageInfo) {
}
-func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) {
return val1, val1 == val2
}
-func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
return val, val
}
@@ -1241,3 +1253,27 @@ func (evictableRangeSetFunctions) Merge(_ EvictableRange, _ evictableRangeSetVal
func (evictableRangeSetFunctions) Split(_ EvictableRange, _ evictableRangeSetValue, _ uint64) (evictableRangeSetValue, evictableRangeSetValue) {
return evictableRangeSetValue{}, evictableRangeSetValue{}
}
+
+// reclaimSetValue is the value type of reclaimSet.
+type reclaimSetValue struct{}
+
+type reclaimSetFunctions struct{}
+
+func (reclaimSetFunctions) MinKey() uint64 {
+ return 0
+}
+
+func (reclaimSetFunctions) MaxKey() uint64 {
+ return math.MaxUint64
+}
+
+func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
+}
+
+func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+ return reclaimSetValue{}, true
+}
+
+func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+ return reclaimSetValue{}, reclaimSetValue{}
+}
diff --git a/pkg/sentry/pgalloc/pgalloc_test.go b/pkg/sentry/pgalloc/pgalloc_test.go
index 428e6a859..405db141f 100644
--- a/pkg/sentry/pgalloc/pgalloc_test.go
+++ b/pkg/sentry/pgalloc/pgalloc_test.go
@@ -17,45 +17,55 @@ package pgalloc
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
page = usermem.PageSize
hugepage = usermem.HugePageSize
+ topPage = (1 << 63) - page
)
func TestFindUnallocatedRange(t *testing.T) {
for _, test := range []struct {
- desc string
- usage *usageSegmentDataSlices
- start uint64
- length uint64
- alignment uint64
- unallocated uint64
- minUnallocated uint64
+ desc string
+ usage *usageSegmentDataSlices
+ fileSize int64
+ length uint64
+ alignment uint64
+ start uint64
+ expectFail bool
}{
{
- desc: "Initial allocation succeeds",
- usage: &usageSegmentDataSlices{},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ desc: "Initial allocation succeeds",
+ usage: &usageSegmentDataSlices{},
+ length: page,
+ alignment: page,
+ start: chunkSize - page, // Grows by chunkSize, allocate down.
},
{
- desc: "Allocation begins at start of file",
+ desc: "Allocation finds empty space at start of file",
usage: &usageSegmentDataSlices{
Start: []uint64{page},
End: []uint64{2 * page},
Values: []usageInfo{{refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 0,
- minUnallocated: 0,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocation finds empty space at end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0},
+ End: []uint64{page},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "In-use frames are not allocatable",
@@ -64,11 +74,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 2 * page,
+ length: page,
+ alignment: page,
+ start: 3 * page, // Double fileSize, allocate top-down.
},
{
desc: "Reclaimable frames are not allocatable",
@@ -77,11 +86,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 2 * page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 0}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: 3 * page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: 5 * page, // Double fileSize, grow down.
},
{
desc: "Gaps between in-use frames are allocatable",
@@ -90,11 +98,10 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: page,
- alignment: page,
- unallocated: page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: page,
+ alignment: page,
+ start: page,
},
{
desc: "Inadequately-sized gaps are rejected",
@@ -103,14 +110,13 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, 3 * page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: 2 * page,
- alignment: page,
- unallocated: 3 * page,
- minUnallocated: page,
+ fileSize: 3 * page,
+ length: 2 * page,
+ alignment: page,
+ start: 4 * page, // Double fileSize, grow down.
},
{
- desc: "Hugepage alignment is honored",
+ desc: "Alignment is honored at end of file",
usage: &usageSegmentDataSlices{
Start: []uint64{0, hugepage + page},
// Hugepage-sized gap here that shouldn't be allocated from
@@ -118,37 +124,103 @@ func TestFindUnallocatedRange(t *testing.T) {
End: []uint64{page, hugepage + 2*page},
Values: []usageInfo{{refs: 1}, {refs: 1}},
},
- start: 0,
- length: hugepage,
- alignment: hugepage,
- unallocated: 2 * hugepage,
- minUnallocated: page,
+ fileSize: hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: 3 * hugepage, // Double fileSize until alignment is satisfied, grow down.
},
{
- desc: "Pages before start ignored",
+ desc: "Alignment is honored before end of file",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{0, 2*hugepage + page},
+ // Page will need to be shifted down from top.
+ End: []uint64{page, 2*hugepage + 2*page},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: 2*hugepage + 2*page,
+ length: hugepage,
+ alignment: hugepage,
+ start: hugepage,
+ },
+ {
+ desc: "Allocation doubles file size more than once if necessary",
+ usage: &usageSegmentDataSlices{},
+ fileSize: page,
+ length: 4 * page,
+ alignment: page,
+ start: 0,
+ },
+ {
+ desc: "Allocations are compact if possible",
usage: &usageSegmentDataSlices{
Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 4 * page,
+ length: page,
+ alignment: page,
+ start: 2 * page,
+ },
+ {
+ desc: "Top-down allocation within one gap",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 4 * page, 7 * page},
+ End: []uint64{2 * page, 5 * page, 8 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 6 * page,
+ },
+ {
+ desc: "Top-down allocation between multiple gaps",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, 3 * page, 5 * page},
+ End: []uint64{2 * page, 4 * page, 6 * page},
+ Values: []usageInfo{{refs: 1}, {refs: 2}, {refs: 1}},
+ },
+ fileSize: 6 * page,
+ length: page,
+ alignment: page,
+ start: 4 * page,
},
{
- desc: "start may be in the middle of segment",
+ desc: "Top-down allocation with large top gap",
usage: &usageSegmentDataSlices{
- Start: []uint64{0, 3 * page},
+ Start: []uint64{page, 3 * page},
End: []uint64{2 * page, 4 * page},
Values: []usageInfo{{refs: 1}, {refs: 2}},
},
- start: page,
- length: page,
- alignment: page,
- unallocated: 2 * page,
- minUnallocated: 2 * page,
+ fileSize: 8 * page,
+ length: page,
+ alignment: page,
+ start: 7 * page,
+ },
+ {
+ desc: "Gaps found with possible overflow",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page, topPage - page},
+ End: []uint64{2 * page, topPage},
+ Values: []usageInfo{{refs: 1}, {refs: 1}},
+ },
+ fileSize: topPage,
+ length: page,
+ alignment: page,
+ start: topPage - 2*page,
+ },
+ {
+ desc: "Overflow detected",
+ usage: &usageSegmentDataSlices{
+ Start: []uint64{page},
+ End: []uint64{topPage},
+ Values: []usageInfo{{refs: 1}},
+ },
+ fileSize: topPage,
+ length: 2 * page,
+ alignment: page,
+ expectFail: true,
},
} {
t.Run(test.desc, func(t *testing.T) {
@@ -156,12 +228,18 @@ func TestFindUnallocatedRange(t *testing.T) {
if err := usage.ImportSortedSlices(test.usage); err != nil {
t.Fatalf("Failed to initialize usage from %v: %v", test.usage, err)
}
- unallocated, minUnallocated := findUnallocatedRange(&usage, test.start, test.length, test.alignment)
- if unallocated != test.unallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got unallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, unallocated, test.unallocated)
+ fr, ok := findAvailableRange(&usage, test.fileSize, test.length, test.alignment)
+ if !test.expectFail && !ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, false wanted %x, true", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if test.expectFail && ok {
+ t.Fatalf("findAvailableRange(%v, %x, %x, %x): got %x, true wanted %x, false", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
+ }
+ if ok && fr.Start != test.start {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got start=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.Start, test.start)
}
- if minUnallocated != test.minUnallocated {
- t.Errorf("findUnallocatedRange(%v, %x, %x, %x): got minUnallocated %x, wanted %x", test.usage, test.start, test.length, test.alignment, minUnallocated, test.minUnallocated)
+ if ok && fr.End != test.start+test.length {
+ t.Errorf("findAvailableRange(%v, %x, %x, %x): got end=%x, wanted %x", test.usage, test.fileSize, test.length, test.alignment, fr.End, test.start+test.length)
}
})
}
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index 1effc7735..78317fa35 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -16,6 +16,7 @@ package pgalloc
import (
"bytes"
+ "context"
"fmt"
"io"
"runtime"
@@ -24,12 +25,13 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/usage"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -78,10 +80,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// Save metadata.
- if err := state.Save(w, &f.fileSize, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.fileSize); err != nil {
return err
}
- if err := state.Save(w, &f.usage, nil); err != nil {
+ if _, err := state.Save(ctx, w, &f.usage); err != nil {
return err
}
@@ -114,9 +116,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error {
// Load metadata.
- if err := state.Load(r, &f.fileSize, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.fileSize); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -124,7 +126,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(r, &f.usage, nil); err != nil {
+ if _, err := state.Load(ctx, r, &f.usage); err != nil {
return err
}
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 157bffa81..209b28053 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,40 +1,21 @@
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
-go_template_instance(
- name = "file_range",
- out = "file_range.go",
- package = "platform",
- prefix = "File",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uint64",
- },
-)
-
go_library(
name = "platform",
srcs = [
"context.go",
- "file_range.go",
"mmap_min_addr.go",
"platform.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/atomicbitops",
- "//pkg/log",
+ "//pkg/context",
"//pkg/seccomp",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/platform/safecopy",
- "//pkg/sentry/safemem",
- "//pkg/sentry/usage",
- "//pkg/sentry/usermem",
- "//pkg/syserror",
+ "//pkg/sentry/memmap",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/context.go b/pkg/sentry/platform/context.go
index e29bc4485..6759cda65 100644
--- a/pkg/sentry/platform/context.go
+++ b/pkg/sentry/platform/context.go
@@ -15,7 +15,7 @@
package platform
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is the auth package's type for context.Context.Value keys.
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index b6d008dbe..83b385f14 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,13 +7,13 @@ go_library(
srcs = [
"interrupt.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/interrupt",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
name = "interrupt_test",
size = "small",
srcs = ["interrupt_test.go"],
- embed = [":interrupt"],
+ library = ":interrupt",
)
diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go
index a4651f500..57be41647 100644
--- a/pkg/sentry/platform/interrupt/interrupt.go
+++ b/pkg/sentry/platform/interrupt/interrupt.go
@@ -17,7 +17,8 @@ package interrupt
import (
"fmt"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// Receiver receives interrupt notifications from a Forwarder.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 31fa48ec5..3970dd81d 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,53 +6,67 @@ go_library(
name = "kvm",
srcs = [
"address_space.go",
- "allocator.go",
"bluepill.go",
+ "bluepill_allocator.go",
"bluepill_amd64.go",
"bluepill_amd64.s",
"bluepill_amd64_unsafe.go",
+ "bluepill_arm64.go",
+ "bluepill_arm64.s",
+ "bluepill_arm64_unsafe.go",
"bluepill_fault.go",
"bluepill_unsafe.go",
"context.go",
- "filters.go",
+ "filters_amd64.go",
+ "filters_arm64.go",
"kvm.go",
"kvm_amd64.go",
"kvm_amd64_unsafe.go",
+ "kvm_arm64.go",
+ "kvm_arm64_unsafe.go",
"kvm_const.go",
+ "kvm_const_arm64.go",
"machine.go",
"machine_amd64.go",
"machine_amd64_unsafe.go",
+ "machine_arm64.go",
+ "machine_arm64_unsafe.go",
"machine_unsafe.go",
"physical_map.go",
+ "physical_map_amd64.go",
+ "physical_map_arm64.go",
"virtual_map.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/atomicbitops",
+ "//pkg/context",
"//pkg/cpuid",
"//pkg/log",
"//pkg/procid",
+ "//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
- "//pkg/sentry/platform/safecopy",
"//pkg/sentry/time",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
+ "//pkg/usermem",
],
)
go_test(
name = "kvm_test",
srcs = [
+ "kvm_amd64_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
- embed = [":kvm"],
+ library = ":kvm",
tags = [
"manual",
"nogotsan",
@@ -65,6 +78,6 @@ go_test(
"//pkg/sentry/platform/kvm/testutil",
"//pkg/sentry/platform/ring0",
"//pkg/sentry/platform/ring0/pagetables",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index acd41f73d..af5c5e191 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -15,27 +15,27 @@
package kvm
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// dirtySet tracks vCPUs for invalidation.
type dirtySet struct {
- vCPUs []uint64
+ vCPUMasks []uint64
}
// forEach iterates over all CPUs in the dirty set.
+//
+//go:nosplit
func (ds *dirtySet) forEach(m *machine, fn func(c *vCPU)) {
- m.mu.RLock()
- defer m.mu.RUnlock()
-
- for index := range ds.vCPUs {
- mask := atomic.SwapUint64(&ds.vCPUs[index], 0)
+ for index := range ds.vCPUMasks {
+ mask := atomic.SwapUint64(&ds.vCPUMasks[index], 0)
if mask != 0 {
for bit := 0; bit < 64; bit++ {
if mask&(1<<uint64(bit)) == 0 {
@@ -54,7 +54,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
index := uint64(c.id) / 64
bit := uint64(1) << uint(c.id%64)
- oldValue := atomic.LoadUint64(&ds.vCPUs[index])
+ oldValue := atomic.LoadUint64(&ds.vCPUMasks[index])
if oldValue&bit != 0 {
return false // Not clean.
}
@@ -62,7 +62,7 @@ func (ds *dirtySet) mark(c *vCPU) bool {
// Set the bit unilaterally, and ensure that a flush takes place. Note
// that it's possible for races to occur here, but since the flush is
// taking place long after these lines there's no race in practice.
- atomicbitops.OrUint64(&ds.vCPUs[index], bit)
+ atomicbitops.OrUint64(&ds.vCPUMasks[index], bit)
return true // Previously clean.
}
@@ -113,7 +113,12 @@ type hostMapEntry struct {
length uintptr
}
-func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
+// mapLocked maps the given host entry.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.AccessType) (inv bool) {
for m.length > 0 {
physical, length, ok := translateToPhysical(m.addr)
if !ok {
@@ -127,24 +132,16 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// not have physical mappings, the KVM module may inject
// spurious exceptions when emulation fails (i.e. it tries to
// emulate because the RIP is pointed at those pages).
- as.machine.mapPhysical(physical, length)
+ as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
// Install the page table mappings. Note that the ordering is
// important; if the pagetable mappings were installed before
// ensuring the physical pages were available, then some other
// thread could theoretically access them.
- //
- // Due to the way KVM's shadow paging implementation works,
- // modifications to the page tables while in host mode may not
- // be trapped, leading to the shadow pages being out of sync.
- // Therefore, we need to ensure that we are in guest mode for
- // page table modifications. See the call to bluepill, below.
- as.machine.retryInGuest(func() {
- inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
- AccessType: at,
- User: true,
- }, physical) || inv
- })
+ inv = as.pageTables.Map(addr, length, pagetables.MapOpts{
+ AccessType: at,
+ User: true,
+ }, physical) || inv
m.addr += length
m.length -= length
addr += usermem.Addr(length)
@@ -154,7 +151,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
}
// MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
as.mu.Lock()
defer as.mu.Unlock()
@@ -176,6 +173,10 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return err
}
+ // See block in mapLocked.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+
// Map the mappings in the sentry's address space (guest physical memory)
// into the application's address space (guest virtual memory).
inv := false
@@ -190,7 +191,12 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
_ = s[i] // Touch to commit.
}
}
- prev := as.mapHost(addr, hostMapEntry{
+
+ // See bluepill_allocator.go.
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ // Perform the mapping.
+ prev := as.mapLocked(addr, hostMapEntry{
addr: b.Addr(),
length: uintptr(b.Len()),
}, at)
@@ -204,17 +210,27 @@ func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.
return nil
}
+// unmapLocked is an escape-checked wrapped around Unmap.
+//
+// +checkescape:hard,stack
+//
+//go:nosplit
+func (as *addressSpace) unmapLocked(addr usermem.Addr, length uint64) bool {
+ return as.pageTables.Unmap(addr, uintptr(length))
+}
+
// Unmap unmaps the given range by calling pagetables.PageTables.Unmap.
func (as *addressSpace) Unmap(addr usermem.Addr, length uint64) {
as.mu.Lock()
defer as.mu.Unlock()
- // See above re: retryInGuest.
- var prev bool
- as.machine.retryInGuest(func() {
- prev = as.pageTables.Unmap(addr, uintptr(length)) || prev
- })
- if prev {
+ // See above & bluepill_allocator.go.
+ as.pageTables.Allocator.(*allocator).cpu = as.machine.Get()
+ defer as.machine.Put(as.pageTables.Allocator.(*allocator).cpu)
+ bluepill(as.pageTables.Allocator.(*allocator).cpu)
+
+ if prev := as.unmapLocked(addr, length); prev {
+ // Invalidate all active vCPUs.
as.invalidate()
// Recycle any freed intermediate pages.
@@ -227,8 +243,14 @@ func (as *addressSpace) Release() {
as.Unmap(0, ^uint64(0))
// Free all pages from the allocator.
- as.pageTables.Allocator.(allocator).base.Drain()
+ as.pageTables.Allocator.(*allocator).base.Drain()
// Drop all cached machine references.
as.machine.dropPageTables(as.pageTables)
}
+
+// PreFork implements platform.AddressSpace.PreFork.
+func (as *addressSpace) PreFork() {}
+
+// PostFork implements platform.AddressSpace.PostFork.
+func (as *addressSpace) PostFork() {}
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index 043de51b3..4b23f7803 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -19,8 +19,9 @@ import (
"reflect"
"syscall"
+ "gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
// bluepill enters guest mode.
@@ -36,6 +37,18 @@ func sighandler()
func dieTrampoline()
var (
+ // bounceSignal is the signal used for bouncing KVM.
+ //
+ // We use SIGCHLD because it is not masked by the runtime, and
+ // it will be ignored properly by other parts of the kernel.
+ bounceSignal = syscall.SIGCHLD
+
+ // bounceSignalMask has only bounceSignal set.
+ bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
+
+ // bounce is the interrupt vector used to return to the kernel.
+ bounce = uint32(ring0.VirtualizationException)
+
// savedHandler is a pointer to the previous handler.
//
// This is called by bluepillHandler.
@@ -45,6 +58,13 @@ var (
dieTrampolineAddr uintptr
)
+// redpill invokes a syscall with -1.
+//
+//go:nosplit
+func redpill() {
+ syscall.RawSyscall(^uintptr(0), 0, 0, 0)
+}
+
// dieHandler is called by dieTrampoline.
//
//go:nosplit
@@ -61,20 +81,14 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
// Save the death message, which will be thrown.
c.dieState.message = msg
- // Reload all registers to have an accurate stack trace when we return
- // to host mode. This means that the stack should be unwound correctly.
- if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
- throw(msg)
- }
-
// Setup the trampoline.
dieArchSetup(c, context, &c.dieState.guestRegs)
}
func init() {
// Install the handler.
- if err := safecopy.ReplaceSignalHandler(syscall.SIGSEGV, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
- panic(fmt.Sprintf("Unable to set handler for signal %d: %v", syscall.SIGSEGV, err))
+ if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+ panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
}
// Extract the address for the trampoline.
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/bluepill_allocator.go
index 80942e9c9..9485e1301 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/bluepill_allocator.go
@@ -21,56 +21,80 @@ import (
)
type allocator struct {
- base *pagetables.RuntimeAllocator
+ base pagetables.RuntimeAllocator
+
+ // cpu must be set prior to any pagetable operation.
+ //
+ // Due to the way KVM's shadow paging implementation works,
+ // modifications to the page tables while in host mode may not be
+ // trapped, leading to the shadow pages being out of sync. Therefore,
+ // we need to ensure that we are in guest mode for page table
+ // modifications. See the call to bluepill, below.
+ cpu *vCPU
}
// newAllocator is used to define the allocator.
-func newAllocator() allocator {
- return allocator{
- base: pagetables.NewRuntimeAllocator(),
- }
+func newAllocator() *allocator {
+ a := new(allocator)
+ a.base.Init()
+ return a
}
// NewPTEs implements pagetables.Allocator.NewPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) NewPTEs() *pagetables.PTEs {
- return a.base.NewPTEs()
+func (a *allocator) NewPTEs() *pagetables.PTEs {
+ ptes := a.base.NewPTEs() // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
+ return ptes
}
// PhysicalFor returns the physical address for a set of PTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
+func (a *allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
virtual := a.base.PhysicalFor(ptes)
physical, _, ok := translateToPhysical(virtual)
if !ok {
- panic(fmt.Sprintf("PhysicalFor failed for %p", ptes))
+ panic(fmt.Sprintf("PhysicalFor failed for %p", ptes)) // escapes: panic.
}
return physical
}
// LookupPTEs implements pagetables.Allocator.LookupPTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
- virtualStart, physicalStart, _, ok := calculateBluepillFault(physical)
+func (a *allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
+ virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
- panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
+ panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) // escapes: panic.
}
return a.base.LookupPTEs(virtualStart + (physical - physicalStart))
}
// FreePTEs implements pagetables.Allocator.FreePTEs.
//
+// +checkescape:all
+//
//go:nosplit
-func (a allocator) FreePTEs(ptes *pagetables.PTEs) {
- a.base.FreePTEs(ptes)
+func (a *allocator) FreePTEs(ptes *pagetables.PTEs) {
+ a.base.FreePTEs(ptes) // escapes: bluepill below.
+ if a.cpu != nil {
+ bluepill(a.cpu)
+ }
}
// Recycle implements pagetables.Allocator.Recycle.
//
//go:nosplit
-func (a allocator) Recycle() {
+func (a *allocator) Recycle() {
a.base.Recycle()
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.go b/pkg/sentry/platform/kvm/bluepill_amd64.go
index 421c88220..ddc1554d5 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.go
@@ -24,26 +24,10 @@ import (
)
var (
- // bounceSignal is the signal used for bouncing KVM.
- //
- // We use SIGCHLD because it is not masked by the runtime, and
- // it will be ignored properly by other parts of the kernel.
- bounceSignal = syscall.SIGCHLD
-
- // bounceSignalMask has only bounceSignal set.
- bounceSignalMask = uint64(1 << (uint64(bounceSignal) - 1))
-
- // bounce is the interrupt vector used to return to the kernel.
- bounce = uint32(ring0.VirtualizationException)
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGSEGV
)
-// redpill on amd64 invokes a syscall with -1.
-//
-//go:nosplit
-func redpill() {
- syscall.RawSyscall(^uintptr(0), 0, 0, 0)
-}
-
// bluepillArchEnter is called during bluepillEnter.
//
//go:nosplit
@@ -79,6 +63,8 @@ func bluepillArchEnter(context *arch.SignalContext64) *vCPU {
// KernelSyscall handles kernel syscalls.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelSyscall() {
regs := c.Registers()
@@ -88,13 +74,15 @@ func (c *vCPU) KernelSyscall() {
// We only trigger a bluepill entry in the bluepill function, and can
// therefore be guaranteed that there is no floating point state to be
// loaded on resuming from halt. We only worry about saving on exit.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no, reload host segment.
}
// KernelException handles kernel exceptions.
//
+// +checkescape:all
+//
//go:nosplit
func (c *vCPU) KernelException(vector ring0.Vector) {
regs := c.Registers()
@@ -105,9 +93,9 @@ func (c *vCPU) KernelException(vector ring0.Vector) {
regs.Rip = 0
}
// See above.
- ring0.SaveFloatingPoint((*byte)(c.floatingPointState))
+ ring0.SaveFloatingPoint((*byte)(c.floatingPointState)) // escapes: no.
ring0.Halt()
- ring0.WriteFS(uintptr(regs.Fs_base)) // Reload host segment.
+ ring0.WriteFS(uintptr(regs.Fs_base)) // escapes: no; reload host segment.
}
// bluepillArchExit is called during bluepillEnter.
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index 9d8af143e..03a98512e 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -17,19 +17,13 @@
package kvm
import (
+ "syscall"
"unsafe"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
-// bluepillArchContext returns the arch-specific context.
-//
-//go:nosplit
-func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
- return &((*arch.UContext64)(context).MContext)
-}
-
// dieArchSetup initializes the state for dieTrampoline.
//
// The amd64 dieTrampoline requires the vCPU to be set in BX, and the last RIP
@@ -38,6 +32,12 @@ func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
//
//go:nosplit
func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // Reload all registers to have an accurate stack trace when we return
+ // to host mode. This means that the stack should be unwound correctly.
+ if errno := c.getUserRegisters(&c.dieState.guestRegs); errno != 0 {
+ throw(c.dieState.message)
+ }
+
// If the vCPU is in user mode, we set the stack to the stored stack
// value in the vCPU itself. We don't want to unwind the user stack.
if guestRegs.RFLAGS&ring0.UserFlagsSet == ring0.UserFlagsSet {
@@ -54,3 +54,34 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
context.Rbx = uint64(uintptr(unsafe.Pointer(c)))
context.Rip = uint64(dieTrampolineAddr)
}
+
+// getHypercallID returns hypercall ID.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ return _KVM_HYPERCALL_MAX
+}
+
+// bluepillStopGuest is reponsible for injecting interrupt.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ // Interrupt: we must have requested an interrupt
+ // window; set the interrupt line.
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_INTERRUPT,
+ uintptr(unsafe.Pointer(&bounce))); errno != 0 {
+ throw("interrupt injection failed")
+ }
+ // Clear previous injection request.
+ c.runData.requestInterruptWindow = 0
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return c.runData.readyForInterruptInjection != 0
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
new file mode 100644
index 000000000..ed5ae03d3
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -0,0 +1,124 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+var (
+ // The action for bluepillSignal is changed by sigaction().
+ bluepillSignal = syscall.SIGILL
+
+ // vcpuSErr is the event of system error.
+ vcpuSErr = kvmVcpuEvents{
+ exception: exception{
+ sErrPending: 1,
+ sErrHasEsr: 0,
+ pad: [6]uint8{0, 0, 0, 0, 0, 0},
+ sErrEsr: 1,
+ },
+ rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
+ }
+)
+
+// bluepillArchEnter is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) {
+ c = vCPUPtr(uintptr(context.Regs[8]))
+ regs := c.CPU.Registers()
+ regs.Regs = context.Regs
+ regs.Sp = context.Sp
+ regs.Pc = context.Pc
+ regs.Pstate = context.Pstate
+ regs.Pstate &^= uint64(ring0.PsrFlagsClear)
+ regs.Pstate |= ring0.KernelFlagsSet
+ return
+}
+
+// bluepillArchExit is called during bluepillEnter.
+//
+//go:nosplit
+func bluepillArchExit(c *vCPU, context *arch.SignalContext64) {
+ regs := c.CPU.Registers()
+ context.Regs = regs.Regs
+ context.Sp = regs.Sp
+ context.Pc = regs.Pc
+ context.Pstate = regs.Pstate
+ context.Pstate &^= uint64(ring0.PsrFlagsClear)
+ context.Pstate |= ring0.UserFlagsSet
+
+ lazyVfp := c.GetLazyVFP()
+ if lazyVfp != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ context.Fpsimd64.Fpsr = fpsimd.Fpsr
+ context.Fpsimd64.Fpcr = fpsimd.Fpcr
+ context.Fpsimd64.Vregs = fpsimd.Vregs
+ }
+}
+
+// KernelSyscall handles kernel syscalls.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelSyscall() {
+ regs := c.Registers()
+ if regs.Regs[8] != ^uint64(0) {
+ regs.Pc -= 4 // Rewind.
+ }
+
+ vfpEnable := ring0.CPACREL1()
+ if vfpEnable != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ fpcr := ring0.GetFPCR()
+ fpsr := ring0.GetFPSR()
+ fpsimd.Fpcr = uint32(fpcr)
+ fpsimd.Fpsr = uint32(fpsr)
+ ring0.SaveVRegs((*byte)(c.floatingPointState))
+ }
+
+ ring0.Halt()
+}
+
+// KernelException handles kernel exceptions.
+//
+// +checkescape:all
+//
+//go:nosplit
+func (c *vCPU) KernelException(vector ring0.Vector) {
+ regs := c.Registers()
+ if vector == ring0.Vector(bounce) {
+ regs.Pc = 0
+ }
+
+ vfpEnable := ring0.CPACREL1()
+ if vfpEnable != 0 {
+ fpsimd := fpsimdPtr((*byte)(c.floatingPointState))
+ fpcr := ring0.GetFPCR()
+ fpsr := ring0.GetFPSR()
+ fpsimd.Fpcr = uint32(fpcr)
+ fpsimd.Fpsr = uint32(fpsr)
+ ring0.SaveVRegs((*byte)(c.floatingPointState))
+ }
+
+ ring0.Halt()
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
new file mode 100644
index 000000000..04efa0147
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -0,0 +1,89 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// VCPU_CPU is the location of the CPU in the vCPU struct.
+//
+// This is guaranteed to be zero.
+#define VCPU_CPU 0x0
+
+// CPU_SELF is the self reference in ring0's percpu.
+//
+// This is guaranteed to be zero.
+#define CPU_SELF 0x0
+
+// Context offsets.
+//
+// Only limited use of the context is done in the assembly stub below, most is
+// done in the Go handlers.
+#define SIGINFO_SIGNO 0x0
+#define CONTEXT_PC 0x1B8
+#define CONTEXT_R0 0xB8
+
+// See bluepill.go.
+TEXT ·bluepill(SB),NOSPLIT,$0
+begin:
+ MOVD vcpu+0(FP), R8
+ MOVD $VCPU_CPU(R8), R9
+ ORR $0xffff000000000000, R9, R9
+ // Trigger sigill.
+ // In ring0.Start(), the value of R8 will be stored into tpidr_el1.
+ // When the context was loaded into vcpu successfully,
+ // we will check if the value of R10 and R9 are the same.
+ WORD $0xd538d08a // MRS TPIDR_EL1, R10
+check_vcpu:
+ CMP R10, R9
+ BEQ right_vCPU
+wrong_vcpu:
+ CALL ·redpill(SB)
+ B begin
+right_vCPU:
+ RET
+
+// sighandler: see bluepill.go for documentation.
+//
+// The arguments are the following:
+//
+// R0 - The signal number.
+// R1 - Pointer to siginfo_t structure.
+// R2 - Pointer to ucontext structure.
+//
+TEXT ·sighandler(SB),NOSPLIT,$0
+ // si_signo should be sigill.
+ MOVD SIGINFO_SIGNO(R1), R7
+ CMPW $4, R7
+ BNE fallback
+
+ MOVD CONTEXT_PC(R2), R7
+ CMPW $0, R7
+ BEQ fallback
+
+ MOVD R2, 8(RSP)
+ BL ·bluepillHandler(SB) // Call the handler.
+
+ RET
+
+fallback:
+ // Jump to the previous signal handler.
+ MOVD ·savedHandler(SB), R7
+ B (R7)
+
+// dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
+TEXT ·dieTrampoline(SB),NOSPLIT,$0
+ // R0: Fake the old PC as caller
+ // R1: First argument (vCPU)
+ MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
+ MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
+ B ·dieHandler(SB)
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
new file mode 100644
index 000000000..b35c930e2
--- /dev/null
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -0,0 +1,97 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+// fpsimdPtr returns a fpsimd64 for the given address.
+//
+//go:nosplit
+func fpsimdPtr(addr *byte) *arch.FpsimdContext {
+ return (*arch.FpsimdContext)(unsafe.Pointer(addr))
+}
+
+// dieArchSetup initialies the state for dieTrampoline.
+//
+// The arm64 dieTrampoline requires the vCPU to be set in R1, and the last PC
+// to be in R0. The trampoline then simulates a call to dieHandler from the
+// provided PC.
+//
+//go:nosplit
+func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) {
+ // If the vCPU is in user mode, we set the stack to the stored stack
+ // value in the vCPU itself. We don't want to unwind the user stack.
+ if guestRegs.Regs.Pstate&ring0.PsrModeMask == ring0.UserFlagsSet {
+ regs := c.CPU.Registers()
+ context.Regs[0] = regs.Regs[0]
+ context.Sp = regs.Sp
+ context.Regs[29] = regs.Regs[29] // stack base address
+ } else {
+ context.Regs[0] = guestRegs.Regs.Pc
+ context.Sp = guestRegs.Regs.Sp
+ context.Regs[29] = guestRegs.Regs.Regs[29]
+ context.Pstate = guestRegs.Regs.Pstate
+ }
+ context.Regs[1] = uint64(uintptr(unsafe.Pointer(c)))
+ context.Pc = uint64(dieTrampolineAddr)
+}
+
+// bluepillArchFpContext returns the arch-specific fpsimd context.
+//
+//go:nosplit
+func bluepillArchFpContext(context unsafe.Pointer) *arch.FpsimdContext {
+ return &((*arch.SignalContext64)(context).Fpsimd64)
+}
+
+// getHypercallID returns hypercall ID.
+//
+// On Arm64, the MMIO address should be 64-bit aligned.
+//
+//go:nosplit
+func getHypercallID(addr uintptr) int {
+ if addr < arm64HypercallMMIOBase || addr >= (arm64HypercallMMIOBase+_AARCH64_HYPERCALL_MMIO_SIZE) {
+ return _KVM_HYPERCALL_MAX
+ } else {
+ return int(((addr) - arm64HypercallMMIOBase) >> 3)
+ }
+}
+
+// bluepillStopGuest is reponsible for injecting sError.
+//
+//go:nosplit
+func bluepillStopGuest(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 {
+ throw("sErr injection failed")
+ }
+}
+
+// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
+//
+//go:nosplit
+func bluepillReadyStopGuest(c *vCPU) bool {
+ return true
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index b97476053..e34f46aeb 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -18,7 +18,7 @@ import (
"sync/atomic"
"syscall"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -46,9 +46,9 @@ func yield() {
// calculateBluepillFault calculates the fault address range.
//
//go:nosplit
-func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) {
+func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) {
alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
- for _, pr := range physicalRegions {
+ for _, pr := range phyRegions {
end := pr.physical + pr.length
if physical < pr.physical || physical >= end {
continue
@@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng
// The corresponding virtual address is returned. This may throw on error.
//
//go:nosplit
-func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
+func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) {
// Paging fault: we need to map the underlying physical pages for this
// fault. This all has to be done in this function because we're in a
// signal handler context. (We can't call any functions that might
// split the stack.)
- virtualStart, physicalStart, length, ok := calculateBluepillFault(physical)
+ virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
return 0, false
}
@@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
yield() // Race with another call.
slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0))
}
- errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart)
+ errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags)
if errno == 0 {
// Successfully added region; we can increment nextSlot and
// allow another set to proceed here.
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 7e8e9f42a..bf357de1a 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -23,6 +23,8 @@ import (
"sync/atomic"
"syscall"
"unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
//go:linkname throw runtime.throw
@@ -49,12 +51,39 @@ func uintptrValue(addr *byte) uintptr {
return (uintptr)(unsafe.Pointer(addr))
}
+// bluepillArchContext returns the UContext64.
+//
+//go:nosplit
+func bluepillArchContext(context unsafe.Pointer) *arch.SignalContext64 {
+ return &((*arch.UContext64)(context).MContext)
+}
+
+// bluepillHandleHlt is reponsible for handling VM-Exit.
+//
+//go:nosplit
+func bluepillGuestExit(c *vCPU, context unsafe.Pointer) {
+ // Copy out registers.
+ bluepillArchExit(c, bluepillArchContext(context))
+
+ // Return to the vCPUReady state; notify any waiters.
+ user := atomic.LoadUint32(&c.state) & vCPUUser
+ switch atomic.SwapUint32(&c.state, user) {
+ case user | vCPUGuest: // Expected case.
+ case user | vCPUGuest | vCPUWaiter:
+ c.notify()
+ default:
+ throw("invalid state")
+ }
+}
+
// bluepillHandler is called from the signal stub.
//
// The world may be stopped while this is executing, and it executes on the
// signal stack. It should only execute raw system calls and functions that are
// explicitly marked go:nosplit.
//
+// +checkescape:all
+//
//go:nosplit
func bluepillHandler(context unsafe.Pointer) {
// Sanitize the registers; interrupts must always be disabled.
@@ -73,20 +102,25 @@ func bluepillHandler(context unsafe.Pointer) {
}
for {
- switch _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0); errno {
+ _, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(c.fd), _KVM_RUN, 0) // escapes: no.
+ switch errno {
case 0: // Expected case.
case syscall.EINTR:
// First, we process whatever pending signal
// interrupted KVM. Since we're in a signal handler
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
- sig, _, errno := syscall.RawSyscall6(
+ timeout := syscall.Timespec{}
+ sig, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
- 0, // siginfo.
- 0, // timeout.
- 8, // sigset size.
+ 0, // siginfo.
+ uintptr(unsafe.Pointer(&timeout)), // timeout.
+ 8, // sigset size.
0, 0)
+ if errno == syscall.EAGAIN {
+ continue
+ }
if errno != 0 {
throw("error waiting for pending signal")
}
@@ -99,12 +133,12 @@ func bluepillHandler(context unsafe.Pointer) {
// PIC, we can't inject an interrupt while they are
// masked. We need to request a window if it's not
// ready.
- if c.runData.readyForInterruptInjection == 0 {
- c.runData.requestInterruptWindow = 1
- continue // Rerun vCPU.
- } else {
+ if bluepillReadyStopGuest(c) {
// Force injection below; the vCPU is ready.
c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN
+ } else {
+ c.runData.requestInterruptWindow = 1
+ continue // Rerun vCPU.
}
case syscall.EFAULT:
// If a fault is not serviceable due to the host
@@ -112,7 +146,7 @@ func bluepillHandler(context unsafe.Pointer) {
// MMIO exit we receive EFAULT from the run ioctl. We
// always inject an NMI here since we may be in kernel
// mode and have interrupts disabled.
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_NMI, 0); errno != 0 {
@@ -143,26 +177,21 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "debug")
return
case _KVM_EXIT_HLT:
- // Copy out registers.
- bluepillArchExit(c, bluepillArchContext(context))
-
- // Return to the vCPUReady state; notify any waiters.
- user := atomic.LoadUint32(&c.state) & vCPUUser
- switch atomic.SwapUint32(&c.state, user) {
- case user | vCPUGuest: // Expected case.
- case user | vCPUGuest | vCPUWaiter:
- c.notify()
- default:
- throw("invalid state")
- }
+ bluepillGuestExit(c, context)
return
case _KVM_EXIT_MMIO:
+ physical := uintptr(c.runData.data[0])
+ if getHypercallID(physical) == _KVM_HYPERCALL_VMEXIT {
+ bluepillGuestExit(c, context)
+ return
+ }
+
// Increment the fault count.
atomic.AddUint32(&c.faults, 1)
// For MMIO, the physical address is the first data item.
- physical := uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical)
+ physical = uintptr(c.runData.data[0])
+ virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
if !ok {
c.die(bluepillArchContext(context), "invalid physical address")
return
@@ -188,17 +217,7 @@ func bluepillHandler(context unsafe.Pointer) {
}
}
case _KVM_EXIT_IRQ_WINDOW_OPEN:
- // Interrupt: we must have requested an interrupt
- // window; set the interrupt line.
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_INTERRUPT,
- uintptr(unsafe.Pointer(&bounce))); errno != 0 {
- throw("interrupt injection failed")
- }
- // Clear previous injection request.
- c.runData.requestInterruptWindow = 0
+ bluepillStopGuest(c)
case _KVM_EXIT_SHUTDOWN:
c.die(bluepillArchContext(context), "shutdown")
return
diff --git a/pkg/sentry/platform/kvm/context.go b/pkg/sentry/platform/kvm/context.go
index 99450d22d..6e6b76416 100644
--- a/pkg/sentry/platform/kvm/context.go
+++ b/pkg/sentry/platform/kvm/context.go
@@ -15,11 +15,12 @@
package kvm
import (
+ pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// context is an implementation of the platform context.
@@ -37,7 +38,8 @@ type context struct {
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, _ int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ as := mm.AddressSpace()
localAS := as.(*addressSpace)
// Grab a vCPU.
@@ -85,3 +87,12 @@ func (c *context) Switch(as platform.AddressSpace, ac arch.Context, _ int32) (*a
func (c *context) Interrupt() {
c.interrupt.NotifyInterrupt()
}
+
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
+
+// FullStateChanged implements platform.Context.FullStateChanged.
+func (c *context) FullStateChanged() {}
+
+// PullFullState implements platform.Context.PullFullState.
+func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {}
diff --git a/pkg/sentry/platform/kvm/filters.go b/pkg/sentry/platform/kvm/filters_amd64.go
index 7d949f1dd..7d949f1dd 100644
--- a/pkg/sentry/platform/kvm/filters.go
+++ b/pkg/sentry/platform/kvm/filters_amd64.go
diff --git a/pkg/sentry/platform/kvm/filters_arm64.go b/pkg/sentry/platform/kvm/filters_arm64.go
new file mode 100644
index 000000000..9245d07c2
--- /dev/null
+++ b/pkg/sentry/platform/kvm/filters_arm64.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// SyscallFilters returns syscalls made exclusively by the KVM platform.
+func (*KVM) SyscallFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_IOCTL: {},
+ syscall.SYS_MMAP: {},
+ syscall.SYS_RT_SIGSUSPEND: {},
+ syscall.SYS_RT_SIGTIMEDWAIT: {},
+ 0xffffffffffffffff: {}, // KVM uses syscall -1 to transition to host.
+ }
+}
diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go
index ee4cd2f4d..ae813e24e 100644
--- a/pkg/sentry/platform/kvm/kvm.go
+++ b/pkg/sentry/platform/kvm/kvm.go
@@ -18,16 +18,47 @@ package kvm
import (
"fmt"
"os"
- "sync"
"syscall"
- "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// userMemoryRegion is a region of physical memory.
+//
+// This mirrors kvm_memory_region.
+type userMemoryRegion struct {
+ slot uint32
+ flags uint32
+ guestPhysAddr uint64
+ memorySize uint64
+ userspaceAddr uint64
+}
+
+// runData is the run structure. This may be mapped for synchronous register
+// access (although that doesn't appear to be supported by my kernel at least).
+//
+// This mirrors kvm_run.
+type runData struct {
+ requestInterruptWindow uint8
+ _ [7]uint8
+
+ exitReason uint32
+ readyForInterruptInjection uint8
+ ifFlag uint8
+ _ [2]uint8
+
+ cr8 uint64
+ apicBase uint64
+
+ // This is the union data for exits. Interpretation depends entirely on
+ // the exitReason above (see vCPU code for more information).
+ data [32]uint64
+}
+
// KVM represents a lightweight VM context.
type KVM struct {
platform.NoCPUPreemptionDetection
@@ -56,18 +87,26 @@ func New(deviceFile *os.File) (*KVM, error) {
// Ensure global initialization is done.
globalOnce.Do(func() {
- physicalInit()
- globalErr = updateSystemValues(int(fd))
- ring0.Init(cpuid.HostFeatureSet())
+ globalErr = updateGlobalOnce(int(fd))
})
if globalErr != nil {
return nil, globalErr
}
// Create a new VM fd.
- vm, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, fd, _KVM_CREATE_VM, 0)
- if errno != 0 {
- return nil, fmt.Errorf("creating VM: %v", errno)
+ var (
+ vm uintptr
+ errno syscall.Errno
+ )
+ for {
+ vm, _, errno = syscall.Syscall(syscall.SYS_IOCTL, fd, _KVM_CREATE_VM, 0)
+ if errno == syscall.EINTR {
+ continue
+ }
+ if errno != 0 {
+ return nil, fmt.Errorf("creating VM: %v", errno)
+ }
+ break
}
// We are done with the device file.
deviceFile.Close()
@@ -152,6 +191,11 @@ func (*constructor) OpenDevice() (*os.File, error) {
return OpenDevice()
}
+// Flags implements platform.Constructor.Flags().
+func (*constructor) Requirements() platform.Requirements {
+ return platform.Requirements{}
+}
+
func init() {
platform.Register("kvm", &constructor{})
}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64.go b/pkg/sentry/platform/kvm/kvm_amd64.go
index 5d8ef4761..093497bc4 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64.go
@@ -17,20 +17,10 @@
package kvm
import (
+ "gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
)
-// userMemoryRegion is a region of physical memory.
-//
-// This mirrors kvm_memory_region.
-type userMemoryRegion struct {
- slot uint32
- flags uint32
- guestPhysAddr uint64
- memorySize uint64
- userspaceAddr uint64
-}
-
// userRegs represents KVM user registers.
//
// This mirrors kvm_regs.
@@ -168,27 +158,6 @@ type modelControlRegisters struct {
entries [16]modelControlRegister
}
-// runData is the run structure. This may be mapped for synchronous register
-// access (although that doesn't appear to be supported by my kernel at least).
-//
-// This mirrors kvm_run.
-type runData struct {
- requestInterruptWindow uint8
- _ [7]uint8
-
- exitReason uint32
- readyForInterruptInjection uint8
- ifFlag uint8
- _ [2]uint8
-
- cr8 uint64
- apicBase uint64
-
- // This is the union data for exits. Interpretation depends entirely on
- // the exitReason above (see vCPU code for more information).
- data [32]uint64
-}
-
// cpuidEntry is a single CPUID entry.
//
// This mirrors kvm_cpuid_entry2.
@@ -211,3 +180,11 @@ type cpuidEntries struct {
_ uint32
entries [_KVM_NR_CPUID_ENTRIES]cpuidEntry
}
+
+// updateGlobalOnce does global initialization. It has to be called only once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ ring0.Init(cpuid.HostFeatureSet())
+ return err
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
new file mode 100644
index 000000000..c0b4fd374
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package kvm
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+)
+
+func TestSegments(t *testing.T) {
+ applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+ testutil.SetTestSegments(regs)
+ for {
+ var si arch.SignalInfo
+ if _, err := c.SwitchToUser(ring0.SwitchOpts{
+ Registers: regs,
+ FloatingPointState: dummyFPState,
+ PageTables: pt,
+ FullRestore: true,
+ }, &si); err == platform.ErrContextInterrupt {
+ continue // Retry.
+ } else if err != nil {
+ t.Errorf("application segment check with full restore got unexpected error: %v", err)
+ }
+ if err := testutil.CheckTestSegments(regs); err != nil {
+ t.Errorf("application segment check with full restore failed: %v", err)
+ }
+ break // Done.
+ }
+ return false
+ })
+}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
new file mode 100644
index 000000000..0b06a923a
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+)
+
+type kvmOneReg struct {
+ id uint64
+ addr uint64
+}
+
+// arm64HypercallMMIOBase is MMIO base address used to dispatch hypercalls.
+var arm64HypercallMMIOBase uintptr
+
+const KVM_NR_SPSR = 5
+
+type userFpsimdState struct {
+ vregs [64]uint64
+ fpsr uint32
+ fpcr uint32
+ reserved [2]uint32
+}
+
+type userRegs struct {
+ Regs arch.Registers
+ sp_el1 uint64
+ elr_el1 uint64
+ spsr [KVM_NR_SPSR]uint64
+ fpRegs userFpsimdState
+}
+
+type exception struct {
+ sErrPending uint8
+ sErrHasEsr uint8
+ pad [6]uint8
+ sErrEsr uint64
+}
+
+type kvmVcpuEvents struct {
+ exception
+ rsvd [12]uint32
+}
+
+// updateGlobalOnce does global initialization. It has to be called only once.
+func updateGlobalOnce(fd int) error {
+ physicalInit()
+ err := updateSystemValues(int(fd))
+ ring0.Init()
+ return err
+}
diff --git a/pkg/sentry/usermem/usermem_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
index 876783e78..48ccf8474 100644
--- a/pkg/sentry/usermem/usermem_unsafe.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
@@ -12,16 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+// +build arm64
+
+package kvm
import (
- "unsafe"
+ "fmt"
+ "syscall"
+)
+
+var (
+ runDataSize int
+ hasGuestPCID bool
)
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
+func updateSystemValues(fd int) error {
+ // Extract the mmap size.
+ sz, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(fd), _KVM_GET_VCPU_MMAP_SIZE, 0)
+ if errno != 0 {
+ return fmt.Errorf("getting VCPU mmap size: %v", errno)
+ }
+ // Save the data.
+ runDataSize = int(sz)
+ hasGuestPCID = true
+
+ // Success.
+ return nil
}
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index d05f05c29..3bf918446 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -35,6 +35,8 @@ const (
_KVM_GET_SUPPORTED_CPUID = 0xc008ae05
_KVM_SET_CPUID2 = 0x4008ae90
_KVM_SET_SIGNAL_MASK = 0x4004ae8b
+ _KVM_GET_VCPU_EVENTS = 0x8040ae9f
+ _KVM_SET_VCPU_EVENTS = 0x4040aea0
)
// KVM exit reasons.
@@ -49,11 +51,15 @@ const (
_KVM_EXIT_SHUTDOWN = 0x8
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
+ _KVM_EXIT_SYSTEM_EVENT = 0x18
)
// KVM capability options.
const (
- _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_MAX_VCPUS = 0x42
+ _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
+ _KVM_CAP_VCPU_EVENTS = 0x29
+ _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e
)
// KVM limits.
@@ -62,3 +68,20 @@ const (
_KVM_NR_INTERRUPTS = 0x100
_KVM_NR_CPUID_ENTRIES = 0x100
)
+
+// KVM kvm_memory_region::flags.
+const (
+ _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0
+ _KVM_MEM_READONLY = uint32(1) << 1
+ _KVM_MEM_FLAGS_NONE = 0
+)
+
+// KVM hypercall list.
+// Canonical list of hypercalls supported.
+const (
+ // On amd64, it uses 'HLT' to leave the guest.
+ // Unlike amd64, arm64 can only uses mmio_exit/psci to leave the guest.
+ // _KVM_HYPERCALL_VMEXIT is only used on Arm64 for now.
+ _KVM_HYPERCALL_VMEXIT int = iota
+ _KVM_HYPERCALL_MAX
+)
diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go
new file mode 100644
index 000000000..9a7be3655
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go
@@ -0,0 +1,152 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// KVM ioctls for Arm64.
+const (
+ _KVM_GET_ONE_REG = 0x4010aeab
+ _KVM_SET_ONE_REG = 0x4010aeac
+
+ _KVM_ARM_TARGET_GENERIC_V8 = 5
+ _KVM_ARM_PREFERRED_TARGET = 0x8020aeaf
+ _KVM_ARM_VCPU_INIT = 0x4020aeae
+ _KVM_ARM64_REGS_PSTATE = 0x6030000000100042
+ _KVM_ARM64_REGS_SP_EL1 = 0x6030000000100044
+ _KVM_ARM64_REGS_R0 = 0x6030000000100000
+ _KVM_ARM64_REGS_R1 = 0x6030000000100002
+ _KVM_ARM64_REGS_R2 = 0x6030000000100004
+ _KVM_ARM64_REGS_R3 = 0x6030000000100006
+ _KVM_ARM64_REGS_R8 = 0x6030000000100010
+ _KVM_ARM64_REGS_R18 = 0x6030000000100024
+ _KVM_ARM64_REGS_PC = 0x6030000000100040
+ _KVM_ARM64_REGS_MAIR_EL1 = 0x603000000013c510
+ _KVM_ARM64_REGS_TCR_EL1 = 0x603000000013c102
+ _KVM_ARM64_REGS_TTBR0_EL1 = 0x603000000013c100
+ _KVM_ARM64_REGS_TTBR1_EL1 = 0x603000000013c101
+ _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080
+ _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082
+ _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600
+)
+
+// Arm64: Architectural Feature Access Control Register EL1.
+const (
+ _FPEN_NOTRAP = 3
+ _FPEN_SHIFT = 20
+)
+
+// Arm64: System Control Register EL1.
+const (
+ _SCTLR_M = 1 << 0
+ _SCTLR_C = 1 << 2
+ _SCTLR_I = 1 << 12
+)
+
+// Arm64: Translation Control Register EL1.
+const (
+ _TCR_IPS_40BITS = 2 << 32 // PA=40
+ _TCR_IPS_48BITS = 5 << 32 // PA=48
+
+ _TCR_T0SZ_OFFSET = 0
+ _TCR_T1SZ_OFFSET = 16
+ _TCR_IRGN0_SHIFT = 8
+ _TCR_IRGN1_SHIFT = 24
+ _TCR_ORGN0_SHIFT = 10
+ _TCR_ORGN1_SHIFT = 26
+ _TCR_SH0_SHIFT = 12
+ _TCR_SH1_SHIFT = 28
+ _TCR_TG0_SHIFT = 14
+ _TCR_TG1_SHIFT = 30
+
+ _TCR_T0SZ_VA48 = 64 - 48 // VA=48
+ _TCR_T1SZ_VA48 = 64 - 48 // VA=48
+
+ _TCR_A1 = 1 << 22
+ _TCR_ASID16 = 1 << 36
+ _TCR_TBI0 = 1 << 37
+
+ _TCR_TXSZ_VA48 = (_TCR_T0SZ_VA48 << _TCR_T0SZ_OFFSET) | (_TCR_T1SZ_VA48 << _TCR_T1SZ_OFFSET)
+
+ _TCR_TG0_4K = 0 << _TCR_TG0_SHIFT // 4K
+ _TCR_TG0_64K = 1 << _TCR_TG0_SHIFT // 64K
+
+ _TCR_TG1_4K = 2 << _TCR_TG1_SHIFT
+
+ _TCR_TG_FLAGS = _TCR_TG0_4K | _TCR_TG1_4K
+
+ _TCR_IRGN0_WBWA = 1 << _TCR_IRGN0_SHIFT
+ _TCR_IRGN1_WBWA = 1 << _TCR_IRGN1_SHIFT
+ _TCR_IRGN_WBWA = _TCR_IRGN0_WBWA | _TCR_IRGN1_WBWA
+
+ _TCR_ORGN0_WBWA = 1 << _TCR_ORGN0_SHIFT
+ _TCR_ORGN1_WBWA = 1 << _TCR_ORGN1_SHIFT
+
+ _TCR_ORGN_WBWA = _TCR_ORGN0_WBWA | _TCR_ORGN1_WBWA
+
+ _TCR_SHARED = (3 << _TCR_SH0_SHIFT) | (3 << _TCR_SH1_SHIFT)
+
+ _TCR_CACHE_FLAGS = _TCR_IRGN_WBWA | _TCR_ORGN_WBWA
+)
+
+// Arm64: Memory Attribute Indirection Register EL1.
+const (
+ _MT_DEVICE_nGnRnE = 0
+ _MT_DEVICE_nGnRE = 1
+ _MT_DEVICE_GRE = 2
+ _MT_NORMAL_NC = 3
+ _MT_NORMAL = 4
+ _MT_NORMAL_WT = 5
+ _MT_EL1_INIT = (0 << _MT_DEVICE_nGnRnE) | (0x4 << _MT_DEVICE_nGnRE * 8) | (0xc << _MT_DEVICE_GRE * 8) | (0x44 << _MT_NORMAL_NC * 8) | (0xff << _MT_NORMAL * 8) | (0xbb << _MT_NORMAL_WT * 8)
+)
+
+const (
+ _KVM_ARM_VCPU_POWER_OFF = 0 // CPU is started in OFF state
+ _KVM_ARM_VCPU_PSCI_0_2 = 2 // CPU uses PSCI v0.2
+)
+
+// Arm64: Exception Syndrome Register EL1.
+const (
+ _ESR_ELx_EC_SHIFT = 26
+ _ESR_ELx_EC_MASK = 0x3F << _ESR_ELx_EC_SHIFT
+
+ _ESR_ELx_EC_IMP_DEF = 0x1f
+ _ESR_ELx_EC_IABT_LOW = 0x20
+ _ESR_ELx_EC_IABT_CUR = 0x21
+ _ESR_ELx_EC_PC_ALIGN = 0x22
+
+ _ESR_ELx_CM = 1 << 8
+ _ESR_ELx_WNR = 1 << 6
+
+ _ESR_ELx_FSC = 0x3F
+
+ _ESR_SEGV_MAPERR_L0 = 0x4
+ _ESR_SEGV_MAPERR_L1 = 0x5
+ _ESR_SEGV_MAPERR_L2 = 0x6
+ _ESR_SEGV_MAPERR_L3 = 0x7
+
+ _ESR_SEGV_ACCERR_L1 = 0x9
+ _ESR_SEGV_ACCERR_L2 = 0xa
+ _ESR_SEGV_ACCERR_L3 = 0xb
+
+ _ESR_SEGV_PEMERR_L1 = 0xd
+ _ESR_SEGV_PEMERR_L2 = 0xe
+ _ESR_SEGV_PEMERR_L3 = 0xf
+)
+
+// Arm64: MMIO base address used to dispatch hypercalls.
+const (
+ // on Arm64, the MMIO address must be 64-bit aligned.
+ // Currently, we only need 1 hypercall: hypercall_vmexit.
+ _AARCH64_HYPERCALL_MMIO_SIZE = 1 << 3
+)
diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go
index 30df725d4..45b3180f1 100644
--- a/pkg/sentry/platform/kvm/kvm_test.go
+++ b/pkg/sentry/platform/kvm/kvm_test.go
@@ -27,7 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var dummyFPState = (*byte)(arch.NewFloatingPointData())
@@ -117,10 +117,10 @@ func TestKernelFloatingPoint(t *testing.T) {
})
}
-func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *syscall.PtraceRegs, *pagetables.PageTables) bool) {
+func applicationTest(t testHarness, useHostMappings bool, target func(), fn func(*vCPU, *arch.Registers, *pagetables.PageTables) bool) {
// Initialize registers & page tables.
var (
- regs syscall.PtraceRegs
+ regs arch.Registers
pt *pagetables.PageTables
)
testutil.SetTestTarget(&regs, target)
@@ -154,7 +154,7 @@ func applicationTest(t testHarness, useHostMappings bool, target func(), fn func
}
func TestApplicationSyscall(t *testing.T) {
- applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -168,7 +168,7 @@ func TestApplicationSyscall(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -184,7 +184,7 @@ func TestApplicationSyscall(t *testing.T) {
}
func TestApplicationFault(t *testing.T) {
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
@@ -199,7 +199,7 @@ func TestApplicationFault(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, nil) // Cause fault.
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
@@ -216,7 +216,7 @@ func TestApplicationFault(t *testing.T) {
}
func TestRegistersSyscall(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.TwiddleRegsSyscall, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
var si arch.SignalInfo
@@ -239,7 +239,7 @@ func TestRegistersSyscall(t *testing.T) {
}
func TestRegistersFault(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.TwiddleRegsFault, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTestRegs(regs) // Fill values for all registers.
for {
var si arch.SignalInfo
@@ -262,32 +262,8 @@ func TestRegistersFault(t *testing.T) {
})
}
-func TestSegments(t *testing.T) {
- applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
- testutil.SetTestSegments(regs)
- for {
- var si arch.SignalInfo
- if _, err := c.SwitchToUser(ring0.SwitchOpts{
- Registers: regs,
- FloatingPointState: dummyFPState,
- PageTables: pt,
- FullRestore: true,
- }, &si); err == platform.ErrContextInterrupt {
- continue // Retry.
- } else if err != nil {
- t.Errorf("application segment check with full restore got unexpected error: %v", err)
- }
- if err := testutil.CheckTestSegments(regs); err != nil {
- t.Errorf("application segment check with full restore failed: %v", err)
- }
- break // Done.
- }
- return false
- })
-}
-
func TestBounce(t *testing.T) {
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
time.Sleep(time.Millisecond)
c.BounceToKernel()
@@ -302,7 +278,7 @@ func TestBounce(t *testing.T) {
}
return false
})
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
go func() {
time.Sleep(time.Millisecond)
c.BounceToKernel()
@@ -321,7 +297,7 @@ func TestBounce(t *testing.T) {
}
func TestBounceStress(t *testing.T) {
- applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
randomSleep := func() {
// O(hundreds of microseconds) is appropriate to ensure
// different overlaps and different schedules.
@@ -357,7 +333,7 @@ func TestBounceStress(t *testing.T) {
func TestInvalidate(t *testing.T) {
var data uintptr // Used below.
- applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, true, testutil.Touch, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
testutil.SetTouchTarget(regs, &data) // Read legitimate value.
for {
var si arch.SignalInfo
@@ -398,7 +374,7 @@ func IsFault(err error, si *arch.SignalInfo) bool {
}
func TestEmptyAddressSpace(t *testing.T) {
- applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -412,7 +388,7 @@ func TestEmptyAddressSpace(t *testing.T) {
}
return false
})
- applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(t, false, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -471,7 +447,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
i int // Iteration includes machine.Get() / machine.Put().
a int // Count for ErrContextInterrupt.
)
- applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
@@ -493,7 +469,7 @@ func BenchmarkApplicationSyscall(b *testing.B) {
func BenchmarkKernelSyscall(b *testing.B) {
// Note that the target passed here is irrelevant, we never execute SwitchToUser.
- applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.Getpid, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
// iteration does not include machine.Get() / machine.Put().
for i := 0; i < b.N; i++ {
testutil.Getpid()
@@ -508,7 +484,7 @@ func BenchmarkWorldSwitchToUserRoundtrip(b *testing.B) {
i int
a int
)
- applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *syscall.PtraceRegs, pt *pagetables.PageTables) bool {
+ applicationTest(b, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
var si arch.SignalInfo
if _, err := c.SwitchToUser(ring0.SwitchOpts{
Registers: regs,
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index cc6c138b2..6c54712d1 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -17,7 +17,6 @@ package kvm
import (
"fmt"
"runtime"
- "sync"
"sync/atomic"
"syscall"
@@ -26,7 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// machine contains state associated with the VM as a whole.
@@ -52,16 +52,19 @@ type machine struct {
// available is notified when vCPUs are available.
available sync.Cond
- // vCPUs are the machine vCPUs.
+ // vCPUsByTID are the machine vCPUs.
//
// These are populated dynamically.
- vCPUs map[uint64]*vCPU
+ vCPUsByTID map[uint64]*vCPU
// vCPUsByID are the machine vCPUs, can be indexed by the vCPU's ID.
- vCPUsByID map[int]*vCPU
+ vCPUsByID []*vCPU
// maxVCPUs is the maximum number of vCPUs supported by the machine.
maxVCPUs int
+
+ // nextID is the next vCPU ID.
+ nextID uint32
}
const (
@@ -137,9 +140,8 @@ type dieState struct {
//
// Precondition: mu must be held.
func (m *machine) newVCPU() *vCPU {
- id := len(m.vCPUs)
-
// Create the vCPU.
+ id := int(atomic.AddUint32(&m.nextID, 1) - 1)
fd, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CREATE_VCPU, uintptr(id))
if errno != 0 {
panic(fmt.Sprintf("error creating new vCPU: %v", errno))
@@ -176,11 +178,7 @@ func (m *machine) newVCPU() *vCPU {
// newMachine returns a new VM context.
func newMachine(vm int) (*machine, error) {
// Create the machine.
- m := &machine{
- fd: vm,
- vCPUs: make(map[uint64]*vCPU),
- vCPUsByID: make(map[int]*vCPU),
- }
+ m := &machine{fd: vm}
m.available.L = &m.mu
m.kernel.Init(ring0.KernelOpts{
PageTables: pagetables.New(newAllocator()),
@@ -194,6 +192,10 @@ func newMachine(vm int) (*machine, error) {
}
log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs)
+ // Create the vCPUs map/slices.
+ m.vCPUsByTID = make(map[uint64]*vCPU)
+ m.vCPUsByID = make([]*vCPU, m.maxVCPUs)
+
// Apply the physical mappings. Note that these mappings may point to
// guest physical addresses that are not actually available. These
// physical pages are mapped on demand, see kernel_unsafe.go.
@@ -215,6 +217,17 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
+ var physicalRegionsReadOnly []physicalRegion
+ var physicalRegionsAvailable []physicalRegion
+
+ physicalRegionsReadOnly = rdonlyRegionsForSetMem()
+ physicalRegionsAvailable = availableRegionsForSetMem()
+
+ // Map all read-only regions.
+ for _, r := range physicalRegionsReadOnly {
+ m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
+ }
+
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
@@ -223,6 +236,13 @@ func newMachine(vm int) (*machine, error) {
if excludeVirtualRegion(vr) {
return // skip region.
}
+
+ for _, r := range physicalRegionsReadOnly {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -236,7 +256,7 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length)
+ m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
virtual += length
}
})
@@ -256,9 +276,11 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
-func (m *machine) mapPhysical(physical, length uintptr) {
+//
+//go:nosplit
+func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
- _, physicalStart, length, ok := calculateBluepillFault(physical)
+ _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
// Should never happen.
panic("mapPhysical on unknown physical address")
@@ -266,7 +288,7 @@ func (m *machine) mapPhysical(physical, length uintptr) {
if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok {
// Not present in the cache; requires setting the slot.
- if _, ok := handleBluepillFault(m, physical); !ok {
+ if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok {
panic("handleBluepillFault failed")
}
}
@@ -286,7 +308,11 @@ func (m *machine) Destroy() {
runtime.SetFinalizer(m, nil)
// Destroy vCPUs.
- for _, c := range m.vCPUs {
+ for _, c := range m.vCPUsByID {
+ if c == nil {
+ continue
+ }
+
// Ensure the vCPU is not still running in guest mode. This is
// possible iff teardown has been done by other threads, and
// somehow a single thread has not executed any system calls.
@@ -311,13 +337,15 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
- if c := m.vCPUs[tid]; c != nil {
+ if c := m.vCPUsByTID[tid]; c != nil {
c.lock()
m.mu.RUnlock()
return c
@@ -325,15 +353,29 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUsByTID[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if atomic.CompareAndSwapUint32(&c.state, vCPUReady, vCPUUser) {
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -341,17 +383,17 @@ func (m *machine) Get() *vCPU {
}
// Create a new vCPU (maybe).
- if len(m.vCPUs) < m.maxVCPUs {
+ if int(m.nextID) < m.maxVCPUs {
c := m.newVCPU()
c.lock()
- m.vCPUs[tid] = c
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
}
// Scan for something not in user mode.
- for origTID, c := range m.vCPUs {
+ for origTID, c := range m.vCPUsByTID {
if !atomic.CompareAndSwapUint32(&c.state, vCPUGuest, vCPUGuest|vCPUWaiter) {
continue
}
@@ -369,8 +411,8 @@ func (m *machine) Get() *vCPU {
}
// Steal the vCPU.
- delete(m.vCPUs, origTID)
- m.vCPUs[tid] = c
+ delete(m.vCPUsByTID, origTID)
+ m.vCPUsByTID[tid] = c
m.mu.Unlock()
c.loadSegments(tid)
return c
@@ -397,7 +439,7 @@ func (m *machine) Put(c *vCPU) {
// newDirtySet returns a new dirty set.
func (m *machine) newDirtySet() *dirtySet {
return &dirtySet{
- vCPUs: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
+ vCPUMasks: make([]uint64, (m.maxVCPUs+63)/64, (m.maxVCPUs+63)/64),
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c1cbe33be..acc823ba6 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// initArchState initializes architecture-specific state.
@@ -51,9 +51,10 @@ func (m *machine) initArchState() error {
recover()
debug.SetPanicOnFault(old)
}()
- m.retryInGuest(func() {
- ring0.SetCPUIDFaulting(true)
- })
+ c := m.Get()
+ defer m.Put(c)
+ bluepill(c)
+ ring0.SetCPUIDFaulting(true)
return nil
}
@@ -89,8 +90,10 @@ func (m *machine) dropPageTables(pt *pagetables.PageTables) {
defer m.mu.Unlock()
// Clear from all PCIDs.
- for _, c := range m.vCPUs {
- c.PCIDs.Drop(pt)
+ for _, c := range m.vCPUsByID {
+ if c != nil && c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
}
}
@@ -333,25 +336,12 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
}
}
-// retryInGuest runs the given function in guest mode.
-//
-// If the function does not complete in guest mode (due to execution of a
-// system call due to a GC stall, for example), then it will be retried. The
-// given function must be idempotent as a result of the retry mechanism.
-func (m *machine) retryInGuest(fn func()) {
- c := m.Get()
- defer m.Put(c)
- for {
- c.ClearErrorCode() // See below.
- bluepill(c) // Force guest mode.
- fn() // Execute the given function.
- _, user := c.ErrorCode()
- if user {
- // If user is set, then we haven't bailed back to host
- // mode via a kernel exception or system call. We
- // consider the full function to have executed in guest
- // mode and we can return.
- break
- }
- }
+// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
+// There is no need to return read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ return nil
+}
+
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ return physicalRegions
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 506ec9af1..290f035dd 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -26,30 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/time"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
// loadSegments copies the current segments.
//
// This may be called from within the signal context and throws on error.
@@ -159,3 +135,43 @@ func (c *vCPU) setSignalMask() error {
}
return nil
}
+
+// setUserRegisters sets user registers in the vCPU.
+func (c *vCPU) setUserRegisters(uregs *userRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return fmt.Errorf("error setting user registers: %v", errno)
+ }
+ return nil
+}
+
+// getUserRegisters reloads user registers in the vCPU.
+//
+// This is safe to call from a nosplit context.
+//
+//go:nosplit
+func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_REGS,
+ uintptr(unsafe.Pointer(uregs))); errno != 0 {
+ return errno
+ }
+ return 0
+}
+
+// setSystemRegisters sets system registers.
+func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SREGS,
+ uintptr(unsafe.Pointer(sregs))); errno != 0 {
+ return fmt.Errorf("error setting system registers: %v", errno)
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
new file mode 100644
index 000000000..9db171af9
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -0,0 +1,183 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type vCPUArchState struct {
+ // PCIDs is the set of PCIDs for this vCPU.
+ //
+ // This starts above fixedKernelPCID.
+ PCIDs *pagetables.PCIDs
+
+ // floatingPointState is the floating point state buffer used in guest
+ // to host transitions. See usage in bluepill_arm64.go.
+ floatingPointState *arch.FloatingPointData
+}
+
+const (
+ // fixedKernelPCID is a fixed kernel PCID used for the kernel page
+ // tables. We must start allocating user PCIDs above this in order to
+ // avoid any conflict (see below).
+ fixedKernelPCID = 1
+
+ // poolPCIDs is the number of PCIDs to record in the database. As this
+ // grows, assignment can take longer, since it is a simple linear scan.
+ // Beyond a relatively small number, there are likely few perform
+ // benefits, since the TLB has likely long since lost any translations
+ // from more than a few PCIDs past.
+ poolPCIDs = 8
+)
+
+// Get all read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ var rdonlyRegions []region
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return
+ }
+
+ if !vr.accessType.Write && vr.accessType.Read {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+
+ // TODO(gvisor.dev/issue/2686): PROT_NONE should be specially treated.
+ // Workaround: treated as rdonly temporarily.
+ if !vr.accessType.Write && !vr.accessType.Read && !vr.accessType.Execute {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+ })
+
+ for _, r := range rdonlyRegions {
+ physical, _, ok := translateToPhysical(r.virtual)
+ if !ok {
+ continue
+ }
+
+ phyRegions = append(phyRegions, physicalRegion{
+ region: region{
+ virtual: r.virtual,
+ length: r.length,
+ },
+ physical: physical,
+ })
+ }
+
+ return phyRegions
+}
+
+// Get all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ var excludeRegions []region
+ applyVirtualRegions(func(vr virtualRegion) {
+ if !vr.accessType.Write {
+ excludeRegions = append(excludeRegions, vr.region)
+ }
+ })
+
+ phyRegions = computePhysicalRegions(excludeRegions)
+
+ return phyRegions
+}
+
+// dropPageTables drops cached page table entries.
+func (m *machine) dropPageTables(pt *pagetables.PageTables) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ // Clear from all PCIDs.
+ for _, c := range m.vCPUsByID {
+ if c.PCIDs != nil {
+ c.PCIDs.Drop(pt)
+ }
+ }
+}
+
+// nonCanonical generates a canonical address return.
+//
+//go:nosplit
+func nonCanonical(addr uint64, signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ *info = arch.SignalInfo{
+ Signo: signal,
+ Code: arch.SignalInfoKernel,
+ }
+ info.SetAddr(addr) // Include address.
+ return usermem.NoAccess, platform.ErrContextSignal
+}
+
+// isInstructionAbort returns true if it is an instruction abort.
+//
+//go:nosplit
+func isInstructionAbort(code uint64) bool {
+ value := (code & _ESR_ELx_EC_MASK) >> _ESR_ELx_EC_SHIFT
+ return value == _ESR_ELx_EC_IABT_LOW
+}
+
+// isWriteFault returns whether it is a write fault.
+//
+//go:nosplit
+func isWriteFault(code uint64) bool {
+ if isInstructionAbort(code) {
+ return false
+ }
+
+ return (code & _ESR_ELx_WNR) != 0
+}
+
+// fault generates an appropriate fault return.
+//
+//go:nosplit
+func (c *vCPU) fault(signal int32, info *arch.SignalInfo) (usermem.AccessType, error) {
+ bluepill(c) // Probably no-op, but may not be.
+ faultAddr := c.GetFaultAddr()
+ code, user := c.ErrorCode()
+
+ if !user {
+ // The last fault serviced by this CPU was not a user
+ // fault, so we can't reliably trust the faultAddr or
+ // the code provided here. We need to re-execute.
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ }
+
+ // Reset the pointed SignalInfo.
+ *info = arch.SignalInfo{Signo: signal}
+ info.SetAddr(uint64(faultAddr))
+
+ ret := code & _ESR_ELx_FSC
+ switch ret {
+ case _ESR_SEGV_MAPERR_L0, _ESR_SEGV_MAPERR_L1, _ESR_SEGV_MAPERR_L2, _ESR_SEGV_MAPERR_L3:
+ info.Code = 1 //SEGV_MAPERR
+ case _ESR_SEGV_ACCERR_L1, _ESR_SEGV_ACCERR_L2, _ESR_SEGV_ACCERR_L3, _ESR_SEGV_PEMERR_L1, _ESR_SEGV_PEMERR_L2, _ESR_SEGV_PEMERR_L3:
+ info.Code = 2 // SEGV_ACCERR.
+ default:
+ info.Code = 2
+ }
+
+ accessType := usermem.AccessType{
+ Read: !isWriteFault(uint64(code)),
+ Write: isWriteFault(uint64(code)),
+ Execute: isInstructionAbort(uint64(code)),
+ }
+
+ return accessType, platform.ErrContextSignal
+}
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
new file mode 100644
index 000000000..905712076
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -0,0 +1,286 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package kvm
+
+import (
+ "fmt"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+type kvmVcpuInit struct {
+ target uint32
+ features [7]uint32
+}
+
+var vcpuInit kvmVcpuInit
+
+// initArchState initializes architecture-specific state.
+func (m *machine) initArchState() error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_ARM_PREFERRED_TARGET,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_PREFERRED_TARGET failed: %v", errno))
+ }
+ return nil
+}
+
+// initArchState initializes architecture-specific state.
+func (c *vCPU) initArchState() error {
+ var (
+ reg kvmOneReg
+ data uint64
+ regGet kvmOneReg
+ dataGet uint64
+ )
+
+ reg.addr = uint64(reflect.ValueOf(&data).Pointer())
+ regGet.addr = uint64(reflect.ValueOf(&dataGet).Pointer())
+
+ vcpuInit.features[0] |= (1 << _KVM_ARM_VCPU_PSCI_0_2)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_ARM_VCPU_INIT,
+ uintptr(unsafe.Pointer(&vcpuInit))); errno != 0 {
+ panic(fmt.Sprintf("error setting KVM_ARM_VCPU_INIT failed: %v", errno))
+ }
+
+ // cpacr_el1
+ reg.id = _KVM_ARM64_REGS_CPACR_EL1
+ // It is off by default, and it is turned on only when in use.
+ data = 0 // Disable fpsimd.
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // tcr_el1
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
+ reg.id = _KVM_ARM64_REGS_TCR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // mair_el1
+ data = _MT_EL1_INIT
+ reg.id = _KVM_ARM64_REGS_MAIR_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // ttbr0_el1
+ data = c.machine.kernel.PageTables.TTBR0_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR0_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ c.SetTtbr0Kvm(uintptr(data))
+
+ // ttbr1_el1
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
+
+ reg.id = _KVM_ARM64_REGS_TTBR1_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // sp_el1
+ data = c.CPU.StackTop()
+ reg.id = _KVM_ARM64_REGS_SP_EL1
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // pc
+ reg.id = _KVM_ARM64_REGS_PC
+ data = uint64(reflect.ValueOf(ring0.Start).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // r8
+ reg.id = _KVM_ARM64_REGS_R8
+ data = uint64(reflect.ValueOf(&c.CPU).Pointer())
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // vbar_el1
+ reg.id = _KVM_ARM64_REGS_VBAR_EL1
+
+ fromLocation := reflect.ValueOf(ring0.Vectors).Pointer()
+ offset := fromLocation & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ toLocation := fromLocation + offset
+ data = uint64(ring0.KernelStartAddress | toLocation)
+ if err := c.setOneRegister(&reg); err != nil {
+ return err
+ }
+
+ // Use the address of the exception vector table as
+ // the MMIO address base.
+ arm64HypercallMMIOBase = toLocation
+
+ // Initialize the PCID database.
+ if hasGuestPCID {
+ // Note that NewPCIDs may return a nil table here, in which
+ // case we simply don't use PCID support (see below). In
+ // practice, this should not happen, however.
+ c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
+ }
+
+ c.floatingPointState = arch.NewFloatingPointData()
+ return nil
+}
+
+//go:nosplit
+func (c *vCPU) loadSegments(tid uint64) {
+ // TODO(gvisor.dev/issue/1238): TLS is not supported.
+ // Get TLS from tpidr_el0.
+ atomic.StoreUint64(&c.tid, tid)
+}
+
+func (c *vCPU) setOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error setting one register: %v", errno)
+ }
+ return nil
+}
+
+func (c *vCPU) getOneRegister(reg *kvmOneReg) error {
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_GET_ONE_REG,
+ uintptr(unsafe.Pointer(reg))); errno != 0 {
+ return fmt.Errorf("error setting one register: %v", errno)
+ }
+ return nil
+}
+
+// setCPUID sets the CPUID to be used by the guest.
+func (c *vCPU) setCPUID() error {
+ return nil
+}
+
+// setSystemTime sets the TSC for the vCPU.
+func (c *vCPU) setSystemTime() error {
+ return nil
+}
+
+// setSignalMask sets the vCPU signal mask.
+//
+// This must be called prior to running the vCPU.
+func (c *vCPU) setSignalMask() error {
+ // The layout of this structure implies that it will not necessarily be
+ // the same layout chosen by the Go compiler. It gets fudged here.
+ var data struct {
+ length uint32
+ mask1 uint32
+ mask2 uint32
+ _ uint32
+ }
+ data.length = 8 // Fixed sigset size.
+ data.mask1 = ^uint32(bounceSignalMask & 0xffffffff)
+ data.mask2 = ^uint32(bounceSignalMask >> 32)
+ if _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_SIGNAL_MASK,
+ uintptr(unsafe.Pointer(&data))); errno != 0 {
+ return fmt.Errorf("error setting signal mask: %v", errno)
+ }
+
+ return nil
+}
+
+// SwitchToUser unpacks architectural-details.
+func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) (usermem.AccessType, error) {
+ // Check for canonical addresses.
+ if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) {
+ return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info)
+ } else if !ring0.IsCanonical(regs.Sp) {
+ return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
+ }
+
+ // Assign PCIDs.
+ if c.PCIDs != nil {
+ var requireFlushPCID bool // Force a flush?
+ switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables)
+ switchOpts.Flush = switchOpts.Flush || requireFlushPCID
+ }
+
+ var vector ring0.Vector
+ ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
+ c.SetTtbr0App(uintptr(ttbr0App))
+
+ // TODO(gvisor.dev/issue/1238): full context-switch supporting for Arm64.
+ // The Arm64 user-mode execution state consists of:
+ // x0-x30
+ // PC, SP, PSTATE
+ // V0-V31: 32 128-bit registers for floating point, and simd
+ // FPSR
+ // TPIDR_EL0, used for TLS
+ appRegs := switchOpts.Registers
+ c.SetAppAddr(ring0.KernelStartAddress | uintptr(unsafe.Pointer(appRegs)))
+
+ entersyscall()
+ bluepill(c)
+ vector = c.CPU.SwitchToUser(switchOpts)
+ exitsyscall()
+
+ switch vector {
+ case ring0.Syscall:
+ // Fast path: system call executed.
+ return usermem.NoAccess, nil
+
+ case ring0.PageFault:
+ return c.fault(int32(syscall.SIGSEGV), info)
+ case ring0.Vector(bounce): // ring0.VirtualizationException
+ return usermem.NoAccess, platform.ErrContextInterrupt
+ case ring0.El0Sync_undef,
+ ring0.El1Sync_undef:
+ *info = arch.SignalInfo{
+ Signo: int32(syscall.SIGILL),
+ Code: 1, // ILL_ILLOPC (illegal opcode).
+ }
+ info.SetAddr(switchOpts.Registers.Pc) // Include address.
+ return usermem.AccessType{}, platform.ErrContextSignal
+ default:
+ panic(fmt.Sprintf("unexpected vector: 0x%x", vector))
+ }
+
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 405e00292..9f86f6a7a 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -35,6 +35,30 @@ func entersyscall()
//go:linkname exitsyscall runtime.exitsyscall
func exitsyscall()
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: uint32(flags),
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
// mapRunData maps the vCPU run data.
func mapRunData(fd int) (*runData, error) {
r, _, errno := syscall.RawSyscall6(
@@ -63,46 +87,6 @@ func unmapRunData(r *runData) error {
return nil
}
-// setUserRegisters sets user registers in the vCPU.
-func (c *vCPU) setUserRegisters(uregs *userRegs) error {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_REGS,
- uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return fmt.Errorf("error setting user registers: %v", errno)
- }
- return nil
-}
-
-// getUserRegisters reloads user registers in the vCPU.
-//
-// This is safe to call from a nosplit context.
-//
-//go:nosplit
-func (c *vCPU) getUserRegisters(uregs *userRegs) syscall.Errno {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_GET_REGS,
- uintptr(unsafe.Pointer(uregs))); errno != 0 {
- return errno
- }
- return 0
-}
-
-// setSystemRegisters sets system registers.
-func (c *vCPU) setSystemRegisters(sregs *systemRegs) error {
- if _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(c.fd),
- _KVM_SET_SREGS,
- uintptr(unsafe.Pointer(sregs))); errno != 0 {
- return fmt.Errorf("error setting system registers: %v", errno)
- }
- return nil
-}
-
// atomicAddressSpace is an atomic address space pointer.
type atomicAddressSpace struct {
pointer unsafe.Pointer
@@ -131,7 +115,7 @@ func (a *atomicAddressSpace) get() *addressSpace {
//
//go:nosplit
func (c *vCPU) notify() {
- _, _, errno := syscall.RawSyscall6(
+ _, _, errno := syscall.RawSyscall6( // escapes: no.
syscall.SYS_FUTEX,
uintptr(unsafe.Pointer(&c.state)),
linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG,
diff --git a/pkg/sentry/platform/kvm/physical_map.go b/pkg/sentry/platform/kvm/physical_map.go
index 586e91bb2..f7fa2f98d 100644
--- a/pkg/sentry/platform/kvm/physical_map.go
+++ b/pkg/sentry/platform/kvm/physical_map.go
@@ -21,16 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
-)
-
-const (
- // reservedMemory is a chunk of physical memory reserved starting at
- // physical address zero. There are some special pages in this region,
- // so we just call the whole thing off.
- //
- // Other architectures may define this to be zero.
- reservedMemory = 0x100000000
+ "gvisor.dev/gvisor/pkg/usermem"
)
type region struct {
@@ -59,8 +50,7 @@ func fillAddressSpace() (excludedRegions []region) {
// We can cut vSize in half, because the kernel will be using the top
// half and we ignore it while constructing mappings. It's as if we've
// already excluded half the possible addresses.
- vSize := uintptr(1) << ring0.VirtualAddressBits()
- vSize = vSize >> 1
+ vSize := ring0.UserspaceSize
// We exclude reservedMemory below from our physical memory size, so it
// needs to be dropped here as well. Otherwise, we could end up with
diff --git a/pkg/sentry/platform/kvm/physical_map_amd64.go b/pkg/sentry/platform/kvm/physical_map_amd64.go
new file mode 100644
index 000000000..c5adfb577
--- /dev/null
+++ b/pkg/sentry/platform/kvm/physical_map_amd64.go
@@ -0,0 +1,22 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+const (
+ // reservedMemory is a chunk of physical memory reserved starting at
+ // physical address zero. There are some special pages in this region,
+ // so we just call the whole thing off.
+ reservedMemory = 0x100000000
+)
diff --git a/pkg/sentry/fsimpl/proc/proc.go b/pkg/sentry/platform/kvm/physical_map_arm64.go
index 31dec36de..4d8561453 100644
--- a/pkg/sentry/fsimpl/proc/proc.go
+++ b/pkg/sentry/platform/kvm/physical_map_arm64.go
@@ -12,5 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package proc implements a partial in-memory file system for procfs.
-package proc
+package kvm
+
+const (
+ reservedMemory = 0
+)
diff --git a/pkg/sentry/platform/kvm/testutil/BUILD b/pkg/sentry/platform/kvm/testutil/BUILD
index b0e45f159..f7feb8683 100644
--- a/pkg/sentry/platform/kvm/testutil/BUILD
+++ b/pkg/sentry/platform/kvm/testutil/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -12,6 +12,6 @@ go_library(
"testutil_arm64.go",
"testutil_arm64.s",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil",
visibility = ["//pkg/sentry/platform/kvm:__pkg__"],
+ deps = ["//pkg/sentry/arch"],
)
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
index 4c108abbf..8048eedec 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_amd64.go
@@ -18,19 +18,20 @@ package testutil
import (
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// TwiddleSegments reads segments into known registers.
func TwiddleSegments()
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
+func SetTestTarget(regs *arch.Registers, fn func()) {
regs.Rip = uint64(reflect.ValueOf(fn).Pointer())
}
// SetTouchTarget sets rax appropriately.
-func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
if target != nil {
regs.Rax = uint64(reflect.ValueOf(target).Pointer())
} else {
@@ -39,12 +40,12 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
}
// RewindSyscall rewinds a syscall RIP.
-func RewindSyscall(regs *syscall.PtraceRegs) {
+func RewindSyscall(regs *arch.Registers) {
regs.Rip -= 2
}
// SetTestRegs initializes registers to known values.
-func SetTestRegs(regs *syscall.PtraceRegs) {
+func SetTestRegs(regs *arch.Registers) {
regs.R15 = 0x15
regs.R14 = 0x14
regs.R13 = 0x13
@@ -64,7 +65,7 @@ func SetTestRegs(regs *syscall.PtraceRegs) {
}
// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
-func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) {
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
if need := ^uint64(0x15); regs.R15 != need {
err = addRegisterMismatch(err, "R15", regs.R15, need)
}
@@ -121,13 +122,13 @@ var fsData uint64 = 0x55
var gsData uint64 = 0x85
// SetTestSegments initializes segments to known values.
-func SetTestSegments(regs *syscall.PtraceRegs) {
+func SetTestSegments(regs *arch.Registers) {
regs.Fs_base = uint64(reflect.ValueOf(&fsData).Pointer())
regs.Gs_base = uint64(reflect.ValueOf(&gsData).Pointer())
}
// CheckTestSegments checks that registers were twiddled per TwiddleSegments.
-func CheckTestSegments(regs *syscall.PtraceRegs) (err error) {
+func CheckTestSegments(regs *arch.Registers) (err error) {
if regs.Rax != fsData {
err = addRegisterMismatch(err, "Rax", regs.Rax, fsData)
}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
index 40b2e4acc..4dad877ba 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.go
@@ -19,16 +19,17 @@ package testutil
import (
"fmt"
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// SetTestTarget sets the rip appropriately.
-func SetTestTarget(regs *syscall.PtraceRegs, fn func()) {
+func SetTestTarget(regs *arch.Registers, fn func()) {
regs.Pc = uint64(reflect.ValueOf(fn).Pointer())
}
// SetTouchTarget sets rax appropriately.
-func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
+func SetTouchTarget(regs *arch.Registers, target *uintptr) {
if target != nil {
regs.Regs[8] = uint64(reflect.ValueOf(target).Pointer())
} else {
@@ -37,23 +38,27 @@ func SetTouchTarget(regs *syscall.PtraceRegs, target *uintptr) {
}
// RewindSyscall rewinds a syscall RIP.
-func RewindSyscall(regs *syscall.PtraceRegs) {
+func RewindSyscall(regs *arch.Registers) {
regs.Pc -= 4
}
// SetTestRegs initializes registers to known values.
-func SetTestRegs(regs *syscall.PtraceRegs) {
+func SetTestRegs(regs *arch.Registers) {
for i := 0; i <= 30; i++ {
regs.Regs[i] = uint64(i) + 1
}
}
// CheckTestRegs checks that registers were twiddled per TwiddleRegs.
-func CheckTestRegs(regs *syscall.PtraceRegs, full bool) (err error) {
+func CheckTestRegs(regs *arch.Registers, full bool) (err error) {
for i := 0; i <= 30; i++ {
if need := ^uint64(i + 1); regs.Regs[i] != need {
err = addRegisterMismatch(err, fmt.Sprintf("R%d", i), regs.Regs[i], need)
}
}
+ // Check tls.
+ if need := ^uint64(11); regs.TPIDR_EL0 != need {
+ err = addRegisterMismatch(err, "tpdir_el0", regs.TPIDR_EL0, need)
+ }
return
}
diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
index 2cd28b2d2..6caf7282d 100644
--- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
+++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s
@@ -50,6 +50,23 @@ TEXT ·SpinLoop(SB),NOSPLIT,$0
start:
B start
+TEXT ·FloatingPointWorks(SB),NOSPLIT,$0-8
+ NO_LOCAL_POINTERS
+ // gc will touch fpsimd, so we should test it.
+ // such as in <runtime.deductSweepCredit>.
+ FMOVD $(9.9), F0
+ MOVD $SYS_GETPID, R8 // getpid
+ SVC
+ FMOVD $(9.9), F1
+ FCMPD F0, F1
+ BNE isNaN
+ MOVD $1, R0
+ MOVD R0, ret+0(FP)
+ RET
+isNaN:
+ MOVD $0, ret+0(FP)
+ RET
+
// MVN: bitwise logical NOT
// This case simulates an application that modified R0-R30.
#define TWIDDLE_REGS() \
@@ -87,5 +104,15 @@ start:
TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0
TWIDDLE_REGS()
+ MSR R10, TPIDR_EL0
+ // Trapped in el0_svc.
SVC
RET // never reached
+
+TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0
+ TWIDDLE_REGS()
+ MSR R10, TPIDR_EL0
+ // Trapped in el0_ia.
+ // Branch to Register branches unconditionally to an address in <Rn>.
+ JMP (R6) // <=> br x6, must fault
+ RET // never reached
diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go
index 2d68855ef..c8897d34f 100644
--- a/pkg/sentry/platform/kvm/virtual_map.go
+++ b/pkg/sentry/platform/kvm/virtual_map.go
@@ -22,7 +22,7 @@ import (
"regexp"
"strconv"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type virtualRegion struct {
diff --git a/pkg/sentry/platform/kvm/virtual_map_test.go b/pkg/sentry/platform/kvm/virtual_map_test.go
index 6a2f145be..327e2be4f 100644
--- a/pkg/sentry/platform/kvm/virtual_map_test.go
+++ b/pkg/sentry/platform/kvm/virtual_map_test.go
@@ -18,7 +18,7 @@ import (
"syscall"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type checker struct {
diff --git a/pkg/sentry/platform/mmap_min_addr.go b/pkg/sentry/platform/mmap_min_addr.go
index 999787462..091c2e365 100644
--- a/pkg/sentry/platform/mmap_min_addr.go
+++ b/pkg/sentry/platform/mmap_min_addr.go
@@ -20,7 +20,7 @@ import (
"strconv"
"strings"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// systemMMapMinAddrSource is the source file.
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index ec22dbf87..ba031516a 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -22,10 +22,11 @@ import (
"os"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Platform provides abstractions for execution contexts (Context,
@@ -114,6 +115,17 @@ func (NoCPUPreemptionDetection) PreemptAllCPUs() error {
panic("This platform does not support CPU preemption detection")
}
+// MemoryManager represents an abstraction above the platform address space
+// which manages memory mappings and their contents.
+type MemoryManager interface {
+ //usermem.IO provides access to the contents of a virtual memory space.
+ usermem.IO
+ // MMap establishes a memory mapping.
+ MMap(ctx context.Context, opts memmap.MMapOpts) (usermem.Addr, error)
+ // AddressSpace returns the AddressSpace bound to mm.
+ AddressSpace() AddressSpace
+}
+
// Context represents the execution context for a single thread.
type Context interface {
// Switch resumes execution of the thread specified by the arch.Context
@@ -143,11 +155,43 @@ type Context interface {
// concurrent call to Switch().
//
// - ErrContextCPUPreempted: See the definition of that error for details.
- Switch(as AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error)
+ Switch(ctx context.Context, mm MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error)
+
+ // PullFullState() pulls a full state of the application thread.
+ //
+ // A platform can support lazy loading/restoring of a thread state
+ // which includes registers and a floating point state.
+ //
+ // For example, when the Sentry handles a system call, it may have only
+ // syscall arguments without other registers and a floating point
+ // state. And in this case, if the Sentry will need to construct a
+ // signal frame to call a signal handler, it will need to call
+ // PullFullState() to load all registers and FPU state.
+ //
+ // Preconditions: The caller must be running on the task goroutine.
+ PullFullState(as AddressSpace, ac arch.Context)
+
+ // FullStateChanged() indicates that a thread state has been changed by
+ // the Sentry. This happens in case of the rt_sigreturn, execve, etc.
+ //
+ // First, it indicates that the Sentry has the full state of the thread
+ // and PullFullState() has to do nothing if it is called after
+ // FullStateChanged().
+ //
+ // Second, it forces restoring the full state of the application
+ // thread. A platform can support lazy loading/restoring of a thread
+ // state. This means that if the Sentry has not changed a thread state,
+ // the platform may not restore it.
+ //
+ // Preconditions: The caller must be running on the task goroutine.
+ FullStateChanged()
// Interrupt interrupts a concurrent call to Switch(), causing it to return
// ErrContextInterrupt.
Interrupt()
+
+ // Release() releases any resources associated with this context.
+ Release()
}
var (
@@ -204,7 +248,7 @@ type AddressSpace interface {
// Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
// at.Any() == true. At least one reference must be held on all pages in
// fr, and must continue to be held as long as pages are mapped.
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+ MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error
// Unmap unmaps the given range.
//
@@ -215,6 +259,13 @@ type AddressSpace interface {
// must be acquired via platform.NewAddressSpace().
Release()
+ // PreFork() is called before creating a copy of AddressSpace. This
+ // guarantees that this address space will be in a consistent state.
+ PreFork()
+
+ // PostFork() is called after creating a copy of AddressSpace.
+ PostFork()
+
// AddressSpaceIO methods are supported iff the associated platform's
// Platform.SupportsAddressSpaceIO() == true. AddressSpaces for which this
// does not hold may panic if AddressSpaceIO methods are invoked.
@@ -307,56 +358,28 @@ func (f SegmentationFault) Error() string {
return fmt.Sprintf("segmentation fault at %#x", f.Addr)
}
-// File represents a host file that may be mapped into an AddressSpace.
-type File interface {
- // All pages in a File are reference-counted.
-
- // IncRef increments the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr. (The File
- // interface does not provide a way to acquire an initial reference;
- // implementors may define mechanisms for doing so.)
- IncRef(fr FileRange)
-
- // DecRef decrements the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr.
- DecRef(fr FileRange)
-
- // MapInternal returns a mapping of the given file offsets in the invoking
- // process' address space for reading and writing.
- //
- // Note that fr.Start and fr.End need not be page-aligned.
- //
- // Preconditions: fr.Length() > 0. At least one reference must be held on
- // all pages in fr.
- //
- // Postconditions: The returned mapping is valid as long as at least one
- // reference is held on the mapped pages.
- MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
-
- // FD returns the file descriptor represented by the File.
- //
- // The only permitted operation on the returned file descriptor is to map
- // pages from it consistent with the requirements of AddressSpace.MapFile.
- FD() int
-}
-
-// FileRange represents a range of uint64 offsets into a File.
-//
-// type FileRange <generated using go_generics>
-
-// String implements fmt.Stringer.String.
-func (fr FileRange) String() string {
- return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+// Requirements is used to specify platform specific requirements.
+type Requirements struct {
+ // RequiresCurrentPIDNS indicates that the sandbox has to be started in the
+ // current pid namespace.
+ RequiresCurrentPIDNS bool
+ // RequiresCapSysPtrace indicates that the sandbox has to be started with
+ // the CAP_SYS_PTRACE capability.
+ RequiresCapSysPtrace bool
}
// Constructor represents a platform type.
type Constructor interface {
+ // New returns a new platform instance.
+ //
+ // Arguments:
+ //
+ // * deviceFile - the device file (e.g. /dev/kvm for the KVM platform).
New(deviceFile *os.File) (Platform, error)
OpenDevice() (*os.File, error)
+
+ // Requirements returns platform specific requirements.
+ Requirements() Requirements
}
// platforms contains all available platform types.
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index ebcc8c098..e04165fbf 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,6 +9,7 @@ go_library(
"ptrace.go",
"ptrace_amd64.go",
"ptrace_arm64.go",
+ "ptrace_arm64_unsafe.go",
"ptrace_unsafe.go",
"stub_amd64.s",
"stub_arm64.s",
@@ -20,18 +21,21 @@ go_library(
"subprocess_linux_unsafe.go",
"subprocess_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ptrace",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/log",
"//pkg/procid",
+ "//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/hostcpu",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
- "//pkg/sentry/platform/safecopy",
- "//pkg/sentry/usermem",
+ "//pkg/sync",
+ "//pkg/usermem",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index 7b120a15d..b52d0fbd8 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -46,13 +46,14 @@ package ptrace
import (
"os"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ pkgcontext "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/interrupt"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var (
@@ -95,7 +96,8 @@ type context struct {
}
// Switch runs the provided context in the given address space.
-func (c *context) Switch(as platform.AddressSpace, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) {
+func (c *context) Switch(ctx pkgcontext.Context, mm platform.MemoryManager, ac arch.Context, cpu int32) (*arch.SignalInfo, usermem.AccessType, error) {
+ as := mm.AddressSpace()
s := as.(*subprocess)
isSyscall := s.switchToApp(c, ac)
@@ -177,6 +179,15 @@ func (c *context) Interrupt() {
c.interrupt.NotifyInterrupt()
}
+// Release implements platform.Context.Release().
+func (c *context) Release() {}
+
+// FullStateChanged implements platform.Context.FullStateChanged.
+func (c *context) FullStateChanged() {}
+
+// PullFullState implements platform.Context.PullFullState.
+func (c *context) PullFullState(as platform.AddressSpace, ac arch.Context) {}
+
// PTrace represents a collection of ptrace subprocesses.
type PTrace struct {
platform.MMapMinAddr
@@ -248,6 +259,16 @@ func (*constructor) OpenDevice() (*os.File, error) {
return nil, nil
}
+// Flags implements platform.Constructor.Flags().
+func (*constructor) Requirements() platform.Requirements {
+ // TODO(b/75837838): Also set a new PID namespace so that we limit
+ // access to other host processes.
+ return platform.Requirements{
+ RequiresCapSysPtrace: true,
+ RequiresCurrentPIDNS: true,
+ }
+}
+
func init() {
platform.Register("ptrace", &constructor{})
}
diff --git a/pkg/sentry/platform/ptrace/ptrace_amd64.go b/pkg/sentry/platform/ptrace/ptrace_amd64.go
index db0212538..3b9a870a5 100644
--- a/pkg/sentry/platform/ptrace/ptrace_amd64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_amd64.go
@@ -15,9 +15,8 @@
package ptrace
import (
- "syscall"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
@@ -28,6 +27,20 @@ func fpRegSet(useXsave bool) uintptr {
return linux.NT_PRFPREG
}
-func stackPointer(r *syscall.PtraceRegs) uintptr {
+func stackPointer(r *arch.Registers) uintptr {
return uintptr(r.Rsp)
}
+
+// x86 use the fs_base register to store the TLS pointer which can be
+// get/set in "func (t *thread) get/setRegs(regs *arch.Registers)".
+// So both of the get/setTLS() operations are noop here.
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64.go b/pkg/sentry/platform/ptrace/ptrace_arm64.go
index 4db28c534..5c869926a 100644
--- a/pkg/sentry/platform/ptrace/ptrace_arm64.go
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64.go
@@ -15,9 +15,8 @@
package ptrace
import (
- "syscall"
-
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// fpRegSet returns the GETREGSET/SETREGSET register set type to be used.
@@ -25,6 +24,6 @@ func fpRegSet(_ bool) uintptr {
return linux.NT_PRFPREG
}
-func stackPointer(r *syscall.PtraceRegs) uintptr {
+func stackPointer(r *arch.Registers) uintptr {
return uintptr(r.Sp)
}
diff --git a/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
new file mode 100644
index 000000000..32b8a6be9
--- /dev/null
+++ b/pkg/sentry/platform/ptrace/ptrace_arm64_unsafe.go
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ptrace
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+// getTLS gets the thread local storage register.
+func (t *thread) getTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
+
+// setTLS sets the thread local storage register.
+func (t *thread) setTLS(tls *uint64) error {
+ iovec := syscall.Iovec{
+ Base: (*byte)(unsafe.Pointer(tls)),
+ Len: uint64(unsafe.Sizeof(*tls)),
+ }
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_SETREGSET,
+ uintptr(t.tid),
+ linux.NT_ARM_TLS,
+ uintptr(unsafe.Pointer(&iovec)),
+ 0, 0)
+ if errno != 0 {
+ return errno
+ }
+ return nil
+}
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 72c7ec564..8b72d24e8 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -20,11 +20,11 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// getRegs gets the general purpose register set.
-func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
+func (t *thread) getRegs(regs *arch.Registers) error {
iovec := syscall.Iovec{
Base: (*byte)(unsafe.Pointer(regs)),
Len: uint64(unsafe.Sizeof(*regs)),
@@ -43,7 +43,7 @@ func (t *thread) getRegs(regs *syscall.PtraceRegs) error {
}
// setRegs sets the general purpose register set.
-func (t *thread) setRegs(regs *syscall.PtraceRegs) error {
+func (t *thread) setRegs(regs *arch.Registers) error {
iovec := syscall.Iovec{
Base: (*byte)(unsafe.Pointer(regs)),
Len: uint64(unsafe.Sizeof(*regs)),
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 64c718d21..16f9c523e 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -64,6 +64,8 @@ begin:
CMPQ AX, $0
JL error
+ MOVQ $0, BX
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -73,23 +75,26 @@ begin:
MOVQ $SIGSTOP, SI
SYSCALL
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+ // The sentry sets BX to 1 when creating stub process.
+ CMPQ BX, $1
+ JE clone
+
+ // Notify the Sentry that syscall exited.
+done:
+ INT $3
+ // Be paranoid.
+ JMP done
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R15 has been updated with the expected PPID.
- JMP begin
+ CMPQ AX, $0
+ JE begin
+ // The clone syscall returns a non-zero value.
+ JMP done
error:
// Exit with -errno.
MOVQ AX, DI
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 2c5e4d5cb..6162df02a 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -59,6 +59,8 @@ begin:
CMP $0x0, R0
BLT error
+ MOVD $0, R9
+
// SIGSTOP to wait for attach.
//
// The SYSCALL instruction will be used for future syscall injection by
@@ -66,22 +68,26 @@ begin:
MOVD $SYS_KILL, R8
MOVD $SIGSTOP, R1
SVC
- // The tracer may "detach" and/or allow code execution here in three cases:
- //
- // 1. New (traced) stub threads are explicitly detached by the
- // goroutine in newSubprocess. However, they are detached while in
- // group-stop, so they do not execute code here.
- //
- // 2. If a tracer thread exits, it implicitly detaches from the stub,
- // potentially allowing code execution here. However, the Go runtime
- // never exits individual threads, so this case never occurs.
- //
- // 3. subprocess.createStub clones a new stub process that is untraced,
+
+ // The sentry sets R9 to 1 when creating stub process.
+ CMP $1, R9
+ BEQ clone
+
+done:
+ // Notify the Sentry that syscall exited.
+ BRK $3
+ B done // Be paranoid.
+clone:
+ // subprocess.createStub clones a new stub process that is untraced,
// thus executing this code. We setup the PDEATHSIG before SIGSTOPing
// ourselves for attach by the tracer.
//
// R7 has been updated with the expected PPID.
- B begin
+ CMP $0, R0
+ BEQ begin
+
+ // The clone system call returned a non-zero value.
+ B done
error:
// Exit with -errno.
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
index aa1b87237..341dde143 100644
--- a/pkg/sentry/platform/ptrace/stub_unsafe.go
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -19,8 +19,8 @@ import (
"syscall"
"unsafe"
- "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// stub is defined in arch-specific assembly.
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index ddb1f41e3..e1d54d8a2 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -18,14 +18,16 @@ import (
"fmt"
"os"
"runtime"
- "sync"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Linux kernel errnos which "should never be seen by user programs", but will
@@ -62,7 +64,7 @@ type thread struct {
// initRegs are the initial registers for the first thread.
//
// These are used for the register set for system calls.
- initRegs syscall.PtraceRegs
+ initRegs arch.Registers
}
// threadPool is a collection of threads.
@@ -316,7 +318,7 @@ const (
)
func (t *thread) dumpAndPanic(message string) {
- var regs syscall.PtraceRegs
+ var regs arch.Registers
message += "\n"
if err := t.getRegs(&regs); err == nil {
message += dumpRegs(&regs)
@@ -331,7 +333,7 @@ func (t *thread) unexpectedStubExit() {
msg, err := t.getEventMessage()
status := syscall.WaitStatus(msg)
if status.Signaled() && status.Signal() == syscall.SIGKILL {
- // SIGKILL can be only sent by an user or OOM-killer. In both
+ // SIGKILL can be only sent by a user or OOM-killer. In both
// these cases, we don't need to panic. There is no reasons to
// think that something wrong in gVisor.
log.Warningf("The ptrace stub process %v has been killed by SIGKILL.", t.tgid)
@@ -422,20 +424,22 @@ func (t *thread) init() {
// This is _not_ for use by application system calls, rather it is for use when
// a system call must be injected into the remote context (e.g. mmap, munmap).
// Note that clones are handled separately.
-func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
+func (t *thread) syscall(regs *arch.Registers) (uintptr, error) {
// Set registers.
if err := t.setRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace set regs failed: %v", err))
}
for {
- // Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ // Execute the syscall instruction. The task has to stop on the
+ // trap instruction which is right after the syscall
+ // instruction.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_CONT, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
+ if sig == syscall.SIGTRAP {
// Reached syscall-enter-stop.
break
} else {
@@ -447,18 +451,6 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
}
- // Complete the actual system call.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- // Wait for syscall-exit-stop. "[Signal-delivery-stop] never happens
- // between syscall-enter-stop and syscall-exit-stop; it happens *after*
- // syscall-exit-stop.)" - ptrace(2), "Syscall-stops"
- if sig := t.wait(stopped); sig != (syscallEvent | syscall.SIGTRAP) {
- t.dumpAndPanic(fmt.Sprintf("wait failed: expected SIGTRAP, got %v [%d]", sig, sig))
- }
-
// Grab registers.
if err := t.getRegs(regs); err != nil {
panic(fmt.Sprintf("ptrace get regs failed: %v", err))
@@ -470,7 +462,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
// syscallIgnoreInterrupt ignores interrupts on the system call thread and
// restarts the syscall if the kernel indicates that should happen.
func (t *thread) syscallIgnoreInterrupt(
- initRegs *syscall.PtraceRegs,
+ initRegs *arch.Registers,
sysno uintptr,
args ...arch.SyscallArgument) (uintptr, error) {
for {
@@ -515,6 +507,9 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
regs := &ac.StateData().Regs
t.resetSysemuRegs(regs)
+ // Extract TLS register
+ tls := uint64(ac.TLS())
+
// Check for interrupts, and ensure that future interrupts will signal t.
if !c.interrupt.Enable(t) {
// Pending interrupt; simulate.
@@ -535,20 +530,23 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.setFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace set fpregs (%+v) failed: %v", fpState, err))
}
+ if err := t.setTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace set tls (%+v) failed: %v", tls, err))
+ }
for {
// Start running until the next system call.
if isSingleStepping(regs) {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU_SINGLESTEP,
+ unix.PTRACE_SYSEMU_SINGLESTEP,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
- syscall.PTRACE_SYSEMU,
+ unix.PTRACE_SYSEMU,
uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
@@ -564,6 +562,12 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
if err := t.getFPRegs(fpState, uint64(fpLen), useXsave); err != nil {
panic(fmt.Sprintf("ptrace get fpregs failed: %v", err))
}
+ if err := t.getTLS(&tls); err != nil {
+ panic(fmt.Sprintf("ptrace get tls failed: %v", err))
+ }
+ if !ac.SetTLS(uintptr(tls)) {
+ panic(fmt.Sprintf("tls value %v is invalid", tls))
+ }
// Is it a system call?
if sig == (syscallEvent | syscall.SIGTRAP) {
@@ -613,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp
}
// MapFile implements platform.AddressSpace.MapFile.
-func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
var flags int
if precommit {
flags |= syscall.MAP_POPULATE
@@ -658,3 +662,9 @@ func (s *subprocess) Unmap(addr usermem.Addr, length uint64) {
panic(fmt.Sprintf("munmap(%x, %x)) failed: %v", addr, length, err))
}
}
+
+// PreFork implements platform.AddressSpace.PreFork.
+func (s *subprocess) PreFork() {}
+
+// PostFork implements platform.AddressSpace.PostFork.
+func (s *subprocess) PostFork() {}
diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go
index 4649a94a7..84b699f0d 100644
--- a/pkg/sentry/platform/ptrace/subprocess_amd64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go
@@ -21,6 +21,9 @@ import (
"strings"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
)
@@ -38,7 +41,7 @@ const (
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
-func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
regs.Cs = t.initRegs.Cs
regs.Ss = t.initRegs.Ss
regs.Ds = t.initRegs.Ds
@@ -50,7 +53,7 @@ func (t *thread) resetSysemuRegs(regs *syscall.PtraceRegs) {
// createSyscallRegs sets up syscall registers.
//
// This should be called to generate registers for a system call.
-func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
// Copy initial registers.
regs := *initRegs
@@ -79,18 +82,18 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch
}
// isSingleStepping determines if the registers indicate single-stepping.
-func isSingleStepping(regs *syscall.PtraceRegs) bool {
+func isSingleStepping(regs *arch.Registers) bool {
return (regs.Eflags & arch.X86TrapFlag) != 0
}
// updateSyscallRegs updates registers after finishing sysemu.
-func updateSyscallRegs(regs *syscall.PtraceRegs) {
+func updateSyscallRegs(regs *arch.Registers) {
// Ptrace puts -ENOSYS in rax on syscall-enter-stop.
regs.Rax = regs.Orig_rax
}
// syscallReturnValue extracts a sensible return from registers.
-func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
rval := int64(regs.Rax)
if rval < 0 {
return 0, syscall.Errno(-rval)
@@ -98,7 +101,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
return uintptr(rval), nil
}
-func dumpRegs(regs *syscall.PtraceRegs) string {
+func dumpRegs(regs *arch.Registers) string {
var m strings.Builder
fmt.Fprintf(&m, "Registers:\n")
@@ -139,7 +142,118 @@ func (t *thread) adjustInitRegsRip() {
t.initRegs.Rip -= initRegsRipAdjustment
}
-// Pass the expected PPID to the child via R15 when creating stub process
-func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+// Pass the expected PPID to the child via R15 when creating stub process.
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
initregs.R15 = uint64(ppid)
+ // Rbx has to be set to 1 when creating stub process.
+ initregs.Rbx = 1
+}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Rip = signalInfo.Addr()
+ regs.Rsp -= 8
+ }
+}
+
+// enableCpuidFault enables cpuid-faulting.
+//
+// This may fail on older kernels or hardware, so we just disregard the result.
+// Host CPUID will be enabled.
+//
+// This is safe to call in an afterFork context.
+//
+//go:nosplit
+func enableCpuidFault() {
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+}
+
+// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
+// Ref attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ rules = append(rules,
+ // Rules for trapping vsyscall access.
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_GETTIMEOFDAY: {},
+ syscall.SYS_TIME: {},
+ unix.SYS_GETCPU: {}, // SYS_GETCPU was not defined in package syscall on amd64.
+ },
+ Action: linux.SECCOMP_RET_TRAP,
+ Vsyscall: true,
+ })
+ if defaultAction != linux.SECCOMP_RET_ALLOW {
+ rules = append(rules,
+ seccomp.RuleSet{
+ Rules: seccomp.SyscallRules{
+ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
+ },
+ },
+ Action: linux.SECCOMP_RET_ALLOW,
+ })
+ }
+ return rules
+}
+
+// probeSeccomp returns true iff seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8. This check is dynamic
+// because kernels have be backported behavior.
+//
+// See createStub for more information.
+//
+// Precondition: the runtime OS thread must be locked.
+func probeSeccomp() bool {
+ // Create a completely new, destroyable process.
+ t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
+ if err != nil {
+ panic(fmt.Sprintf("seccomp probe failed: %v", err))
+ }
+ defer t.destroy()
+
+ // Set registers to the yield system call. This call is not allowed
+ // by the filters specified in the attachThread function.
+ regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
+ if err := t.setRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace set regs failed: %v", err))
+ }
+
+ for {
+ // Attempt an emulation.
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, unix.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
+ panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
+ }
+
+ sig := t.wait(stopped)
+ if sig == (syscallEvent | syscall.SIGTRAP) {
+ // Did the seccomp errno hook already run? This would
+ // indicate that seccomp is first in line and we're
+ // less than 4.8.
+ if err := t.getRegs(&regs); err != nil {
+ panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
+ }
+ if _, err := syscallReturnValue(&regs); err == nil {
+ // The seccomp errno mode ran first, and reset
+ // the error in the registers.
+ return false
+ }
+ // The seccomp hook did not run yet, and therefore it
+ // is safe to use RET_KILL mode for dispatched calls.
+ return true
+ }
+ }
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_arm64.go b/pkg/sentry/platform/ptrace/subprocess_arm64.go
index bec884ba5..bd618fae8 100644
--- a/pkg/sentry/platform/ptrace/subprocess_arm64.go
+++ b/pkg/sentry/platform/ptrace/subprocess_arm64.go
@@ -17,8 +17,12 @@
package ptrace
import (
+ "fmt"
+ "strings"
"syscall"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
)
@@ -37,13 +41,13 @@ const (
// resetSysemuRegs sets up emulation registers.
//
// This should be called prior to calling sysemu.
-func (s *subprocess) resetSysemuRegs(regs *syscall.PtraceRegs) {
+func (t *thread) resetSysemuRegs(regs *arch.Registers) {
}
// createSyscallRegs sets up syscall registers.
//
// This should be called to generate registers for a system call.
-func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch.SyscallArgument) syscall.PtraceRegs {
+func createSyscallRegs(initRegs *arch.Registers, sysno uintptr, args ...arch.SyscallArgument) arch.Registers {
// Copy initial registers (Pc, Sp, etc.).
regs := *initRegs
@@ -74,7 +78,7 @@ func createSyscallRegs(initRegs *syscall.PtraceRegs, sysno uintptr, args ...arch
}
// isSingleStepping determines if the registers indicate single-stepping.
-func isSingleStepping(regs *syscall.PtraceRegs) bool {
+func isSingleStepping(regs *arch.Registers) bool {
// Refer to the ARM SDM D2.12.3: software step state machine
// return (regs.Pstate.SS == 1) && (MDSCR_EL1.SS == 1).
//
@@ -85,13 +89,13 @@ func isSingleStepping(regs *syscall.PtraceRegs) bool {
}
// updateSyscallRegs updates registers after finishing sysemu.
-func updateSyscallRegs(regs *syscall.PtraceRegs) {
+func updateSyscallRegs(regs *arch.Registers) {
// No special work is necessary.
return
}
// syscallReturnValue extracts a sensible return from registers.
-func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
+func syscallReturnValue(regs *arch.Registers) (uintptr, error) {
rval := int64(regs.Regs[0])
if rval < 0 {
return 0, syscall.Errno(-rval)
@@ -99,7 +103,7 @@ func syscallReturnValue(regs *syscall.PtraceRegs) (uintptr, error) {
return uintptr(rval), nil
}
-func dumpRegs(regs *syscall.PtraceRegs) string {
+func dumpRegs(regs *arch.Registers) string {
var m strings.Builder
fmt.Fprintf(&m, "Registers:\n")
@@ -121,6 +125,50 @@ func (t *thread) adjustInitRegsRip() {
}
// Pass the expected PPID to the child via X7 when creating stub process
-func initChildProcessPPID(initregs *syscall.PtraceRegs, ppid int32) {
+func initChildProcessPPID(initregs *arch.Registers, ppid int32) {
initregs.Regs[7] = uint64(ppid)
+ // R9 has to be set to 1 when creating stub process.
+ initregs.Regs[9] = 1
+}
+
+// patchSignalInfo patches the signal info to account for hitting the seccomp
+// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
+// synchronous trap, but patch the structure to appear like a SIGSEGV with the
+// Rip as the faulting address.
+//
+// Note that this should only be called after verifying that the signalInfo has
+// been generated by the kernel.
+func patchSignalInfo(regs *arch.Registers, signalInfo *arch.SignalInfo) {
+ if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
+ signalInfo.Signo = int32(linux.SIGSEGV)
+
+ // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
+ // with the si_call_addr field pointing to the current RIP. This field
+ // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
+ // anything there. We do need to unwind emulation however, so we set the
+ // instruction pointer to the faulting value, and "unpop" the stack.
+ regs.Pc = signalInfo.Addr()
+ regs.Sp -= 8
+ }
+}
+
+// Noop on arm64.
+//
+//go:nosplit
+func enableCpuidFault() {
+}
+
+// appendArchSeccompRules append architecture specific seccomp rules when creating BPF program.
+// Ref attachedThread() for more detail.
+func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFAction) []seccomp.RuleSet {
+ return rules
+}
+
+// probeSeccomp returns true if seccomp is run after ptrace notifications,
+// which is generally the case for kernel version >= 4.8.
+//
+// On arm64, the support of PTRACE_SYSEMU was added in the 5.3 kernel, so
+// probeSeccomp can always return true.
+func probeSeccomp() bool {
+ return true
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index 3782d4332..2ce528601 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -29,75 +29,6 @@ import (
const syscallEvent syscall.Signal = 0x80
-// probeSeccomp returns true iff seccomp is run after ptrace notifications,
-// which is generally the case for kernel version >= 4.8. This check is dynamic
-// because kernels have be backported behavior.
-//
-// See createStub for more information.
-//
-// Precondition: the runtime OS thread must be locked.
-func probeSeccomp() bool {
- // Create a completely new, destroyable process.
- t, err := attachedThread(0, linux.SECCOMP_RET_ERRNO)
- if err != nil {
- panic(fmt.Sprintf("seccomp probe failed: %v", err))
- }
- defer t.destroy()
-
- // Set registers to the yield system call. This call is not allowed
- // by the filters specified in the attachThread function.
- regs := createSyscallRegs(&t.initRegs, syscall.SYS_SCHED_YIELD)
- if err := t.setRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace set regs failed: %v", err))
- }
-
- for {
- // Attempt an emulation.
- if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
- panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
- }
-
- sig := t.wait(stopped)
- if sig == (syscallEvent | syscall.SIGTRAP) {
- // Did the seccomp errno hook already run? This would
- // indicate that seccomp is first in line and we're
- // less than 4.8.
- if err := t.getRegs(&regs); err != nil {
- panic(fmt.Sprintf("ptrace get-regs failed: %v", err))
- }
- if _, err := syscallReturnValue(&regs); err == nil {
- // The seccomp errno mode ran first, and reset
- // the error in the registers.
- return false
- }
- // The seccomp hook did not run yet, and therefore it
- // is safe to use RET_KILL mode for dispatched calls.
- return true
- }
- }
-}
-
-// patchSignalInfo patches the signal info to account for hitting the seccomp
-// filters from vsyscall emulation, specified below. We allow for SIGSYS as a
-// synchronous trap, but patch the structure to appear like a SIGSEGV with the
-// Rip as the faulting address.
-//
-// Note that this should only be called after verifying that the signalInfo has
-// been generated by the kernel.
-func patchSignalInfo(regs *syscall.PtraceRegs, signalInfo *arch.SignalInfo) {
- if linux.Signal(signalInfo.Signo) == linux.SIGSYS {
- signalInfo.Signo = int32(linux.SIGSEGV)
-
- // Unwind the kernel emulation, if any has occurred. A SIGSYS is delivered
- // with the si_call_addr field pointing to the current RIP. This field
- // aligns with the si_addr field for a SIGSEGV, so we don't need to touch
- // anything there. We do need to unwind emulation however, so we set the
- // instruction pointer to the faulting value, and "unpop" the stack.
- regs.Rip = signalInfo.Addr()
- regs.Rsp -= 8
- }
-}
-
// createStub creates a fresh stub processes.
//
// Precondition: the runtime OS thread must be locked.
@@ -143,18 +74,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// stub and all its children. This is used to create child stubs
// (below), so we must include the ability to fork, but otherwise lock
// down available calls only to what is needed.
- rules := []seccomp.RuleSet{
- // Rules for trapping vsyscall access.
- {
- Rules: seccomp.SyscallRules{
- syscall.SYS_GETTIMEOFDAY: {},
- syscall.SYS_TIME: {},
- 309: {}, // SYS_GETCPU.
- },
- Action: linux.SECCOMP_RET_TRAP,
- Vsyscall: true,
- },
- }
+ rules := []seccomp.RuleSet{}
if defaultAction != linux.SECCOMP_RET_ALLOW {
rules = append(rules, seccomp.RuleSet{
Rules: seccomp.SyscallRules{
@@ -173,10 +93,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// For the initial process creation.
syscall.SYS_WAIT4: {},
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)},
- },
- syscall.SYS_EXIT: {},
+ syscall.SYS_EXIT: {},
// For the stub prctl dance (all).
syscall.SYS_PRCTL: []seccomp.Rule{
@@ -197,6 +114,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
Action: linux.SECCOMP_RET_ALLOW,
})
}
+ rules = appendArchSeccompRules(rules, defaultAction)
instrs, err := seccomp.BuildProgram(rules, defaultAction)
if err != nil {
return nil, err
@@ -267,9 +185,8 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
syscall.RawSyscall(syscall.SYS_EXIT, uintptr(errno), 0, 0)
}
- // Enable cpuid-faulting; this may fail on older kernels or hardware,
- // so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
+ // Enable cpuid-faulting.
+ enableCpuidFault()
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
index de6783fb0..245b20722 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go
@@ -18,13 +18,14 @@
package ptrace
import (
- "sync"
"sync/atomic"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/hostcpu"
+ "gvisor.dev/gvisor/pkg/sync"
)
// maskPool contains reusable CPU masks for setting affinity. Unfortunately,
@@ -49,20 +50,6 @@ func unmaskAllSignals() syscall.Errno {
return errno
}
-// getCPU gets the current CPU.
-//
-// Precondition: the current runtime thread should be locked.
-func getCPU() (uint32, error) {
- var cpu uintptr
- if _, _, errno := syscall.RawSyscall(
- unix.SYS_GETCPU,
- uintptr(unsafe.Pointer(&cpu)),
- 0, 0); errno != 0 {
- return 0, errno
- }
- return uint32(cpu), nil
-}
-
// setCPU sets the CPU affinity.
func (t *thread) setCPU(cpu uint32) error {
mask := maskPool.Get().([]uintptr)
@@ -93,10 +80,8 @@ func (t *thread) setCPU(cpu uint32) error {
//
// Precondition: the current runtime thread should be locked.
func (t *thread) bind() {
- currentCPU, err := getCPU()
- if err != nil {
- return
- }
+ currentCPU := hostcpu.GetCPU()
+
if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU {
// Set the affinity on the thread and save the CPU for next
// round; we don't expect CPUs to bounce around too frequently.
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
index b80a3604d..0bee995e4 100644
--- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 48b0ceaec..679b287c3 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,10 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
go_template(
- name = "defs",
+ name = "defs_amd64",
srcs = [
"defs.go",
"defs_amd64.go",
@@ -14,11 +14,29 @@ go_template(
visibility = [":__subpackages__"],
)
+go_template(
+ name = "defs_arm64",
+ srcs = [
+ "aarch64.go",
+ "defs.go",
+ "defs_arm64.go",
+ "offsets_arm64.go",
+ ],
+ visibility = [":__subpackages__"],
+)
+
go_template_instance(
- name = "defs_impl",
- out = "defs_impl.go",
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
package = "ring0",
- template = ":defs",
+ template = ":defs_amd64",
+)
+
+go_template_instance(
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
+ package = "ring0",
+ template = ":defs_arm64",
)
genrule(
@@ -29,24 +47,40 @@ genrule(
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
+genrule(
+ name = "entry_impl_arm64",
+ srcs = ["entry_arm64.s"],
+ outs = ["entry_impl_arm64.s"],
+ cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
+)
+
go_library(
name = "ring0",
srcs = [
- "defs_impl.go",
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
"entry_amd64.go",
+ "entry_arm64.go",
"entry_impl_amd64.s",
+ "entry_impl_arm64.s",
"kernel.go",
"kernel_amd64.go",
+ "kernel_arm64.go",
"kernel_unsafe.go",
"lib_amd64.go",
"lib_amd64.s",
+ "lib_arm64.go",
+ "lib_arm64.s",
+ "lib_arm64_unsafe.go",
"ring0.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/cpuid",
+ "//pkg/safecopy",
+ "//pkg/sentry/arch",
"//pkg/sentry/platform/ring0/pagetables",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go
new file mode 100644
index 000000000..87a573cc4
--- /dev/null
+++ b/pkg/sentry/platform/ring0/aarch64.go
@@ -0,0 +1,111 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// Useful bits.
+const (
+ _PGD_PGT_BASE = 0x1000
+ _PGD_PGT_SIZE = 0x1000
+ _PUD_PGT_BASE = 0x2000
+ _PUD_PGT_SIZE = 0x1000
+ _PMD_PGT_BASE = 0x3000
+ _PMD_PGT_SIZE = 0x4000
+ _PTE_PGT_BASE = 0x7000
+ _PTE_PGT_SIZE = 0x1000
+)
+
+const (
+ // DAIF bits:debug, sError, IRQ, FIQ.
+ _PSR_D_BIT = 0x00000200
+ _PSR_A_BIT = 0x00000100
+ _PSR_I_BIT = 0x00000080
+ _PSR_F_BIT = 0x00000040
+ _PSR_DAIF_SHIFT = 6
+ _PSR_DAIF_MASK = 0xf << _PSR_DAIF_SHIFT
+
+ // PSR bits.
+ _PSR_MODE_EL0t = 0x00000000
+ _PSR_MODE_EL1t = 0x00000004
+ _PSR_MODE_EL1h = 0x00000005
+ _PSR_MODE_MASK = 0x0000000f
+
+ PsrFlagsClear = _PSR_MODE_MASK | _PSR_DAIF_MASK
+ PsrModeMask = _PSR_MODE_MASK
+
+ // KernelFlagsSet should always be set in the kernel.
+ KernelFlagsSet = _PSR_MODE_EL1h | _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT
+
+ // UserFlagsSet are always set in userspace.
+ UserFlagsSet = _PSR_MODE_EL0t
+)
+
+// Vector is an exception vector.
+type Vector uintptr
+
+// Exception vectors.
+const (
+ El1SyncInvalid = iota
+ El1IrqInvalid
+ El1FiqInvalid
+ El1ErrorInvalid
+ El1Sync
+ El1Irq
+ El1Fiq
+ El1Error
+ El0Sync
+ El0Irq
+ El0Fiq
+ El0Error
+ El0Sync_invalid
+ El0Irq_invalid
+ El0Fiq_invalid
+ El0Error_invalid
+ El1Sync_da
+ El1Sync_ia
+ El1Sync_sp_pc
+ El1Sync_undef
+ El1Sync_dbg
+ El1Sync_inv
+ El0Sync_svc
+ El0Sync_da
+ El0Sync_ia
+ El0Sync_fpsimd_acc
+ El0Sync_sve_acc
+ El0Sync_sys
+ El0Sync_sp_pc
+ El0Sync_undef
+ El0Sync_dbg
+ El0Sync_inv
+ _NR_INTERRUPTS
+)
+
+// System call vectors.
+const (
+ Syscall Vector = El0Sync_svc
+ PageFault Vector = El0Sync_da
+ VirtualizationException Vector = El0Error
+)
+
+// VirtualAddressBits returns the number bits available for virtual addresses.
+func VirtualAddressBits() uint32 {
+ return 48
+}
+
+// PhysicalAddressBits returns the number of bits available for physical addresses.
+func PhysicalAddressBits() uint32 {
+ return 40
+}
diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go
index 076063f85..e6daf24df 100644
--- a/pkg/sentry/platform/ring0/defs.go
+++ b/pkg/sentry/platform/ring0/defs.go
@@ -15,20 +15,8 @@
package ring0
import (
- "syscall"
-
- "gvisor.dev/gvisor/pkg/sentry/usermem"
-)
-
-var (
- // UserspaceSize is the total size of userspace.
- UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
-
- // MaximumUserAddress is the largest possible user address.
- MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
-
- // KernelStartAddress is the starting kernel address.
- KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
)
// Kernel is a global kernel object.
@@ -83,7 +71,7 @@ type CPU struct {
// registers is a set of registers; these may be used on kernel system
// calls and exceptions via the Registers function.
- registers syscall.PtraceRegs
+ registers arch.Registers
// hooks are kernel hooks.
hooks Hooks
@@ -94,14 +82,14 @@ type CPU struct {
// This is explicitly safe to call during KernelException and KernelSyscall.
//
//go:nosplit
-func (c *CPU) Registers() *syscall.PtraceRegs {
+func (c *CPU) Registers() *arch.Registers {
return &c.registers
}
// SwitchOpts are passed to the Switch function.
type SwitchOpts struct {
// Registers are the user register state.
- Registers *syscall.PtraceRegs
+ Registers *arch.Registers
// FloatingPointState is a byte pointer where floating point state is
// saved and restored.
diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go
index 7206322b1..9c6c2cf5c 100644
--- a/pkg/sentry/platform/ring0/defs_amd64.go
+++ b/pkg/sentry/platform/ring0/defs_amd64.go
@@ -18,6 +18,18 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits() - 1)
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
)
// Segment indices and Selectors.
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go
new file mode 100644
index 000000000..0e2ab716c
--- /dev/null
+++ b/pkg/sentry/platform/ring0/defs_arm64.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var (
+ // UserspaceSize is the total size of userspace.
+ UserspaceSize = uintptr(1) << (VirtualAddressBits())
+
+ // MaximumUserAddress is the largest possible user address.
+ MaximumUserAddress = (UserspaceSize - 1) & ^uintptr(usermem.PageSize-1)
+
+ // KernelStartAddress is the starting kernel address.
+ KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1)
+)
+
+// KernelOpts has initialization options for the kernel.
+type KernelOpts struct {
+ // PageTables are the kernel pagetables; this must be provided.
+ PageTables *pagetables.PageTables
+}
+
+// KernelArchState contains architecture-specific state.
+type KernelArchState struct {
+ KernelOpts
+}
+
+// CPUArchState contains CPU-specific arch state.
+type CPUArchState struct {
+ // stack is the stack used for interrupts on this CPU.
+ stack [512]byte
+
+ // errorCode is the error code from the last exception.
+ errorCode uintptr
+
+ // errorType indicates the type of error code here, it is always set
+ // along with the errorCode value above.
+ //
+ // It will either by 1, which indicates a user error, or 0 indicating a
+ // kernel error. If the error code below returns false (kernel error),
+ // then it cannot provide relevant information about the last
+ // exception.
+ errorType uintptr
+
+ // faultAddr is the value of far_el1.
+ faultAddr uintptr
+
+ // ttbr0Kvm is the value of ttbr0_el1 for sentry.
+ ttbr0Kvm uintptr
+
+ // ttbr0App is the value of ttbr0_el1 for applicaton.
+ ttbr0App uintptr
+
+ // exception vector.
+ vecCode Vector
+
+ // application context pointer.
+ appAddr uintptr
+
+ // lazyVFP is the value of cpacr_el1.
+ lazyVFP uintptr
+}
+
+// ErrorCode returns the last error code.
+//
+// The returned boolean indicates whether the error code corresponds to the
+// last user error or not. If it does not, then fault information must be
+// ignored. This is generally the result of a kernel fault while servicing a
+// user fault.
+//
+//go:nosplit
+func (c *CPU) ErrorCode() (value uintptr, user bool) {
+ return c.errorCode, c.errorType != 0
+}
+
+// ClearErrorCode resets the error code.
+//
+//go:nosplit
+func (c *CPU) ClearErrorCode() {
+ c.errorCode = 0 // No code.
+ c.errorType = 1 // User mode.
+}
+
+//go:nosplit
+func (c *CPU) GetFaultAddr() (value uintptr) {
+ return c.faultAddr
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0Kvm(value uintptr) {
+ c.ttbr0Kvm = value
+}
+
+//go:nosplit
+func (c *CPU) SetTtbr0App(value uintptr) {
+ c.ttbr0App = value
+}
+
+//go:nosplit
+func (c *CPU) GetVector() (value Vector) {
+ return c.vecCode
+}
+
+//go:nosplit
+func (c *CPU) SetAppAddr(value uintptr) {
+ c.appAddr = value
+}
+
+// GetLazyVFP returns the value of cpacr_el1.
+//go:nosplit
+func (c *CPU) GetLazyVFP() (value uintptr) {
+ return c.lazyVFP
+}
+
+// SwitchArchOpts are embedded in SwitchOpts.
+type SwitchArchOpts struct {
+ // UserASID indicates that the application ASID to be used on switch,
+ UserASID uint16
+
+ // KernelASID indicates that the kernel ASID to be used on return,
+ KernelASID uint16
+}
+
+func init() {
+}
diff --git a/pkg/sentry/platform/ring0/entry_amd64.go b/pkg/sentry/platform/ring0/entry_amd64.go
index a5ce67885..7fa43c2f5 100644
--- a/pkg/sentry/platform/ring0/entry_amd64.go
+++ b/pkg/sentry/platform/ring0/entry_amd64.go
@@ -17,7 +17,7 @@
package ring0
import (
- "syscall"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// This is an assembly function.
@@ -41,7 +41,7 @@ func swapgs()
// The return code is the vector that interrupted execution.
//
// See stubs.go for a note regarding the frame size of this function.
-func sysret(*CPU, *syscall.PtraceRegs) Vector
+func sysret(*CPU, *arch.Registers) Vector
// "iret is the cadillac of CPL switching."
//
@@ -50,7 +50,7 @@ func sysret(*CPU, *syscall.PtraceRegs) Vector
// iret is nearly identical to sysret, except an iret is used to fully restore
// all user state. This must be called in cases where all registers need to be
// restored.
-func iret(*CPU, *syscall.PtraceRegs) Vector
+func iret(*CPU, *arch.Registers) Vector
// exception is the generic exception entry.
//
diff --git a/pkg/sentry/platform/ring0/entry_arm64.go b/pkg/sentry/platform/ring0/entry_arm64.go
new file mode 100644
index 000000000..62a93f3d6
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// This is an assembly function.
+//
+// The sysenter function is invoked in two situations:
+//
+// (1) The guest kernel has executed a system call.
+// (2) The guest application has executed a system call.
+//
+// The interrupt flag is examined to determine whether the system call was
+// executed from kernel mode or not and the appropriate stub is called.
+
+func El1_sync_invalid()
+func El1_irq_invalid()
+func El1_fiq_invalid()
+func El1_error_invalid()
+
+func El1_sync()
+func El1_irq()
+func El1_fiq()
+func El1_error()
+
+func El0_sync()
+func El0_irq()
+func El0_fiq()
+func El0_error()
+
+func El0_sync_invalid()
+func El0_irq_invalid()
+func El0_fiq_invalid()
+func El0_error_invalid()
+
+func Vectors()
+
+// Start is the CPU entrypoint.
+//
+// The CPU state will be set to c.Registers().
+func Start()
+func kernelExitToEl1()
+
+func kernelExitToEl0()
+
+// Shutdown execution
+func Shutdown()
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
new file mode 100644
index 000000000..9d29b7168
--- /dev/null
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -0,0 +1,786 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+// NB: Offsets are programatically generated (see BUILD).
+//
+// This file is concatenated with the definitions.
+
+// Saves a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+
+// ERET returns using the ELR and SPSR for the current exception level.
+#define ERET() \
+ WORD $0xd69f03e0
+
+// RSV_REG is a register that holds el1 information temporarily.
+#define RSV_REG R18_PLATFORM
+
+// RSV_REG_APP is a register that holds el0 information temporarily.
+#define RSV_REG_APP R9
+
+#define FPEN_NOTRAP 0x3
+#define FPEN_SHIFT 20
+
+#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+
+// sctlr_el1: system control register el1.
+#define SCTLR_M 1 << 0
+#define SCTLR_C 1 << 2
+#define SCTLR_I 1 << 12
+#define SCTLR_UCT 1 << 15
+
+#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT)
+
+// cntkctl_el1: counter-timer kernel control register el1.
+#define CNTKCTL_EL0PCTEN 1 << 0
+#define CNTKCTL_EL0VCTEN 1 << 1
+
+#define CNTKCTL_EL1_DEFAULT (CNTKCTL_EL0PCTEN | CNTKCTL_EL0VCTEN)
+
+// Saves a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not saved: R9, R18.
+#define REGISTERS_SAVE(reg, offset) \
+ MOVD R0, offset+PTRACE_R0(reg); \
+ MOVD R1, offset+PTRACE_R1(reg); \
+ MOVD R2, offset+PTRACE_R2(reg); \
+ MOVD R3, offset+PTRACE_R3(reg); \
+ MOVD R4, offset+PTRACE_R4(reg); \
+ MOVD R5, offset+PTRACE_R5(reg); \
+ MOVD R6, offset+PTRACE_R6(reg); \
+ MOVD R7, offset+PTRACE_R7(reg); \
+ MOVD R8, offset+PTRACE_R8(reg); \
+ MOVD R10, offset+PTRACE_R10(reg); \
+ MOVD R11, offset+PTRACE_R11(reg); \
+ MOVD R12, offset+PTRACE_R12(reg); \
+ MOVD R13, offset+PTRACE_R13(reg); \
+ MOVD R14, offset+PTRACE_R14(reg); \
+ MOVD R15, offset+PTRACE_R15(reg); \
+ MOVD R16, offset+PTRACE_R16(reg); \
+ MOVD R17, offset+PTRACE_R17(reg); \
+ MOVD R19, offset+PTRACE_R19(reg); \
+ MOVD R20, offset+PTRACE_R20(reg); \
+ MOVD R21, offset+PTRACE_R21(reg); \
+ MOVD R22, offset+PTRACE_R22(reg); \
+ MOVD R23, offset+PTRACE_R23(reg); \
+ MOVD R24, offset+PTRACE_R24(reg); \
+ MOVD R25, offset+PTRACE_R25(reg); \
+ MOVD R26, offset+PTRACE_R26(reg); \
+ MOVD R27, offset+PTRACE_R27(reg); \
+ MOVD g, offset+PTRACE_R28(reg); \
+ MOVD R29, offset+PTRACE_R29(reg); \
+ MOVD R30, offset+PTRACE_R30(reg);
+
+// Loads a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not loaded: R9, R18.
+#define REGISTERS_LOAD(reg, offset) \
+ MOVD offset+PTRACE_R0(reg), R0; \
+ MOVD offset+PTRACE_R1(reg), R1; \
+ MOVD offset+PTRACE_R2(reg), R2; \
+ MOVD offset+PTRACE_R3(reg), R3; \
+ MOVD offset+PTRACE_R4(reg), R4; \
+ MOVD offset+PTRACE_R5(reg), R5; \
+ MOVD offset+PTRACE_R6(reg), R6; \
+ MOVD offset+PTRACE_R7(reg), R7; \
+ MOVD offset+PTRACE_R8(reg), R8; \
+ MOVD offset+PTRACE_R10(reg), R10; \
+ MOVD offset+PTRACE_R11(reg), R11; \
+ MOVD offset+PTRACE_R12(reg), R12; \
+ MOVD offset+PTRACE_R13(reg), R13; \
+ MOVD offset+PTRACE_R14(reg), R14; \
+ MOVD offset+PTRACE_R15(reg), R15; \
+ MOVD offset+PTRACE_R16(reg), R16; \
+ MOVD offset+PTRACE_R17(reg), R17; \
+ MOVD offset+PTRACE_R19(reg), R19; \
+ MOVD offset+PTRACE_R20(reg), R20; \
+ MOVD offset+PTRACE_R21(reg), R21; \
+ MOVD offset+PTRACE_R22(reg), R22; \
+ MOVD offset+PTRACE_R23(reg), R23; \
+ MOVD offset+PTRACE_R24(reg), R24; \
+ MOVD offset+PTRACE_R25(reg), R25; \
+ MOVD offset+PTRACE_R26(reg), R26; \
+ MOVD offset+PTRACE_R27(reg), R27; \
+ MOVD offset+PTRACE_R28(reg), g; \
+ MOVD offset+PTRACE_R29(reg), R29; \
+ MOVD offset+PTRACE_R30(reg), R30;
+
+// NOP-s
+#define nop31Instructions() \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f; \
+ WORD $0xd503201f;
+
+#define ESR_ELx_EC_UNKNOWN (0x00)
+#define ESR_ELx_EC_WFx (0x01)
+/* Unallocated EC: 0x02 */
+#define ESR_ELx_EC_CP15_32 (0x03)
+#define ESR_ELx_EC_CP15_64 (0x04)
+#define ESR_ELx_EC_CP14_MR (0x05)
+#define ESR_ELx_EC_CP14_LS (0x06)
+#define ESR_ELx_EC_FP_ASIMD (0x07)
+#define ESR_ELx_EC_CP10_ID (0x08) /* EL2 only */
+#define ESR_ELx_EC_PAC (0x09) /* EL2 and above */
+/* Unallocated EC: 0x0A - 0x0B */
+#define ESR_ELx_EC_CP14_64 (0x0C)
+/* Unallocated EC: 0x0d */
+#define ESR_ELx_EC_ILL (0x0E)
+/* Unallocated EC: 0x0F - 0x10 */
+#define ESR_ELx_EC_SVC32 (0x11)
+#define ESR_ELx_EC_HVC32 (0x12) /* EL2 only */
+#define ESR_ELx_EC_SMC32 (0x13) /* EL2 and above */
+/* Unallocated EC: 0x14 */
+#define ESR_ELx_EC_SVC64 (0x15)
+#define ESR_ELx_EC_HVC64 (0x16) /* EL2 and above */
+#define ESR_ELx_EC_SMC64 (0x17) /* EL2 and above */
+#define ESR_ELx_EC_SYS64 (0x18)
+#define ESR_ELx_EC_SVE (0x19)
+/* Unallocated EC: 0x1A - 0x1E */
+#define ESR_ELx_EC_IMP_DEF (0x1f) /* EL3 only */
+#define ESR_ELx_EC_IABT_LOW (0x20)
+#define ESR_ELx_EC_IABT_CUR (0x21)
+#define ESR_ELx_EC_PC_ALIGN (0x22)
+/* Unallocated EC: 0x23 */
+#define ESR_ELx_EC_DABT_LOW (0x24)
+#define ESR_ELx_EC_DABT_CUR (0x25)
+#define ESR_ELx_EC_SP_ALIGN (0x26)
+/* Unallocated EC: 0x27 */
+#define ESR_ELx_EC_FP_EXC32 (0x28)
+/* Unallocated EC: 0x29 - 0x2B */
+#define ESR_ELx_EC_FP_EXC64 (0x2C)
+/* Unallocated EC: 0x2D - 0x2E */
+#define ESR_ELx_EC_SERROR (0x2F)
+#define ESR_ELx_EC_BREAKPT_LOW (0x30)
+#define ESR_ELx_EC_BREAKPT_CUR (0x31)
+#define ESR_ELx_EC_SOFTSTP_LOW (0x32)
+#define ESR_ELx_EC_SOFTSTP_CUR (0x33)
+#define ESR_ELx_EC_WATCHPT_LOW (0x34)
+#define ESR_ELx_EC_WATCHPT_CUR (0x35)
+/* Unallocated EC: 0x36 - 0x37 */
+#define ESR_ELx_EC_BKPT32 (0x38)
+/* Unallocated EC: 0x39 */
+#define ESR_ELx_EC_VECTOR32 (0x3A) /* EL2 only */
+/* Unallocted EC: 0x3B */
+#define ESR_ELx_EC_BRK64 (0x3C)
+/* Unallocated EC: 0x3D - 0x3F */
+#define ESR_ELx_EC_MAX (0x3F)
+
+#define ESR_ELx_EC_SHIFT (26)
+#define ESR_ELx_EC_MASK (UL(0x3F) << ESR_ELx_EC_SHIFT)
+#define ESR_ELx_EC(esr) (((esr) & ESR_ELx_EC_MASK) >> ESR_ELx_EC_SHIFT)
+
+#define ESR_ELx_IL_SHIFT (25)
+#define ESR_ELx_IL (UL(1) << ESR_ELx_IL_SHIFT)
+#define ESR_ELx_ISS_MASK (ESR_ELx_IL - 1)
+
+/* ISS field definitions shared by different classes */
+#define ESR_ELx_WNR_SHIFT (6)
+#define ESR_ELx_WNR (UL(1) << ESR_ELx_WNR_SHIFT)
+
+/* Asynchronous Error Type */
+#define ESR_ELx_IDS_SHIFT (24)
+#define ESR_ELx_IDS (UL(1) << ESR_ELx_IDS_SHIFT)
+#define ESR_ELx_AET_SHIFT (10)
+#define ESR_ELx_AET (UL(0x7) << ESR_ELx_AET_SHIFT)
+
+#define ESR_ELx_AET_UC (UL(0) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEU (UL(1) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UEO (UL(2) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_UER (UL(3) << ESR_ELx_AET_SHIFT)
+#define ESR_ELx_AET_CE (UL(6) << ESR_ELx_AET_SHIFT)
+
+/* Shared ISS field definitions for Data/Instruction aborts */
+#define ESR_ELx_SET_SHIFT (11)
+#define ESR_ELx_SET_MASK (UL(3) << ESR_ELx_SET_SHIFT)
+#define ESR_ELx_FnV_SHIFT (10)
+#define ESR_ELx_FnV (UL(1) << ESR_ELx_FnV_SHIFT)
+#define ESR_ELx_EA_SHIFT (9)
+#define ESR_ELx_EA (UL(1) << ESR_ELx_EA_SHIFT)
+#define ESR_ELx_S1PTW_SHIFT (7)
+#define ESR_ELx_S1PTW (UL(1) << ESR_ELx_S1PTW_SHIFT)
+
+/* Shared ISS fault status code(IFSC/DFSC) for Data/Instruction aborts */
+#define ESR_ELx_FSC (0x3F)
+#define ESR_ELx_FSC_TYPE (0x3C)
+#define ESR_ELx_FSC_EXTABT (0x10)
+#define ESR_ELx_FSC_SERROR (0x11)
+#define ESR_ELx_FSC_ACCESS (0x08)
+#define ESR_ELx_FSC_FAULT (0x04)
+#define ESR_ELx_FSC_PERM (0x0C)
+
+/* ISS field definitions for Data Aborts */
+#define ESR_ELx_ISV_SHIFT (24)
+#define ESR_ELx_ISV (UL(1) << ESR_ELx_ISV_SHIFT)
+#define ESR_ELx_SAS_SHIFT (22)
+#define ESR_ELx_SAS (UL(3) << ESR_ELx_SAS_SHIFT)
+#define ESR_ELx_SSE_SHIFT (21)
+#define ESR_ELx_SSE (UL(1) << ESR_ELx_SSE_SHIFT)
+#define ESR_ELx_SRT_SHIFT (16)
+#define ESR_ELx_SRT_MASK (UL(0x1F) << ESR_ELx_SRT_SHIFT)
+#define ESR_ELx_SF_SHIFT (15)
+#define ESR_ELx_SF (UL(1) << ESR_ELx_SF_SHIFT)
+#define ESR_ELx_AR_SHIFT (14)
+#define ESR_ELx_AR (UL(1) << ESR_ELx_AR_SHIFT)
+#define ESR_ELx_CM_SHIFT (8)
+#define ESR_ELx_CM (UL(1) << ESR_ELx_CM_SHIFT)
+
+/* ISS field definitions for exceptions taken in to Hyp */
+#define ESR_ELx_CV (UL(1) << 24)
+#define ESR_ELx_COND_SHIFT (20)
+#define ESR_ELx_COND_MASK (UL(0xF) << ESR_ELx_COND_SHIFT)
+#define ESR_ELx_WFx_ISS_TI (UL(1) << 0)
+#define ESR_ELx_WFx_ISS_WFI (UL(0) << 0)
+#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
+#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+
+// LOAD_KERNEL_ADDRESS loads a kernel address.
+#define LOAD_KERNEL_ADDRESS(from, to) \
+ MOVD from, to; \
+ ORR $0xffff000000000000, to, to;
+
+// LOAD_KERNEL_STACK loads the kernel temporary stack.
+#define LOAD_KERNEL_STACK(from) \
+ LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
+ MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
+ MOVD RSV_REG, RSP; \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ ISB $15; \
+ DSB $15;
+
+// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
+#define SWITCH_TO_APP_PAGETABLE(from) \
+ MOVD CPU_TTBR0_APP(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
+#define SWITCH_TO_KVM_PAGETABLE(from) \
+ MOVD CPU_TTBR0_KVM(from), RSV_REG; \
+ WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
+ ISB $15; \
+ DSB $15;
+
+#define VFP_ENABLE \
+ MOVD $FPEN_ENABLE, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+#define VFP_DISABLE \
+ MOVD $0x0, R0; \
+ WORD $0xd5181040; \ //MSR R0, CPACR_EL1
+ ISB $15;
+
+// KERNEL_ENTRY_FROM_EL0 is the entry code of the vcpu from el0 to el1.
+#define KERNEL_ENTRY_FROM_EL0 \
+ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable.
+ SWITCH_TO_KVM_PAGETABLE(RSV_REG); \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer.
+ REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context.
+ MOVD RSV_REG_APP, R20; \
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \
+ ADD $16, RSP, RSP; \
+ MOVD RSV_REG, PTRACE_R18(R20); \
+ MOVD RSV_REG_APP, PTRACE_R9(R20); \
+ MOVD R20, RSV_REG_APP; \
+ WORD $0xd5384003; \ // MRS SPSR_EL1, R3
+ MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \
+ MRS ELR_EL1, R3; \
+ MOVD R3, PTRACE_PC(RSV_REG_APP); \
+ WORD $0xd5384103; \ // MRS SP_EL0, R3
+ MOVD R3, PTRACE_SP(RSV_REG_APP);
+
+// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
+#define KERNEL_ENTRY_FROM_EL1 \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
+ WORD $0xd5384004; \ // MRS SPSR_EL1, R4
+ MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
+ MRS ELR_EL1, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \
+ MOVD RSP, R4; \
+ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
+ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+
+// Halt halts execution.
+TEXT ·Halt(SB),NOSPLIT,$0
+ // Clear bluepill.
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ CMP RSV_REG, R9
+ BNE mmio_exit
+ MOVD $0, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ // Flush dcache.
+ WORD $0xd5087e52 // DC CISW
+mmio_exit:
+ // Disable fpsimd.
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, CPU_LAZY_VFP(RSV_REG)
+ VFP_DISABLE
+
+ // Trigger MMIO_EXIT/_KVM_HYPERCALL_VMEXIT.
+ //
+ // To keep it simple, I used the address of exception table as the
+ // MMIO base address, so that I can trigger a MMIO-EXIT by forcibly writing
+ // a read-only space.
+ // Also, the length is engough to match a sufficient number of hypercall ID.
+ // Then, in host user space, I can calculate this address to find out
+ // which hypercall.
+ MRS VBAR_EL1, R9
+ MOVD R0, 0x0(R9)
+
+ // Flush dcahce.
+ WORD $0xd5087e52 // DC CISW
+
+ RET
+
+// HaltAndResume halts execution and point the pointer to the resume function.
+TEXT ·HaltAndResume(SB),NOSPLIT,$0
+ BL ·Halt(SB)
+ B ·kernelExitToEl1(SB) // Resume.
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
+// Shutdown stops the guest.
+TEXT ·Shutdown(SB),NOSPLIT,$0
+ // PSCI EVENT.
+ MOVD $0x84000009, R0
+ HVC $0
+
+// See kernel.go.
+TEXT ·Current(SB),NOSPLIT,$0-8
+ MOVD CPU_SELF(RSV_REG), R8
+ MOVD R8, ret+0(FP)
+ RET
+
+#define STACK_FRAME_SIZE 16
+
+// kernelExitToEl0 is the entrypoint for application in guest_el0.
+// Prepare the vcpu environment for container application.
+TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
+ // Step1, save sentry context into memory.
+ MRS TPIDR_EL1, RSV_REG
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS)
+ MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG)
+
+ WORD $0xd5384003 // MRS SPSR_EL1, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG)
+ MOVD R30, CPU_REGISTERS+PTRACE_PC(RSV_REG)
+ MOVD RSP, R3
+ MOVD R3, CPU_REGISTERS+PTRACE_SP(RSV_REG)
+
+ MOVD CPU_REGISTERS+PTRACE_R3(RSV_REG), R3
+
+ // Step2, switch to temporary stack.
+ LOAD_KERNEL_STACK(RSV_REG)
+
+ // Step3, load app context pointer.
+ MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP
+
+ // Step4, prepare the environment for container application.
+ // set sp_el0.
+ MOVD PTRACE_SP(RSV_REG_APP), R1
+ WORD $0xd5184101 //MSR R1, SP_EL0
+ // set pc.
+ MOVD PTRACE_PC(RSV_REG_APP), R1
+ MSR R1, ELR_EL1
+ // set pstate.
+ MOVD PTRACE_PSTATE(RSV_REG_APP), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ // RSV_REG & RSV_REG_APP will be loaded at the end.
+ REGISTERS_LOAD(RSV_REG_APP, 0)
+
+ // switch to user pagetable.
+ MOVD PTRACE_R18(RSV_REG_APP), RSV_REG
+ MOVD PTRACE_R9(RSV_REG_APP), RSV_REG_APP
+
+ SUB $STACK_FRAME_SIZE, RSP, RSP
+ STP (RSV_REG, RSV_REG_APP), 16*0(RSP)
+
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ SWITCH_TO_APP_PAGETABLE(RSV_REG)
+
+ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
+ ADD $STACK_FRAME_SIZE, RSP, RSP
+
+ ISB $15
+ ERET()
+
+// kernelExitToEl1 is the entrypoint for sentry in guest_el1.
+// Prepare the vcpu environment for sentry.
+TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ MOVD CPU_REGISTERS+PTRACE_PSTATE(RSV_REG), R1
+ WORD $0xd5184001 //MSR R1, SPSR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1
+ MSR R1, ELR_EL1
+
+ MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
+ MOVD R1, RSP
+
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
+
+ ERET()
+
+// Start is the CPU entrypoint.
+TEXT ·Start(SB),NOSPLIT,$0
+ // Flush dcache.
+ WORD $0xd5087e52 // DC CISW
+ // Init.
+ MOVD $SCTLR_EL1_DEFAULT, R1
+ MSR R1, SCTLR_EL1
+
+ MOVD $CNTKCTL_EL1_DEFAULT, R1
+ MSR R1, CNTKCTL_EL1
+
+ MOVD R8, RSV_REG
+ ORR $0xffff000000000000, RSV_REG, RSV_REG
+ WORD $0xd518d092 //MSR R18, TPIDR_EL1
+
+ B ·kernelExitToEl1(SB)
+
+// El1_sync_invalid is the handler for an invalid EL1_sync.
+TEXT ·El1_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_irq_invalid is the handler for an invalid El1_irq.
+TEXT ·El1_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_fiq_invalid is the handler for an invalid El1_fiq.
+TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_error_invalid is the handler for an invalid El1_error.
+TEXT ·El1_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_sync is the handler for El1_sync.
+TEXT ·El1_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL1
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_DABT_CUR, R24
+ BEQ el1_da
+ CMP $ESR_ELx_EC_IABT_CUR, R24
+ BEQ el1_ia
+ CMP $ESR_ELx_EC_SYS64, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el1_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el1_undef
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el1_svc
+ CMP $ESR_ELx_EC_BREAKPT_CUR, R24
+ BGE el1_dbg
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el1_fpsimd_acc
+ B el1_invalid
+
+el1_da:
+el1_ia:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+el1_sp_pc:
+ B ·Shutdown(SB)
+
+el1_undef:
+ B ·Shutdown(SB)
+
+el1_svc:
+ MOVD $0, CPU_ERROR_CODE(RSV_REG)
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+ B ·HaltEl1SvcAndResume(SB)
+
+el1_dbg:
+ B ·Shutdown(SB)
+
+el1_fpsimd_acc:
+ VFP_ENABLE
+ B ·kernelExitToEl1(SB) // Resume.
+
+el1_invalid:
+ B ·Shutdown(SB)
+
+// El1_irq is the handler for El1_irq.
+TEXT ·El1_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_fiq is the handler for El1_fiq.
+TEXT ·El1_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El1_error is the handler for El1_error.
+TEXT ·El1_error(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// El0_sync is the handler for El0_sync.
+TEXT ·El0_sync(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd5385219 // MRS ESR_EL1, R25
+ LSR $ESR_ELx_EC_SHIFT, R25, R24
+ CMP $ESR_ELx_EC_SVC64, R24
+ BEQ el0_svc
+ CMP $ESR_ELx_EC_DABT_LOW, R24
+ BEQ el0_da
+ CMP $ESR_ELx_EC_IABT_LOW, R24
+ BEQ el0_ia
+ CMP $ESR_ELx_EC_FP_ASIMD, R24
+ BEQ el0_fpsimd_acc
+ CMP $ESR_ELx_EC_SVE, R24
+ BEQ el0_sve_acc
+ CMP $ESR_ELx_EC_FP_EXC64, R24
+ BEQ el0_fpsimd_exc
+ CMP $ESR_ELx_EC_SP_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_PC_ALIGN, R24
+ BEQ el0_sp_pc
+ CMP $ESR_ELx_EC_UNKNOWN, R24
+ BEQ el0_undef
+ CMP $ESR_ELx_EC_BREAKPT_LOW, R24
+ BGE el0_dbg
+ B el0_invalid
+
+el0_svc:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+
+ MOVD $0, CPU_ERROR_CODE(RSV_REG) // Clear error code.
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $Syscall, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_da:
+el0_ia:
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $PageFault, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ MRS ESR_EL1, R3
+ MOVD R3, CPU_ERROR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_fpsimd_acc:
+ B ·Shutdown(SB)
+
+el0_sve_acc:
+ B ·Shutdown(SB)
+
+el0_fpsimd_exc:
+ B ·Shutdown(SB)
+
+el0_sp_pc:
+ B ·Shutdown(SB)
+
+el0_undef:
+ MOVD $El0Sync_undef, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·kernelExitToEl1(SB)
+
+el0_dbg:
+ B ·Shutdown(SB)
+
+el0_invalid:
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error(SB),NOSPLIT,$0
+ KERNEL_ENTRY_FROM_EL0
+ WORD $0xd538d092 //MRS TPIDR_EL1, R18
+ WORD $0xd538601a //MRS FAR_EL1, R26
+
+ MOVD R26, CPU_FAULT_ADDR(RSV_REG)
+
+ MOVD $1, R3
+ MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user.
+
+ MOVD $VirtualizationException, R3
+ MOVD R3, CPU_VECTOR_CODE(RSV_REG)
+
+ B ·HaltAndResume(SB)
+
+TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_irq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+TEXT ·El0_error_invalid(SB),NOSPLIT,$0
+ B ·Shutdown(SB)
+
+// Vectors implements exception vector table.
+TEXT ·Vectors(SB),NOSPLIT,$0
+ B ·El1_sync_invalid(SB)
+ nop31Instructions()
+ B ·El1_irq_invalid(SB)
+ nop31Instructions()
+ B ·El1_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El1_error_invalid(SB)
+ nop31Instructions()
+
+ B ·El1_sync(SB)
+ nop31Instructions()
+ B ·El1_irq(SB)
+ nop31Instructions()
+ B ·El1_fiq(SB)
+ nop31Instructions()
+ B ·El1_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync(SB)
+ nop31Instructions()
+ B ·El0_irq(SB)
+ nop31Instructions()
+ B ·El0_fiq(SB)
+ nop31Instructions()
+ B ·El0_error(SB)
+ nop31Instructions()
+
+ B ·El0_sync_invalid(SB)
+ nop31Instructions()
+ B ·El0_irq_invalid(SB)
+ nop31Instructions()
+ B ·El0_fiq_invalid(SB)
+ nop31Instructions()
+ B ·El0_error_invalid(SB)
+ nop31Instructions()
+
+ // The exception-vector-table is required to be 11-bits aligned.
+ // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
+ // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
+ // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
+ WORD $0xd503201f //nop
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
+ WORD $0xd503201f
+ nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 780bf9a66..549f3d228 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -1,25 +1,34 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
- name = "defs_impl",
- out = "defs_impl.go",
+ name = "defs_impl_arm64",
+ out = "defs_impl_arm64.go",
package = "main",
- template = "//pkg/sentry/platform/ring0:defs",
+ template = "//pkg/sentry/platform/ring0:defs_arm64",
+)
+
+go_template_instance(
+ name = "defs_impl_amd64",
+ out = "defs_impl_amd64.go",
+ package = "main",
+ template = "//pkg/sentry/platform/ring0:defs_amd64",
)
go_binary(
name = "gen_offsets",
srcs = [
- "defs_impl.go",
+ "defs_impl_amd64.go",
+ "defs_impl_arm64.go",
"main.go",
],
visibility = ["//pkg/sentry/platform/ring0:__pkg__"],
deps = [
"//pkg/cpuid",
+ "//pkg/sentry/arch",
"//pkg/sentry/platform/ring0/pagetables",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go
index 900c0bba7..021693791 100644
--- a/pkg/sentry/platform/ring0/kernel.go
+++ b/pkg/sentry/platform/ring0/kernel.go
@@ -31,23 +31,39 @@ type defaultHooks struct{}
// KernelSyscall implements Hooks.KernelSyscall.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelSyscall() { Halt() }
+func (defaultHooks) KernelSyscall() {
+ Halt()
+}
// KernelException implements Hooks.KernelException.
//
+// +checkescape:all
+//
//go:nosplit
-func (defaultHooks) KernelException(Vector) { Halt() }
+func (defaultHooks) KernelException(Vector) {
+ Halt()
+}
// kernelSyscall is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelSyscall(c *CPU) { c.hooks.KernelSyscall() }
+func kernelSyscall(c *CPU) {
+ c.hooks.KernelSyscall()
+}
// kernelException is a trampoline.
//
+// +checkescape:hard,stack
+//
//go:nosplit
-func kernelException(c *CPU, vector Vector) { c.hooks.KernelException(vector) }
+func kernelException(c *CPU, vector Vector) {
+ c.hooks.KernelException(vector)
+}
// Init initializes a new CPU.
//
diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go
index 0feff8778..d37981dbf 100644
--- a/pkg/sentry/platform/ring0/kernel_amd64.go
+++ b/pkg/sentry/platform/ring0/kernel_amd64.go
@@ -178,6 +178,8 @@ func IsCanonical(addr uint64) bool {
//
// Precondition: the Rip, Rsp, Fs and Gs registers must be canonical.
//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
userCR3 := switchOpts.PageTables.CR3(!switchOpts.Flush, switchOpts.UserPCID)
@@ -192,9 +194,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
// Perform the switch.
swapgs() // GS will be swapped on return.
- WriteFS(uintptr(regs.Fs_base)) // Set application FS.
- WriteGS(uintptr(regs.Gs_base)) // Set application GS.
- LoadFloatingPoint(switchOpts.FloatingPointState) // Copy in floating point.
+ WriteFS(uintptr(regs.Fs_base)) // escapes: no. Set application FS.
+ WriteGS(uintptr(regs.Gs_base)) // escapes: no. Set application GS.
+ LoadFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy in floating point.
jumpToKernel() // Switch to upper half.
writeCR3(uintptr(userCR3)) // Change to user address space.
if switchOpts.FullRestore {
@@ -204,8 +206,8 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
writeCR3(uintptr(kernelCR3)) // Return to kernel address space.
jumpToUser() // Return to lower half.
- SaveFloatingPoint(switchOpts.FloatingPointState) // Copy out floating point.
- WriteFS(uintptr(c.registers.Fs_base)) // Restore kernel FS.
+ SaveFloatingPoint(switchOpts.FloatingPointState) // escapes: no. Copy out floating point.
+ WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
return
}
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
new file mode 100644
index 000000000..d0afa1aaa
--- /dev/null
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -0,0 +1,72 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// HaltAndResume halts execution and point the pointer to the resume function.
+//go:nosplit
+func HaltAndResume()
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+//go:nosplit
+func HaltEl1SvcAndResume()
+
+// init initializes architecture-specific state.
+func (k *Kernel) init(opts KernelOpts) {
+ // Save the root page tables.
+ k.PageTables = opts.PageTables
+}
+
+// init initializes architecture-specific state.
+func (c *CPU) init() {
+ // Set the kernel stack pointer(virtual address).
+ c.registers.Sp = uint64(c.StackTop())
+
+}
+
+// StackTop returns the kernel's stack address.
+//
+//go:nosplit
+func (c *CPU) StackTop() uint64 {
+ return uint64(kernelAddr(&c.stack[0])) + uint64(len(c.stack))
+}
+
+// IsCanonical indicates whether addr is canonical per the arm64 spec.
+//
+//go:nosplit
+func IsCanonical(addr uint64) bool {
+ return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
+}
+
+//go:nosplit
+func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
+ regs := switchOpts.Registers
+
+ regs.Pstate &= ^uint64(PsrFlagsClear)
+ regs.Pstate |= UserFlagsSet
+
+ LoadFloatingPoint(switchOpts.FloatingPointState)
+ SetTLS(regs.TPIDR_EL0)
+
+ kernelExitToEl0()
+
+ regs.TPIDR_EL0 = GetTLS()
+ SaveFloatingPoint(switchOpts.FloatingPointState)
+
+ vector = c.vecCode
+
+ return
+}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
new file mode 100644
index 000000000..00e52c8af
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -0,0 +1,58 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+// CPACREL1 returns the value of the CPACR_EL1 register.
+func CPACREL1() (value uintptr)
+
+// FPCR returns the value of FPCR register.
+func GetFPCR() (value uintptr)
+
+// SetFPCR writes the FPCR value.
+func SetFPCR(value uintptr)
+
+// FPSR returns the value of FPSR register.
+func GetFPSR() (value uintptr)
+
+// SetFPSR writes the FPSR value.
+func SetFPSR(value uintptr)
+
+// SaveVRegs saves V0-V31 registers.
+// V0-V31: 32 128-bit registers for floating point and simd.
+func SaveVRegs(*byte)
+
+// LoadVRegs loads V0-V31 registers.
+func LoadVRegs(*byte)
+
+// LoadFloatingPoint loads floating point state.
+func LoadFloatingPoint(*byte)
+
+// SaveFloatingPoint saves floating point state.
+func SaveFloatingPoint(*byte)
+
+// GetTLS returns the value of TPIDR_EL0 register.
+func GetTLS() (value uint64)
+
+// SetTLS writes the TPIDR_EL0 value.
+func SetTLS(value uint64)
+
+// Init sets function pointers based on architectural features.
+//
+// This must be called prior to using ring0.
+func Init() {
+ rewriteVectors()
+}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
new file mode 100644
index 000000000..86bfbe46f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -0,0 +1,217 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "funcdata.h"
+#include "textflag.h"
+
+TEXT ·GetTLS(SB),NOSPLIT,$0-8
+ MRS TPIDR_EL0, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetTLS(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ MSR R1, TPIDR_EL0
+ RET
+
+TEXT ·CPACREL1(SB),NOSPLIT,$0-8
+ WORD $0xd5381041 // MRS CPACR_EL1, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPCR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4201 // MRS NZCV, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·GetFPSR(SB),NOSPLIT,$0-8
+ WORD $0xd53b4421 // MRS FPSR, R1
+ MOVD R1, ret+0(FP)
+ RET
+
+TEXT ·SetFPCR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4201 // MSR R1, NZCV
+ RET
+
+TEXT ·SetFPSR(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ WORD $0xd51b4421 // MSR R1, FPSR
+ RET
+
+TEXT ·SaveVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+ ISB $15
+
+ RET
+
+TEXT ·LoadVRegs(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ // Skip aarch64_ctx, fpsr, fpcr.
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+ ISB $15
+
+ RET
+
+TEXT ·LoadFloatingPoint(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ MOVD 0(R0), R1
+ MOVD R1, FPSR
+ MOVD 8(R0), R1
+ MOVD R1, NZCV
+
+ FMOVD 16*1(R0), F0
+ FMOVD 16*2(R0), F1
+ FMOVD 16*3(R0), F2
+ FMOVD 16*4(R0), F3
+ FMOVD 16*5(R0), F4
+ FMOVD 16*6(R0), F5
+ FMOVD 16*7(R0), F6
+ FMOVD 16*8(R0), F7
+ FMOVD 16*9(R0), F8
+ FMOVD 16*10(R0), F9
+ FMOVD 16*11(R0), F10
+ FMOVD 16*12(R0), F11
+ FMOVD 16*13(R0), F12
+ FMOVD 16*14(R0), F13
+ FMOVD 16*15(R0), F14
+ FMOVD 16*16(R0), F15
+ FMOVD 16*17(R0), F16
+ FMOVD 16*18(R0), F17
+ FMOVD 16*19(R0), F18
+ FMOVD 16*20(R0), F19
+ FMOVD 16*21(R0), F20
+ FMOVD 16*22(R0), F21
+ FMOVD 16*23(R0), F22
+ FMOVD 16*24(R0), F23
+ FMOVD 16*25(R0), F24
+ FMOVD 16*26(R0), F25
+ FMOVD 16*27(R0), F26
+ FMOVD 16*28(R0), F27
+ FMOVD 16*29(R0), F28
+ FMOVD 16*30(R0), F29
+ FMOVD 16*31(R0), F30
+ FMOVD 16*32(R0), F31
+
+ RET
+
+TEXT ·SaveFloatingPoint(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R0
+
+ MOVD FPSR, R1
+ MOVD R1, 0(R0)
+ MOVD NZCV, R1
+ MOVD R1, 8(R0)
+
+ FMOVD F0, 16*1(R0)
+ FMOVD F1, 16*2(R0)
+ FMOVD F2, 16*3(R0)
+ FMOVD F3, 16*4(R0)
+ FMOVD F4, 16*5(R0)
+ FMOVD F5, 16*6(R0)
+ FMOVD F6, 16*7(R0)
+ FMOVD F7, 16*8(R0)
+ FMOVD F8, 16*9(R0)
+ FMOVD F9, 16*10(R0)
+ FMOVD F10, 16*11(R0)
+ FMOVD F11, 16*12(R0)
+ FMOVD F12, 16*13(R0)
+ FMOVD F13, 16*14(R0)
+ FMOVD F14, 16*15(R0)
+ FMOVD F15, 16*16(R0)
+ FMOVD F16, 16*17(R0)
+ FMOVD F17, 16*18(R0)
+ FMOVD F18, 16*19(R0)
+ FMOVD F19, 16*20(R0)
+ FMOVD F20, 16*21(R0)
+ FMOVD F21, 16*22(R0)
+ FMOVD F22, 16*23(R0)
+ FMOVD F23, 16*24(R0)
+ FMOVD F24, 16*25(R0)
+ FMOVD F25, 16*26(R0)
+ FMOVD F26, 16*27(R0)
+ FMOVD F27, 16*28(R0)
+ FMOVD F28, 16*29(R0)
+ FMOVD F29, 16*30(R0)
+ FMOVD F30, 16*31(R0)
+ FMOVD F31, 16*32(R0)
+
+ RET
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
new file mode 100644
index 000000000..c05166fea
--- /dev/null
+++ b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
@@ -0,0 +1,108 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "reflect"
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/safecopy"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const (
+ nopInstruction = 0xd503201f
+ instSize = unsafe.Sizeof(uint32(0))
+ vectorsRawLen = 0x800
+)
+
+func unsafeSlice(addr uintptr, length int) (slice []uint32) {
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
+ hdr.Data = addr
+ hdr.Len = length / int(instSize)
+ hdr.Cap = length / int(instSize)
+ return slice
+}
+
+// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
+//
+// According to the design documentation of Arm64,
+// the start address of exception vector table should be 11-bits aligned.
+// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
+// But, we can't align a function's start address to a specific address by using golang.
+// We have raised this question in golang community:
+// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
+// This function will be removed when golang supports this feature.
+//
+// There are 2 jobs were implemented in this function:
+// 1, move the start address of exception vector table into the specific address.
+// 2, modify the offset of each instruction.
+func rewriteVectors() {
+ vectorsBegin := reflect.ValueOf(Vectors).Pointer()
+
+ // The exception-vector-table is required to be 11-bits aligned.
+ // And the size is 0x800.
+ // Please see the documentation as reference:
+ // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
+ //
+ // But, golang does not allow to set a function's address to a specific value.
+ // So, for gvisor, I defined the size of exception-vector-table as 4K,
+ // filled the 2nd 2K part with NOP-s.
+ // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
+ //
+ // So, the prerequisite for this function to work correctly is:
+ // vectorsSafeLen >= 0x1000
+ // vectorsRawLen = 0x800
+ vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
+ if vectorsSafeLen < 2*vectorsRawLen {
+ panic("Can't update vectors")
+ }
+
+ vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
+ vectorsRawLen32 := vectorsRawLen / int(instSize)
+
+ offset := vectorsBegin & (1<<11 - 1)
+ if offset != 0 {
+ offset = 1<<11 - offset
+ }
+
+ pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
+
+ _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+
+ offset = offset / instSize // By index, not bytes.
+ // Move exception-vector-table into the specific address, should uses memmove here.
+ for i := 1; i <= vectorsRawLen32; i++ {
+ vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
+ }
+
+ // Adjust branch since instruction was moved forward.
+ for i := 0; i < vectorsRawLen32; i++ {
+ if vectorsSafeTable[int(offset)+i] != nopInstruction {
+ vectorsSafeTable[int(offset)+i] -= uint32(offset)
+ }
+ }
+
+ _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
+ if errno != 0 {
+ panic(errno.Error())
+ }
+}
diff --git a/pkg/sentry/platform/ring0/offsets_amd64.go b/pkg/sentry/platform/ring0/offsets_amd64.go
index 85cc3fdad..b8ab120a0 100644
--- a/pkg/sentry/platform/ring0/offsets_amd64.go
+++ b/pkg/sentry/platform/ring0/offsets_amd64.go
@@ -20,7 +20,8 @@ import (
"fmt"
"io"
"reflect"
- "syscall"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
)
// Emit prints architecture-specific offsets.
@@ -64,7 +65,7 @@ func Emit(w io.Writer) {
fmt.Fprintf(w, "#define SyscallInt80 0x%02x\n", SyscallInt80)
fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
- p := &syscall.PtraceRegs{}
+ p := &arch.Registers{}
fmt.Fprintf(w, "\n// Ptrace registers.\n")
fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.R15).Pointer()-reflect.ValueOf(p).Pointer())
fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.R14).Pointer()-reflect.ValueOf(p).Pointer())
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go
new file mode 100644
index 000000000..f3de962f0
--- /dev/null
+++ b/pkg/sentry/platform/ring0/offsets_arm64.go
@@ -0,0 +1,127 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package ring0
+
+import (
+ "fmt"
+ "io"
+ "reflect"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// Emit prints architecture-specific offsets.
+func Emit(w io.Writer) {
+ fmt.Fprintf(w, "// Automatically generated, do not edit.\n")
+
+ c := &CPU{}
+ fmt.Fprintf(w, "\n// CPU offsets.\n")
+ fmt.Fprintf(w, "#define CPU_SELF 0x%02x\n", reflect.ValueOf(&c.self).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_REGISTERS 0x%02x\n", reflect.ValueOf(&c.registers).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_STACK_TOP 0x%02x\n", reflect.ValueOf(&c.stack[0]).Pointer()-reflect.ValueOf(c).Pointer()+uintptr(len(c.stack)))
+ fmt.Fprintf(w, "#define CPU_ERROR_CODE 0x%02x\n", reflect.ValueOf(&c.errorCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_ERROR_TYPE 0x%02x\n", reflect.ValueOf(&c.errorType).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_FAULT_ADDR 0x%02x\n", reflect.ValueOf(&c.faultAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_KVM 0x%02x\n", reflect.ValueOf(&c.ttbr0Kvm).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_TTBR0_APP 0x%02x\n", reflect.ValueOf(&c.ttbr0App).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer())
+ fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer())
+
+ fmt.Fprintf(w, "\n// Bits.\n")
+ fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet)
+
+ fmt.Fprintf(w, "\n// Vectors.\n")
+ fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid)
+ fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid)
+ fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid)
+ fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid)
+
+ fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync)
+ fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq)
+ fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq)
+ fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error)
+
+ fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync)
+ fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq)
+ fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq)
+ fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error)
+
+ fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid)
+ fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid)
+ fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid)
+ fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid)
+
+ fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da)
+ fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia)
+ fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc)
+ fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef)
+ fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg)
+ fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv)
+
+ fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc)
+ fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da)
+ fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia)
+ fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc)
+ fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc)
+ fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys)
+ fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc)
+ fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef)
+ fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg)
+ fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv)
+
+ fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault)
+ fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall)
+ fmt.Fprintf(w, "#define VirtualizationException 0x%02x\n", VirtualizationException)
+
+ p := &arch.Registers{}
+ fmt.Fprintf(w, "\n// Ptrace registers.\n")
+ fmt.Fprintf(w, "#define PTRACE_R0 0x%02x\n", reflect.ValueOf(&p.Regs[0]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R1 0x%02x\n", reflect.ValueOf(&p.Regs[1]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R2 0x%02x\n", reflect.ValueOf(&p.Regs[2]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R3 0x%02x\n", reflect.ValueOf(&p.Regs[3]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R4 0x%02x\n", reflect.ValueOf(&p.Regs[4]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R5 0x%02x\n", reflect.ValueOf(&p.Regs[5]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R6 0x%02x\n", reflect.ValueOf(&p.Regs[6]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R7 0x%02x\n", reflect.ValueOf(&p.Regs[7]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R8 0x%02x\n", reflect.ValueOf(&p.Regs[8]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R9 0x%02x\n", reflect.ValueOf(&p.Regs[9]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R10 0x%02x\n", reflect.ValueOf(&p.Regs[10]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R11 0x%02x\n", reflect.ValueOf(&p.Regs[11]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R12 0x%02x\n", reflect.ValueOf(&p.Regs[12]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R13 0x%02x\n", reflect.ValueOf(&p.Regs[13]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R14 0x%02x\n", reflect.ValueOf(&p.Regs[14]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R15 0x%02x\n", reflect.ValueOf(&p.Regs[15]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R16 0x%02x\n", reflect.ValueOf(&p.Regs[16]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R17 0x%02x\n", reflect.ValueOf(&p.Regs[17]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R18 0x%02x\n", reflect.ValueOf(&p.Regs[18]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R19 0x%02x\n", reflect.ValueOf(&p.Regs[19]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R20 0x%02x\n", reflect.ValueOf(&p.Regs[20]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R21 0x%02x\n", reflect.ValueOf(&p.Regs[21]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R22 0x%02x\n", reflect.ValueOf(&p.Regs[22]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R23 0x%02x\n", reflect.ValueOf(&p.Regs[23]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R24 0x%02x\n", reflect.ValueOf(&p.Regs[24]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R25 0x%02x\n", reflect.ValueOf(&p.Regs[25]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R26 0x%02x\n", reflect.ValueOf(&p.Regs[26]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R27 0x%02x\n", reflect.ValueOf(&p.Regs[27]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R28 0x%02x\n", reflect.ValueOf(&p.Regs[28]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R29 0x%02x\n", reflect.ValueOf(&p.Regs[29]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_R30 0x%02x\n", reflect.ValueOf(&p.Regs[30]).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_SP 0x%02x\n", reflect.ValueOf(&p.Sp).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PC 0x%02x\n", reflect.ValueOf(&p.Pc).Pointer()-reflect.ValueOf(p).Pointer())
+ fmt.Fprintf(w, "#define PTRACE_PSTATE 0x%02x\n", reflect.ValueOf(&p.Pstate).Pointer()-reflect.ValueOf(p).Pointer())
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 934a90378..16d5f478b 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,14 +1,14 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test", "select_arch")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
go_template(
name = "generic_walker",
- srcs = [
- "walker_amd64.go",
- ],
+ srcs = select_arch(
+ amd64 = ["walker_amd64.go"],
+ arm64 = ["walker_arm64.go"],
+ ),
opt_types = [
"Visitor",
],
@@ -76,20 +76,29 @@ go_library(
"allocator.go",
"allocator_unsafe.go",
"pagetables.go",
+ "pagetables_aarch64.go",
"pagetables_amd64.go",
+ "pagetables_arm64.go",
"pagetables_x86.go",
+ "pcids.go",
+ "pcids_aarch64.go",
+ "pcids_aarch64.s",
"pcids_x86.go",
+ "walker_amd64.go",
+ "walker_arm64.go",
"walker_empty.go",
"walker_lookup.go",
"walker_map.go",
"walker_unmap.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables",
visibility = [
"//pkg/sentry/platform/kvm:__subpackages__",
"//pkg/sentry/platform/ring0:__subpackages__",
],
- deps = ["//pkg/sentry/usermem"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/usermem",
+ ],
)
go_test(
@@ -97,9 +106,10 @@ go_test(
size = "small",
srcs = [
"pagetables_amd64_test.go",
+ "pagetables_arm64_test.go",
"pagetables_test.go",
"walker_check.go",
],
- embed = [":pagetables"],
- deps = ["//pkg/sentry/usermem"],
+ library = ":pagetables",
+ deps = ["//pkg/usermem"],
)
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator.go b/pkg/sentry/platform/ring0/pagetables/allocator.go
index 23fd5c352..8d75b7599 100644
--- a/pkg/sentry/platform/ring0/pagetables/allocator.go
+++ b/pkg/sentry/platform/ring0/pagetables/allocator.go
@@ -53,9 +53,14 @@ type RuntimeAllocator struct {
// NewRuntimeAllocator returns an allocator that uses runtime allocation.
func NewRuntimeAllocator() *RuntimeAllocator {
- return &RuntimeAllocator{
- used: make(map[*PTEs]struct{}),
- }
+ r := new(RuntimeAllocator)
+ r.Init()
+ return r
+}
+
+// Init initializes a RuntimeAllocator.
+func (r *RuntimeAllocator) Init() {
+ r.used = make(map[*PTEs]struct{})
}
// Recycle returns freed pages to the pool.
diff --git a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
index a90394a33..d08bfdeb3 100644
--- a/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
+++ b/pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go
@@ -17,7 +17,7 @@ package pagetables
import (
"unsafe"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// newAlignedPTEs returns a set of aligned PTEs.
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go
index 904f1a6de..7f18ac296 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go
@@ -21,7 +21,7 @@
package pagetables
import (
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// PageTables is a set of page tables.
@@ -48,15 +48,6 @@ func New(a Allocator) *PageTables {
return p
}
-// Init initializes a set of PageTables.
-//
-//go:nosplit
-func (p *PageTables) Init(allocator Allocator) {
- p.Allocator = allocator
- p.root = p.Allocator.NewPTEs()
- p.rootPhysical = p.Allocator.PhysicalFor(p.root)
-}
-
// mapVisitor is used for map.
type mapVisitor struct {
target uintptr // Input.
@@ -95,6 +86,8 @@ func (*mapVisitor) requiresSplit() bool { return true }
//
// Precondition: addr & length must be page-aligned, their sum must not overflow.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool {
if !opts.AccessType.Any() {
@@ -137,6 +130,8 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool {
w := unmapWalker{
@@ -171,6 +166,8 @@ func (v *emptyVisitor) visit(start uintptr, pte *PTE, align uintptr) {
//
// Precondition: addr & length must be page-aligned.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) IsEmpty(addr usermem.Addr, length uintptr) bool {
w := emptyWalker{
@@ -206,6 +203,8 @@ func (*lookupVisitor) requiresSplit() bool { return false }
// Lookup returns the physical address for the given virtual address.
//
+// +checkescape:hard,stack
+//
//go:nosplit
func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) {
mask := uintptr(usermem.PageSize - 1)
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
new file mode 100644
index 000000000..6409d1d91
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go
@@ -0,0 +1,215 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// archPageTables is architecture-specific data.
+type archPageTables struct {
+ // root is the pagetable root for kernel space.
+ root *PTEs
+
+ // rootPhysical is the cached physical address of the root.
+ //
+ // This is saved only to prevent constant translation.
+ rootPhysical uintptr
+
+ asid uint16
+}
+
+// TTBR0_EL1 returns the translation table base register 0.
+//
+//go:nosplit
+func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// TTBR1_EL1 returns the translation table base register 1.
+//
+//go:nosplit
+func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 {
+ return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
+}
+
+// Bits in page table entries.
+const (
+ typeTable = 0x3 << 0
+ typeSect = 0x1 << 0
+ typePage = 0x3 << 0
+ pteValid = 0x1 << 0
+ pteTableBit = 0x1 << 1
+ pteTypeMask = 0x3 << 0
+ present = pteValid | pteTableBit
+ user = 0x1 << 6 /* AP[1] */
+ readOnly = 0x1 << 7 /* AP[2] */
+ accessed = 0x1 << 10
+ dbm = 0x1 << 51
+ writable = dbm
+ cont = 0x1 << 52
+ pxn = 0x1 << 53
+ xn = 0x1 << 54
+ dirty = 0x1 << 55
+ nG = 0x1 << 11
+ shared = 0x3 << 8
+)
+
+const (
+ mtDevicenGnRE = 0x1 << 2
+ mtNormal = 0x4 << 2
+)
+
+const (
+ executeDisable = xn
+ optionMask = 0xfff | 0xfff<<48
+ protDefault = accessed | shared
+)
+
+// MapOpts are x86 options.
+type MapOpts struct {
+ // AccessType defines permissions.
+ AccessType usermem.AccessType
+
+ // Global indicates the page is globally accessible.
+ Global bool
+
+ // User indicates the page is a user page.
+ User bool
+}
+
+// PTE is a page table entry.
+type PTE uintptr
+
+// Clear clears this PTE, including sect page information.
+//
+//go:nosplit
+func (p *PTE) Clear() {
+ atomic.StoreUintptr((*uintptr)(p), 0)
+}
+
+// Valid returns true iff this entry is valid.
+//
+//go:nosplit
+func (p *PTE) Valid() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&present != 0
+}
+
+// Opts returns the PTE options.
+//
+// These are all options except Valid and Sect.
+//
+//go:nosplit
+func (p *PTE) Opts() MapOpts {
+ v := atomic.LoadUintptr((*uintptr)(p))
+
+ return MapOpts{
+ AccessType: usermem.AccessType{
+ Read: true,
+ Write: v&readOnly == 0,
+ Execute: v&xn == 0,
+ },
+ Global: v&nG == 0,
+ User: v&user != 0,
+ }
+}
+
+// SetSect sets this page as a sect page.
+//
+// The page must not be valid or a panic will result.
+//
+//go:nosplit
+func (p *PTE) SetSect() {
+ if p.Valid() {
+ // This is not allowed.
+ panic("SetSect called on valid page!")
+ }
+ atomic.StoreUintptr((*uintptr)(p), typeSect)
+}
+
+// IsSect returns true iff this page is a sect page.
+//
+//go:nosplit
+func (p *PTE) IsSect() bool {
+ return atomic.LoadUintptr((*uintptr)(p))&pteTypeMask == typeSect
+}
+
+// Set sets this PTE value.
+//
+// This does not change the sect page property.
+//
+//go:nosplit
+func (p *PTE) Set(addr uintptr, opts MapOpts) {
+ if !opts.AccessType.Any() {
+ p.Clear()
+ return
+ }
+ v := (addr &^ optionMask) | protDefault | nG | readOnly
+
+ if p.IsSect() {
+ // Note that this is inherited from the previous instance. Set
+ // does not change the value of Sect. See above.
+ v |= typeSect
+ } else {
+ v |= typePage
+ }
+
+ if opts.Global {
+ v = v &^ nG
+ }
+
+ if opts.AccessType.Execute {
+ v = v &^ executeDisable
+ } else {
+ v |= executeDisable
+ }
+ if opts.AccessType.Write {
+ v = v &^ readOnly
+ }
+
+ if opts.User {
+ v |= user
+ v |= mtNormal
+ } else {
+ v = v &^ user
+ v |= mtDevicenGnRE // Strong order for the addresses with ring0.KernelStartAddress.
+ }
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// setPageTable sets this PTE value and forces the write bit and sect bit to
+// be cleared. This is used explicitly for breaking sect pages.
+//
+//go:nosplit
+func (p *PTE) setPageTable(pt *PageTables, ptes *PTEs) {
+ addr := pt.Allocator.PhysicalFor(ptes)
+ if addr&^optionMask != addr {
+ // This should never happen.
+ panic("unaligned physical address!")
+ }
+ v := addr | typeTable | protDefault | mtNormal
+ atomic.StoreUintptr((*uintptr)(p), v)
+}
+
+// Address extracts the address. This should only be used if Valid returns true.
+//
+//go:nosplit
+func (p *PTE) Address() uintptr {
+ return atomic.LoadUintptr((*uintptr)(p)) &^ optionMask
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
index 7aa6c524e..0c153cf8c 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go
@@ -41,5 +41,14 @@ const (
entriesPerPage = 512
)
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+}
+
// PTEs is a collection of entries.
type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go
index 35e917526..54e8e554f 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64_test.go
@@ -19,7 +19,7 @@ package pagetables
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func Test2MAnd4K(t *testing.T) {
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
new file mode 100644
index 000000000..1a49f12a2
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go
@@ -0,0 +1,57 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+// Address constraints.
+//
+// The lowerTop and upperBottom currently apply to four-level pagetables;
+// additional refactoring would be necessary to support five-level pagetables.
+const (
+ lowerTop = 0x0000ffffffffffff
+ upperBottom = 0xffff000000000000
+ pteShift = 12
+ pmdShift = 21
+ pudShift = 30
+ pgdShift = 39
+
+ pteMask = 0x1ff << pteShift
+ pmdMask = 0x1ff << pmdShift
+ pudMask = 0x1ff << pudShift
+ pgdMask = 0x1ff << pgdShift
+
+ pteSize = 1 << pteShift
+ pmdSize = 1 << pmdShift
+ pudSize = 1 << pudShift
+ pgdSize = 1 << pgdShift
+
+ ttbrASIDOffset = 55
+ ttbrASIDMask = 0xff
+
+ entriesPerPage = 512
+)
+
+// Init initializes a set of PageTables.
+//
+//go:nosplit
+func (p *PageTables) Init(allocator Allocator) {
+ p.Allocator = allocator
+ p.root = p.Allocator.NewPTEs()
+ p.rootPhysical = p.Allocator.PhysicalFor(p.root)
+ p.archPageTables.root = p.Allocator.NewPTEs()
+ p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root)
+}
+
+// PTEs is a collection of entries.
+type PTEs [entriesPerPage]PTE
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
new file mode 100644
index 000000000..2f73d424f
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64_test.go
@@ -0,0 +1,80 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func Test2MAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a huge page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*47)
+
+ pt.Map(0xffff000000400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: false}, pteSize*42)
+ pt.Map(0xffffff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: false}, pmdSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ {0xffff000000400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: false}},
+ {0xffffff0000000000, pmdSize, pmdSize * 47, MapOpts{AccessType: usermem.Read, User: false}},
+ })
+}
+
+func Test1GAnd4K(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a small page and a super page.
+ pt.Map(0x400000, pteSize, MapOpts{AccessType: usermem.ReadWrite, User: true}, pteSize*42)
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*47)
+
+ checkMappings(t, pt, []mapping{
+ {0x400000, pteSize, pteSize * 42, MapOpts{AccessType: usermem.ReadWrite, User: true}},
+ {0x0000ff0000000000, pudSize, pudSize * 47, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit1GPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a super page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pudSize, MapOpts{AccessType: usermem.Read, User: true}, pudSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pudSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pudSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pudSize - pteSize, pteSize, pudSize*42 + pudSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
+
+func TestSplit2MPage(t *testing.T) {
+ pt := New(NewRuntimeAllocator())
+
+ // Map a huge page and knock out the middle.
+ pt.Map(0x0000ff0000000000, pmdSize, MapOpts{AccessType: usermem.Read, User: true}, pmdSize*42)
+ pt.Unmap(usermem.Addr(0x0000ff0000000000+pteSize), pmdSize-(2*pteSize))
+
+ checkMappings(t, pt, []mapping{
+ {0x0000ff0000000000, pteSize, pmdSize * 42, MapOpts{AccessType: usermem.Read, User: true}},
+ {0x0000ff0000000000 + pmdSize - pteSize, pteSize, pmdSize*42 + pmdSize - pteSize, MapOpts{AccessType: usermem.Read, User: true}},
+ })
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go
index 6e95ad2b9..5c88d087d 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_test.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_test.go
@@ -17,7 +17,7 @@ package pagetables
import (
"testing"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
type mapping struct {
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
index 3e2383c5e..157438d9b 100644
--- a/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pagetables_x86.go
@@ -12,14 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package pagetables
import (
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// archPageTables is architecture-specific data.
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids.go b/pkg/sentry/platform/ring0/pagetables/pcids.go
new file mode 100644
index 000000000..964496aac
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids.go
@@ -0,0 +1,104 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pagetables
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// PCIDs is a simple PCID database.
+//
+// This is not protected by locks and is thus suitable for use only with a
+// single CPU at a time.
+type PCIDs struct {
+ // mu protects below.
+ mu sync.Mutex
+
+ // cache are the assigned page tables.
+ cache map[*PageTables]uint16
+
+ // avail are available PCIDs.
+ avail []uint16
+}
+
+// NewPCIDs returns a new PCID database.
+//
+// start is the first index to assign. Typically this will be one, as the zero
+// pcid will always be flushed on transition (see pagetables_x86.go). This may
+// be more than one if specific PCIDs are reserved.
+//
+// Nil is returned iff the start and size are out of range.
+func NewPCIDs(start, size uint16) *PCIDs {
+ if start+uint16(size) > limitPCID {
+ return nil // See comment.
+ }
+ p := &PCIDs{
+ cache: make(map[*PageTables]uint16),
+ }
+ for pcid := start; pcid < start+size; pcid++ {
+ p.avail = append(p.avail, pcid)
+ }
+ return p
+}
+
+// Assign assigns a PCID to the given PageTables.
+//
+// This may overwrite any previous assignment provided. If this in the case,
+// true is returned to indicate that the PCID should be flushed.
+func (p *PCIDs) Assign(pt *PageTables) (uint16, bool) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ p.mu.Unlock()
+ return pcid, false // No flush.
+ }
+
+ // Is there something available?
+ if len(p.avail) > 0 {
+ pcid := p.avail[len(p.avail)-1]
+ p.avail = p.avail[:len(p.avail)-1]
+ p.cache[pt] = pcid
+
+ // We need to flush because while this is in the available
+ // pool, it may have been used previously.
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // Evict an existing table.
+ for old, pcid := range p.cache {
+ delete(p.cache, old)
+ p.cache[pt] = pcid
+
+ // A flush is definitely required in this case, these page
+ // tables may still be active. (They will just be assigned some
+ // other PCID if and when they hit the given CPU again.)
+ p.mu.Unlock()
+ return pcid, true
+ }
+
+ // No PCID.
+ p.mu.Unlock()
+ return 0, false
+}
+
+// Drop drops references to a set of page tables.
+func (p *PCIDs) Drop(pt *PageTables) {
+ p.mu.Lock()
+ if pcid, ok := p.cache[pt]; ok {
+ delete(p.cache, pt)
+ p.avail = append(p.avail, pcid)
+ }
+ p.mu.Unlock()
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
new file mode 100644
index 000000000..fbfd41d83
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// limitPCID is the maximum value of PCIDs.
+//
+// In VMSAv8-64, the PCID(ASID) size is an IMPLEMENTATION DEFINED choice
+// of 8 bits or 16 bits, and ID_AA64MMFR0_EL1.ASIDBits identifies the
+// supported size. When an implementation supports a 16-bit ASID, TCR_ELx.AS
+// selects whether the top 8 bits of the ASID are used.
+var limitPCID uint16
+
+// GetASIDBits return the system ASID bits, 8 or 16 bits.
+func GetASIDBits() uint8
+
+func init() {
+ limitPCID = uint16(1)<<GetASIDBits() - 1
+}
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
new file mode 100644
index 000000000..e9d62d768
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_aarch64.s
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+#include "funcdata.h"
+#include "textflag.h"
+
+#define ID_AA64MMFR0_ASIDBITS_SHIFT 4
+#define ID_AA64MMFR0_ASIDBITS_16 2
+#define TCR_EL1_AS_BIT 36
+
+// GetASIDBits return the system ASID bits, 8 or 16 bits.
+//
+// func GetASIDBits() uint8
+TEXT ·GetASIDBits(SB),NOSPLIT,$0-1
+ // First, check whether 16bits ASID is supported.
+ // ID_AA64MMFR0_EL1.ASIDBITS[7:4] == 0010.
+ WORD $0xd5380700 // MRS ID_AA64MMFR0_EL1, R0
+ UBFX $ID_AA64MMFR0_ASIDBITS_SHIFT, R0, $4, R0
+ CMPW $ID_AA64MMFR0_ASIDBITS_16, R0
+ BNE bits_8
+
+ // Second, check whether 16bits ASID is enabled.
+ // TCR_EL1.AS[36] == 1.
+ WORD $0xd5382040 // MRS TCR_EL1, R0
+ TBZ $TCR_EL1_AS_BIT, R0, bits_8
+ MOVD $16, R0
+ B done
+bits_8:
+ MOVD $8, R0
+done:
+ MOVB R0, ret+0(FP)
+ RET
diff --git a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
index 0f029f25d..91fc5e8dd 100644
--- a/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
+++ b/pkg/sentry/platform/ring0/pagetables/pcids_x86.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,94 +16,5 @@
package pagetables
-import (
- "sync"
-)
-
-// limitPCID is the number of valid PCIDs.
-const limitPCID = 4096
-
-// PCIDs is a simple PCID database.
-//
-// This is not protected by locks and is thus suitable for use only with a
-// single CPU at a time.
-type PCIDs struct {
- // mu protects below.
- mu sync.Mutex
-
- // cache are the assigned page tables.
- cache map[*PageTables]uint16
-
- // avail are available PCIDs.
- avail []uint16
-}
-
-// NewPCIDs returns a new PCID database.
-//
-// start is the first index to assign. Typically this will be one, as the zero
-// pcid will always be flushed on transition (see pagetables_x86.go). This may
-// be more than one if specific PCIDs are reserved.
-//
-// Nil is returned iff the start and size are out of range.
-func NewPCIDs(start, size uint16) *PCIDs {
- if start+uint16(size) >= limitPCID {
- return nil // See comment.
- }
- p := &PCIDs{
- cache: make(map[*PageTables]uint16),
- }
- for pcid := start; pcid < start+size; pcid++ {
- p.avail = append(p.avail, pcid)
- }
- return p
-}
-
-// Assign assigns a PCID to the given PageTables.
-//
-// This may overwrite any previous assignment provided. If this in the case,
-// true is returned to indicate that the PCID should be flushed.
-func (p *PCIDs) Assign(pt *PageTables) (uint16, bool) {
- p.mu.Lock()
- if pcid, ok := p.cache[pt]; ok {
- p.mu.Unlock()
- return pcid, false // No flush.
- }
-
- // Is there something available?
- if len(p.avail) > 0 {
- pcid := p.avail[len(p.avail)-1]
- p.avail = p.avail[:len(p.avail)-1]
- p.cache[pt] = pcid
-
- // We need to flush because while this is in the available
- // pool, it may have been used previously.
- p.mu.Unlock()
- return pcid, true
- }
-
- // Evict an existing table.
- for old, pcid := range p.cache {
- delete(p.cache, old)
- p.cache[pt] = pcid
-
- // A flush is definitely required in this case, these page
- // tables may still be active. (They will just be assigned some
- // other PCID if and when they hit the given CPU again.)
- p.mu.Unlock()
- return pcid, true
- }
-
- // No PCID.
- p.mu.Unlock()
- return 0, false
-}
-
-// Drop drops references to a set of page tables.
-func (p *PCIDs) Drop(pt *PageTables) {
- p.mu.Lock()
- if pcid, ok := p.cache[pt]; ok {
- delete(p.cache, pt)
- p.avail = append(p.avail, pcid)
- }
- p.mu.Unlock()
-}
+// limitPCID is the maximum value of valid PCIDs.
+const limitPCID = 4095
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
new file mode 100644
index 000000000..c261d393a
--- /dev/null
+++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package pagetables
+
+// Visitor is a generic type.
+type Visitor interface {
+ // visit is called on each PTE.
+ visit(start uintptr, pte *PTE, align uintptr)
+
+ // requiresAlloc indicates that new entries should be allocated within
+ // the walked range.
+ requiresAlloc() bool
+
+ // requiresSplit indicates that entries in the given range should be
+ // split if they are huge or jumbo pages.
+ requiresSplit() bool
+}
+
+// Walker walks page tables.
+type Walker struct {
+ // pageTables are the tables to walk.
+ pageTables *PageTables
+
+ // Visitor is the set of arguments.
+ visitor Visitor
+}
+
+// iterateRange iterates over all appropriate levels of page tables for the given range.
+//
+// If requiresAlloc is true, then Set _must_ be called on all given PTEs. The
+// exception is sect pages. If a valid sect page (huge or jumbo) cannot be
+// installed, then the walk will continue to individual entries.
+//
+// This algorithm will attempt to maximize the use of sect pages whenever
+// possible. Whether a sect page is provided will be clear through the range
+// provided in the callback.
+//
+// Note that if requiresAlloc is true, then no gaps will be present. However,
+// if alloc is not set, then the iteration will likely be full of gaps.
+//
+// Note that this function should generally be avoided in favor of Map, Unmap,
+// etc. when not necessary.
+//
+// Precondition: start must be page-aligned.
+//
+// Precondition: start must be less than end.
+//
+// Precondition: If requiresAlloc is true, then start and end should not span
+// non-canonical ranges. If they do, a panic will result.
+//
+//go:nosplit
+func (w *Walker) iterateRange(start, end uintptr) {
+ if start%pteSize != 0 {
+ panic("unaligned start")
+ }
+ if end < start {
+ panic("start > end")
+ }
+ if start < lowerTop {
+ if end <= lowerTop {
+ w.iterateRangeCanonical(start, end)
+ } else if end > lowerTop && end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(start, lowerTop)
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else if start < upperBottom {
+ if end <= upperBottom {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ } else {
+ if w.visitor.requiresAlloc() {
+ panic("alloc spans non-canonical range")
+ }
+ w.iterateRangeCanonical(upperBottom, end)
+ }
+ } else {
+ w.iterateRangeCanonical(start, end)
+ }
+}
+
+// next returns the next address quantized by the given size.
+//
+//go:nosplit
+func next(start uintptr, size uintptr) uintptr {
+ start &= ^(size - 1)
+ start += size
+ return start
+}
+
+// iterateRangeCanonical walks a canonical range.
+//
+//go:nosplit
+func (w *Walker) iterateRangeCanonical(start, end uintptr) {
+ pgdEntryIndex := w.pageTables.root
+ if start >= upperBottom {
+ pgdEntryIndex = w.pageTables.archPageTables.root
+ }
+
+ for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ {
+ var (
+ pgdEntry = &pgdEntryIndex[pgdIndex]
+ pudEntries *PTEs
+ )
+ if !pgdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ start = next(start, pgdSize)
+ continue
+ }
+
+ // Allocate a new pgd.
+ pudEntries = w.pageTables.Allocator.NewPTEs()
+ pgdEntry.setPageTable(w.pageTables, pudEntries)
+ } else {
+ pudEntries = w.pageTables.Allocator.LookupPTEs(pgdEntry.Address())
+ }
+
+ // Map the next level.
+ clearPUDEntries := uint16(0)
+
+ for pudIndex := uint16((start & pudMask) >> pudShift); start < end && pudIndex < entriesPerPage; pudIndex++ {
+ var (
+ pudEntry = &pudEntries[pudIndex]
+ pmdEntries *PTEs
+ )
+ if !pudEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPUDEntries++
+ start = next(start, pudSize)
+ continue
+ }
+
+ // This level has 1-GB sect pages. Is this
+ // entire region at least as large as a single
+ // PUD entry? If so, we can skip allocating a
+ // new page for the pmd.
+ if start&(pudSize-1) == 0 && end-start >= pudSize {
+ pudEntry.SetSect()
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+ if pudEntry.Valid() {
+ start = next(start, pudSize)
+ continue
+ }
+ }
+
+ // Allocate a new pud.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+
+ } else if pudEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pudSize-1) != 0 || end < next(start, pudSize)) {
+ // Install the relevant entries.
+ pmdEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pmdEntries[index].SetSect()
+ pmdEntries[index].Set(
+ pudEntry.Address()+(pmdSize*uintptr(index)),
+ pudEntry.Opts())
+ }
+ pudEntry.setPageTable(w.pageTables, pmdEntries)
+ } else {
+ // A sect page to be checked directly.
+ w.visitor.visit(uintptr(start), pudEntry, pudSize-1)
+
+ // Might have been cleared.
+ if !pudEntry.Valid() {
+ clearPUDEntries++
+ }
+
+ // Note that the sect page was changed.
+ start = next(start, pudSize)
+ continue
+ }
+
+ } else {
+ pmdEntries = w.pageTables.Allocator.LookupPTEs(pudEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPMDEntries := uint16(0)
+
+ for pmdIndex := uint16((start & pmdMask) >> pmdShift); start < end && pmdIndex < entriesPerPage; pmdIndex++ {
+ var (
+ pmdEntry = &pmdEntries[pmdIndex]
+ pteEntries *PTEs
+ )
+ if !pmdEntry.Valid() {
+ if !w.visitor.requiresAlloc() {
+ // Skip over this entry.
+ clearPMDEntries++
+ start = next(start, pmdSize)
+ continue
+ }
+
+ // This level has 2-MB huge pages. If this
+ // region is contined in a single PMD entry?
+ // As above, we can skip allocating a new page.
+ if start&(pmdSize-1) == 0 && end-start >= pmdSize {
+ pmdEntry.SetSect()
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+ if pmdEntry.Valid() {
+ start = next(start, pmdSize)
+ continue
+ }
+ }
+
+ // Allocate a new pmd.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+
+ } else if pmdEntry.IsSect() {
+ // Does this page need to be split?
+ if w.visitor.requiresSplit() && (start&(pmdSize-1) != 0 || end < next(start, pmdSize)) {
+ // Install the relevant entries.
+ pteEntries = w.pageTables.Allocator.NewPTEs()
+ for index := uint16(0); index < entriesPerPage; index++ {
+ pteEntries[index].Set(
+ pmdEntry.Address()+(pteSize*uintptr(index)),
+ pmdEntry.Opts())
+ }
+ pmdEntry.setPageTable(w.pageTables, pteEntries)
+ } else {
+ // A huge page to be checked directly.
+ w.visitor.visit(uintptr(start), pmdEntry, pmdSize-1)
+
+ // Might have been cleared.
+ if !pmdEntry.Valid() {
+ clearPMDEntries++
+ }
+
+ // Note that the huge page was changed.
+ start = next(start, pmdSize)
+ continue
+ }
+
+ } else {
+ pteEntries = w.pageTables.Allocator.LookupPTEs(pmdEntry.Address())
+ }
+
+ // Map the next level, since this is valid.
+ clearPTEEntries := uint16(0)
+
+ for pteIndex := uint16((start & pteMask) >> pteShift); start < end && pteIndex < entriesPerPage; pteIndex++ {
+ var (
+ pteEntry = &pteEntries[pteIndex]
+ )
+ if !pteEntry.Valid() && !w.visitor.requiresAlloc() {
+ clearPTEEntries++
+ start += pteSize
+ continue
+ }
+
+ // At this point, we are guaranteed that start%pteSize == 0.
+ w.visitor.visit(uintptr(start), pteEntry, pteSize-1)
+ if !pteEntry.Valid() {
+ if w.visitor.requiresAlloc() {
+ panic("PTE not set after iteration with requiresAlloc!")
+ }
+ clearPTEEntries++
+ }
+
+ // Note that the pte was changed.
+ start += pteSize
+ continue
+ }
+
+ // Check if we no longer need this page.
+ if clearPTEEntries == entriesPerPage {
+ pmdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pteEntries)
+ clearPMDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPMDEntries == entriesPerPage {
+ pudEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pmdEntries)
+ clearPUDEntries++
+ }
+ }
+
+ // Check if we no longer need this page.
+ if clearPUDEntries == entriesPerPage {
+ pgdEntry.Clear()
+ w.pageTables.Allocator.FreePTEs(pudEntries)
+ }
+ }
+}
diff --git a/pkg/sentry/platform/ring0/x86.go b/pkg/sentry/platform/ring0/x86.go
index 5f80d64e8..9da0ea685 100644
--- a/pkg/sentry/platform/ring0/x86.go
+++ b/pkg/sentry/platform/ring0/x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build i386 amd64
+// +build 386 amd64
package ring0
diff --git a/pkg/sentry/sighandling/BUILD b/pkg/sentry/sighandling/BUILD
index f561670c7..6c38a3f44 100644
--- a/pkg/sentry/sighandling/BUILD
+++ b/pkg/sentry/sighandling/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,6 @@ go_library(
"sighandling.go",
"sighandling_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/sighandling",
visibility = ["//pkg/sentry:internal"],
deps = ["//pkg/abi/linux"],
)
diff --git a/pkg/sentry/sighandling/sighandling.go b/pkg/sentry/sighandling/sighandling.go
index 2f65db70b..83195d5a1 100644
--- a/pkg/sentry/sighandling/sighandling.go
+++ b/pkg/sentry/sighandling/sighandling.go
@@ -16,7 +16,6 @@
package sighandling
import (
- "fmt"
"os"
"os/signal"
"reflect"
@@ -31,37 +30,25 @@ const numSignals = 32
// handleSignals listens for incoming signals and calls the given handler
// function.
//
-// It starts when the start channel is closed, stops when the stop channel
-// is closed, and closes done once it will no longer deliver signals to k.
-func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start, stop, done chan struct{}) {
+// It stops when the stop channel is closed. The done channel is closed once it
+// will no longer deliver signals to k.
+func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), stop, done chan struct{}) {
// Build a select case.
- sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(start)}}
+ sc := []reflect.SelectCase{{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}}
for _, sigchan := range sigchans {
sc = append(sc, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sigchan)})
}
- started := false
for {
// Wait for a notification.
index, _, ok := reflect.Select(sc)
- // Was it the start / stop channel?
+ // Was it the stop channel?
if index == 0 {
if !ok {
- if !started {
- // start channel; start forwarding and
- // swap this case for the stop channel
- // to select stop requests.
- started = true
- sc[0] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(stop)}
- } else {
- // stop channel; stop forwarding and
- // clear this case so it is never
- // selected again.
- started = false
- close(done)
- sc[0].Chan = reflect.Value{}
- }
+ // Stop forwarding and notify that it's done.
+ close(done)
+ return
}
continue
}
@@ -73,44 +60,17 @@ func handleSignals(sigchans []chan os.Signal, handler func(linux.Signal), start,
// Otherwise, it was a signal on channel N. Index 0 represents the stop
// channel, so index N represents the channel for signal N.
- signal := linux.Signal(index)
-
- if !started {
- // Kernel cannot receive signals, either because it is
- // not ready yet or is shutting down.
- //
- // Kill ourselves if this signal would have killed the
- // process before PrepareForwarding was called. i.e., all
- // _SigKill signals; see Go
- // src/runtime/sigtab_linux_generic.go.
- //
- // Otherwise ignore the signal.
- //
- // TODO(b/114489875): Drop in Go 1.12, which uses tgkill
- // in runtime.raise.
- switch signal {
- case linux.SIGHUP, linux.SIGINT, linux.SIGTERM:
- dieFromSignal(signal)
- panic(fmt.Sprintf("Failed to die from signal %d", signal))
- default:
- continue
- }
- }
-
- // Pass the signal to the handler.
- handler(signal)
+ handler(linux.Signal(index))
}
}
-// PrepareHandler ensures that synchronous signals are passed to the given
-// handler function and returns a callback that starts signal delivery, which
-// itself returns a callback that stops signal handling.
+// StartSignalForwarding ensures that synchronous signals are passed to the
+// given handler function and returns a callback that stops signal delivery.
//
// Note that this function permanently takes over signal handling. After the
// stop callback, signals revert to the default Go runtime behavior, which
// cannot be overridden with external calls to signal.Notify.
-func PrepareHandler(handler func(linux.Signal)) func() func() {
- start := make(chan struct{})
+func StartSignalForwarding(handler func(linux.Signal)) func() {
stop := make(chan struct{})
done := make(chan struct{})
@@ -125,16 +85,18 @@ func PrepareHandler(handler func(linux.Signal)) func() func() {
for sig := 1; sig <= numSignals+1; sig++ {
sigchan := make(chan os.Signal, 1)
sigchans = append(sigchans, sigchan)
+
+ // SIGURG is used by Go's runtime scheduler.
+ if sig == int(linux.SIGURG) {
+ continue
+ }
signal.Notify(sigchan, syscall.Signal(sig))
}
// Start up our listener.
- go handleSignals(sigchans, handler, start, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
+ go handleSignals(sigchans, handler, stop, done) // S/R-SAFE: synchronized by Kernel.extMu.
- return func() func() {
- close(start)
- return func() {
- close(stop)
- <-done
- }
+ return func() {
+ close(stop)
+ <-done
}
}
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index c303435d5..1ebe22d34 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -15,8 +15,6 @@
package sighandling
import (
- "fmt"
- "runtime"
"syscall"
"unsafe"
@@ -48,27 +46,3 @@ func IgnoreChildStop() error {
return nil
}
-
-// dieFromSignal kills the current process with sig.
-//
-// Preconditions: The default action of sig is termination.
-func dieFromSignal(sig linux.Signal) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- sa := sigaction{handler: linux.SIG_DFL}
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGACTION, uintptr(sig), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigaction failed: %v", e))
- }
-
- set := linux.MakeSignalSet(sig)
- if _, _, e := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_UNBLOCK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0); e != 0 {
- panic(fmt.Sprintf("rt_sigprocmask failed: %v", e))
- }
-
- if err := syscall.Tgkill(syscall.Getpid(), syscall.Gettid(), syscall.Signal(sig)); err != nil {
- panic(fmt.Sprintf("tgkill failed: %v", err))
- }
-
- panic("failed to die")
-}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 26176b10d..c0fd3425b 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -1,24 +1,25 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "socket",
srcs = ["socket.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 4a6e83a8b..ca16d0381 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,11 +1,13 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "control",
- srcs = ["control.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control",
+ srcs = [
+ "control.go",
+ "control_vfs2.go",
+ ],
imports = [
"gvisor.dev/gvisor/pkg/sentry/fs",
],
@@ -13,12 +15,15 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/binary",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4e95101b7..70ccf77a7 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -19,13 +19,15 @@ package control
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const maxInt = int(^uint(0) >> 1)
@@ -39,6 +41,8 @@ type SCMCredentials interface {
Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
}
+// LINT.IfChange
+
// SCMRights represents a SCM_RIGHTS socket control message.
type SCMRights interface {
transport.RightsControlMessage
@@ -64,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
for _, fd := range fds {
file := t.GetFile(fd)
if file == nil {
- files.Release()
+ files.Release(t)
return nil, syserror.EBADF
}
files = append(files, file)
@@ -96,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage {
}
// Release implements transport.RightsControlMessage.Release.
-func (fs *RightsFiles) Release() {
+func (fs *RightsFiles) Release(ctx context.Context) {
for _, f := range *fs {
- f.DecRef()
+ f.DecRef(ctx)
}
*fs = nil
}
@@ -111,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
CloseOnExec: cloexec,
})
- files[0].DecRef()
+ files[0].DecRef(t)
files = files[1:]
if err != nil {
t.Warningf("Error inserting FD: %v", err)
@@ -140,6 +144,8 @@ func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flag
return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
}
+// LINT.ThenChange(./control_vfs2.go)
+
// scmCredentials represents an SCM_CREDENTIALS socket control message.
//
// +stateify savable
@@ -188,21 +194,21 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
//
// align must be >= 4 and each data int32 is 4 bytes. The length of the
- // header is already aligned, so if we align to the with of the data there
+ // header is already aligned, so if we align to the width of the data there
// are two cases:
// 1. The aligned length is less than the length of the header. The
// unaligned length was also less than the length of the header, so we
// can't write anything.
// 2. The aligned length is greater than or equal to the length of the
- // header. We can write the header plus zero or more datas. We can't write
- // a partial int32, so the length of the message will be
- // min(aligned length, header + datas).
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
if space < linux.SizeOfControlMessageHeader {
flags |= linux.MSG_CTRUNC
return buf, flags
@@ -239,12 +245,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = binary.Marshal(buf, usermem.ByteOrder, data)
- // Check if we went over.
+ // If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
return hdrBuf
}
- // Fix up length.
+ // Update control message length to include data.
putUint64(ob, uint64(len(buf)-len(ob)))
return alignSlice(buf, align)
@@ -281,19 +287,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -320,35 +316,139 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
buf,
linux.SOL_TCP,
linux.TCP_INQ,
- 4,
+ t.Arch().Width(),
inq,
)
}
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
- fds linux.ControlMessageRights
- haveCreds bool
- creds linux.ControlMessageCredentials
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
)
for i := 0; i < len(buf); {
if i+linux.SizeOfControlMessageHeader > len(buf) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return cmsgs, syserror.EINVAL
}
var h linux.ControlMessageHeader
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
if h.Length > uint64(len(buf)-i) {
- return transport.ControlMessages{}, syserror.EINVAL
- }
- if h.Level != linux.SOL_SOCKET {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
i += linux.SizeOfControlMessageHeader
@@ -358,59 +458,105 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.
// sizeof(long) in CMSG_ALIGN.
width := t.Arch().Width()
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
- numRights := rightsSize / linux.SizeOfControlMessageRight
-
- if len(fds)+numRights > linux.SCM_MAX_FD {
- return transport.ControlMessages{}, syserror.EINVAL
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += binary.AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += binary.AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- i += AlignUp(length, width)
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
- return transport.ControlMessages{}, syserror.EINVAL
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
- haveCreds = true
- i += AlignUp(length, width)
-
default:
- // Unknown message type.
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
}
- var credentials SCMCredentials
- if haveCreds {
- var err error
- if credentials, err = NewSCMCredentials(t, creds); err != nil {
- return transport.ControlMessages{}, err
- }
- } else {
- credentials = makeCreds(t, socketOrEndpoint)
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
}
- var rights SCMRights
if len(fds) > 0 {
- var err error
- if rights, err = NewSCMRights(t, fds); err != nil {
- return transport.ControlMessages{}, err
+ if kernel.VFS2Enabled {
+ rights, err := NewSCMRightsVFS2(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ } else {
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
}
}
- if credentials == nil && rights == nil {
- return transport.ControlMessages{}, nil
- }
-
- return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil
+ return cmsgs, nil
}
func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
@@ -432,6 +578,8 @@ func MakeCreds(t *kernel.Task) SCMCredentials {
return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
}
+// LINT.IfChange
+
// New creates default control messages if needed.
func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
return transport.ControlMessages{
@@ -439,3 +587,5 @@ func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transpo
Rights: rights,
}
}
+
+// LINT.ThenChange(./control_vfs2.go)
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
new file mode 100644
index 000000000..d9621968c
--- /dev/null
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+type SCMRightsVFS2 interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each vfs.FileDescription and is release either when an FD is created or
+// when the Release method is called.
+type RightsFilesVFS2 []*vfs.FileDescription
+
+// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
+// representation using local sentry FDs.
+func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
+ files := make(RightsFilesVFS2, 0, len(fds))
+ for _, fd := range fds {
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ files.Release(t)
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+func (fs *RightsFilesVFS2) Files(ctx context.Context, max int) (RightsFilesVFS2, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFilesVFS2(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFilesVFS2) Release(ctx context.Context) {
+ for _, f := range *fs {
+ f.DecRef(ctx)
+ }
+ *fs = nil
+}
+
+// rightsFDsVFS2 gets up to the specified maximum number of FDs.
+func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
+ files[0].DecRef(t)
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRightsVFS2 packs as many FDs as will fit into the unused capacity of buf.
+func PackRightsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDsVFS2(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// NewVFS2 creates default control messages if needed.
+func NewVFS2(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRightsVFS2) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index c1b20eaf8..8448ea401 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -10,31 +10,41 @@ go_library(
"save_restore.go",
"socket.go",
"socket_unsafe.go",
+ "socket_vfs2.go",
+ "sockopt_impl.go",
"stack.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/context",
"//pkg/fdnotifier",
"//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/hostfd",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 92beb1bcf..242e6bf76 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -18,21 +18,26 @@ import (
"fmt"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
@@ -41,8 +46,14 @@ const (
// sizeofSockaddr is the size in bytes of the largest sockaddr type
// supported by this package.
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+
+ // maxControlLen is the maximum size of a control message buffer used in a
+ // recvmsg or sendmsg syscall.
+ maxControlLen = 1024
)
+// LINT.IfChange
+
// socketOperations implements fs.FileOperations and socket.Socket for a socket
// implemented using a host socket.
type socketOperations struct {
@@ -53,55 +64,74 @@ type socketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
socket.SendReceiveTimeout
family int // Read-only.
stype linux.SockType // Read-only.
protocol int // Read-only.
- fd int // must be O_NONBLOCK
queue waiter.Queue
+
+ // fd is the host socket fd. It must have O_NONBLOCK, so that operations
+ // will return EWOULDBLOCK instead of blocking on the host. This allows us to
+ // handle blocking behavior independently in the sentry.
+ fd int
}
var _ = socket.Socket(&socketOperations{})
func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
s := &socketOperations{
- family: family,
- stype: stype,
- protocol: protocol,
- fd: fd,
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
}
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
dirent := socket.NewDirent(ctx, socketDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil
}
// Release implements fs.FileOperations.Release.
-func (s *socketOperations) Release() {
+func (s *socketOpsCommon) Release(context.Context) {
fdnotifier.RemoveFD(int32(s.fd))
syscall.Close(s.fd)
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.queue.EventRegister(e, mask)
fdnotifier.UpdateFD(int32(s.fd))
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.queue.EventUnregister(e)
fdnotifier.UpdateFD(int32(s.fd))
}
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, io, args)
+}
+
// Read implements fs.FileOperations.Read.
func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
@@ -120,7 +150,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
}
return uint64(n), nil
}
- return readv(s.fd, iovecsFromBlockSeq(dsts))
+ return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
}))
return int64(n), err
}
@@ -143,13 +173,13 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
return uint64(n), nil
}
- return writev(s.fd, iovecsFromBlockSeq(srcs))
+ return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
}))
return int64(n), err
}
// Connect implements socket.Socket.Connect.
-func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -189,7 +219,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
var peerAddr linux.SockAddr
var peerAddrBuf []byte
var peerAddrlen uint32
@@ -203,7 +233,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
// Conservatively ignore all flags specified by the application and add
- // SOCK_NONBLOCK since socketOperations requires it.
+ // SOCK_NONBLOCK since socketOpsCommon requires it.
fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
if blocking {
var ch chan struct{}
@@ -229,23 +259,41 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
}
- f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
- if err != nil {
- syscall.Close(fd)
- return 0, nil, 0, err
- }
- defer f.DecRef()
+ var (
+ kfd int32
+ kerr error
+ )
+ if kernel.VFS2Enabled {
+ f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK))
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef(t)
- kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{
- CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
- })
- t.Kernel().RecordSocket(f)
+ kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocketVFS2(f)
+ } else {
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef(t)
+
+ kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocket(f)
+ }
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
// Bind implements socket.Socket.Bind.
-func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) > sizeofSockaddr {
sockaddr = sockaddr[:sizeofSockaddr]
}
@@ -258,12 +306,12 @@ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Listen implements socket.Socket.Listen.
-func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return syserr.FromError(syscall.Listen(s.fd, backlog))
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
switch how {
case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR:
return syserr.FromError(syscall.Shutdown(s.fd, how))
@@ -273,34 +321,40 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
if outLen < 0 {
return nil, syserr.ErrInvalidArgument
}
- // Whitelist options and constrain option length.
- var optlen int
+ // Only allow known and safe options.
+ optlen := getSockOptLen(t, level, name)
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
optlen = sizeofInt32
- case syscall.SO_LINGER:
+ case linux.SO_LINGER:
optlen = syscall.SizeofLinger
}
- case syscall.SOL_TCP:
+ case linux.SOL_TCP:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
- case syscall.TCP_INFO:
+ case linux.TCP_INFO:
optlen = int(linux.SizeOfTCPInfo)
}
}
+
if optlen == 0 {
return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
}
@@ -312,30 +366,39 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
if err != nil {
return nil, syserr.FromError(err)
}
- return opt, nil
+ optP := primitive.ByteSlice(opt)
+ return &optP, nil
}
// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
- // Whitelist options and constrain option length.
- var optlen int
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Only allow known and safe options.
+ optlen := setSockOptLen(t, level, name)
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_IPV6:
switch name {
- case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_TCP:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_TCP:
+ switch name {
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
}
}
+
if optlen == 0 {
// Pretend to accept socket options we don't understand. This seems
// dangerous, but it's what netstack does...
@@ -354,11 +417,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Whitelist flags.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that netstack/tcpip/transport/unix doesn't understand. Kill the
+ // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
// Socket interface's dependence on netstack.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
@@ -370,6 +433,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
senderAddrBuf = make([]byte, sizeofSockaddr)
}
+ var controlBuf []byte
var msgFlags int
recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
@@ -384,12 +448,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We always do a non-blocking recv*().
sysflags := flags | syscall.MSG_DONTWAIT
- if dsts.NumBlocks() == 1 {
- // Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
- }
-
- iovs := iovecsFromBlockSeq(dsts)
+ iovs := safemem.IovecsFromBlockSeq(dsts)
msg := syscall.Msghdr{
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
@@ -398,12 +457,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msg.Name = &senderAddrBuf[0]
msg.Namelen = uint32(len(senderAddrBuf))
}
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
+ controlLen = uint64(msg.Controllen)
return n, nil
})
@@ -429,36 +497,75 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
}
}
-
- // We don't allow control messages.
- msgFlags &^= linux.MSG_CTRUNC
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
+
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ controlMessages := socket.ControlMessages{}
+ for _, unixCmsg := range unixControlMessages {
+ switch unixCmsg.Header.Level {
+ case syscall.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case syscall.IP_TOS:
+ controlMessages.IP.HasTOS = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ }
+
+ case syscall.SOL_IPV6:
+ switch unixCmsg.Header.Type {
+ case syscall.IPV6_TCLASS:
+ controlMessages.IP.HasTClass = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ }
+ }
+ }
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
- // Whitelist flags.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Only allow known and safe flags.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
return 0, syserr.ErrInvalidArgument
}
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
if uint64(src.NumBytes()) != srcs.NumBytes() {
return 0, nil
}
- if srcs.IsEmpty() {
+ if srcs.IsEmpty() && len(controlBuf) == 0 {
return 0, nil
}
// We always do a non-blocking send*().
sysflags := flags | syscall.MSG_DONTWAIT
- if srcs.NumBlocks() == 1 {
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
// Skip allocating []syscall.Iovec.
src := srcs.Head()
n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
@@ -468,7 +575,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return uint64(n), nil
}
- iovs := iovecsFromBlockSeq(srcs)
+ iovs := safemem.IovecsFromBlockSeq(srcs)
msg := syscall.Msghdr{
Iov: &iovs[0],
Iovlen: uint64(len(iovs)),
@@ -477,6 +584,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
msg.Name = &to[0]
msg.Namelen = uint32(len(to))
}
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
return sendmsg(s.fd, &msg, sysflags)
})
@@ -509,21 +620,6 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return int(n), syserr.FromError(err)
}
-func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec {
- iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
- for ; !bs.IsEmpty(); bs = bs.Tail() {
- b := bs.Head()
- iovs = append(iovs, syscall.Iovec{
- Base: &b.ToSlice()[0],
- Len: uint64(b.Len()),
- })
- // We don't need to care about b.NeedSafecopy(), because the host
- // kernel will handle such address ranges just fine (by returning
- // EFAULT).
- }
- return iovs
-}
-
func translateIOSyscallError(err error) error {
if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
return syserror.ErrWouldBlock
@@ -532,7 +628,7 @@ func translateIOSyscallError(err error) error {
}
// State implements socket.Socket.State.
-func (s *socketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
info := linux.TCPInfo{}
buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
if err != nil {
@@ -554,7 +650,7 @@ func (s *socketOperations) State() uint32 {
}
// Type implements socket.Socket.Type.
-func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
return s.family, s.stype, s.protocol
}
@@ -610,8 +706,11 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int
return nil, nil, nil
}
+// LINT.ThenChange(./socket_vfs2.go)
+
func init() {
for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
socket.RegisterProvider(family, &socketProvider{family})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{family})
}
}
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
index e69ec38c2..3f420c2ec 100644
--- a/pkg/sentry/socket/hostinet/socket_unsafe.go
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -19,14 +19,13 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
func firstBytePtr(bs []byte) unsafe.Pointer {
@@ -54,12 +53,11 @@ func writev(fd int, srcs []syscall.Iovec) (uint64, error) {
return uint64(n), nil
}
-// Ioctl implements fs.FileOperations.Ioctl.
-func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
switch cmd := uintptr(args[1].Int()); cmd {
case syscall.TIOCINQ, syscall.TIOCOUTQ:
var val int32
- if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
+ if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
return 0, translateIOSyscallError(errno)
}
var buf [4]byte
@@ -93,7 +91,7 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
@@ -104,7 +102,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32,
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr := make([]byte, sizeofSockaddr)
addrlen := uint32(len(addr))
_, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
new file mode 100644
index 000000000..8a1d52ebf
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -0,0 +1,203 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type socketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // We store metadata for hostinet sockets internally. Technically, we should
+ // access metadata (e.g. through stat, chmod) on the host for correctness,
+ // but this is not very useful for inet socket fds, which do not belong to a
+ // concrete file anyway.
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&socketVFS2{})
+
+func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &socketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ fdnotifier.RemoveFD(int32(s.fd))
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, uio, args)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
+ return int64(n), err
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+type socketProviderVFS2 struct {
+ family int
+}
+
+// Socket implements socket.ProviderVFS2.Socket.
+func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK))
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 4b460d30e..3d3fabb30 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -25,15 +25,16 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
)
var defaultRecvBufSize = inet.TCPBufferSize{
@@ -55,6 +56,7 @@ type Stack struct {
interfaceAddrs map[int32][]inet.InterfaceAddr
routes []inet.Route
supportsIPv6 bool
+ tcpRecovery inet.TCPLossRecovery
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
tcpSACKEnabled bool
@@ -128,6 +130,13 @@ func (s *Stack) Configure() error {
log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false")
}
+ s.ipv4Forwarding = false
+ if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil {
+ s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
+ } else {
+ log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false")
+ }
+
return nil
}
@@ -321,6 +330,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return addrs
}
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
// SupportsIPv6 implements inet.Stack.SupportsIPv6.
func (s *Stack) SupportsIPv6() bool {
return s.supportsIPv6
@@ -356,6 +370,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserror.EACCES
}
+// TCPRecovery implements inet.Stack.TCPRecovery.
+func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
+ return s.tcpRecovery, nil
+}
+
+// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
+func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+ return syserror.EACCES
+}
+
// getLine reads one line from proc file, with specified prefix.
// The last argument, withHeader, specifies if it contains line header.
func getLine(f *os.File, prefix string, withHeader bool) string {
@@ -455,6 +479,15 @@ func (s *Stack) RouteTable() []inet.Route {
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {}
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+
// Forwarding implements inet.Stack.Forwarding.
func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
switch protocol {
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 5eb06bbf4..721094bbf 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -1,24 +1,29 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "netfilter",
srcs = [
+ "extensions.go",
"netfilter.go",
+ "owner_matcher.go",
+ "targets.go",
+ "tcp_matcher.go",
+ "udp_matcher.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter",
# This target depends on netstack and should only be used by epsocket,
# which is allowed to depend on netstack.
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/log",
"//pkg/sentry/kernel",
- "//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/tcpip",
- "//pkg/tcpip/iptables",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
new file mode 100644
index 000000000..0336a32d8
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TODO(gvisor.dev/issue/170): The following per-matcher params should be
+// supported:
+// - Table name
+// - Match size
+// - User size
+// - Hooks
+// - Proto
+// - Family
+
+// matchMaker knows how to (un)marshal the matcher named name().
+type matchMaker interface {
+ // name is the matcher name as stored in the xt_entry_match struct.
+ name() string
+
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
+
+ // unmarshal converts from the ABI matcher struct to an
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
+}
+
+// matchMakers maps the name of supported matchers to the matchMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var matchMakers = map[string]matchMaker{}
+
+// registermatchMaker should be called by match extensions to register them
+// with the netfilter package.
+func registerMatchMaker(mm matchMaker) {
+ if _, ok := matchMakers[mm.name()]; ok {
+ panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name()))
+ }
+ matchMakers[mm.name()] = mm
+}
+
+func marshalMatcher(matcher stack.Matcher) []byte {
+ matchMaker, ok := matchMakers[matcher.Name()]
+ if !ok {
+ panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
+ }
+ return matchMaker.marshal(matcher)
+}
+
+// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and
+// data appended at the end.
+func marshalEntryMatch(name string, data []byte) []byte {
+ nflog("marshaling matcher %q", name)
+
+ // We have to pad this struct size to a multiple of 8 bytes.
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ matcher := linux.KernelXTEntryMatch{
+ XTEntryMatch: linux.XTEntryMatch{
+ MatchSize: uint16(size),
+ },
+ Data: data,
+ }
+ copy(matcher.Name[:], name)
+
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ return append(buf, make([]byte, size-len(buf))...)
+}
+
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
+ matchMaker, ok := matchMakers[match.Name.String()]
+ if !ok {
+ return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
+ }
+ return matchMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 9f87c32f1..e91b0624c 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,270 +17,506 @@
package netfilter
import (
+ "bytes"
+ "errors"
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// errorTargetName is used to mark targets as error targets. Error targets
-// shouldn't be reached - an error has occurred if we fall through to one.
-const errorTargetName = "ERROR"
-
-// metadata is opaque to netstack. It holds data that we need to translate
-// between Linux's and netstack's iptables representations.
-type metadata struct {
- HookEntry [linux.NF_INET_NUMHOOKS]uint32
- Underflow [linux.NF_INET_NUMHOOKS]uint32
- NumEntries uint32
- Size uint32
+// enableLogging controls whether to log the (de)serialization of netfilter
+// structs between userspace and netstack. These logs are useful when
+// developing iptables, but can pollute sentry logs otherwise.
+const enableLogging = false
+
+// emptyFilter is for comparison with a rule's filters to determine whether it
+// is also empty. It is immutable.
+var emptyFilter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00",
+}
+
+// nflog logs messages related to the writing and reading of iptables.
+func nflog(format string, args ...interface{}) {
+ if enableLogging && log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
}
// GetInfo returns information about iptables.
-func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
- if _, err := t.CopyIn(outPtr, &info); err != nil {
+ if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
- // Find the appropriate table.
- table, err := findTable(ep, info.TableName())
+ _, info, err := convertNetstackToBinary(stack, info.Name)
if err != nil {
- return linux.IPTGetinfo{}, err
+ nflog("couldn't convert iptables: %v", err)
+ return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
}
- // Get the hooks that apply to this table.
- info.ValidHooks = table.ValidHooks()
-
- // Grab the metadata struct, which is used to store information (e.g.
- // the number of entries) that applies to the user's encoding of
- // iptables, but not netstack's.
- metadata := table.Metadata().(metadata)
-
- // Set values from metadata.
- info.HookEntry = metadata.HookEntry
- info.Underflow = metadata.Underflow
- info.NumEntries = metadata.NumEntries
- info.Size = metadata.Size
-
+ nflog("returning info: %+v", info)
return info, nil
}
// GetEntries returns netstack's iptables rules encoded for the iptables tool.
-func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
- if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
return linux.KernelIPTGetEntries{}, syserr.FromError(err)
}
- // Find the appropriate table.
- table, err := findTable(ep, userEntries.TableName())
- if err != nil {
- return linux.KernelIPTGetEntries{}, err
- }
-
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
if err != nil {
- return linux.KernelIPTGetEntries{}, err
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
if binary.Size(entries) > uintptr(outLen) {
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
}
return entries, nil
}
-func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
- ipt, err := ep.IPTables()
- if err != nil {
- return iptables.Table{}, syserr.FromError(err)
- }
- table, ok := ipt.Tables[tableName]
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ table, ok := stack.IPTables().GetTable(tablename.String())
if !ok {
- return iptables.Table{}, syserr.ErrInvalidArgument
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ var entries linux.KernelIPTGetEntries
+ var info linux.IPTGetinfo
+ info.ValidHooks = table.ValidHooks()
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
}
- return table, nil
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], tablename[:])
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
+ nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
+ info.HookEntry[hook] = entries.Size
+ }
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
+ info.Underflow[underflow] = entries.Size
+ }
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ Entry: linux.IPTEntry{
+ IP: linux.IPTIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.Entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ info.Size = entries.Size
+ return entries, info, nil
}
-// FillDefaultIPTables sets stack's IPTables to the default tables and
-// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) {
- ipt := iptables.DefaultTables()
+// SetEntries sets iptables rules for a single table. See
+// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
+ // Get the basic rules data (struct ipt_replace).
+ if len(optVal) < linux.SizeOfIPTReplace {
+ nflog("optVal has insufficient size for replace %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+
+ // TODO(gvisor.dev/issue/170): Support other tables.
+ var table stack.Table
+ switch replace.Name.String() {
+ case stack.FilterTable:
+ table = stack.EmptyFilterTable()
+ case stack.NATTable:
+ table = stack.EmptyNATTable()
+ default:
+ nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ return syserr.ErrInvalidArgument
+ }
- // In order to fill in the metadata, we have to translate ipt from its
- // netstack format to Linux's giant-binary-blob format.
- for name, table := range ipt.Tables {
- _, metadata, err := convertNetstackToBinary(name, table)
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIPTEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIPTEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIPTIP(entry.IP)
if err != nil {
- panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ nflog("bad iptip: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
}
- table.SetMetadata(metadata)
- ipt.Tables[name] = table
}
- stack.SetIPTables(ipt)
-}
+ // Go through the list of supported hooks for this table and, for each
+ // one, set the rule it corresponds to.
+ for hook, _ := range replace.HookEntry {
+ if table.ValidHooks()&(1<<hook) != 0 {
+ hk := hookFromLinux(hook)
+ table.BuiltinChains[hk] = stack.HookUnset
+ table.Underflows[hk] = stack.HookUnset
+ for offset, ruleIdx := range offsets {
+ if offset == replace.HookEntry[hook] {
+ table.BuiltinChains[hk] = ruleIdx
+ }
+ if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ return syserr.ErrInvalidArgument
+ }
+ table.Underflows[hk] = ruleIdx
+ }
+ }
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
+ nflog("hook %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
+ nflog("underflow %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
-// convertNetstackToBinary converts the iptables as stored in netstack to the
-// format expected by the iptables tool. Linux stores each table as a binary
-// blob that can only be traversed by parsing a bit, reading some offsets,
-// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
- // Return values.
- var entries linux.KernelIPTGetEntries
- var meta metadata
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ if _, ok := rule.Target.(stack.UserChainTarget); !ok {
+ continue
+ }
- // The table name has to fit in the struct.
- if linux.XT_TABLE_MAXNAMELEN < len(name) {
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
}
- copy(entries.Name[:], name)
- // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
- // these chains ends with an unconditional policy entry.
- for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
- chain, ok := table.BuiltinChains[hook]
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
if !ok {
- // This table doesn't support this hook.
continue
}
- // Sanity check.
- if len(chain.Rules) < 1 {
- return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
}
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
- for ruleIdx, rule := range chain.Rules {
- // If this is the first rule of a builtin chain, set
- // the metadata hook entry point.
- if ruleIdx == 0 {
- meta.HookEntry[hook] = entries.Size
+ // TODO(gvisor.dev/issue/170): Support other chains.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
+ for hook, ruleIdx := range table.BuiltinChains {
+ if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
+ if ruleIdx == stack.HookUnset {
+ continue
}
-
- // Each rule corresponds to an entry.
- entry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ if !isUnconditionalAccept(table.Rules[ruleIdx]) {
+ nflog("hook %d is unsupported.", hook)
+ return syserr.ErrInvalidArgument
}
+ }
+ }
- for _, matcher := range rule.Matchers {
- // Serialize the matcher and add it to the
- // entry.
- serialized := marshalMatcher(matcher)
- entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
- entry.TargetOffset += uint16(len(serialized))
- }
+ // TODO(gvisor.dev/issue/170): Check the following conditions:
+ // - There are no loops.
+ // - There are no chains without an unconditional final rule.
+ // - There are no chains without an unconditional underflow rule.
- // Serialize and append the target.
- serialized := marshalTarget(rule.Target)
- entry.Elems = append(entry.Elems, serialized...)
- entry.NextOffset += uint16(len(serialized))
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
+}
- // The underflow rule is the last rule in the chain,
- // and is an unconditional rule (i.e. it matches any
- // packet). This is enforced when saving iptables.
- if ruleIdx == len(chain.Rules)-1 {
- meta.Underflow[hook] = entries.Size
- }
+// parseMatchers parses 0 or more matchers from optVal. optVal should contain
+// only the matchers.
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
+ nflog("set entries: parsing matchers of size %d", len(optVal))
+ var matchers []stack.Matcher
+ for len(optVal) > 0 {
+ nflog("set entries: optVal has len %d", len(optVal))
+
+ // Get the XTEntryMatch.
+ if len(optVal) < linux.SizeOfXTEntryMatch {
+ return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
+ }
+ var match linux.XTEntryMatch
+ buf := optVal[:linux.SizeOfXTEntryMatch]
+ binary.Unmarshal(buf, usermem.ByteOrder, &match)
+ nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
+
+ // Check some invariants.
+ if match.MatchSize < linux.SizeOfXTEntryMatch {
- entries.Size += uint32(entry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, entry)
- meta.NumEntries++
+ return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
+ }
+ if len(optVal) < int(match.MatchSize) {
+ return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal))
}
+ // Parse the specific matcher.
+ matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %v", err)
+ }
+ matchers = append(matchers, matcher)
+
+ // TODO(gvisor.dev/issue/170): Check the revision field.
+ optVal = optVal[match.MatchSize:]
}
- // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
- // these starts with an error node holding the chain's name and ends
- // with an unconditional return.
-
- // Lastly, each table ends with an unconditional error target rule as
- // its final entry.
- errorEntry := linux.KernelIPTEntry{
- IPTEntry: linux.IPTEntry{
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
+ if len(optVal) != 0 {
+ return nil, errors.New("optVal should be exhausted after parsing matchers")
}
- var errorTarget linux.XTErrorTarget
- errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
- copy(errorTarget.ErrorName[:], errorTargetName)
- copy(errorTarget.Target.Name[:], errorTargetName)
-
- // Serialize and add it to the list of entries.
- errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
- serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
- errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
- errorEntry.NextOffset += uint16(len(serializedErrorTarget))
-
- entries.Size += uint32(errorEntry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, errorEntry)
- meta.NumEntries++
- meta.Size = entries.Size
-
- return entries, meta, nil
+
+ return matchers, nil
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
- switch matcher.(type) {
- default:
- // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
- // any call to marshalMatcher will panic.
- panic(fmt.Errorf("unknown matcher of type %T", matcher))
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
}
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
+ }, nil
}
-func marshalTarget(target iptables.Target) []byte {
- switch target.(type) {
- case iptables.UnconditionalAcceptTarget:
- return marshalUnconditionalAcceptTarget()
+func containsUnsupportedFields(iptip linux.IPTIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags != 0 ||
+ iptip.InverseFlags&^inverseMask != 0
+}
+
+func validUnderflow(rule stack.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ if rule.Filter != emptyFilter {
+ return false
+ }
+ switch rule.Target.(type) {
+ case stack.AcceptTarget, stack.DropTarget:
+ return true
default:
- panic(fmt.Errorf("unknown target of type %T", target))
+ return false
}
}
-func marshalUnconditionalAcceptTarget() []byte {
- // The target's name will be the empty string.
- target := linux.XTStandardTarget{
- Target: linux.XTEntryTarget{
- TargetSize: linux.SizeOfXTStandardTarget,
- },
- Verdict: translateStandardVerdict(iptables.Accept),
+func isUnconditionalAccept(rule stack.Rule) bool {
+ if !validUnderflow(rule) {
+ return false
}
-
- ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
- return binary.Marshal(ret, usermem.ByteOrder, target)
+ _, ok := rule.Target.(stack.AcceptTarget)
+ return ok
}
-// translateStandardVerdict translates verdicts the same way as the iptables
-// tool.
-func translateStandardVerdict(verdict iptables.Verdict) int32 {
- switch verdict {
- case iptables.Accept:
- return -linux.NF_ACCEPT - 1
- case iptables.Drop:
- return -linux.NF_DROP - 1
- case iptables.Queue:
- return -linux.NF_QUEUE - 1
- case iptables.Return:
- return linux.NF_RETURN
- case iptables.Jump:
- // TODO(gvisor.dev/issue/170): Support Jump.
- panic("Jump isn't supported yet")
- default:
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+func hookFromLinux(hook int) stack.Hook {
+ switch hook {
+ case linux.NF_INET_PRE_ROUTING:
+ return stack.Prerouting
+ case linux.NF_INET_LOCAL_IN:
+ return stack.Input
+ case linux.NF_INET_FORWARD:
+ return stack.Forward
+ case linux.NF_INET_LOCAL_OUT:
+ return stack.Output
+ case linux.NF_INET_POST_ROUTING:
+ return stack.Postrouting
}
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..1b4e0ad79
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID and GID match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ if matcher.invertUID {
+ iptOwnerInfo.Invert = linux.XT_OWNER_UID
+ }
+ }
+ if matcher.matchGID {
+ iptOwnerInfo.Match |= linux.XT_OWNER_GID
+ if matcher.invertGID {
+ iptOwnerInfo.Invert |= linux.XT_OWNER_GID
+ }
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+
+ // Check flags.
+ if matchData.Match&linux.XT_OWNER_UID != 0 {
+ owner.matchUID = true
+ if matchData.Invert&linux.XT_OWNER_UID != 0 {
+ owner.invertUID = true
+ }
+ }
+ if matchData.Match&linux.XT_OWNER_GID != 0 {
+ owner.matchGID = true
+ if matchData.Invert&linux.XT_OWNER_GID != 0 {
+ owner.invertGID = true
+ }
+ }
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invertUID bool
+ invertGID bool
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ if pkt.Owner == nil {
+ return false, true
+ }
+
+ var matches bool
+ // Check for UID match.
+ if om.matchUID {
+ if pkt.Owner.UID() == om.uid {
+ matches = true
+ }
+ if matches == om.invertUID {
+ return false, false
+ }
+ }
+
+ // Check for GID match.
+ if om.matchGID {
+ matches = false
+ if pkt.Owner.GID() == om.gid {
+ matches = true
+ }
+ if matches == om.invertGID {
+ return false, false
+ }
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..8ebdaff18
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,282 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "errors"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
+func marshalTarget(target stack.Target) []byte {
+ switch tg := target.(type) {
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
+ return marshalErrorTarget(errorTargetName)
+ case stack.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget(tg)
+ case JumpTarget:
+ return marshalJumpTarget(tg)
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateFromStandardVerdict(verdict),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalErrorTarget(errorName string) []byte {
+ // This is an error target named error
+ target := linux.XTErrorTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTErrorTarget,
+ },
+ }
+ copy(target.Name[:], errorName)
+ copy(target.Target.Name[:], errorTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ target.NfRange.RangeSize = 1
+ if rt.RangeProtoSpecified {
+ target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ }
+ // Convert port from little endian to big endian.
+ port := make([]byte, 2)
+ binary.LittleEndian.PutUint16(port, rt.MinPort)
+ target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
+ binary.LittleEndian.PutUint16(port, rt.MaxPort)
+ target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateFromStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
+ switch verdict {
+ case stack.RuleAccept:
+ return -linux.NF_ACCEPT - 1
+ case stack.RuleDrop:
+ return -linux.NF_DROP - 1
+ case stack.RuleReturn:
+ return linux.NF_RETURN
+ default:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
+}
+
+// translateToStandardTarget translates from the value in a
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
+ // TODO(gvisor.dev/issue/170): Support other verdicts.
+ switch val {
+ case -linux.NF_ACCEPT - 1:
+ return stack.AcceptTarget{}, nil
+ case -linux.NF_DROP - 1:
+ return stack.DropTarget{}, nil
+ case -linux.NF_QUEUE - 1:
+ return nil, errors.New("unsupported iptables verdict QUEUE")
+ case linux.NF_RETURN:
+ return stack.ReturnTarget{}, nil
+ default:
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
+ }
+}
+
+// parseTarget parses a target from optVal. optVal should contain only the
+// target.
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
+ nflog("set entries: parsing target of size %d", len(optVal))
+ if len(optVal) < linux.SizeOfXTEntryTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
+ }
+ var target linux.XTEntryTarget
+ buf := optVal[:linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ switch target.Name.String() {
+ case "":
+ // Standard target.
+ if len(optVal) != linux.SizeOfXTStandardTarget {
+ return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = optVal[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
+ }
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
+
+ case errorTargetName:
+ // Error target.
+ if len(optVal) != linux.SizeOfXTErrorTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = optVal[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named errorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch name := errorTarget.Name.String(); name {
+ case errorTargetName:
+ nflog("set entries: error target")
+ return stack.ErrorTarget{}, nil
+ default:
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
+ }
+
+ // Unknown target.
+ return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String())
+}
+
+// JumpTarget implements stack.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
new file mode 100644
index 000000000..0bfd6c1f4
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -0,0 +1,130 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameTCP = "tcp"
+
+func init() {
+ registerMatchMaker(tcpMarshaler{})
+}
+
+// tcpMarshaler implements matchMaker for TCP matching.
+type tcpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (tcpMarshaler) name() string {
+ return matcherNameTCP
+}
+
+// marshal implements matchMaker.marshal.
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*TCPMatcher)
+ xttcp := linux.XTTCP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTTCP)
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTTCP {
+ return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.XTTCP
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTTCP: %+v", matchData)
+
+ if matchData.Option != 0 ||
+ matchData.FlagMask != 0 ||
+ matchData.FlagCompare != 0 ||
+ matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported TCP matcher flags set")
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber {
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ }
+
+ return &TCPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// TCPMatcher matches TCP packets and their headers. It implements Matcher.
+type TCPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*TCPMatcher) Name() string {
+ return matcherNameTCP
+}
+
+// Match implements Matcher.Match.
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
new file mode 100644
index 000000000..7ed05461d
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameUDP = "udp"
+
+func init() {
+ registerMatchMaker(udpMarshaler{})
+}
+
+// udpMarshaler implements matchMaker for UDP matching.
+type udpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (udpMarshaler) name() string {
+ return matcherNameUDP
+}
+
+// marshal implements matchMaker.marshal.
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*UDPMatcher)
+ xtudp := linux.XTUDP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTUDP)
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTUDP {
+ return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may exceed what's
+ // strictly necessary to hold matchData.
+ var matchData linux.XTUDP
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTUDP: %+v", matchData)
+
+ if matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported UDP matcher inverse flags set")
+ }
+
+ if filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ }
+
+ return &UDPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// UDPMatcher matches UDP packets and their headers. It implements Matcher.
+type UDPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*UDPMatcher) Name() string {
+ return matcherNameUDP
+}
+
+// Match implements Matcher.Match.
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the stack.Check codepath as matchers are added.
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ if len(udpHeader) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index f95803f91..0546801bf 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,29 +7,48 @@ go_library(
srcs = [
"message.go",
"provider.go",
+ "provider_vfs2.go",
"socket.go",
+ "socket_vfs2.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/binary",
+ "//pkg/context",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
+ "//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
+ ],
+)
+
+go_test(
+ name = "netlink_test",
+ size = "small",
+ srcs = [
+ "message_test.go",
+ ],
+ deps = [
+ ":netlink",
+ "//pkg/abi/linux",
],
)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index ce0a1afd0..0899c61d1 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -20,18 +20,19 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// alignUp rounds a length up to an alignment.
+// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) &^ (int(align) - 1)
+func alignPad(length int, align uint) int {
+ return binary.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
type Message struct {
+ hdr linux.NetlinkMessageHeader
buf []byte
}
@@ -40,10 +41,86 @@ type Message struct {
// The header length will be updated by Finalize.
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
+ hdr: hdr,
buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
}
}
+// ParseMessage parses the first message seen at buf, returning the rest of the
+// buffer. If message is malformed, ok of false is returned. For last message,
+// padding check is loose, if there isn't enought padding, whole buf is consumed
+// and ok is set to true.
+func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
+ b := BytesView(buf)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+ var hdr linux.NetlinkMessageHeader
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ // Msg portion.
+ totalMsgLen := int(hdr.Length)
+ _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+
+ // Padding.
+ numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return
+ }
+
+ return &Message{
+ hdr: hdr,
+ buf: buf[:totalMsgLen],
+ }, []byte(b), true
+}
+
+// Header returns the header of this message.
+func (m *Message) Header() linux.NetlinkMessageHeader {
+ return m.hdr
+}
+
+// GetData unmarshals the payload message header from this netlink message, and
+// returns the attributes portion.
+func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+ b := BytesView(m.buf)
+
+ _, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return nil, false
+ }
+
+ size := int(binary.Size(msg))
+ msgBytes, ok := b.Extract(size)
+ if !ok {
+ return nil, false
+ }
+ binary.Unmarshal(msgBytes, usermem.ByteOrder, msg)
+
+ numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return nil, false
+ }
+
+ return AttrsView(b), true
+}
+
// Finalize returns the []byte containing the entire message, with the total
// length set in the message header. The Message must not be modified after
// calling Finalize.
@@ -54,7 +131,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -89,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) {
m.Put(v)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -106,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) {
m.putZeros(1)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -157,3 +234,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
ms.Messages = append(ms.Messages, m)
return m
}
+
+// AttrsView is a view into the attributes portion of a netlink message.
+type AttrsView []byte
+
+// Empty returns whether there is no attribute left in v.
+func (v AttrsView) Empty() bool {
+ return len(v) == 0
+}
+
+// ParseFirst parses first netlink attribute at the beginning of v.
+func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
+ b := BytesView(v)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+
+ _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
+ if !ok {
+ return
+ }
+
+ return hdr, value, AttrsView(b), ok
+}
+
+// BytesView supports extracting data from a byte slice with bounds checking.
+type BytesView []byte
+
+// Extract removes the first n bytes from v and returns it. If n is out of
+// bounds, it returns false.
+func (v *BytesView) Extract(n int) ([]byte, bool) {
+ if n < 0 || n > len(*v) {
+ return nil, false
+ }
+ extracted := (*v)[:n]
+ *v = (*v)[n:]
+ return extracted, true
+}
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
new file mode 100644
index 000000000..ef13d9386
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -0,0 +1,312 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message_test
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+)
+
+type dummyNetlinkMsg struct {
+ Foo uint16
+}
+
+func TestParseMessage(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ header linux.NetlinkMessageHeader
+ dataMsg *dummyNetlinkMsg
+ restLen int
+ ok bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid with next message",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ 0xFF, // Next message (rest)
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 1,
+ ok: true,
+ },
+ {
+ desc: "valid for last message without padding",
+ input: []byte{
+ 0x12, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 18,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid for last message not to be aligned",
+ input: []byte{
+ 0x13, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ 0x00, // Excessive 1 byte permitted at end
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 19,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "header.Length too short",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header.Length too long",
+ input: []byte{
+ 0xFF, 0xFF, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header incomplete",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ },
+ ok: false,
+ },
+ {
+ desc: "empty message",
+ input: []byte{},
+ ok: false,
+ },
+ }
+ for _, test := range tests {
+ msg, rest, ok := netlink.ParseMessage(test.input)
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ continue
+ }
+ if !test.ok {
+ continue
+ }
+ if !reflect.DeepEqual(msg.Header(), test.header) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
+ }
+
+ dataMsg := &dummyNetlinkMsg{}
+ _, dataOk := msg.GetData(dataMsg)
+ if !dataOk {
+ t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
+ } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
+ }
+
+ if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
+
+func TestAttrView(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ // Outputs for ParseFirst.
+ hdr linux.NetlinkAttrHeader
+ value []byte
+ restLen int
+ ok bool
+
+ // Outputs for Empty.
+ isEmpty bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x06, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 6,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment with rest data",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ 0xFF, 0xFE, // Rest data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 2,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too long",
+ input: []byte{
+ 0xFF, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too short",
+ input: []byte{
+ 0x01, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "empty",
+ input: []byte{},
+ ok: false,
+ isEmpty: true,
+ },
+ }
+ for _, test := range tests {
+ attrs := netlink.AttrsView(test.input)
+
+ // Test ParseFirst().
+ hdr, value, rest, ok := attrs.ParseFirst()
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ } else if test.ok {
+ if !reflect.DeepEqual(hdr, test.hdr) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr)
+ }
+ if !bytes.Equal(value, test.value) {
+ t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value)
+ }
+ if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest)
+ }
+ }
+
+ // Test Empty().
+ if got, want := attrs.Empty(), test.isEmpty; got != want {
+ t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 463544c1a..3a22923d8 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,17 +1,16 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "port",
srcs = ["port.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port",
visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
)
go_test(
name = "port_test",
srcs = ["port_test.go"],
- embed = [":port"],
+ library = ":port",
)
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
index e9d3275b1..2cd3afc22 100644
--- a/pkg/sentry/socket/netlink/port/port.go
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -24,7 +24,8 @@ import (
"fmt"
"math"
"math/rand"
- "sync"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
// maxPorts is a sanity limit on the maximum number of ports to allocate per
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 689cad997..31e374833 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -18,7 +18,7 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -30,12 +30,19 @@ type Protocol interface {
// Protocol returns the Linux netlink protocol value.
Protocol() int
+ // CanSend returns true if this protocol may ever send messages.
+ //
+ // TODO(gvisor.dev/issue/1119): This is a workaround to allow
+ // advertising support for otherwise unimplemented features on sockets
+ // that will never send messages, thus making those features no-ops.
+ CanSend() bool
+
// ProcessMessage processes a single message from userspace.
//
// If err == nil, any messages added to ms will be sent back to the
// other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
// message to be sent even if ms contains no messages.
- ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error
+ ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error
}
// Provider is a function that creates a new Protocol for a specific netlink
@@ -60,6 +67,8 @@ func RegisterProvider(protocol int, provider Provider) {
protocols[protocol] = provider
}
+// LINT.IfChange
+
// socketProvider implements socket.Provider.
type socketProvider struct {
}
@@ -88,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int
}
d := socket.NewDirent(t, netlinkSocketDevice)
- defer d.DecRef()
+ defer d.DecRef(t)
return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil
}
@@ -98,7 +107,10 @@ func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.Fi
return nil, nil, syserr.ErrNotSupported
}
+// LINT.ThenChange(./provider_vfs2.go)
+
// init registers the socket provider.
func init() {
socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{})
+ socket.RegisterProviderVFS2(linux.AF_NETLINK, &socketProviderVFS2{})
}
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
new file mode 100644
index 000000000..bb205be0d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// socketProviderVFS2 implements socket.Provider.
+type socketProviderVFS2 struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewVFS2(t, stype, p)
+ if err != nil {
+ return nil, err
+ }
+
+ vfsfd := &s.vfsfd
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProviderVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets never supports creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 1d4912753..93127398d 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -1,15 +1,16 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "route",
- srcs = ["protocol.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route",
+ srcs = [
+ "protocol.go",
+ ],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index cc70ac237..c84d8bd7c 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -17,9 +17,10 @@ package route
import (
"bytes"
+ "syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -61,8 +62,13 @@ func (p *Protocol) Protocol() int {
return linux.NETLINK_ROUTE
}
-// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests.
-func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return true
+}
+
+// dumpLinks handles RTM_GETLINK dump requests.
+func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
// ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
// userspace applications (including glibc) still include rtgenmsg.
@@ -86,38 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader
return nil
}
- for id, i := range stack.Interfaces() {
- m := ms.AddMessage(linux.NetlinkMessageHeader{
- Type: linux.RTM_NEWLINK,
- })
+ for idx, i := range stack.Interfaces() {
+ addNewLinkMessage(ms, idx, i)
+ }
- m.Put(linux.InterfaceInfoMessage{
- Family: linux.AF_UNSPEC,
- Type: i.DeviceType,
- Index: id,
- Flags: i.Flags,
- })
+ return nil
+}
- m.PutAttrString(linux.IFLA_IFNAME, i.Name)
- m.PutAttr(linux.IFLA_MTU, i.MTU)
+// getLinks handles RTM_GETLINK requests.
+func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
- mac := make([]byte, 6)
- brd := mac
- if len(i.Addr) > 0 {
- mac = i.Addr
- brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ // Parse message.
+ var ifi linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifi)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Parse attributes.
+ var byName []byte
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
}
- m.PutAttr(linux.IFLA_ADDRESS, mac)
- m.PutAttr(linux.IFLA_BROADCAST, brd)
+ attrs = rest
- // TODO(gvisor.dev/issue/578): There are many more attributes.
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ byName = value[:len(value)-1]
+
+ // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK.
+ }
}
+ found := false
+ for idx, i := range stack.Interfaces() {
+ switch {
+ case ifi.Index > 0:
+ if idx != ifi.Index {
+ continue
+ }
+ case byName != nil:
+ if string(byName) != i.Name {
+ continue
+ }
+ default:
+ // Criteria not specified.
+ return syserr.ErrInvalidArgument
+ }
+
+ addNewLinkMessage(ms, idx, i)
+ found = true
+ break
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
return nil
}
-// dumpAddrs handles RTM_GETADDR + NLM_F_DUMP requests.
-func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+// addNewLinkMessage appends RTM_NEWLINK message for the given interface into
+// the message set.
+func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWLINK,
+ })
+
+ m.Put(linux.InterfaceInfoMessage{
+ Family: linux.AF_UNSPEC,
+ Type: i.DeviceType,
+ Index: idx,
+ Flags: i.Flags,
+ })
+
+ m.PutAttrString(linux.IFLA_IFNAME, i.Name)
+ m.PutAttr(linux.IFLA_MTU, i.MTU)
+
+ mac := make([]byte, 6)
+ brd := mac
+ if len(i.Addr) > 0 {
+ mac = i.Addr
+ brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ }
+ m.PutAttr(linux.IFLA_ADDRESS, mac)
+ m.PutAttr(linux.IFLA_BROADCAST, brd)
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+}
+
+// dumpAddrs handles RTM_GETADDR dump requests.
+func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// RTM_GETADDR dump requests need not contain anything more than the
// netlink header and 1 byte protocol family common to all
// NETLINK_ROUTE requests.
@@ -149,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
Index: uint32(id),
})
+ m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
// TODO(gvisor.dev/issue/578): There are many more attributes.
@@ -158,22 +232,136 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader
return nil
}
-// dumpRoutes handles RTM_GETROUTE + NLM_F_DUMP requests.
-func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+// commonPrefixLen reports the length of the longest IP address prefix.
+// This is a simplied version from Golang's src/net/addrselect.go.
+func commonPrefixLen(a, b []byte) (cpl int) {
+ for len(a) > 0 {
+ if a[0] == b[0] {
+ cpl += 8
+ a = a[1:]
+ b = b[1:]
+ continue
+ }
+ bits := 8
+ ab, bb := a[0], b[0]
+ for {
+ ab >>= 1
+ bb >>= 1
+ bits--
+ if ab == bb {
+ cpl += bits
+ return
+ }
+ }
+ }
+ return
+}
+
+// fillRoute returns the Route using LPM algorithm. Refer to Linux's
+// net/ipv4/route.c:rt_fill_info().
+func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) {
+ family := uint8(linux.AF_INET)
+ if len(addr) != 4 {
+ family = linux.AF_INET6
+ }
+
+ idx := -1 // Index of the Route rule to be returned.
+ idxDef := -1 // Index of the default route rule.
+ prefix := 0 // Current longest prefix.
+ for i, route := range routes {
+ if route.Family != family {
+ continue
+ }
+
+ if len(route.GatewayAddr) > 0 && route.DstLen == 0 {
+ idxDef = i
+ continue
+ }
+
+ cpl := commonPrefixLen(addr, route.DstAddr)
+ if cpl < int(route.DstLen) {
+ continue
+ }
+ cpl = int(route.DstLen)
+ if cpl > prefix {
+ idx = i
+ prefix = cpl
+ }
+ }
+ if idx == -1 {
+ idx = idxDef
+ }
+ if idx == -1 {
+ return inet.Route{}, syserr.ErrNoRoute
+ }
+
+ route := routes[idx]
+ if family == linux.AF_INET {
+ route.DstLen = 32
+ } else {
+ route.DstLen = 128
+ }
+ route.DstAddr = addr
+ route.Flags |= linux.RTM_F_CLONED // This route is cloned.
+ return route, nil
+}
+
+// parseForDestination parses a message as format of RouteMessage-RtAttr-dst.
+func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) {
+ var rtMsg linux.RouteMessage
+ attrs, ok := msg.GetData(&rtMsg)
+ if !ok {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See
+ // commit bc234301af12. Note we don't check this flag for backward
+ // compatibility.
+ if rtMsg.Flags != 0 && rtMsg.Flags != linux.RTM_F_LOOKUP_TABLE {
+ return nil, syserr.ErrNotSupported
+ }
+
+ // Expect first attribute is RTA_DST.
+ if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST {
+ return value, nil
+ }
+ return nil, syserr.ErrInvalidArgument
+}
+
+// dumpRoutes handles RTM_GETROUTE requests.
+func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
// RTM_GETROUTE dump requests need not contain anything more than the
// netlink header and 1 byte protocol family common to all
// NETLINK_ROUTE requests.
- // We always send back an NLMSG_DONE.
- ms.Multi = true
-
stack := inet.StackFromContext(ctx)
if stack == nil {
// No network routes.
return nil
}
- for _, rt := range stack.RouteTable() {
+ hdr := msg.Header()
+ routeTables := stack.RouteTable()
+
+ if hdr.Flags == linux.NLM_F_REQUEST {
+ dst, err := parseForDestination(msg)
+ if err != nil {
+ return err
+ }
+ route, err := fillRoute(routeTables, dst)
+ if err != nil {
+ // TODO(gvisor.dev/issue/1237): return NLMSG_ERROR with ENETUNREACH.
+ return syserr.ErrNotSupported
+ }
+ routeTables = append([]inet.Route{}, route)
+ } else if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+ } else {
+ // TODO(b/68878065): Only above cases are supported.
+ return syserr.ErrNotSupported
+ }
+
+ for _, rt := range routeTables {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.RTM_NEWROUTE,
})
@@ -214,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade
return nil
}
+// newAddr handles RTM_NEWADDR requests.
+func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err == syscall.EEXIST {
+ flags := msg.Header().Flags
+ if flags&linux.NLM_F_EXCL != 0 {
+ return syserr.ErrExists
+ }
+ } else if err != nil {
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+ return nil
+}
+
// ProcessMessage implements netlink.Protocol.ProcessMessage.
-func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ hdr := msg.Header()
+
// All messages start with a 1 byte protocol family.
- if len(data) < 1 {
+ var family uint8
+ if _, ok := msg.GetData(&family); !ok {
// Linux ignores messages missing the protocol family. See
// net/core/rtnetlink.c:rtnetlink_rcv_msg.
return nil
@@ -231,22 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH
}
}
- // TODO(b/68878065): Only the dump variant of the types below are
- // supported.
- if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP {
- return syserr.ErrNotSupported
- }
-
- switch hdr.Type {
- case linux.RTM_GETLINK:
- return p.dumpLinks(ctx, hdr, data, ms)
- case linux.RTM_GETADDR:
- return p.dumpAddrs(ctx, hdr, data, ms)
- case linux.RTM_GETROUTE:
- return p.dumpRoutes(ctx, hdr, data, ms)
- default:
- return syserr.ErrNotSupported
+ if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // TODO(b/68878065): Only the dump variant of the types below are
+ // supported.
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.dumpLinks(ctx, msg, ms)
+ case linux.RTM_GETADDR:
+ return p.dumpAddrs(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST {
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.getLink(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ case linux.RTM_NEWADDR:
+ return p.newAddr(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
}
+ return syserr.ErrNotSupported
}
// init registers the NETLINK_ROUTE provider.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index b2732ca29..68a9b9a96 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -17,27 +17,29 @@ package netlink
import (
"math"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
@@ -53,15 +55,19 @@ const (
maxSendBufferSize = 4 << 20 // 4MB
)
+var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+
// netlinkSocketDevice is the netlink socket virtual device.
var netlinkSocketDevice = device.NewAnonDevice()
+// LINT.IfChange
+
// Socket is the base socket type for netlink sockets.
//
// This implementation only supports userspace sending and receiving messages
// to/from the kernel.
//
-// Socket implements socket.Socket.
+// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
@@ -72,6 +78,14 @@ type Socket struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
socket.SendReceiveTimeout
// ports provides netlink port allocation.
@@ -104,9 +118,19 @@ type Socket struct {
// sendBufferSize is the send buffer "size". We don't actually have a
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
+
+ // filter indicates that this socket has a BPF filter "installed".
+ //
+ // TODO(gvisor.dev/issue/1119): We don't actually support filtering,
+ // this is just bookkeeping for tracking add/remove.
+ filter bool
}
var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
// NewSocket creates a new Socket.
func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
@@ -116,31 +140,33 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke
// Bind the endpoint for good measure so we can connect to it. The
// bound address will never be exposed.
if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
// Create a connection from which the kernel can write messages.
connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
if err != nil {
- ep.Close()
+ ep.Close(t)
return nil, err
}
return &Socket{
- ports: t.Kernel().NetlinkPorts(),
- protocol: protocol,
- skType: skType,
- ep: ep,
- connection: connection,
- sendBufferSize: defaultSendBufferSize,
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
}, nil
}
// Release implements fs.FileOperations.Release.
-func (s *Socket) Release() {
- s.connection.Release()
- s.ep.Close()
+func (s *socketOpsCommon) Release(ctx context.Context) {
+ s.connection.Release(ctx)
+ s.ep.Close(ctx)
if s.bound {
s.ports.Release(s.protocol.Protocol(), s.portID)
@@ -148,7 +174,7 @@ func (s *Socket) Release() {
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
// ep holds messages to be read and thus handles EventIn readiness.
ready := s.ep.Readiness(mask)
@@ -162,16 +188,32 @@ func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
// Writable readiness never changes, so no registration is needed.
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *Socket) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
+// Passcred implements transport.Credentialer.Passcred.
+func (s *socketOpsCommon) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *socketOpsCommon) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
// Ioctl implements fs.FileOperations.Ioctl.
func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
// TODO(b/68878065): no ioctls supported.
@@ -199,7 +241,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
// port of 0 defaults to the ThreadGroup ID.
//
// Preconditions: mu is held.
-func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
+func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error {
if s.bound {
// Re-binding is only allowed if the port doesn't change.
if port != s.portID {
@@ -223,7 +265,7 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
}
// Bind implements socket.Socket.Bind.
-func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
a, err := ExtractSockAddr(sockaddr)
if err != nil {
return err
@@ -241,7 +283,7 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Connect implements socket.Socket.Connect.
-func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
a, err := ExtractSockAddr(sockaddr)
if err != nil {
return err
@@ -272,25 +314,25 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr
}
// Accept implements socket.Socket.Accept.
-func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Netlink sockets never support accept.
return 0, nil, 0, syserr.ErrNotSupported
}
// Listen implements socket.Socket.Listen.
-func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// Netlink sockets never support listen.
return syserr.ErrNotSupported
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// Netlink sockets never support shutdown.
return syserr.ErrNotSupported
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -300,18 +342,31 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
}
s.mu.Lock()
defer s.mu.Unlock()
- return int32(s.sendBufferSize), nil
+ sendBufferSizeP := primitive.Int32(s.sendBufferSize)
+ return &sendBufferSizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have limit on receiving size.
- return int32(math.MaxInt32), nil
+ recvBufferSizeP := primitive.Int32(math.MaxInt32)
+ return &recvBufferSizeP, nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred primitive.Int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return &passcred, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
+
case linux.SOL_NETLINK:
switch name {
case linux.NETLINK_BROADCAST_ERROR,
@@ -330,7 +385,7 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
}
// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -348,6 +403,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
s.sendBufferSize = size
s.mu.Unlock()
return nil
+
case linux.SO_RCVBUF:
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -355,6 +411,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
// We don't have limit on receiving size. So just accept anything as
// valid for compatibility.
return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_ATTACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): We don't actually
+ // support filtering. If this socket can't ever send
+ // messages, then there is nothing to filter and we can
+ // advertise support. Otherwise, be conservative and
+ // return an error.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ s.filter = true
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_DETACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): See above.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ filter := s.filter
+ s.filter = false
+ s.mu.Unlock()
+
+ if !filter {
+ return errNoFilter
+ }
+
+ return nil
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -380,7 +482,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -392,7 +494,7 @@ func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
// TODO(b/68878065): Support non-kernel peers. For now the peer
@@ -403,7 +505,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
from := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: 0,
@@ -413,29 +515,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
trunc := flags&linux.MSG_TRUNC != 0
r := unix.EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Peek: flags&linux.MSG_PEEK != 0,
}
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
// If MSG_TRUNC is set with a zero byte destination then we still need
// to read the message and discard it, or in the case where MSG_PEEK is
// set, leave it be. In both cases the full message length must be
- // returned. However, the memory manager for the destination will not read
- // the endpoint if the destination is zero length.
- //
- // In order for the endpoint to be read when the destination size is zero,
- // we must cause a read of the endpoint by using a separate fake zero
- // length block sequence and calling the EndpointReader directly.
+ // returned.
if trunc && dst.Addrs.NumBytes() == 0 {
- // Perform a read to a zero byte block sequence. We can ignore the
- // original destination since it was zero bytes. The length returned by
- // ReadToBlocks is ignored and we return the full message length to comply
- // with MSG_TRUNC.
- _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
- return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -453,7 +555,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var mflags int
if n < int64(r.MsgSize) {
mflags |= linux.MSG_TRUNC
@@ -483,18 +585,43 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _
})
}
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the concrete version of kernelSCM used in all creds.
+var kernelCreds = &kernelSCM{}
+
// sendResponse sends the response messages in ms back to userspace.
-func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
// Linux combines multiple netlink messages into a single datagram.
bufs := make([][]byte, 0, len(ms.Messages))
for _, m := range ms.Messages {
bufs = append(bufs, m.Finalize())
}
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
if len(bufs) > 0 {
// RecvMsg never receives the address, so we don't need to send
// one.
- _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{})
// If the buffer is full, we simply drop messages, just like
// Linux.
if err != nil && err != syserr.ErrWouldBlock {
@@ -521,7 +648,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
// Add the dump_done_errno payload.
m.Put(int64(0))
- _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
}
@@ -533,47 +660,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
return nil
}
-func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error {
+func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
-
m.Put(linux.NetlinkErrorMessage{
Error: int32(-err.ToLinux().Number()),
Header: hdr,
})
- return nil
+}
+func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: 0,
+ Header: hdr,
+ })
}
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
-func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error {
for len(buf) > 0 {
- if len(buf) < linux.NetlinkMessageHeaderSize {
+ msg, rest, ok := ParseMessage(buf)
+ if !ok {
// Linux ignores messages that are too short. See
// net/netlink/af_netlink.c:netlink_rcv_skb.
break
}
-
- var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr)
-
- if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) {
- // Linux ignores malformed messages. See
- // net/netlink/af_netlink.c:netlink_rcv_skb.
- break
- }
-
- // Data from this message.
- data := buf[linux.NetlinkMessageHeaderSize:hdr.Length]
-
- // Advance to the next message.
- next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO)
- if next >= len(buf)-1 {
- next = len(buf) - 1
- }
- buf = buf[next:]
+ buf = rest
+ hdr := msg.Header()
// Ignore control messages.
if hdr.Type < linux.NLMSG_MIN_TYPE {
@@ -581,19 +699,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
}
ms := NewMessageSet(s.portID, hdr.Seq)
- var err *syserr.Error
- // TODO(b/68877377): ACKs not supported yet.
- if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
- err = syserr.ErrNotSupported
- } else {
-
- err = s.protocol.ProcessMessage(ctx, hdr, data, ms)
- }
- if err != nil {
- ms = NewMessageSet(s.portID, hdr.Seq)
- if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil {
- return err
- }
+ if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
+ dumpErrorMesage(hdr, ms, err)
+ } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+ dumpAckMesage(hdr, ms)
}
if err := s.sendResponse(ctx, ms); err != nil {
@@ -605,7 +714,7 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
}
// sendMsg is the core of message send, used for SendMsg and Write.
-func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
dstPort := int32(0)
if len(to) != 0 {
@@ -652,7 +761,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte,
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
return s.sendMsg(t, src, to, flags, controlMessages)
}
@@ -663,11 +772,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence,
}
// State implements socket.Socket.State.
-func (s *Socket) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
}
+
+// LINT.ThenChange(./socket_vfs2.go)
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
new file mode 100644
index 000000000..a38d25da9
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -0,0 +1,152 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 is the base VFS2 socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ socket.SocketVFS2 = (*SocketVFS2)(nil)
+var _ transport.Credentialer = (*SocketVFS2)(nil)
+
+// NewVFS2 creates a new SocketVFS2.
+func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketVFS2, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless(t)
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close(t)
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
+ if err != nil {
+ ep.Close(t)
+ return nil, err
+ }
+
+ fd := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
+ }
+ fd.LockFD.Init(&vfs.FileLocks{})
+ return fd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (*SocketVFS2) Ioctl(context.Context, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD
new file mode 100644
index 000000000..b6434923c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uevent",
+ srcs = ["protocol.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
new file mode 100644
index 000000000..029ba21b5
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol.
+//
+// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does
+// not support any device events, so these sockets never send any messages.
+package uevent
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_KOBJECT_UEVENT
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return false
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // Silently ignore all messages.
+ return nil
+}
+
+// init registers the NETLINK_KOBJECT_UEVENT provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index e414d8055..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -7,44 +7,52 @@ go_library(
srcs = [
"device.go",
"netstack.go",
+ "netstack_vfs2.go",
"provider.go",
+ "provider_vfs2.go",
"save_restore.go",
"stack.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack",
visibility = [
"//pkg/sentry:internal",
],
deps = [
"//pkg/abi/linux",
+ "//pkg/amutex",
"//pkg/binary",
+ "//pkg/context",
"//pkg/log",
"//pkg/metric",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netfilter",
"//pkg/sentry/unimpl",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
+ "//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 27c6692c4..e4846bc0b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,29 +26,32 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
- "sync"
+ "sync/atomic"
"syscall"
"time"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -57,12 +60,21 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
var cm tcpip.StatCounter
- metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value)
+ metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+func mustCreateGauge(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value)
return &cm
}
@@ -138,19 +150,23 @@ var Metrics = tcpip.Stats{
},
},
IP: tcpip.IPStats{
- PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
- InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
- PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
- PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
- OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
- MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
- MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
},
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
- CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."),
+ CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."),
+ CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."),
EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
@@ -178,6 +194,8 @@ var Metrics = tcpip.Stats{
MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
+ ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."),
},
}
@@ -220,19 +238,29 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
}
+// LINT.IfChange
+
// SocketOperations encapsulates all the state needed to represent a network stack
// endpoint in the kernel context.
//
@@ -244,6 +272,14 @@ type SocketOperations struct {
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
socket.SendReceiveTimeout
*waiter.Queue
@@ -252,14 +288,21 @@ type SocketOperations struct {
skType linux.SockType
protocol int
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readView contains the remaining payload from the last packet.
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
- sender tcpip.FullAddress
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+ linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@@ -281,19 +324,21 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
dirent := socket.NewDirent(t, netstackDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(t)
return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
- Queue: queue,
- family: family,
- Endpoint: endpoint,
- skType: skType,
- protocol: protocol,
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
}), nil
}
@@ -314,22 +359,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
// AF_INET6, and AF_PACKET addresses.
//
-// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
-//
// AddressAndFamily returns an address and its family.
-func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
- family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
- }
-
// Get the rest of the fields based on the address family.
- switch family {
+ switch family := usermem.ByteOrder.Uint16(addr); family {
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
@@ -385,7 +423,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
return tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
@@ -399,33 +437,49 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
}
-func (s *SocketOperations) isPacketBased() bool {
+func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
// fetchReadView updates the readView field of the socket if it's currently
// empty. It assumes that the socket is locked.
-func (s *SocketOperations) fetchReadView() *syserr.Error {
+//
+// Precondition: s.readMu must be held.
+func (s *socketOpsCommon) fetchReadView() *syserr.Error {
if len(s.readView) > 0 {
return nil
}
-
s.readView = nil
s.sender = tcpip.FullAddress{}
+ s.linkPacketInfo = tcpip.LinkPacketInfo{}
- v, cms, err := s.Endpoint.Read(&s.sender)
+ var v buffer.View
+ var cms tcpip.ControlMessages
+ var err *tcpip.Error
+
+ switch e := s.Endpoint.(type) {
+ // The ordering of these interfaces matters. The most specific
+ // interfaces must be specified before the more generic Endpoint
+ // interface.
+ case tcpip.PacketEndpoint:
+ v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
+ case tcpip.Endpoint:
+ v, cms, err = e.Read(&s.sender)
+ }
if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
}
s.readView = v
s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
return nil
}
// Release implements fs.FileOperations.Release.
-func (s *SocketOperations) Release() {
+func (s *socketOpsCommon) Release(context.Context) {
s.Endpoint.Close()
}
@@ -520,11 +574,9 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := ctx.(*kernel.Task)
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err).ToError()
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
}
-
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
}
@@ -593,11 +645,9 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := ctx.(*kernel.Task)
- if err := t.Block(resCh); err != nil {
- return 0, syserr.FromError(err).ToError()
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
}
-
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
Atomic: true, // See above.
})
@@ -612,26 +662,54 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
// Readiness returns a mask of ready events for socket s.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
// Check our cached value iff the caller asked for readability and the
// endpoint itself is currently not readable.
if (mask & ^r & waiter.EventIn) != 0 {
- s.readMu.Lock()
- if len(s.readView) > 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
r |= waiter.EventIn
}
- s.readMu.Unlock()
}
return r
}
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+ if family == uint16(s.family) {
+ return nil
+ }
+ if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
+ v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ if !v {
+ return nil
+ }
+ }
+ return syserr.ErrInvalidArgument
+}
+
+// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
+// receiver's family is AF_INET6.
+//
+// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
+// represented by the empty string.
+//
+// TODO(gvisor.dev/issue/1556): remove this function.
+func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
+ if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
+ addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ }
+ return addr
+}
+
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ addr, family, err := AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -643,6 +721,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
return syserr.TranslateNetstackError(err)
}
+
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
+
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -655,6 +739,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
defer s.EventUnregister(&e)
if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP unlike UDP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
return syserr.TranslateNetstackError(err)
}
@@ -670,10 +762,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
- if err != nil {
- return err
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
+ }
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
// Issue the bind request to the endpoint.
@@ -682,13 +808,13 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// Listen implements the linux syscall listen(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
}
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -728,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if err != nil {
return 0, nil, 0, err
}
- defer ns.DecRef()
+ defer ns.DecRef(t)
if flags&linux.SOCK_NONBLOCK != 0 {
flags := ns.Flags()
@@ -774,7 +900,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := ConvertShutdown(how)
if err != nil {
return err
@@ -786,7 +912,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -796,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -824,22 +950,30 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
return nil, syserr.ErrInvalidArgument
}
- info, err := netfilter.GetInfo(t, s.Endpoint, outPtr)
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
return nil, syserr.ErrInvalidArgument
}
- entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen)
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -849,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -874,8 +1008,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return nil, syserr.ErrProtocolNotAvailable
}
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -886,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -896,23 +1040,25 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.PasscredOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -928,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -944,74 +1091,93 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReuseAddressOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReusePortOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
- return []byte{}, nil
+ if v == 0 {
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.BroadcastOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.KeepaliveEnabledOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1019,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1027,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1039,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
+
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1048,58 +1229,58 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if v == 0 {
- return int32(1), nil
- }
- return int32(0), nil
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.CorkOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.QuickAckOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MaxSegOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1110,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1122,8 +1303,32 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
+
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(v)
+ return &vP, nil
- return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1136,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1171,8 +1377,59 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
+
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
+
+ case linux.TCP_DEFER_ACCEPT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPDeferAcceptOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
+
+ case linux.TCP_SYNCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(v)
+ return &vP, nil
+
+ case linux.TCP_WINDOW_CLAMP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1180,19 +1437,20 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.V6OnlyOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1200,21 +1458,41 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
- var v tcpip.IPv6TrafficClassOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is lesser than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
+
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
+ case linux.SO_ORIGINAL_DST:
+ // TODO(gvisor.dev/issue/170): ip6tables.
+ return nil, syserr.ErrInvalidArgument
default:
emitUnimplementedEventIPv6(t, name)
@@ -1223,36 +1501,38 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.TTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastTTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1266,36 +1546,76 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastLoopOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if v {
- return int32(1), nil
- }
- return int32(0), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
- var v tcpip.IPv4TOSOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
+ }
+ vP := primitive.Int32(v)
+ return &vP, nil
+
+ case linux.IP_RECVTOS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- return int32(v), nil
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
+
+ case linux.SO_ORIGINAL_DST:
+ if outLen < int(binary.Size(linux.SockAddrInet{})) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OriginalDestinationOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ return a.(*linux.SockAddrInet), nil
default:
emitUnimplementedEventIP(t, name)
@@ -1330,12 +1650,32 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
return nil
}
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
// sockets backed by a commonEndpoint.
-func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
return setSockOptSocket(t, s, ep, name, optVal)
@@ -1362,7 +1702,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
-func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.SO_SNDBUF:
if len(optVal) < sizeOfInt32 {
@@ -1386,7 +1726,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1394,14 +1734,27 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
@@ -1409,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1417,7 +1770,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1425,7 +1778,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1466,6 +1819,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+ case linux.SO_NO_CHECK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+
case linux.SO_LINGER:
if len(optVal) < linux.SizeOfLinger {
return syserr.ErrInvalidArgument
@@ -1480,6 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
return nil
+ case linux.SO_DETACH_FILTER:
+ // optval is ignored.
+ var v tcpip.SocketDetachFilterOption
+ return syserr.TranslateNetstackError(ep.SetSockOpt(v))
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -1497,11 +1863,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o int
- if v == 0 {
- o = 1
- }
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1509,7 +1871,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -1517,7 +1879,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -1525,7 +1887,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
@@ -1549,6 +1911,28 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
case linux.TCP_CONGESTION:
v := tcpip.CongestionControlOption(optVal)
if err := ep.SetSockOpt(v); err != nil {
@@ -1556,6 +1940,40 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return nil
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_DEFER_ACCEPT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_SYNCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v)))
+
+ case linux.TCP_WINDOW_CLAMP:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v)))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1576,13 +1994,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
case linux.IPV6_ADD_MEMBERSHIP,
linux.IPV6_DROP_MEMBERSHIP,
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1606,7 +2025,15 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
if v == -1 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
+
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
default:
emitUnimplementedEventIPv6(t, name)
@@ -1683,7 +2110,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if v < 0 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
case linux.IP_ADD_MEMBERSHIP:
req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
@@ -1730,9 +2157,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(
- tcpip.MulticastLoopOption(v != 0),
- ))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -1751,7 +2176,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
} else if v < 1 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
case linux.IP_TOS:
if len(optVal) == 0 {
@@ -1761,7 +2186,34 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
+
+ case linux.IP_RECVTOS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
@@ -1769,7 +2221,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_CHECKSUM,
linux.IP_DROP_SOURCE_MEMBERSHIP,
linux.IP_FREEBIND,
- linux.IP_HDRINCL,
linux.IP_IPSEC_POLICY,
linux.IP_MINTTL,
linux.IP_MSFILTER,
@@ -1778,12 +2229,10 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
linux.IP_RECVORIGDSTADDR,
- linux.IP_RECVTOS,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -1811,30 +2260,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
switch name {
case linux.TCP_CONGESTION,
linux.TCP_CORK,
- linux.TCP_DEFER_ACCEPT,
linux.TCP_FASTOPEN,
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_KEEPCNT,
- linux.TCP_KEEPIDLE,
- linux.TCP_KEEPINTVL,
- linux.TCP_LINGER2,
- linux.TCP_MAXSEG,
linux.TCP_QUEUE_SEQ,
- linux.TCP_QUICKACK,
linux.TCP_REPAIR,
linux.TCP_REPAIR_QUEUE,
linux.TCP_REPAIR_WINDOW,
linux.TCP_SAVED_SYN,
linux.TCP_SAVE_SYN,
- linux.TCP_SYNCNT,
linux.TCP_THIN_DUPACK,
linux.TCP_THIN_LINEAR_TIMEOUTS,
linux.TCP_TIMESTAMP,
- linux.TCP_ULP,
- linux.TCP_USER_TIMEOUT,
- linux.TCP_WINDOW_CLAMP:
+ linux.TCP_ULP:
t.Kernel().EmitUnimplementedEvent(t)
}
@@ -1876,7 +2315,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -1981,8 +2419,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
case linux.AF_INET6:
var out linux.SockAddrInet6
- if len(addr.Addr) == 4 {
- // Copy address is v4-mapped format.
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
copy(out.Addr[12:], addr.Addr)
out.Addr[10] = 0xff
out.Addr[11] = 0xff
@@ -1997,7 +2435,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(sockAddrInet6Size)
case linux.AF_PACKET:
- // TODO(b/129292371): Return protocol too.
+ // TODO(gvisor.dev/issue/173): Return protocol too.
var out linux.SockAddrLink
out.Family = linux.AF_PACKET
out.InterfaceIndex = int32(addr.NIC)
@@ -2012,7 +2450,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2024,7 +2462,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32,
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.Endpoint.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -2039,16 +2477,21 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// caller.
//
// Precondition: s.readMu must be locked.
-func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
+func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
var err *syserr.Error
var copied int
// Copy as many views as possible into the user-provided buffer.
- for dst.NumBytes() != 0 {
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
err = s.fetchReadView()
if err != nil {
break
}
+ if dst.NumBytes() == 0 {
+ break
+ }
var n int
var e error
@@ -2066,6 +2509,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}
copied += n
s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
@@ -2082,7 +2529,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
return 0, err
}
-func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
+func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
if !s.sockOptInq {
return
}
@@ -2094,10 +2541,27 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
-func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
@@ -2112,9 +2576,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// caller-supplied buffer.
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
- s.readMu.Unlock()
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
}
@@ -2149,6 +2613,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = ConvertAddress(s.family, s.sender)
+ switch v := addr.(type) {
+ case *linux.SockAddrLink:
+ v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
+ }
}
if peek {
@@ -2188,6 +2657,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
var flags int
if msgLen > int(n) {
flags |= linux.MSG_TRUNC
@@ -2202,15 +2675,26 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
}
-func (s *SocketOperations) controlMessages() socket.ControlMessages {
- return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}}
+func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
+ return socket.ControlMessages{
+ IP: tcpip.ControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ },
+ }
}
// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
// successfully writing packet data out to userspace.
//
// Precondition: s.readMu must be locked.
-func (s *SocketOperations) updateTimestamp() {
+func (s *socketOpsCommon) updateTimestamp() {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
@@ -2220,7 +2704,7 @@ func (s *SocketOperations) updateTimestamp() {
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -2288,7 +2772,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
// Reject Unix control messages.
if !controlMessages.Unix.Empty() {
return 0, syserr.ErrInvalidArgument
@@ -2296,10 +2780,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
+ addrBuf, family, err := AddressAndFamily(to)
if err != nil {
return 0, err
}
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return 0, err
+ }
+ addrBuf = s.mapFamily(addrBuf, family)
addr = &addrBuf
}
@@ -2360,11 +2848,20 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// Ioctl implements fs.FileOperations.Ioctl.
func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, io, args)
+}
+
+func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
// SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
switch args[1].Int() {
- case syscall.SIOCGSTAMP:
+ case linux.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2372,9 +2869,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
tv := linux.NsecToTimeval(s.timestampNS)
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := tv.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2384,16 +2879,17 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
}
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ // Copy result to userspace.
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
}
@@ -2402,52 +2898,49 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// Ioctl performs a socket ioctl.
func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("ioctl(2) may only be called from a task goroutine")
+ }
+
switch arg := int(args[1].Int()); arg {
- case syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFFLAGS,
+ linux.SIOCGIFADDR,
+ linux.SIOCGIFBRDADDR,
+ linux.SIOCGIFDSTADDR,
+ linux.SIOCGIFHWADDR,
+ linux.SIOCGIFINDEX,
+ linux.SIOCGIFMAP,
+ linux.SIOCGIFMETRIC,
+ linux.SIOCGIFMTU,
+ linux.SIOCGIFNAME,
+ linux.SIOCGIFNETMASK,
+ linux.SIOCGIFTXQLEN,
+ linux.SIOCETHTOOL:
var ifr linux.IFReq
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
return 0, err.ToError()
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ _, err := ifr.CopyOut(t, args[2].Pointer())
return 0, err
- case syscall.SIOCGIFCONF:
+ case linux.SIOCGIFCONF:
// Return a list of interface addresses or the buffer size
// necessary to hold the list.
var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil {
return 0, err
}
- if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ if err := ifconfIoctl(ctx, t, io, &ifc); err != nil {
return 0, err
}
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
+ _, err := ifc.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCINQ:
@@ -2459,10 +2952,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
if v > math.MaxInt32 {
v = math.MaxInt32
}
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ // Copy result to userspace.
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.TIOCOUTQ:
@@ -2475,10 +2967,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
v = math.MaxInt32
}
- // Copy result to user-space.
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
- AddressSpaceActive: true,
- })
+ // Copy result to userspace.
+ vP := primitive.Int32(v)
+ _, err := vP.CopyOut(t, args[2].Pointer())
return 0, err
case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
@@ -2504,7 +2995,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
// identify a device.
- if arg == syscall.SIOCGIFNAME {
+ if arg == linux.SIOCGIFNAME {
// Gets the name of the interface given the interface index
// stored in ifr_ifindex.
index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
@@ -2527,21 +3018,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
switch arg {
- case syscall.SIOCGIFINDEX:
+ case linux.SIOCGIFINDEX:
// Copy out the index to the data.
usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
- case syscall.SIOCGIFHWADDR:
+ case linux.SIOCGIFHWADDR:
// Copy the hardware address out.
- ifr.Data[0] = 6 // IEEE802.2 arp type.
- ifr.Data[1] = 0
+ //
+ // Refer: https://linux.die.net/man/7/netdevice
+ // SIOCGIFHWADDR, SIOCSIFHWADDR
+ //
+ // Get or set the hardware address of a device using
+ // ifr_hwaddr. The hardware address is specified in a struct
+ // sockaddr. sa_family contains the ARPHRD_* device type,
+ // sa_data the L2 hardware address starting from byte 0. Setting
+ // the hardware address is a privileged operation.
+ usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType)
n := copy(ifr.Data[2:], iface.Addr)
for i := 2 + n; i < len(ifr.Data); i++ {
ifr.Data[i] = 0 // Clear padding.
}
- usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
- case syscall.SIOCGIFFLAGS:
+ case linux.SIOCGIFFLAGS:
f, err := interfaceStatusFlags(stack, iface.Name)
if err != nil {
return err
@@ -2550,7 +3048,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
// matches Linux behavior.
usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
- case syscall.SIOCGIFADDR:
+ case linux.SIOCGIFADDR:
// Copy the IPv4 address out.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2561,32 +3059,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
- case syscall.SIOCGIFMETRIC:
+ case linux.SIOCGIFMETRIC:
// Gets the metric of the device. As per netdevice(7), this
// always just sets ifr_metric to 0.
usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
- case syscall.SIOCGIFMTU:
+ case linux.SIOCGIFMTU:
// Gets the MTU of the device.
usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
- case syscall.SIOCGIFMAP:
+ case linux.SIOCGIFMAP:
// Gets the hardware parameters of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFTXQLEN:
+ case linux.SIOCGIFTXQLEN:
// Gets the transmit queue length of the device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFDSTADDR:
+ case linux.SIOCGIFDSTADDR:
// Gets the destination address of a point-to-point device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFBRDADDR:
+ case linux.SIOCGIFBRDADDR:
// Gets the broadcast address of a device.
// TODO(gvisor.dev/issue/505): Implement.
- case syscall.SIOCGIFNETMASK:
+ case linux.SIOCGIFNETMASK:
// Gets the network mask of a device.
for _, addr := range stack.InterfaceAddrs()[index] {
// This ioctl is only compatible with AF_INET addresses.
@@ -2603,6 +3101,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
break
}
+ case linux.SIOCETHTOOL:
+ // Stubbed out for now, Ideally we should implement the required
+ // sub-commands for ETHTOOL
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c
+ return syserr.ErrEndpointOperation
+
default:
// Not a valid call.
return syserr.ErrInvalidArgument
@@ -2612,7 +3118,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe
}
// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error {
// If Ptr is NULL, return the necessary buffer size via Len.
// Otherwise, write up to Len bytes starting at Ptr containing ifreq
// structs.
@@ -2649,9 +3155,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Copy the ifr to userspace.
dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
+ if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil {
return err
}
}
@@ -2697,7 +3201,7 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
if s.family != linux.AF_INET && s.family != linux.AF_INET6 {
// States not implemented for this socket's family.
return 0
@@ -2757,6 +3261,8 @@ func (s *SocketOperations) State() uint32 {
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
return s.family, s.skType, s.protocol
}
+
+// LINT.ThenChange(./netstack_vfs2.go)
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
new file mode 100644
index 000000000..3335e7430
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -0,0 +1,332 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
+// SocketVFS2 encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&SocketVFS2{})
+
+// NewVFS2 creates a new endpoint socket.
+func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef(t)
+
+ if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, 0, syserr.FromError(err)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+
+ t.Kernel().RecordSocketVFS2(ns)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, uio, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := primitive.Int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return &val, nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := primitive.Int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return &val, nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return &info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return &entries, nil
+
+ }
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 2d2c1ba2a..ead3b2b79 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -18,7 +18,7 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -33,6 +33,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// provider is an inet socket provider.
type provider struct {
family int
@@ -62,10 +64,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
- // TODO(b/142504697): "In order to create a raw socket, a
- // process must have the CAP_NET_RAW capability in the user
- // namespace that governs its network namespace." - raw(7)
-
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -75,6 +73,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
@@ -124,6 +124,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
}
if e != nil {
return nil, syserr.TranslateNetstackError(e)
@@ -133,10 +139,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // TODO(b/142504697): "In order to create a packet socket, a process
- // must have the CAP_NET_RAW capability in the user namespace that
- // governs its network namespace." - packet(7)
-
// Packet sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(t)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -167,6 +169,8 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
}
+// LINT.ThenChange(./provider_vfs2.go)
+
// Pair just returns nil sockets (not supported).
func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
new file mode 100644
index 000000000..2a01143f6
--- /dev/null
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// providerVFS2 is an inet socket provider.
+type providerVFS2 struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
+func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocketVFS2(t, eps, stype, protocol)
+ }
+
+ // Figure out the transport protocol.
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return NewVFS2(t, p.family, stype, int(transProto), wq, ep)
+}
+
+func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return NewVFS2(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
+// Pair just returns nil sockets (not supported).
+func (*providerVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
+func init() {
+ // Providers backed by netstack.
+ p := []providerVFS2{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ {
+ family: linux.AF_PACKET,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProviderVFS2(p[i].family, &p[i])
+ }
+}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index d0102cfa3..f9097d6b2 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -15,14 +15,15 @@
package netstack
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -41,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool {
return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
}
+// Converts Netstack's ARPHardwareType to equivalent linux constants.
+func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 {
+ switch t {
+ case header.ARPHardwareNone:
+ return linux.ARPHRD_NONE
+ case header.ARPHardwareLoopback:
+ return linux.ARPHRD_LOOPBACK
+ case header.ARPHardwareEther:
+ return linux.ARPHRD_ETHER
+ default:
+ panic(fmt.Sprintf("unknown ARPHRD type: %d", t))
+ }
+}
+
// Interfaces implements inet.Stack.Interfaces.
func (s *Stack) Interfaces() map[int32]inet.Interface {
is := make(map[int32]inet.Interface)
for id, ni := range s.Stack.NICInfo() {
- var devType uint16
- if ni.Flags.Loopback {
- devType = linux.ARPHRD_LOOPBACK
- }
is[int32(id)] = inet.Interface{
Name: ni.Name,
Addr: []byte(ni.LinkAddress),
Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
- DeviceType: devType,
+ DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType),
MTU: ni.MTU,
}
}
@@ -89,6 +100,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
return nicAddrs
}
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ var (
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ )
+ switch addr.Family {
+ case linux.AF_INET:
+ if len(addr.Addr) < header.IPv4AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv4AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv4.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
+
+ case linux.AF_INET6:
+ if len(addr.Addr) < header.IPv6AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv6AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv6.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
+
+ default:
+ return syserror.ENOTSUP
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: int(addr.PrefixLen),
+ },
+ }
+
+ // Attach address to interface.
+ if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network.
+ s.Stack.AddRoute(tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: tcpip.NICID(idx),
+ })
+ return nil
+}
+
// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
var rs tcp.ReceiveBufferSizeOption
@@ -143,39 +207,83 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
}
+// TCPRecovery implements inet.Stack.TCPRecovery.
+func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
+ var recovery tcp.Recovery
+ if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+ return inet.TCPLossRecovery(recovery), nil
+}
+
+// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
+func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError()
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
case *inet.StatSNMPIP:
ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
- 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
- ip.PacketsReceived.Value(), // InReceives.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
- ip.InvalidAddressesReceived.Value(), // InAddrErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
- ip.PacketsDelivered.Value(), // InDelivers.
- ip.PacketsSent.Value(), // OutRequests.
- ip.OutgoingPacketErrors.Value(), // OutDiscards.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
+ ip.PacketsReceived.Value(), // InReceives.
+ 0, // Ip/InHdrErrors.
+ ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
+ ip.PacketsDelivered.Value(), // InDelivers.
+ ip.PacketsSent.Value(), // OutRequests.
+ ip.OutgoingPacketErrors.Value(), // OutDiscards.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ 0, // Icmp/InMsgs.
Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ 0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
in.ParamProblem.Value(), // InParmProbs.
@@ -187,7 +295,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.TimestampReply.Value(), // InTimestampReps.
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ 0, // Icmp/OutMsgs.
Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
out.DstUnreachable.Value(), // OutDestUnreachs.
out.TimeExceeded.Value(), // OutTimeExcds.
@@ -223,15 +331,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
udp.UnknownPortErrors.Value(), // NoPorts.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ 0, // Udp/InErrors.
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ 0, // Udp/SndbufErrors.
+ udp.ChecksumErrors.Value(), // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
}
default:
return syserr.ErrEndpointOperation.ToError()
@@ -278,21 +387,30 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (iptables.IPTables, error) {
+func (s *Stack) IPTables() (*stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
-// FillDefaultIPTables sets the stack's iptables to the default tables, which
-// allow and do not modify all traffic.
-func (s *Stack) FillDefaultIPTables() {
- netfilter.FillDefaultIPTables(s.Stack)
-}
-
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {
s.Stack.Resume()
}
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return s.Stack.RegisteredEndpoints()
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
+ return s.Stack.CleanupEndpoints()
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
+ s.Stack.RestoreCleanupEndpoints(es)
+}
+
// Forwarding implements inet.Stack.Forwarding.
func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
switch protocol {
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
deleted file mode 100644
index 3a6baa308..000000000
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ /dev/null
@@ -1,68 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "rpcinet",
- srcs = [
- "device.go",
- "rpcinet.go",
- "socket.go",
- "stack.go",
- "stack_unsafe.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- ":syscall_rpc_go_proto",
- "//pkg/abi/linux",
- "//pkg/binary",
- "//pkg/sentry/arch",
- "//pkg/sentry/context",
- "//pkg/sentry/device",
- "//pkg/sentry/fs",
- "//pkg/sentry/fs/fsutil",
- "//pkg/sentry/inet",
- "//pkg/sentry/kernel",
- "//pkg/sentry/kernel/time",
- "//pkg/sentry/socket",
- "//pkg/sentry/socket/hostinet",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/sentry/socket/rpcinet/notifier",
- "//pkg/sentry/unimpl",
- "//pkg/sentry/usermem",
- "//pkg/syserr",
- "//pkg/syserror",
- "//pkg/tcpip",
- "//pkg/tcpip/buffer",
- "//pkg/unet",
- "//pkg/waiter",
- ],
-)
-
-proto_library(
- name = "syscall_rpc_proto",
- srcs = ["syscall_rpc.proto"],
- visibility = [
- "//visibility:public",
- ],
-)
-
-cc_proto_library(
- name = "syscall_rpc_cc_proto",
- visibility = [
- "//visibility:public",
- ],
- deps = [":syscall_rpc_proto"],
-)
-
-go_proto_library(
- name = "syscall_rpc_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
- proto = ":syscall_rpc_proto",
- visibility = [
- "//visibility:public",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD
deleted file mode 100644
index 23eadcb1b..000000000
--- a/pkg/sentry/socket/rpcinet/conn/BUILD
+++ /dev/null
@@ -1,17 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "conn",
- srcs = ["conn.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/binary",
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/syserr",
- "//pkg/unet",
- "@com_github_golang_protobuf//proto:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
deleted file mode 100644
index 356adad99..000000000
--- a/pkg/sentry/socket/rpcinet/conn/conn.go
+++ /dev/null
@@ -1,187 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package conn is an RPC connection to a syscall RPC server.
-package conn
-
-import (
- "fmt"
- "sync"
- "sync/atomic"
- "syscall"
-
- "github.com/golang/protobuf/proto"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/unet"
-
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
-)
-
-type request struct {
- response []byte
- ready chan struct{}
- ignoreResult bool
-}
-
-// RPCConnection represents a single RPC connection to a syscall gofer.
-type RPCConnection struct {
- // reqID is the ID of the last request and must be accessed atomically.
- reqID uint64
-
- sendMu sync.Mutex
- socket *unet.Socket
-
- reqMu sync.Mutex
- requests map[uint64]request
-}
-
-// NewRPCConnection initializes a RPC connection to a socket gofer.
-func NewRPCConnection(s *unet.Socket) *RPCConnection {
- conn := &RPCConnection{socket: s, requests: map[uint64]request{}}
- go func() { // S/R-FIXME(b/77962828)
- var nums [16]byte
- for {
- for n := 0; n < len(nums); {
- nn, err := conn.socket.Read(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error reading length from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- b := make([]byte, binary.LittleEndian.Uint64(nums[:8]))
- id := binary.LittleEndian.Uint64(nums[8:])
-
- for n := 0; n < len(b); {
- nn, err := conn.socket.Read(b[n:])
- if err != nil {
- panic(fmt.Sprint("error reading request from socket rpc gofer: ", err))
- }
- n += nn
- }
-
- conn.reqMu.Lock()
- r := conn.requests[id]
- if r.ignoreResult {
- delete(conn.requests, id)
- } else {
- r.response = b
- conn.requests[id] = r
- }
- conn.reqMu.Unlock()
- close(r.ready)
- }
- }()
- return conn
-}
-
-// NewRequest makes a request to the RPC gofer and returns the request ID and a
-// channel which will be closed once the request completes.
-func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) {
- b, err := proto.Marshal(&req)
- if err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- id := atomic.AddUint64(&c.reqID, 1)
- ch := make(chan struct{})
-
- c.reqMu.Lock()
- c.requests[id] = request{ready: ch, ignoreResult: ignoreResult}
- c.reqMu.Unlock()
-
- c.sendMu.Lock()
- defer c.sendMu.Unlock()
-
- var nums [16]byte
- binary.LittleEndian.PutUint64(nums[:8], uint64(len(b)))
- binary.LittleEndian.PutUint64(nums[8:], id)
- for n := 0; n < len(nums); {
- nn, err := c.socket.Write(nums[n:])
- if err != nil {
- panic(fmt.Sprint("error writing length and ID to socket gofer: ", err))
- }
- n += nn
- }
-
- for n := 0; n < len(b); {
- nn, err := c.socket.Write(b[n:])
- if err != nil {
- panic(fmt.Sprint("error writing request to socket gofer: ", err))
- }
- n += nn
- }
-
- return id, ch
-}
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{
- Path: path,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result
- if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadFileResponse_Data).Data, nil
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (c *RPCConnection) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- req := &pb.SyscallRequest_WriteFile{&pb.WriteFileRequest{
- Path: path,
- Content: data,
- }}
-
- id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-ch
-
- res := c.Request(id).Result.(*pb.SyscallResponse_WriteFile).WriteFile
- if e := res.ErrorNumber; e != 0 {
- return int64(res.Written), syserr.FromHost(syscall.Errno(e))
- }
-
- return int64(res.Written), nil
-}
-
-// Request retrieves the request corresponding to the given request ID.
-//
-// The channel returned by NewRequest must have been closed before Request can
-// be called. This will happen automatically, do not manually close the
-// channel.
-func (c *RPCConnection) Request(id uint64) pb.SyscallResponse {
- c.reqMu.Lock()
- r := c.requests[id]
- delete(c.requests, id)
- c.reqMu.Unlock()
-
- var resp pb.SyscallResponse
- if err := proto.Unmarshal(r.response, &resp); err != nil {
- panic(fmt.Sprint("invalid proto: ", err))
- }
-
- return resp
-}
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
deleted file mode 100644
index a3585e10d..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/BUILD
+++ /dev/null
@@ -1,16 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "notifier",
- srcs = ["notifier.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier",
- visibility = ["//:sandbox"],
- deps = [
- "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
- "//pkg/sentry/socket/rpcinet/conn",
- "//pkg/waiter",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
deleted file mode 100644
index 7efe4301f..000000000
--- a/pkg/sentry/socket/rpcinet/notifier/notifier.go
+++ /dev/null
@@ -1,231 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package notifier implements an FD notifier implementation over RPC.
-package notifier
-
-import (
- "fmt"
- "sync"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-type fdInfo struct {
- queue *waiter.Queue
- waiting bool
-}
-
-// Notifier holds all the state necessary to issue notifications when IO events
-// occur in the observed FDs.
-type Notifier struct {
- // rpcConn is the connection that is used for sending RPCs.
- rpcConn *conn.RPCConnection
-
- // epFD is the epoll file descriptor used to register for io
- // notifications.
- epFD uint32
-
- // mu protects fdMap.
- mu sync.Mutex
-
- // fdMap maps file descriptors to their notification queues and waiting
- // status.
- fdMap map[uint32]*fdInfo
-}
-
-// NewRPCNotifier creates a new notifier object.
-func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) {
- id, c := cn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}, false /* ignoreResult */)
- <-c
-
- res := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result
- if e, ok := res.(*pb.EpollCreate1Response_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- w := &Notifier{
- rpcConn: cn,
- epFD: res.(*pb.EpollCreate1Response_Fd).Fd,
- fdMap: make(map[uint32]*fdInfo),
- }
-
- go w.waitAndNotify() // S/R-FIXME(b/77962828)
-
- return w, nil
-}
-
-// waitFD waits on mask for fd. The fdMap mutex must be hold.
-func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
- if !fi.waiting && mask == 0 {
- return nil
- }
-
- e := pb.EpollEvent{
- Events: mask.ToLinux() | unix.EPOLLET,
- Fd: fd,
- }
-
- switch {
- case !fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
-
- fi.waiting = true
- case fi.waiting && mask == 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */)
- <-c
- n.rpcConn.Request(id)
-
- fi.waiting = false
- case fi.waiting && mask != 0:
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &e}}}, false /* ignoreResult */)
- <-c
-
- e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber
- if e != 0 {
- return syscall.Errno(e)
- }
- }
-
- return nil
-}
-
-// addFD adds an FD to the list of FDs observed by n.
-func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- // Panic if we're already notifying on this FD.
- if _, ok := n.fdMap[fd]; ok {
- panic(fmt.Sprintf("File descriptor %d added twice", fd))
- }
-
- // We have nothing to wait for at the moment. Just add it to the map.
- n.fdMap[fd] = &fdInfo{queue: queue}
-}
-
-// updateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) updateFD(fd uint32) error {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- if fi, ok := n.fdMap[fd]; ok {
- return n.waitFD(fd, fi, fi.queue.Events())
- }
-
- return nil
-}
-
-// RemoveFD removes an FD from the list of FDs observed by n.
-func (n *Notifier) removeFD(fd uint32) {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- // Remove from map, then from epoll object.
- n.waitFD(fd, n.fdMap[fd], 0)
- delete(n.fdMap, fd)
-}
-
-// hasFD returns true if the FD is in the list of observed FDs.
-func (n *Notifier) hasFD(fd uint32) bool {
- n.mu.Lock()
- defer n.mu.Unlock()
-
- _, ok := n.fdMap[fd]
- return ok
-}
-
-// waitAndNotify loops waiting for io event notifications from the epoll
-// object. Once notifications arrive, they are dispatched to the
-// registered queue.
-func (n *Notifier) waitAndNotify() error {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result
- if e, ok := res.(*pb.EpollWaitResponse_ErrorNumber); ok {
- err := syscall.Errno(e.ErrorNumber)
- // NOTE(magi): I don't think epoll_wait can return EAGAIN but I'm being
- // conseratively careful here since exiting the notification thread
- // would be really bad.
- if err == syscall.EINTR || err == syscall.EAGAIN {
- continue
- }
- return err
- }
-
- n.mu.Lock()
- for _, e := range res.(*pb.EpollWaitResponse_Events).Events.Events {
- if fi, ok := n.fdMap[e.Fd]; ok {
- fi.queue.Notify(waiter.EventMaskFromLinux(e.Events))
- }
- }
- n.mu.Unlock()
- }
-}
-
-// AddFD adds an FD to the list of observed FDs.
-func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error {
- n.addFD(fd, queue)
- return nil
-}
-
-// UpdateFD updates the set of events the FD needs to be notified on.
-func (n *Notifier) UpdateFD(fd uint32) error {
- return n.updateFD(fd)
-}
-
-// RemoveFD removes an FD from the list of observed FDs.
-func (n *Notifier) RemoveFD(fd uint32) {
- n.removeFD(fd)
-}
-
-// HasFD returns true if the FD is in the list of observed FDs.
-//
-// This should only be used by tests to assert that FDs are correctly
-// registered.
-func (n *Notifier) HasFD(fd uint32) bool {
- return n.hasFD(fd)
-}
-
-// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just
-// to query the FD's current state; this method will block on the RPC response
-// although the syscall is non-blocking.
-func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask {
- for {
- id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: mask.ToLinux()}}}, false /* ignoreResult */)
- <-c
-
- res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result
- if e, ok := res.(*pb.PollResponse_ErrorNumber); ok {
- if syscall.Errno(e.ErrorNumber) == syscall.EINTR {
- continue
- }
- return mask
- }
-
- return waiter.EventMaskFromLinux(res.(*pb.PollResponse_Events).Events)
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
deleted file mode 100644
index ddb76d9d4..000000000
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ /dev/null
@@ -1,909 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "sync/atomic"
- "syscall"
- "time"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/unimpl"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-// socketOperations implements fs.FileOperations and socket.Socket for a socket
-// implemented using a host socket.
-type socketOperations struct {
- fsutil.FilePipeSeek `state:"nosave"`
- fsutil.FileNotDirReaddir `state:"nosave"`
- fsutil.FileNoFsync `state:"nosave"`
- fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
- fsutil.FileNoopFlush `state:"nosave"`
- fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- socket.SendReceiveTimeout
-
- family int // Read-only.
- stype linux.SockType // Read-only.
- protocol int // Read-only.
-
- fd uint32 // must be O_NONBLOCK
- wq *waiter.Queue
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-
- // shState is the state of the connection with respect to shutdown. Because
- // we're mixing non-blocking semantics on the other side we have to adapt for
- // some strange differences between blocking and non-blocking sockets.
- shState int32
-}
-
-// Verify that we actually implement socket.Socket.
-var _ = socket.Socket(&socketOperations{})
-
-// New creates a new RPC socket.
-func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- fd := res.(*pb.SocketResponse_Fd).Fd
-
- var wq waiter.Queue
- stack.notifier.AddFD(fd, &wq)
-
- dirent := socket.NewDirent(ctx, socketDevice)
- defer dirent.DecRef()
- return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
- family: family,
- stype: skType,
- protocol: protocol,
- wq: &wq,
- fd: fd,
- rpcConn: stack.rpcConn,
- notifier: stack.notifier,
- }), nil
-}
-
-func isBlockingErrno(err error) bool {
- return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
-}
-
-func translateIOSyscallError(err error) error {
- if isBlockingErrno(err) {
- return syserror.ErrWouldBlock
- }
- return err
-}
-
-// setShutdownFlags will set the shutdown flag so we can handle blocking reads
-// after a read shutdown.
-func (s *socketOperations) setShutdownFlags(how int) {
- var f tcpip.ShutdownFlags
- switch how {
- case linux.SHUT_RD:
- f = tcpip.ShutdownRead
- case linux.SHUT_WR:
- f = tcpip.ShutdownWrite
- case linux.SHUT_RDWR:
- f = tcpip.ShutdownWrite | tcpip.ShutdownRead
- }
-
- // Atomically update the flags.
- for {
- old := atomic.LoadInt32(&s.shState)
- if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) {
- break
- }
- }
-}
-
-func (s *socketOperations) resetShutdownFlags() {
- atomic.StoreInt32(&s.shState, 0)
-}
-
-func (s *socketOperations) isShutRdSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0
-}
-
-func (s *socketOperations) isShutWrSet() bool {
- return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0
-}
-
-// Release implements fs.FileOperations.Release.
-func (s *socketOperations) Release() {
- s.notifier.RemoveFD(s.fd)
-
- // We always need to close the FD.
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}, true /* ignoreResult */)
-}
-
-// Readiness implements waiter.Waitable.Readiness.
-func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
- return s.notifier.NonBlockingPoll(s.fd, mask)
-}
-
-// EventRegister implements waiter.Waitable.EventRegister.
-func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
- s.wq.EventRegister(e, mask)
- s.notifier.UpdateFD(s.fd)
-}
-
-// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *socketOperations) EventUnregister(e *waiter.Entry) {
- s.wq.EventUnregister(e)
- s.notifier.UpdateFD(s.fd)
-}
-
-func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result
- if e, ok := res.(*pb.ReadResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.ReadResponse_Data), nil
-}
-
-// Read implements fs.FileOperations.Read.
-func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- req := &pb.SyscallRequest_Read{&pb.ReadRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- }}
-
- res, se := rpcRead(ctx.(*kernel.Task), req)
- if se == nil {
- n, e := dst.CopyOut(ctx, res.Data)
- return int64(n), e
- }
-
- return 0, se.ToError()
-}
-
-func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result
- if e, ok := res.(*pb.WriteResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.WriteResponse_Length).Length, nil
-}
-
-// Write implements fs.FileOperations.Write.
-func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- t := ctx.(*kernel.Task)
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, err
- }
-
- n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}})
- if n > 0 && n < uint32(src.NumBytes()) {
- // The FileOperations.Write interface expects us to return ErrWouldBlock in
- // the event of a partial write.
- return int64(n), syserror.ErrWouldBlock
- }
- return int64(n), err.ToError()
-}
-
-func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Connect implements socket.Socket.Connect.
-func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- if !blocking {
- e := rpcConnect(t, s.fd, sockaddr)
- if e == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return e
- }
-
- // Register for notification when the endpoint becomes writable, then
- // initiate the connection.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut|waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
- for {
- if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress {
- if err == nil {
- // Reset the shutdown state on new connects.
- s.resetShutdownFlags()
- }
- return err
- }
-
- // It's pending, so we have to wait for a notification, and fetch the
- // result once the wait completes.
- if err := t.Block(ch); err != nil {
- return syserr.FromError(err)
- }
- }
-}
-
-func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result
- if e, ok := res.(*pb.AcceptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.AcceptResponse_Payload).Payload, nil
-}
-
-// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- payload, se := rpcAccept(t, s.fd, peerRequested)
-
- // Check if we need to block.
- if blocking && se == syserr.ErrTryAgain {
- // Register for notifications.
- e, ch := waiter.NewChannelEntry(nil)
- // FIXME(b/119878986): This waiter.EventHUp is a partial
- // measure, need to figure out how to translate linux events to
- // internal events.
- s.EventRegister(&e, waiter.EventIn|waiter.EventHUp)
- defer s.EventUnregister(&e)
-
- // Try to accept the connection again; if it fails, then wait until we
- // get a notification.
- for {
- if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrTryAgain {
- break
- }
-
- if err := t.Block(ch); err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- }
- }
-
- // Handle any error from accept.
- if se != nil {
- return 0, nil, 0, se
- }
-
- var wq waiter.Queue
- s.notifier.AddFD(payload.Fd, &wq)
-
- dirent := socket.NewDirent(t, socketDevice)
- defer dirent.DecRef()
- fileFlags := fs.FileFlags{
- Read: true,
- Write: true,
- NonSeekable: true,
- NonBlocking: flags&linux.SOCK_NONBLOCK != 0,
- }
- file := fs.NewFile(t, dirent, fileFlags, &socketOperations{
- family: s.family,
- stype: s.stype,
- protocol: s.protocol,
- wq: &wq,
- fd: payload.Fd,
- rpcConn: s.rpcConn,
- notifier: s.notifier,
- })
- defer file.DecRef()
-
- fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
- CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- })
- if err != nil {
- return 0, nil, 0, syserr.FromError(err)
- }
- t.Kernel().RecordSocket(file)
-
- if peerRequested {
- return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil
- }
-
- return fd, nil, 0, nil
-}
-
-// Bind implements socket.Socket.Bind.
-func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Listen implements socket.Socket.Listen.
-func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// Shutdown implements socket.Socket.Shutdown.
-func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
- // We save the shutdown state because of strange differences on linux
- // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD.
- // We need to emulate that behavior on the blocking side.
- // TODO(b/120096741): There is a possible race that can exist with loopback,
- // where data could possibly be lost.
- s.setShutdownFlags(how)
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
-
- return nil
-}
-
-// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
- // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed
- // within the sentry.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.RecvTimeout()), nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if outLen < linux.SizeOfTimeval {
- return nil, syserr.ErrInvalidArgument
- }
-
- return linux.NsecToTimeval(s.SendTimeout()), nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result
- if e, ok := res.(*pb.GetSockOptResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.GetSockOptResponse_Opt).Opt, nil
-}
-
-// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
- // Because blocking actually happens within the sentry we need to inspect
- // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO,
- // and if so, we will save it and use it as the deadline for recv(2)
- // or send(2) related syscalls.
- if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetRecvTimeout(v.ToNsecCapped())
- return nil
- }
- if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO {
- if len(opt) < linux.SizeOfTimeval {
- return syserr.ErrInvalidArgument
- }
-
- var v linux.Timeval
- binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
- return syserr.ErrDomain
- }
- s.SetSendTimeout(v.ToNsecCapped())
- return nil
- }
-
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
- <-c
-
- if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-// GetPeerName implements socket.Socket.GetPeerName.
-func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result
- if e, ok := res.(*pb.GetPeerNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetPeerNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-// GetSockName implements socket.Socket.GetSockName.
-func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
- stack := t.NetworkContext().(*Stack)
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result
- if e, ok := res.(*pb.GetSockNameResponse_ErrorNumber); ok {
- return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- addr := res.(*pb.GetSockNameResponse_Address).Address
- return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil
-}
-
-func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) {
- stack := t.NetworkContext().(*Stack)
-
- id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Ioctl{&pb.IOCtlRequest{Fd: fd, Cmd: cmd, Arg: arg}}}, false /* ignoreResult */)
- <-c
-
- res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Ioctl).Ioctl.Result
- if e, ok := res.(*pb.IOCtlResponse_ErrorNumber); ok {
- return nil, syscall.Errno(e.ErrorNumber)
- }
-
- return res.(*pb.IOCtlResponse_Value).Value, nil
-}
-
-// ifconfIoctlFromStack populates a struct ifconf for the SIOCGIFCONF ioctl.
-func ifconfIoctlFromStack(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
- // If Ptr is NULL, return the necessary buffer size via Len.
- // Otherwise, write up to Len bytes starting at Ptr containing ifreq
- // structs.
- t := ctx.(*kernel.Task)
- s := t.NetworkContext().(*Stack)
- if s == nil {
- return syserr.ErrNoDevice.ToError()
- }
-
- if ifc.Ptr == 0 {
- ifc.Len = int32(len(s.Interfaces())) * int32(linux.SizeOfIFReq)
- return nil
- }
-
- max := ifc.Len
- ifc.Len = 0
- for key, ifaceAddrs := range s.InterfaceAddrs() {
- iface := s.Interfaces()[key]
- for _, ifaceAddr := range ifaceAddrs {
- // Don't write past the end of the buffer.
- if ifc.Len+int32(linux.SizeOfIFReq) > max {
- break
- }
- if ifaceAddr.Family != linux.AF_INET {
- continue
- }
-
- // Populate ifr.ifr_addr.
- ifr := linux.IFReq{}
- ifr.SetName(iface.Name)
- usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
- usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
- copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
-
- // Copy the ifr to userspace.
- dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
- ifc.Len += int32(linux.SizeOfIFReq)
- if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-// Ioctl implements fs.FileOperations.Ioctl.
-func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- t := ctx.(*kernel.Task)
-
- cmd := uint32(args[1].Int())
- arg := args[2].Pointer()
-
- var buf []byte
- switch cmd {
- // The following ioctls take 4 byte argument parameters.
- case syscall.TIOCINQ,
- syscall.TIOCOUTQ:
- buf = make([]byte, 4)
- // The following ioctls have args which are sizeof(struct ifreq).
- case syscall.SIOCGIFADDR,
- syscall.SIOCGIFBRDADDR,
- syscall.SIOCGIFDSTADDR,
- syscall.SIOCGIFFLAGS,
- syscall.SIOCGIFHWADDR,
- syscall.SIOCGIFINDEX,
- syscall.SIOCGIFMAP,
- syscall.SIOCGIFMETRIC,
- syscall.SIOCGIFMTU,
- syscall.SIOCGIFNAME,
- syscall.SIOCGIFNETMASK,
- syscall.SIOCGIFTXQLEN:
- buf = make([]byte, linux.SizeOfIFReq)
- case syscall.SIOCGIFCONF:
- // SIOCGIFCONF has slightly different behavior than the others, in that it
- // will need to populate the array of ifreqs.
- var ifc linux.IFConf
- if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- }); err != nil {
- return 0, err
- }
-
- if err := ifconfIoctlFromStack(ctx, io, &ifc); err != nil {
- return 0, err
- }
- _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- return 0, err
-
- case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
- unimpl.EmitUnimplementedEvent(ctx)
-
- default:
- return 0, syserror.ENOTTY
- }
-
- _, err := io.CopyIn(ctx, arg, buf, usermem.IOOpts{
- AddressSpaceActive: true,
- })
-
- if err != nil {
- return 0, err
- }
-
- v, err := rpcIoctl(t, s.fd, cmd, buf)
- if err != nil {
- return 0, err
- }
-
- if len(v) != len(buf) {
- return 0, syserror.EINVAL
- }
-
- _, err = io.CopyOut(ctx, arg, v, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- return 0, err
-}
-
-func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-// Because we only support SO_TIMESTAMP we will search control messages for
-// that value and set it if so, all other control messages will be ignored.
-func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_ResultPayload) socket.ControlMessages {
- c := socket.ControlMessages{}
- if len(payload.GetCmsgData()) > 0 {
- // Parse the control messages looking for SO_TIMESTAMP.
- msgs, e := syscall.ParseSocketControlMessage(payload.GetCmsgData())
- if e != nil {
- return socket.ControlMessages{}
- }
- for _, m := range msgs {
- if m.Header.Level != linux.SOL_SOCKET || m.Header.Type != linux.SO_TIMESTAMP {
- continue
- }
-
- // Let's parse the time stamp and set it.
- if len(m.Data) < linux.SizeOfTimeval {
- // Give up on locating the SO_TIMESTAMP option.
- return socket.ControlMessages{}
- }
-
- var v linux.Timeval
- binary.Unmarshal(m.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
- c.IP.HasTimestamp = true
- c.IP.Timestamp = v.ToNsecCapped()
- break
- }
- }
- return c
-}
-
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- Sender: senderRequested,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- CmsgLength: uint32(controlDataLen),
- }}
-
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
- // We'll have to block. Register for notifications and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventIn)
- defer s.EventUnregister(&e)
-
- for {
- res, err := rpcRecvMsg(t, req)
- if err == nil {
- var e error
- var n int
- if len(res.Data) > 0 {
- n, e = dst.CopyOut(t, res.Data)
- if e == nil && n != len(res.Data) {
- panic("CopyOut failed to copy full buffer")
- }
- }
- c := s.extractControlMessages(res)
- return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e)
- }
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- return 0, 0, nil, 0, socket.ControlMessages{}, err
- }
-
- if s.isShutRdSet() {
- // Blocking would have caused us to block indefinitely so we return 0,
- // this is the same behavior as Linux.
- return 0, 0, nil, 0, socket.ControlMessages{}, nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
- }
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
- }
- }
-}
-
-func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- s := t.NetworkContext().(*Stack)
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-// SendMsg implements socket.Socket.SendMsg.
-func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- // Reject Unix control messages.
- if !controlMessages.Unix.Empty() {
- return 0, syserr.ErrInvalidArgument
- }
-
- v := buffer.NewView(int(src.NumBytes()))
-
- // Copy all the data into the buffer.
- if _, err := src.CopyIn(t, v); err != nil {
- return 0, syserr.FromError(err)
- }
-
- // TODO(bgeffon): this needs to change to map directly to a SendMsg syscall
- // in the RPC.
- totalWritten := 0
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
- return int(n), err
- }
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
- }
-
- // We'll have to block. Register for notification and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut)
- defer s.EventUnregister(&e)
-
- for {
- n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: uint32(s.fd),
- Data: v,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }})
-
- if n > 0 {
- totalWritten += int(n)
- v.TrimFront(int(n))
-
- if err == nil && totalWritten < int(src.NumBytes()) {
- continue
- }
- }
-
- if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- // We eat the error in this situation.
- return int(totalWritten), nil
- }
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return int(totalWritten), syserr.ErrTryAgain
- }
- return int(totalWritten), syserr.FromError(err)
- }
- }
-}
-
-// State implements socket.Socket.State.
-func (s *socketOperations) State() uint32 {
- // TODO(b/127845868): Define a new rpc to query the socket state.
- return 0
-}
-
-// Type implements socket.Socket.Type.
-func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
- return s.family, s.stype, s.protocol
-}
-
-type socketProvider struct {
- family int
-}
-
-// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // Check that we are using the RPC network stack.
- stack := t.NetworkContext()
- if stack == nil {
- return nil, nil
- }
-
- s, ok := stack.(*Stack)
- if !ok {
- return nil, nil
- }
-
- // Only accept TCP and UDP.
- //
- // Try to restrict the flags we will accept to minimize backwards
- // incompatibility with netstack.
- stype := stypeflags & linux.SOCK_TYPE_MASK
- switch stype {
- case syscall.SOCK_STREAM:
- switch protocol {
- case 0, syscall.IPPROTO_TCP:
- // ok
- default:
- return nil, nil
- }
- case syscall.SOCK_DGRAM:
- switch protocol {
- case 0, syscall.IPPROTO_UDP:
- // ok
- default:
- return nil, nil
- }
- default:
- return nil, nil
- }
-
- return newSocketFile(t, s, p.family, stype, protocol)
-}
-
-// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
- // Not supported by AF_INET/AF_INET6.
- return nil, nil, nil
-}
-
-func init() {
- for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
- socket.RegisterProvider(family, &socketProvider{family})
- }
-}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
deleted file mode 100644
index f5441b826..000000000
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ /dev/null
@@ -1,178 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "fmt"
- "syscall"
-
- "gvisor.dev/gvisor/pkg/sentry/inet"
- "gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
- "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
- "gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/unet"
-)
-
-// Stack implements inet.Stack for RPC backed sockets.
-type Stack struct {
- interfaces map[int32]inet.Interface
- interfaceAddrs map[int32][]inet.InterfaceAddr
- routes []inet.Route
- rpcConn *conn.RPCConnection
- notifier *notifier.Notifier
-}
-
-// NewStack returns a Stack containing the current state of the host network
-// stack.
-func NewStack(fd int32) (*Stack, error) {
- sock, err := unet.NewSocket(int(fd))
- if err != nil {
- return nil, err
- }
-
- stack := &Stack{
- interfaces: make(map[int32]inet.Interface),
- interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
- rpcConn: conn.NewRPCConnection(sock),
- }
-
- var e error
- stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn)
- if e != nil {
- return nil, e
- }
-
- links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETLINK failed: %v", err)
- }
-
- addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETADDR failed: %v", err)
- }
-
- e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs)
- if e != nil {
- return nil, e
- }
-
- routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE)
- if err != nil {
- return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err)
- }
-
- stack.routes, e = hostinet.ExtractHostRoutes(routes)
- if e != nil {
- return nil, e
- }
-
- return stack, nil
-}
-
-// RPCReadFile will execute the ReadFile helper RPC method which avoids the
-// common pattern of open(2), read(2), close(2) by doing all three operations
-// as a single RPC. It will read the entire file or return EFBIG if the file
-// was too large.
-func (s *Stack) RPCReadFile(path string) ([]byte, *syserr.Error) {
- return s.rpcConn.RPCReadFile(path)
-}
-
-// RPCWriteFile will execute the WriteFile helper RPC method which avoids the
-// common pattern of open(2), write(2), write(2), close(2) by doing all
-// operations as a single RPC.
-func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) {
- return s.rpcConn.RPCWriteFile(path, data)
-}
-
-// Interfaces implements inet.Stack.Interfaces.
-func (s *Stack) Interfaces() map[int32]inet.Interface {
- interfaces := make(map[int32]inet.Interface)
- for k, v := range s.interfaces {
- interfaces[k] = v
- }
- return interfaces
-}
-
-// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
-func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
- addrs := make(map[int32][]inet.InterfaceAddr)
- for k, v := range s.interfaceAddrs {
- addrs[k] = append([]inet.InterfaceAddr(nil), v...)
- }
- return addrs
-}
-
-// SupportsIPv6 implements inet.Stack.SupportsIPv6.
-func (s *Stack) SupportsIPv6() bool {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
-func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
-func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
-func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-
-}
-
-// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
-func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
-func (s *Stack) TCPSACKEnabled() (bool, error) {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// Statistics implements inet.Stack.Statistics.
-func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserr.ErrEndpointOperation.ToError()
-}
-
-// RouteTable implements inet.Stack.RouteTable.
-func (s *Stack) RouteTable() []inet.Route {
- return append([]inet.Route(nil), s.routes...)
-}
-
-// Resume implements inet.Stack.Resume.
-func (s *Stack) Resume() {}
-
-// Forwarding implements inet.Stack.Forwarding.
-func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- panic("rpcinet handles procfs directly this method should not be called")
-}
-
-// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
- panic("rpcinet handles procfs directly this method should not be called")
-}
diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go
deleted file mode 100644
index a94bdad83..000000000
--- a/pkg/sentry/socket/rpcinet/stack_unsafe.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package rpcinet
-
-import (
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
- pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserr"
-)
-
-// NewNetlinkRouteRequest builds a netlink message for getting the RIB,
-// the routing information base.
-func newNetlinkRouteRequest(proto, seq, family int) []byte {
- rr := &syscall.NetlinkRouteRequest{}
- rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg)
- rr.Header.Type = uint16(proto)
- rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST
- rr.Header.Seq = uint32(seq)
- rr.Data.Family = uint8(family)
- return netlinkRRtoWireFormat(rr)
-}
-
-func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte {
- b := make([]byte, rr.Header.Len)
- *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
- *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
- *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
- *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
- *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
- b[16] = byte(rr.Data.Family)
- return b
-}
-
-func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(syscall.AF_NETLINK), Type: int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK), Protocol: int64(syscall.NETLINK_ROUTE)}}}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
- if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
- return res.(*pb.SocketResponse_Fd).Fd, nil
-}
-
-func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */)
- <-c
-
- if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
- return syserr.FromHost(syscall.Errno(e))
- }
- return nil
-}
-
-func (s *Stack) closeNetlinkFd(fd uint32) {
- _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}}, true /* ignoreResult */)
-}
-
-func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
- if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok {
- return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.SendmsgResponse_Length).Length, nil
-}
-
-func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) {
- // Whitelist flags.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
- return 0, syserr.ErrInvalidArgument
- }
-
- req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
- Fd: fd,
- Data: buf,
- Address: to,
- More: flags&linux.MSG_MORE != 0,
- EndOfRecord: flags&linux.MSG_EOR != 0,
- }}
-
- n, err := s.rpcSendMsg(req)
- return int(n), err
-}
-
-func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
- id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
- <-c
-
- res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
- if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok {
- return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
- }
-
- return res.(*pb.RecvmsgResponse_Payload).Payload, nil
-}
-
-func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) {
- req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: fd,
- Length: l,
- Sender: false,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
- }}
-
- res, err := s.rpcRecvMsg(req)
- if err != nil {
- return nil, err
- }
- return res.Data, nil
-}
-
-func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) {
- fd, err := s.getNetlinkFd()
- if err != nil {
- return nil, err.ToError()
- }
- defer s.closeNetlinkFd(fd)
-
- lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
- b := binary.Marshal(nil, usermem.ByteOrder, &lsa)
- if err := s.bindNetlinkFd(fd, b); err != nil {
- return nil, err.ToError()
- }
-
- wb := newNetlinkRouteRequest(proto, 1, family)
- _, err = s.sendMsg(fd, wb, b, 0)
- if err != nil {
- return nil, err.ToError()
- }
-
- var tab []byte
-done:
- for {
- rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0)
- nr := len(rb)
- if err != nil {
- return nil, err.ToError()
- }
-
- if nr < syscall.NLMSG_HDRLEN {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
-
- tab = append(tab, rb...)
- msgs, e := syscall.ParseNetlinkMessage(rb)
- if e != nil {
- return nil, e
- }
-
- for _, m := range msgs {
- if m.Header.Type == syscall.NLMSG_DONE {
- break done
- }
- if m.Header.Type == syscall.NLMSG_ERROR {
- return nil, syserr.ErrInvalidArgument.ToError()
- }
- }
- }
-
- return tab, nil
-}
-
-// DoNetlinkRouteRequest returns routing information base, also known as RIB,
-// which consists of network facility information, states and parameters.
-func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
- data, err := s.netlinkRequest(req, syscall.AF_UNSPEC)
- if err != nil {
- return nil, err
- }
- return syscall.ParseNetlinkMessage(data)
-}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
deleted file mode 100644
index 9586f5923..000000000
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ /dev/null
@@ -1,353 +0,0 @@
-syntax = "proto3";
-
-// package syscall_rpc is a set of networking related system calls that can be
-// forwarded to a socket gofer.
-//
-// TODO(b/77963526): Document individual RPCs.
-package syscall_rpc;
-
-message SendmsgRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
- bytes address = 3;
- bool more = 4;
- bool end_of_record = 5;
-}
-
-message SendmsgResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
-
-message IOCtlRequest {
- uint32 fd = 1;
- uint32 cmd = 2;
- bytes arg = 3;
-}
-
-message IOCtlResponse {
- oneof result {
- uint32 error_number = 1;
- bytes value = 2;
- }
-}
-
-message RecvmsgRequest {
- uint32 fd = 1;
- uint32 length = 2;
- bool sender = 3;
- bool peek = 4;
- bool trunc = 5;
- uint32 cmsg_length = 6;
-}
-
-message OpenRequest {
- bytes path = 1;
- uint32 flags = 2;
- uint32 mode = 3;
-}
-
-message OpenResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message ReadRequest {
- uint32 fd = 1;
- uint32 length = 2;
-}
-
-message ReadResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message ReadFileRequest {
- string path = 1;
-}
-
-message ReadFileResponse {
- oneof result {
- uint32 error_number = 1;
- bytes data = 2 [ctype = CORD];
- }
-}
-
-message WriteRequest {
- uint32 fd = 1;
- bytes data = 2 [ctype = CORD];
-}
-
-message WriteResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 length = 2;
- }
-}
-
-message WriteFileRequest {
- string path = 1;
- bytes content = 2;
-}
-
-message WriteFileResponse {
- uint32 error_number = 1;
- uint32 written = 2;
-}
-
-message AddressResponse {
- bytes address = 1;
- uint32 length = 2;
-}
-
-message RecvmsgResponse {
- message ResultPayload {
- bytes data = 1 [ctype = CORD];
- AddressResponse address = 2;
- uint32 length = 3;
- bytes cmsg_data = 4;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message BindRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message BindResponse {
- uint32 error_number = 1;
-}
-
-message AcceptRequest {
- uint32 fd = 1;
- bool peer = 2;
- int64 flags = 3;
-}
-
-message AcceptResponse {
- message ResultPayload {
- uint32 fd = 1;
- AddressResponse address = 2;
- }
- oneof result {
- uint32 error_number = 1;
- ResultPayload payload = 2;
- }
-}
-
-message ConnectRequest {
- uint32 fd = 1;
- bytes address = 2;
-}
-
-message ConnectResponse {
- uint32 error_number = 1;
-}
-
-message ListenRequest {
- uint32 fd = 1;
- int64 backlog = 2;
-}
-
-message ListenResponse {
- uint32 error_number = 1;
-}
-
-message ShutdownRequest {
- uint32 fd = 1;
- int64 how = 2;
-}
-
-message ShutdownResponse {
- uint32 error_number = 1;
-}
-
-message CloseRequest {
- uint32 fd = 1;
-}
-
-message CloseResponse {
- uint32 error_number = 1;
-}
-
-message GetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- uint32 length = 4;
-}
-
-message GetSockOptResponse {
- oneof result {
- uint32 error_number = 1;
- bytes opt = 2;
- }
-}
-
-message SetSockOptRequest {
- uint32 fd = 1;
- int64 level = 2;
- int64 name = 3;
- bytes opt = 4;
-}
-
-message SetSockOptResponse {
- uint32 error_number = 1;
-}
-
-message GetSockNameRequest {
- uint32 fd = 1;
-}
-
-message GetSockNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message GetPeerNameRequest {
- uint32 fd = 1;
-}
-
-message GetPeerNameResponse {
- oneof result {
- uint32 error_number = 1;
- AddressResponse address = 2;
- }
-}
-
-message SocketRequest {
- int64 family = 1;
- int64 type = 2;
- int64 protocol = 3;
-}
-
-message SocketResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message EpollWaitRequest {
- uint32 fd = 1;
- uint32 num_events = 2;
- sint64 msec = 3;
-}
-
-message EpollEvent {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message EpollEvents {
- repeated EpollEvent events = 1;
-}
-
-message EpollWaitResponse {
- oneof result {
- uint32 error_number = 1;
- EpollEvents events = 2;
- }
-}
-
-message EpollCtlRequest {
- uint32 epfd = 1;
- int64 op = 2;
- uint32 fd = 3;
- EpollEvent event = 4;
-}
-
-message EpollCtlResponse {
- uint32 error_number = 1;
-}
-
-message EpollCreate1Request {
- int64 flag = 1;
-}
-
-message EpollCreate1Response {
- oneof result {
- uint32 error_number = 1;
- uint32 fd = 2;
- }
-}
-
-message PollRequest {
- uint32 fd = 1;
- uint32 events = 2;
-}
-
-message PollResponse {
- oneof result {
- uint32 error_number = 1;
- uint32 events = 2;
- }
-}
-
-message SyscallRequest {
- oneof args {
- SocketRequest socket = 1;
- SendmsgRequest sendmsg = 2;
- RecvmsgRequest recvmsg = 3;
- BindRequest bind = 4;
- AcceptRequest accept = 5;
- ConnectRequest connect = 6;
- ListenRequest listen = 7;
- ShutdownRequest shutdown = 8;
- CloseRequest close = 9;
- GetSockOptRequest get_sock_opt = 10;
- SetSockOptRequest set_sock_opt = 11;
- GetSockNameRequest get_sock_name = 12;
- GetPeerNameRequest get_peer_name = 13;
- EpollWaitRequest epoll_wait = 14;
- EpollCtlRequest epoll_ctl = 15;
- EpollCreate1Request epoll_create1 = 16;
- PollRequest poll = 17;
- ReadRequest read = 18;
- WriteRequest write = 19;
- OpenRequest open = 20;
- IOCtlRequest ioctl = 21;
- WriteFileRequest write_file = 22;
- ReadFileRequest read_file = 23;
- }
-}
-
-message SyscallResponse {
- oneof result {
- SocketResponse socket = 1;
- SendmsgResponse sendmsg = 2;
- RecvmsgResponse recvmsg = 3;
- BindResponse bind = 4;
- AcceptResponse accept = 5;
- ConnectResponse connect = 6;
- ListenResponse listen = 7;
- ShutdownResponse shutdown = 8;
- CloseResponse close = 9;
- GetSockOptResponse get_sock_opt = 10;
- SetSockOptResponse set_sock_opt = 11;
- GetSockNameResponse get_sock_name = 12;
- GetPeerNameResponse get_peer_name = 13;
- EpollWaitResponse epoll_wait = 14;
- EpollCtlResponse epoll_ctl = 15;
- EpollCreate1Response epoll_create1 = 16;
- PollResponse poll = 17;
- ReadResponse read = 18;
- WriteResponse write = 19;
- OpenResponse open = 20;
- IOCtlResponse ioctl = 21;
- WriteFileResponse write_file = 22;
- ReadFileResponse read_file = 23;
- }
-}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 8c250c325..04b259d27 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -24,16 +24,18 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
@@ -43,11 +45,30 @@ type ControlMessages struct {
IP tcpip.ControlMessages
}
-// Socket is the interface containing socket syscalls used by the syscall layer
-// to redirect them to the appropriate implementation.
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release(ctx context.Context) {
+ c.Unix.Release(ctx)
+}
+
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
type Socket interface {
fs.FileOperations
+ SocketOps
+}
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
// Connect implements the connect(2) linux syscall.
Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
@@ -66,7 +87,7 @@ type Socket interface {
Shutdown(t *kernel.Task, how int) *syserr.Error
// GetSockOpt implements the getsockopt(2) linux syscall.
- GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error)
// SetSockOpt implements the setsockopt(2) linux syscall.
SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
@@ -148,6 +169,8 @@ var families = make(map[int][]Provider)
// RegisterProvider registers the provider of a given address family so that
// sockets of that type can be created via socket() and/or socketpair()
// syscalls.
+//
+// This should only be called during the initialization of the address family.
func RegisterProvider(family int, provider Provider) {
families[family] = append(families[family], provider)
}
@@ -211,6 +234,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
}
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
@@ -317,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) {
linux.SO_MARK,
linux.SO_MAX_PACING_RATE,
linux.SO_NOFCS,
- linux.SO_NO_CHECK,
linux.SO_OOBINLINE,
linux.SO_PASSCRED,
linux.SO_PASSSEC,
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 5b6a154f6..cb953e4dc 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -1,35 +1,54 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "socket_refs",
+ out = "socket_refs.go",
+ package = "unix",
+ prefix = "socketOpsCommon",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "socketOpsCommon",
+ },
+)
+
go_library(
name = "unix",
srcs = [
"device.go",
"io.go",
+ "socket_refs.go",
"unix.go",
+ "unix_vfs2.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/log",
"//pkg/refs",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
+ "//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 2ec1a662d..129949990 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -15,8 +15,8 @@
package unix
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -83,6 +83,19 @@ type EndpointReader struct {
ControlTrunc bool
}
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 788ad70d2..c708b6030 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -25,13 +25,14 @@ go_library(
"transport_message_list.go",
"unix.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/ilist",
+ "//pkg/log",
"//pkg/refs",
- "//pkg/sentry/context",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index dea11e253..c67b602f0 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -15,10 +15,9 @@
package transport
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
@@ -212,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool {
// The socket will be a fresh state after a call to close and may be reused.
// That is, close may be used to "unbind" or "disconnect" the socket in error
// paths.
-func (e *connectionedEndpoint) Close() {
+func (e *connectionedEndpoint) Close(ctx context.Context) {
e.Lock()
var c ConnectedEndpoint
var r Receiver
@@ -234,7 +233,7 @@ func (e *connectionedEndpoint) Close() {
case e.Listening():
close(e.acceptedChan)
for n := range e.acceptedChan {
- n.Close()
+ n.Close(ctx)
}
e.acceptedChan = nil
e.path = ""
@@ -242,18 +241,18 @@ func (e *connectionedEndpoint) Close() {
e.Unlock()
if c != nil {
c.CloseNotify()
- c.Release()
+ c.Release(ctx)
}
if r != nil {
r.CloseNotify()
- r.Release()
+ r.Release(ctx)
}
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
if ce.Type() != e.stype {
- return syserr.ErrConnectionRefused
+ return syserr.ErrWrongProtocolForSocket
}
// Check if ce is e to avoid a deadlock.
@@ -341,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
return nil
default:
// Busy; return ECONNREFUSED per spec.
- ne.Close()
+ ne.Close(ctx)
e.Unlock()
ce.Unlock()
return syserr.ErrConnectionRefused
@@ -477,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
// State implements socket.Socket.State.
func (e *connectionedEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
if e.Connected() {
return linux.SS_CONNECTED
}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 0322dec0b..70ee8f9b8 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -16,7 +16,7 @@ package transport
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
@@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool {
// Close puts the endpoint in a closed state and frees all resources associated
// with it.
-func (e *connectionlessEndpoint) Close() {
+func (e *connectionlessEndpoint) Close(ctx context.Context) {
e.Lock()
if e.connected != nil {
- e.connected.Release()
+ e.connected.Release(ctx)
e.connected = nil
}
@@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() {
e.Unlock()
r.CloseNotify()
- r.Release()
+ r.Release(ctx)
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
@@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C
if err != nil {
return 0, syserr.ErrInvalidEndpointState
}
- defer connected.Release()
+ defer connected.Release(ctx)
e.Lock()
- n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
if notify {
@@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
e.Lock()
if e.connected != nil {
- e.connected.Release()
+ e.connected.Release(ctx)
}
e.connected = connected
e.Unlock()
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e27b1c714..ef6043e19 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -15,10 +15,12 @@
package transport
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -56,10 +58,10 @@ func (q *queue) Close() {
// Both the read and write queues must be notified after resetting:
// q.ReaderQueue.Notify(waiter.EventIn)
// q.WriterQueue.Notify(waiter.EventOut)
-func (q *queue) Reset() {
+func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release()
+ cur.Release(ctx)
}
q.dataList.Reset()
q.used = 0
@@ -67,8 +69,8 @@ func (q *queue) Reset() {
}
// DecRef implements RefCounter.DecRef with destructor q.Reset.
-func (q *queue) DecRef() {
- q.DecRefWithDestructor(q.Reset)
+func (q *queue) DecRef(ctx context.Context) {
+ q.DecRefWithDestructor(ctx, q.Reset)
// We don't need to notify after resetting because no one cares about
// this queue after all references have been dropped.
}
@@ -101,12 +103,16 @@ func (q *queue) IsWritable() bool {
// Enqueue adds an entry to the data queue if room is available.
//
+// If discardEmpty is true and there are zero bytes of data, the packet is
+// dropped.
+//
// If truncate is true, Enqueue may truncate the message before enqueuing it.
-// Otherwise, the entire message must fit. If n < e.Length(), err indicates why.
+// Otherwise, the entire message must fit. If l is less than the size of data,
+// err indicates why.
//
// If notify is true, ReaderQueue.Notify must be called:
// q.ReaderQueue.Notify(waiter.EventIn)
-func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *syserr.Error) {
+func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) {
q.mu.Lock()
if q.closed {
@@ -114,9 +120,16 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s
return 0, false, syserr.ErrClosedForSend
}
- free := q.limit - q.used
+ for _, d := range data {
+ l += int64(len(d))
+ }
+ if discardEmpty && l == 0 {
+ q.mu.Unlock()
+ c.Release(ctx)
+ return 0, false, nil
+ }
- l = e.Length()
+ free := q.limit - q.used
if l > free && truncate {
if free == 0 {
@@ -125,8 +138,7 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s
return 0, false, syserr.ErrWouldBlock
}
- e.Truncate(free)
- l = e.Length()
+ l = free
err = syserr.ErrWouldBlock
}
@@ -137,14 +149,26 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s
}
if l > free {
- // Message can't fit right now.
+ // Message can't fit right now, and could not be truncated.
q.mu.Unlock()
return 0, false, syserr.ErrWouldBlock
}
+ // Aggregate l bytes of data. This will truncate the data if l is less than
+ // the total bytes held in data.
+ v := make([]byte, l)
+ for i, b := 0, v; i < len(data) && len(b) > 0; i++ {
+ n := copy(b, data[i])
+ b = b[n:]
+ }
+
notify = q.dataList.Front() == nil
q.used += l
- q.dataList.PushBack(e)
+ q.dataList.PushBack(&message{
+ Data: buffer.View(v),
+ Control: c,
+ Address: from,
+ })
q.mu.Unlock()
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 529a7a7a9..475d7177e 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,11 +16,12 @@
package transport
import (
- "sync"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -36,7 +37,7 @@ type RightsControlMessage interface {
Clone() RightsControlMessage
// Release releases any resources owned by the RightsControlMessage.
- Release()
+ Release(ctx context.Context)
}
// A CredentialsControlMessage is a control message containing Unix credentials.
@@ -73,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages {
}
// Release releases both the credentials and the rights.
-func (c *ControlMessages) Release() {
+func (c *ControlMessages) Release(ctx context.Context) {
if c.Rights != nil {
- c.Rights.Release()
+ c.Rights.Release(ctx)
}
*c = ControlMessages{}
}
@@ -89,7 +90,7 @@ type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
- Close()
+ Close(ctx context.Context)
// RecvMsg reads data and a control message from the endpoint. This method
// does not block if there is no data pending.
@@ -175,17 +176,25 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
- SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
- GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
// State returns the current state of the socket, as represented by Linux in
// procfs.
@@ -243,7 +252,7 @@ type BoundEndpoint interface {
// Release releases any resources held by the BoundEndpoint. It must be
// called before dropping all references to a BoundEndpoint returned by a
// function.
- Release()
+ Release(ctx context.Context)
}
// message represents a message passed over a Unix domain socket.
@@ -272,8 +281,8 @@ func (m *message) Length() int64 {
}
// Release releases any resources held by the message.
-func (m *message) Release() {
- m.Control.Release()
+func (m *message) Release(ctx context.Context) {
+ m.Control.Release(ctx)
}
// Peek returns a copy of the message.
@@ -295,7 +304,7 @@ type Receiver interface {
// See Endpoint.RecvMsg for documentation on shared arguments.
//
// notify indicates if RecvNotify should be called.
- Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+ Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
// RecvNotify notifies the Receiver of a successful Recv. This must not be
// called while holding any endpoint locks.
@@ -324,7 +333,7 @@ type Receiver interface {
// Release releases any resources owned by the Receiver. It should be
// called before droping all references to a Receiver.
- Release()
+ Release(ctx context.Context)
}
// queueReceiver implements Receiver for datagram sockets.
@@ -335,7 +344,7 @@ type queueReceiver struct {
}
// Recv implements Receiver.Recv.
-func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
var m *message
var notify bool
var err *syserr.Error
@@ -389,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 {
}
// Release implements Receiver.Release.
-func (q *queueReceiver) Release() {
- q.readQueue.DecRef()
+func (q *queueReceiver) Release(ctx context.Context) {
+ q.readQueue.DecRef(ctx)
}
// streamQueueReceiver implements Receiver for stream sockets.
@@ -447,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
}
// Recv implements Receiver.Recv.
-func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
q.mu.Lock()
defer q.mu.Unlock()
@@ -493,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int,
var cmTruncated bool
if c.Rights != nil && numRights == 0 {
- c.Rights.Release()
+ c.Rights.Release(ctx)
c.Rights = nil
cmTruncated = true
}
@@ -548,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int,
// Consume rights.
if numRights == 0 {
cmTruncated = true
- q.control.Rights.Release()
+ q.control.Rights.Release(ctx)
} else {
c.Rights = q.control.Rights
haveRights = true
@@ -573,7 +582,7 @@ type ConnectedEndpoint interface {
//
// syserr.ErrWouldBlock can be returned along with a partial write if
// the caller should block to send the rest of the data.
- Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
+ Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
// SendNotify notifies the ConnectedEndpoint of a successful Send. This
// must not be called while holding any endpoint locks.
@@ -607,7 +616,7 @@ type ConnectedEndpoint interface {
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
- Release()
+ Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
// the peer socket.
@@ -645,35 +654,22 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
}
// Send implements ConnectedEndpoint.Send.
-func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
- var l int64
- for _, d := range data {
- l += int64(len(d))
- }
-
+func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ discardEmpty := false
truncate := false
if e.endpoint.Type() == linux.SOCK_STREAM {
- // Since stream sockets don't preserve message boundaries, we
- // can write only as much of the message as fits in the queue.
- truncate = true
-
// Discard empty stream packets. Since stream sockets don't
// preserve message boundaries, sending zero bytes is a no-op.
// In Linux, the receiver actually uses a zero-length receive
// as an indication that the stream was closed.
- if l == 0 {
- controlMessages.Release()
- return 0, false, nil
- }
- }
+ discardEmpty = true
- v := make([]byte, 0, l)
- for _, d := range data {
- v = append(v, d...)
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
}
- l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
- return int64(l), notify, err
+ return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate)
}
// SendNotify implements ConnectedEndpoint.SendNotify.
@@ -711,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 {
}
// Release implements ConnectedEndpoint.Release.
-func (e *connectedEndpoint) Release() {
- e.writeQueue.DecRef()
+func (e *connectedEndpoint) Release(ctx context.Context) {
+ e.writeQueue.DecRef(ctx)
}
// CloseUnread implements ConnectedEndpoint.CloseUnread.
@@ -802,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n
return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
}
- recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek)
e.Unlock()
if err != nil {
return 0, 0, ControlMessages{}, false, err
@@ -831,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
return 0, syserr.ErrAlreadyConnected
}
- n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
e.Unlock()
if notify {
@@ -843,19 +839,46 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option. Currently not supported.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
+ return nil
+}
+
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
case tcpip.PasscredOption:
- e.setPasscred(v != 0)
- return nil
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
}
return nil
}
-func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
-func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -911,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return int(v), nil
default:
+ log.Warningf("Unsupported socket option: %d", opt)
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
- }
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
+ log.Warningf("Unsupported socket option: %T", opt)
return tcpip.ErrUnknownProtocolOption
}
}
@@ -988,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// Release implements BoundEndpoint.Release.
-func (*baseEndpoint) Release() {
+func (*baseEndpoint) Release(context.Context) {
// Binding a baseEndpoint doesn't take a reference.
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 1aaae8487..b7e8e4325 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -22,9 +22,9 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -33,11 +33,13 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -52,17 +54,14 @@ type SocketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- refs.AtomicRefCount
- socket.SendReceiveTimeout
- ep transport.Endpoint
- stype linux.SockType
+ socketOpsCommon
}
// New creates a new unix socket.
func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true})
}
@@ -75,29 +74,51 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
}
s := SocketOperations{
- ep: ep,
- stype: stype,
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
}
- s.EnableLeakCheck("unix.SocketOperations")
+ s.EnableLeakCheck()
return fs.NewFile(ctx, d, flags, &s)
}
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socketOpsCommonRefs
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+
+ // abstractName and abstractNamespace indicate the name and namespace of the
+ // socket if it is bound to an abstract socket namespace. Once the socket is
+ // bound, they cannot be modified.
+ abstractName string
+ abstractNamespace *kernel.AbstractSocketNamespace
+}
+
// DecRef implements RefCounter.DecRef.
-func (s *SocketOperations) DecRef() {
- s.DecRefWithDestructor(func() {
- s.ep.Close()
+func (s *socketOpsCommon) DecRef(ctx context.Context) {
+ s.socketOpsCommonRefs.DecRef(func() {
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
})
}
// Release implemements fs.FileOperations.Release.
-func (s *SocketOperations) Release() {
+func (s *socketOpsCommon) Release(ctx context.Context) {
// Release only decrements a reference on s because s may be referenced in
// the abstract socket namespace.
- s.DecRef()
+ s.DecRef(ctx)
}
-func (s *SocketOperations) isPacket() bool {
+func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
return true
@@ -110,16 +131,22 @@ func (s *SocketOperations) isPacket() bool {
}
// Endpoint extracts the transport.Endpoint.
-func (s *SocketOperations) Endpoint() transport.Endpoint {
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
return s.ep
}
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, family, err := netstack.AddressAndFamily(sockaddr)
if err != nil {
+ if err == syserr.ErrAddressFamilyNotSupported {
+ err = syserr.ErrInvalidArgument
+ }
return "", err
}
+ if family != linux.AF_UNIX {
+ return "", syserr.ErrInvalidArgument
+ }
// The address is trimmed by GetAddress.
p := string(addr.Addr)
@@ -137,7 +164,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -149,7 +176,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -166,13 +193,13 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
// Listen implements the linux syscall listen(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return s.ep.Listen(backlog)
}
@@ -215,7 +242,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
ns := New(t, ep, s.stype)
- defer ns.DecRef()
+ defer ns.DecRef(t)
if flags&linux.SOCK_NONBLOCK != 0 {
flags := ns.Flags()
@@ -265,17 +292,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
+ s.abstractName = name
+ s.abstractNamespace = asn
} else {
// The parent and name.
var d *fs.Dirent
var name string
cwd := t.FSContext().WorkingDirectory()
- defer cwd.DecRef()
+ defer cwd.DecRef(t)
// Is there no slash at all?
if !strings.Contains(p, "/") {
@@ -283,7 +314,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
name = p
} else {
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
// Find the last path component, we know that something follows
// that final slash, otherwise extractPath() would have failed.
lastSlash := strings.LastIndex(p, "/")
@@ -299,16 +330,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
// No path available.
return syserr.ErrNoSuchFile
}
- defer d.DecRef()
+ defer d.DecRef(t)
name = p[lastSlash+1:]
}
// Create the socket.
+ //
+ // Note that the file permissions here are not set correctly (see
+ // gvisor.dev/issue/2324). There is no convenient way to get permissions
+ // on the socket referred to by s, so we will leave this discrepancy
+ // unresolved until VFS2 replaces this code.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
}
- childDir.DecRef()
+ childDir.DecRef(t)
}
return nil
@@ -339,41 +375,76 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
return ep, nil
}
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
+ root.DecRef(t)
+ if relPath {
+ start.DecRef(t)
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
// Find the node in the filesystem.
root := t.FSContext().RootDirectory()
cwd := t.FSContext().WorkingDirectory()
remainingTraversals := uint(fs.DefaultTraversalLimit)
d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
- cwd.DecRef()
- root.DecRef()
+ cwd.DecRef(t)
+ root.DecRef(t)
if e != nil {
return nil, syserr.FromError(e)
}
// Extract the endpoint if one is there.
ep := d.Inode.BoundEndpoint(path)
- d.DecRef()
+ d.DecRef(t)
if ep == nil {
// No socket!
return nil, syserr.ErrConnectionRefused
}
-
return ep, nil
}
// Connect implements the linux syscall connect(2) for unix sockets.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
ep, err := extractEndpoint(t, sockaddr)
if err != nil {
return err
}
- defer ep.Release()
+ defer ep.Release(t)
// Connect the server endpoint.
- return s.ep.Connect(t, ep)
+ err = s.ep.Connect(t, ep)
+
+ if err == syserr.ErrWrongProtocolForSocket {
+ // Linux for abstract sockets returns ErrConnectionRefused
+ // instead of ErrWrongProtocolForSocket.
+ path, _ := extractPath(sockaddr)
+ if len(path) > 0 && path[0] == 0 {
+ err = syserr.ErrConnectionRefused
+ }
+ }
+
+ return err
}
-// Writev implements fs.FileOperations.Write.
+// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
t := kernel.TaskFromContext(ctx)
ctrl := control.New(t, s.ep, nil)
@@ -393,7 +464,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Ctx: t,
Endpoint: s.ep,
@@ -401,15 +472,25 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
To: nil,
}
if len(to) > 0 {
- ep, err := extractEndpoint(t, to)
- if err != nil {
- return 0, err
- }
- defer ep.Release()
- w.To = ep
+ switch s.stype {
+ case linux.SOCK_SEQPACKET:
+ to = nil
+ case linux.SOCK_STREAM:
+ if s.State() == linux.SS_CONNECTED {
+ return 0, syserr.ErrAlreadyConnected
+ }
+ return 0, syserr.ErrNotSupported
+ default:
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release(t)
+ w.To = ep
- if ep.Passcred() && w.Control.Credentials == nil {
- w.Control.Credentials = control.MakeCreds(t)
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
}
}
@@ -447,27 +528,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *SocketOperations) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
return s.ep.Passcred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *SocketOperations) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
return s.ep.ConnectedPasscred()
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.ep.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
@@ -479,7 +560,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
@@ -505,7 +586,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -541,8 +622,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
var total int64
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -577,7 +677,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
@@ -623,12 +723,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// State implements socket.Socket.State.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
// Unix domain sockets always have a protocol of 0.
return linux.AF_UNIX, s.stype, 0
}
@@ -681,4 +781,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
func init() {
socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..d066ef8ab
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,376 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&SocketVFS2{})
+
+// NewSockfsFile creates a new socket file in the global sockfs mount and
+// returns a corresponding file description.
+func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return fd, nil
+}
+
+// NewFileDescription creates and returns a socket file description
+// corresponding to the given mount and dentry.
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ sock := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+ sock.LockFD.Init(locks)
+ vfsfd := &sock.vfsfd
+ if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewSockfsFile(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef(t)
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ s.abstractName = name
+ s.abstractNamespace = asn
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef(t)
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // File permissions correspond to net/unix/af_unix.c:unix_bind.
+ Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewSockfsFile(t, ep, stype)
+ if err != nil {
+ ep.Close(t)
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewSockfsFile(t, ep1, stype)
+ if err != nil {
+ ep1.Close(t)
+ ep2.Close(t)
+ return nil, nil, err
+ }
+ s2, err := NewSockfsFile(t, ep2, stype)
+ if err != nil {
+ s1.DecRef(t)
+ ep2.Close(t)
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}
diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD
index 88765f4d6..0ea4aab8b 100644
--- a/pkg/sentry/state/BUILD
+++ b/pkg/sentry/state/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,7 +9,6 @@ go_library(
"state_metadata.go",
"state_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/state",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go
index 9eb626b76..a06c9b8ab 100644
--- a/pkg/sentry/state/state.go
+++ b/pkg/sentry/state/state.go
@@ -60,6 +60,7 @@ type SaveOpts struct {
func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error {
log.Infof("Sandbox save started, pausing all tasks.")
k.Pause()
+ k.ReceiveTaskStates()
defer k.Unpause()
defer log.Infof("Tasks resumed after save.")
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 72ebf766d..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,6 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools:defs.bzl", "go_library", "proto_library")
package(licenses = ["notice"])
@@ -9,17 +7,19 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
- "linux64.go",
+ "linux64_amd64.go",
+ "linux64_arm64.go",
"open.go",
"poll.go",
"ptrace.go",
+ "select.go",
"signal.go",
"socket.go",
"strace.go",
"syscalls.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/strace",
visibility = ["//:sandbox"],
deps = [
":strace_go_proto",
@@ -31,29 +31,15 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
- "//pkg/sentry/socket/control",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
],
)
proto_library(
- name = "strace_proto",
+ name = "strace",
srcs = ["strace.proto"],
visibility = ["//visibility:public"],
)
-
-cc_proto_library(
- name = "strace_cc_proto",
- visibility = ["//visibility:public"],
- deps = [":strace_proto"],
-)
-
-go_proto_library(
- name = "strace_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
- proto = ":strace_proto",
- visibility = ["//visibility:public"],
-)
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..5d51a7792
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ // Allowing addr to overflow is consistent with Linux, and harmless; if
+ // this isn't the last iteration of the loop, the next call to CopyIn
+ // will just fail with EFAULT.
+ addr, _ = addr.AddLength(uint64(linux.SizeOfEpollEvent))
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPOLLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64_amd64.go
index 5d57b75af..71b92eaee 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,8 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package strace
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
// linuxAMD64 provides a mapping of the Linux amd64 syscalls and their argument
// types for display / formatting.
var linuxAMD64 = SyscallMap{
@@ -30,7 +37,7 @@ var linuxAMD64 = SyscallMap{
10: makeSyscallInfo("mprotect", Hex, Hex, Hex),
11: makeSyscallInfo("munmap", Hex, Hex),
12: makeSyscallInfo("brk", Hex),
- 13: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction),
+ 13: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction, Hex),
14: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
15: makeSyscallInfo("rt_sigreturn"),
16: makeSyscallInfo("ioctl", FD, Hex, Hex),
@@ -40,7 +47,7 @@ var linuxAMD64 = SyscallMap{
20: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
21: makeSyscallInfo("access", Path, Oct),
22: makeSyscallInfo("pipe", PipeFDs),
- 23: makeSyscallInfo("select", Hex, Hex, Hex, Hex, Timeval),
+ 23: makeSyscallInfo("select", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timeval),
24: makeSyscallInfo("sched_yield"),
25: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
26: makeSyscallInfo("msync", Hex, Hex, Hex),
@@ -71,8 +78,8 @@ var linuxAMD64 = SyscallMap{
51: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
52: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
53: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
- 54: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
- 55: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 54: makeSyscallInfo("setsockopt", FD, SockOptLevel, SockOptName, SetSockOptVal, Hex /* length by value, not a pointer */),
+ 55: makeSyscallInfo("getsockopt", FD, SockOptLevel, SockOptName, GetSockOptVal, SockLen),
56: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
57: makeSyscallInfo("fork"),
58: makeSyscallInfo("vfork"),
@@ -249,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -287,7 +294,7 @@ var linuxAMD64 = SyscallMap{
267: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
268: makeSyscallInfo("fchmodat", FD, Path, Mode),
269: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
- 270: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 270: makeSyscallInfo("pselect6", Hex, SelectFDSet, SelectFDSet, SelectFDSet, Timespec, SigSet),
271: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
272: makeSyscallInfo("unshare", CloneFlags),
273: makeSyscallInfo("set_robust_list", Hex, Hex),
@@ -298,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
@@ -335,5 +342,43 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 318: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 319: makeSyscallInfo("memfd_create", Path, Hex), // Not quite a path, but close.
+ 320: makeSyscallInfo("kexec_file_load", FD, FD, Hex, Hex, Hex),
+ 321: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 322: makeSyscallInfo("execveat", FD, Path, ExecveStringVector, ExecveStringVector, Hex),
+ 323: makeSyscallInfo("userfaultfd", Hex),
+ 324: makeSyscallInfo("membarrier", Hex, Hex),
+ 325: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 326: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 327: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 328: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 329: makeSyscallInfo("pkey_mprotect", Hex, Hex, Hex, Hex),
+ 330: makeSyscallInfo("pkey_alloc", Hex, Hex),
+ 331: makeSyscallInfo("pkey_free", Hex),
332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 333: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 334: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.AMD64,
+ syscalls: linuxAMD64,
+ },
+ )
}
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
new file mode 100644
index 000000000..bd7361a52
--- /dev/null
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -0,0 +1,323 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package strace
+
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+)
+
+// linuxARM64 provides a mapping of the Linux arm64 syscalls and their argument
+// types for display / formatting.
+var linuxARM64 = SyscallMap{
+ 0: makeSyscallInfo("io_setup", Hex, Hex),
+ 1: makeSyscallInfo("io_destroy", Hex),
+ 2: makeSyscallInfo("io_submit", Hex, Hex, Hex),
+ 3: makeSyscallInfo("io_cancel", Hex, Hex, Hex),
+ 4: makeSyscallInfo("io_getevents", Hex, Hex, Hex, Hex, Timespec),
+ 5: makeSyscallInfo("setxattr", Path, Path, Hex, Hex, Hex),
+ 6: makeSyscallInfo("lsetxattr", Path, Path, Hex, Hex, Hex),
+ 7: makeSyscallInfo("fsetxattr", FD, Path, Hex, Hex, Hex),
+ 8: makeSyscallInfo("getxattr", Path, Path, Hex, Hex),
+ 9: makeSyscallInfo("lgetxattr", Path, Path, Hex, Hex),
+ 10: makeSyscallInfo("fgetxattr", FD, Path, Hex, Hex),
+ 11: makeSyscallInfo("listxattr", Path, Path, Hex),
+ 12: makeSyscallInfo("llistxattr", Path, Path, Hex),
+ 13: makeSyscallInfo("flistxattr", FD, Path, Hex),
+ 14: makeSyscallInfo("removexattr", Path, Path),
+ 15: makeSyscallInfo("lremovexattr", Path, Path),
+ 16: makeSyscallInfo("fremovexattr", FD, Path),
+ 17: makeSyscallInfo("getcwd", PostPath, Hex),
+ 18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
+ 19: makeSyscallInfo("eventfd2", Hex, Hex),
+ 20: makeSyscallInfo("epoll_create1", Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
+ 23: makeSyscallInfo("dup", FD),
+ 24: makeSyscallInfo("dup3", FD, FD, Hex),
+ 25: makeSyscallInfo("fcntl", FD, Hex, Hex),
+ 26: makeSyscallInfo("inotify_init1", Hex),
+ 27: makeSyscallInfo("inotify_add_watch", Hex, Path, Hex),
+ 28: makeSyscallInfo("inotify_rm_watch", Hex, Hex),
+ 29: makeSyscallInfo("ioctl", FD, Hex, Hex),
+ 30: makeSyscallInfo("ioprio_set", Hex, Hex, Hex),
+ 31: makeSyscallInfo("ioprio_get", Hex, Hex),
+ 32: makeSyscallInfo("flock", FD, Hex),
+ 33: makeSyscallInfo("mknodat", FD, Path, Mode, Hex),
+ 34: makeSyscallInfo("mkdirat", FD, Path, Hex),
+ 35: makeSyscallInfo("unlinkat", FD, Path, Hex),
+ 36: makeSyscallInfo("symlinkat", Path, Hex, Path),
+ 37: makeSyscallInfo("linkat", FD, Path, Hex, Path, Hex),
+ 38: makeSyscallInfo("renameat", FD, Path, Hex, Path),
+ 39: makeSyscallInfo("umount2", Path, Hex),
+ 40: makeSyscallInfo("mount", Path, Path, Path, Hex, Path),
+ 41: makeSyscallInfo("pivot_root", Path, Path),
+ 42: makeSyscallInfo("nfsservctl", Hex, Hex, Hex),
+ 43: makeSyscallInfo("statfs", Path, Hex),
+ 44: makeSyscallInfo("fstatfs", FD, Hex),
+ 45: makeSyscallInfo("truncate", Path, Hex),
+ 46: makeSyscallInfo("ftruncate", FD, Hex),
+ 47: makeSyscallInfo("fallocate", FD, Hex, Hex, Hex),
+ 48: makeSyscallInfo("faccessat", FD, Path, Oct, Hex),
+ 49: makeSyscallInfo("chdir", Path),
+ 50: makeSyscallInfo("fchdir", FD),
+ 51: makeSyscallInfo("chroot", Path),
+ 52: makeSyscallInfo("fchmod", FD, Mode),
+ 53: makeSyscallInfo("fchmodat", FD, Path, Mode),
+ 54: makeSyscallInfo("fchownat", FD, Path, Hex, Hex, Hex),
+ 55: makeSyscallInfo("fchown", FD, Hex, Hex),
+ 56: makeSyscallInfo("openat", FD, Path, OpenFlags, Mode),
+ 57: makeSyscallInfo("close", FD),
+ 58: makeSyscallInfo("vhangup"),
+ 59: makeSyscallInfo("pipe2", PipeFDs, Hex),
+ 60: makeSyscallInfo("quotactl", Hex, Hex, Hex, Hex),
+ 61: makeSyscallInfo("getdents64", FD, Hex, Hex),
+ 62: makeSyscallInfo("lseek", Hex, Hex, Hex),
+ 63: makeSyscallInfo("read", FD, ReadBuffer, Hex),
+ 64: makeSyscallInfo("write", FD, WriteBuffer, Hex),
+ 65: makeSyscallInfo("readv", FD, ReadIOVec, Hex),
+ 66: makeSyscallInfo("writev", FD, WriteIOVec, Hex),
+ 67: makeSyscallInfo("pread64", FD, ReadBuffer, Hex, Hex),
+ 68: makeSyscallInfo("pwrite64", FD, WriteBuffer, Hex, Hex),
+ 69: makeSyscallInfo("preadv", FD, ReadIOVec, Hex, Hex),
+ 70: makeSyscallInfo("pwritev", FD, WriteIOVec, Hex, Hex),
+ 71: makeSyscallInfo("sendfile", FD, FD, Hex, Hex),
+ 72: makeSyscallInfo("pselect6", Hex, Hex, Hex, Hex, Hex, Hex),
+ 73: makeSyscallInfo("ppoll", PollFDs, Hex, Timespec, SigSet, Hex),
+ 74: makeSyscallInfo("signalfd4", Hex, Hex, Hex, Hex),
+ 75: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
+ 76: makeSyscallInfo("splice", FD, Hex, FD, Hex, Hex, Hex),
+ 77: makeSyscallInfo("tee", FD, FD, Hex, Hex),
+ 78: makeSyscallInfo("readlinkat", FD, Path, ReadBuffer, Hex),
+ 79: makeSyscallInfo("fstatat", FD, Path, Stat, Hex),
+ 80: makeSyscallInfo("fstat", FD, Stat),
+ 81: makeSyscallInfo("sync"),
+ 82: makeSyscallInfo("fsync", FD),
+ 83: makeSyscallInfo("fdatasync", FD),
+ 84: makeSyscallInfo("sync_file_range", FD, Hex, Hex, Hex),
+ 85: makeSyscallInfo("timerfd_create", Hex, Hex),
+ 86: makeSyscallInfo("timerfd_settime", FD, Hex, ItimerSpec, PostItimerSpec),
+ 87: makeSyscallInfo("timerfd_gettime", FD, PostItimerSpec),
+ 88: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
+ 89: makeSyscallInfo("acct", Hex),
+ 90: makeSyscallInfo("capget", CapHeader, PostCapData),
+ 91: makeSyscallInfo("capset", CapHeader, CapData),
+ 92: makeSyscallInfo("personality", Hex),
+ 93: makeSyscallInfo("exit", Hex),
+ 94: makeSyscallInfo("exit_group", Hex),
+ 95: makeSyscallInfo("waitid", Hex, Hex, Hex, Hex, Rusage),
+ 96: makeSyscallInfo("set_tid_address", Hex),
+ 97: makeSyscallInfo("unshare", CloneFlags),
+ 98: makeSyscallInfo("futex", Hex, FutexOp, Hex, Timespec, Hex, Hex),
+ 99: makeSyscallInfo("set_robust_list", Hex, Hex),
+ 100: makeSyscallInfo("get_robust_list", Hex, Hex, Hex),
+ 101: makeSyscallInfo("nanosleep", Timespec, PostTimespec),
+ 102: makeSyscallInfo("getitimer", ItimerType, PostItimerVal),
+ 103: makeSyscallInfo("setitimer", ItimerType, ItimerVal, PostItimerVal),
+ 104: makeSyscallInfo("kexec_load", Hex, Hex, Hex, Hex),
+ 105: makeSyscallInfo("init_module", Hex, Hex, Hex),
+ 106: makeSyscallInfo("delete_module", Hex, Hex),
+ 107: makeSyscallInfo("timer_create", Hex, Hex, Hex),
+ 108: makeSyscallInfo("timer_gettime", Hex, PostItimerSpec),
+ 109: makeSyscallInfo("timer_getoverrun", Hex),
+ 110: makeSyscallInfo("timer_settime", Hex, Hex, ItimerSpec, PostItimerSpec),
+ 111: makeSyscallInfo("timer_delete", Hex),
+ 112: makeSyscallInfo("clock_settime", Hex, Timespec),
+ 113: makeSyscallInfo("clock_gettime", Hex, PostTimespec),
+ 114: makeSyscallInfo("clock_getres", Hex, PostTimespec),
+ 115: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
+ 116: makeSyscallInfo("syslog", Hex, Hex, Hex),
+ 117: makeSyscallInfo("ptrace", PtraceRequest, Hex, Hex, Hex),
+ 118: makeSyscallInfo("sched_setparam", Hex, Hex),
+ 119: makeSyscallInfo("sched_setscheduler", Hex, Hex, Hex),
+ 120: makeSyscallInfo("sched_getscheduler", Hex),
+ 121: makeSyscallInfo("sched_getparam", Hex, Hex),
+ 122: makeSyscallInfo("sched_setaffinity", Hex, Hex, Hex),
+ 123: makeSyscallInfo("sched_getaffinity", Hex, Hex, Hex),
+ 124: makeSyscallInfo("sched_yield"),
+ 125: makeSyscallInfo("sched_get_priority_max", Hex),
+ 126: makeSyscallInfo("sched_get_priority_min", Hex),
+ 127: makeSyscallInfo("sched_rr_get_interval", Hex, Hex),
+ 128: makeSyscallInfo("restart_syscall"),
+ 129: makeSyscallInfo("kill", Hex, Signal),
+ 130: makeSyscallInfo("tkill", Hex, Signal),
+ 131: makeSyscallInfo("tgkill", Hex, Hex, Signal),
+ 132: makeSyscallInfo("sigaltstack", Hex, Hex),
+ 133: makeSyscallInfo("rt_sigsuspend", Hex),
+ 134: makeSyscallInfo("rt_sigaction", Signal, SigAction, PostSigAction, Hex),
+ 135: makeSyscallInfo("rt_sigprocmask", SignalMaskAction, SigSet, PostSigSet, Hex),
+ 136: makeSyscallInfo("rt_sigpending", Hex),
+ 137: makeSyscallInfo("rt_sigtimedwait", SigSet, Hex, Timespec, Hex),
+ 138: makeSyscallInfo("rt_sigqueueinfo", Hex, Signal, Hex),
+ 139: makeSyscallInfo("rt_sigreturn"),
+ 140: makeSyscallInfo("setpriority", Hex, Hex, Hex),
+ 141: makeSyscallInfo("getpriority", Hex, Hex),
+ 142: makeSyscallInfo("reboot", Hex, Hex, Hex, Hex),
+ 143: makeSyscallInfo("setregid", Hex, Hex),
+ 144: makeSyscallInfo("setgid", Hex),
+ 145: makeSyscallInfo("setreuid", Hex, Hex),
+ 146: makeSyscallInfo("setuid", Hex),
+ 147: makeSyscallInfo("setresuid", Hex, Hex, Hex),
+ 148: makeSyscallInfo("getresuid", Hex, Hex, Hex),
+ 149: makeSyscallInfo("setresgid", Hex, Hex, Hex),
+ 150: makeSyscallInfo("getresgid", Hex, Hex, Hex),
+ 151: makeSyscallInfo("setfsuid", Hex),
+ 152: makeSyscallInfo("setfsgid", Hex),
+ 153: makeSyscallInfo("times", Hex),
+ 154: makeSyscallInfo("setpgid", Hex, Hex),
+ 155: makeSyscallInfo("getpgid", Hex),
+ 156: makeSyscallInfo("getsid", Hex),
+ 157: makeSyscallInfo("setsid"),
+ 158: makeSyscallInfo("getgroups", Hex, Hex),
+ 159: makeSyscallInfo("setgroups", Hex, Hex),
+ 160: makeSyscallInfo("uname", Uname),
+ 161: makeSyscallInfo("sethostname", Hex, Hex),
+ 162: makeSyscallInfo("setdomainname", Hex, Hex),
+ 163: makeSyscallInfo("getrlimit", Hex, Hex),
+ 164: makeSyscallInfo("setrlimit", Hex, Hex),
+ 165: makeSyscallInfo("getrusage", Hex, Rusage),
+ 166: makeSyscallInfo("umask", Hex),
+ 167: makeSyscallInfo("prctl", Hex, Hex, Hex, Hex, Hex),
+ 168: makeSyscallInfo("getcpu", Hex, Hex, Hex),
+ 169: makeSyscallInfo("gettimeofday", Timeval, Hex),
+ 170: makeSyscallInfo("settimeofday", Timeval, Hex),
+ 171: makeSyscallInfo("adjtimex", Hex),
+ 172: makeSyscallInfo("getpid"),
+ 173: makeSyscallInfo("getppid"),
+ 174: makeSyscallInfo("getuid"),
+ 175: makeSyscallInfo("geteuid"),
+ 176: makeSyscallInfo("getgid"),
+ 177: makeSyscallInfo("getegid"),
+ 178: makeSyscallInfo("gettid"),
+ 179: makeSyscallInfo("sysinfo", Hex),
+ 180: makeSyscallInfo("mq_open", Hex, Hex, Hex, Hex),
+ 181: makeSyscallInfo("mq_unlink", Hex),
+ 182: makeSyscallInfo("mq_timedsend", Hex, Hex, Hex, Hex, Hex),
+ 183: makeSyscallInfo("mq_timedreceive", Hex, Hex, Hex, Hex, Hex),
+ 184: makeSyscallInfo("mq_notify", Hex, Hex),
+ 185: makeSyscallInfo("mq_getsetattr", Hex, Hex, Hex),
+ 186: makeSyscallInfo("msgget", Hex, Hex),
+ 187: makeSyscallInfo("msgctl", Hex, Hex, Hex),
+ 188: makeSyscallInfo("msgrcv", Hex, Hex, Hex, Hex, Hex),
+ 189: makeSyscallInfo("msgsnd", Hex, Hex, Hex, Hex),
+ 190: makeSyscallInfo("semget", Hex, Hex, Hex),
+ 191: makeSyscallInfo("semctl", Hex, Hex, Hex, Hex),
+ 192: makeSyscallInfo("semtimedop", Hex, Hex, Hex, Hex),
+ 193: makeSyscallInfo("semop", Hex, Hex, Hex),
+ 194: makeSyscallInfo("shmget", Hex, Hex, Hex),
+ 195: makeSyscallInfo("shmctl", Hex, Hex, Hex),
+ 196: makeSyscallInfo("shmat", Hex, Hex, Hex),
+ 197: makeSyscallInfo("shmdt", Hex),
+ 198: makeSyscallInfo("socket", SockFamily, SockType, SockProtocol),
+ 199: makeSyscallInfo("socketpair", SockFamily, SockType, SockProtocol, Hex),
+ 200: makeSyscallInfo("bind", FD, SockAddr, Hex),
+ 201: makeSyscallInfo("listen", FD, Hex),
+ 202: makeSyscallInfo("accept", FD, PostSockAddr, SockLen),
+ 203: makeSyscallInfo("connect", FD, SockAddr, Hex),
+ 204: makeSyscallInfo("getsockname", FD, PostSockAddr, SockLen),
+ 205: makeSyscallInfo("getpeername", FD, PostSockAddr, SockLen),
+ 206: makeSyscallInfo("sendto", FD, Hex, Hex, Hex, SockAddr, Hex),
+ 207: makeSyscallInfo("recvfrom", FD, Hex, Hex, Hex, PostSockAddr, SockLen),
+ 208: makeSyscallInfo("setsockopt", FD, Hex, Hex, Hex, Hex),
+ 209: makeSyscallInfo("getsockopt", FD, Hex, Hex, Hex, Hex),
+ 210: makeSyscallInfo("shutdown", FD, Hex),
+ 211: makeSyscallInfo("sendmsg", FD, SendMsgHdr, Hex),
+ 212: makeSyscallInfo("recvmsg", FD, RecvMsgHdr, Hex),
+ 213: makeSyscallInfo("readahead", Hex, Hex, Hex),
+ 214: makeSyscallInfo("brk", Hex),
+ 215: makeSyscallInfo("munmap", Hex, Hex),
+ 216: makeSyscallInfo("mremap", Hex, Hex, Hex, Hex, Hex),
+ 217: makeSyscallInfo("add_key", Hex, Hex, Hex, Hex, Hex),
+ 218: makeSyscallInfo("request_key", Hex, Hex, Hex, Hex),
+ 219: makeSyscallInfo("keyctl", Hex, Hex, Hex, Hex, Hex),
+ 220: makeSyscallInfo("clone", CloneFlags, Hex, Hex, Hex, Hex),
+ 221: makeSyscallInfo("execve", Path, ExecveStringVector, ExecveStringVector),
+ 222: makeSyscallInfo("mmap", Hex, Hex, Hex, Hex, FD, Hex),
+ 223: makeSyscallInfo("fadvise64", FD, Hex, Hex, Hex),
+ 224: makeSyscallInfo("swapon", Hex, Hex),
+ 225: makeSyscallInfo("swapoff", Hex),
+ 226: makeSyscallInfo("mprotect", Hex, Hex, Hex),
+ 227: makeSyscallInfo("msync", Hex, Hex, Hex),
+ 228: makeSyscallInfo("mlock", Hex, Hex),
+ 229: makeSyscallInfo("munlock", Hex, Hex),
+ 230: makeSyscallInfo("mlockall", Hex),
+ 231: makeSyscallInfo("munlockall"),
+ 232: makeSyscallInfo("mincore", Hex, Hex, Hex),
+ 233: makeSyscallInfo("madvise", Hex, Hex, Hex),
+ 234: makeSyscallInfo("remap_file_pages", Hex, Hex, Hex, Hex, Hex),
+ 235: makeSyscallInfo("mbind", Hex, Hex, Hex, Hex, Hex, Hex),
+ 236: makeSyscallInfo("get_mempolicy", Hex, Hex, Hex, Hex, Hex),
+ 237: makeSyscallInfo("set_mempolicy", Hex, Hex, Hex),
+ 238: makeSyscallInfo("migrate_pages", Hex, Hex, Hex, Hex),
+ 239: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
+ 240: makeSyscallInfo("rt_tgsigqueueinfo", Hex, Hex, Signal, Hex),
+ 241: makeSyscallInfo("perf_event_open", Hex, Hex, Hex, Hex, Hex),
+ 242: makeSyscallInfo("accept4", FD, PostSockAddr, SockLen, SockFlags),
+ 243: makeSyscallInfo("recvmmsg", FD, Hex, Hex, Hex, Hex),
+
+ 260: makeSyscallInfo("wait4", Hex, Hex, Hex, Rusage),
+ 261: makeSyscallInfo("prlimit64", Hex, Hex, Hex, Hex),
+ 262: makeSyscallInfo("fanotify_init", Hex, Hex),
+ 263: makeSyscallInfo("fanotify_mark", Hex, Hex, Hex, Hex, Hex),
+ 264: makeSyscallInfo("name_to_handle_at", FD, Hex, Hex, Hex, Hex),
+ 265: makeSyscallInfo("open_by_handle_at", FD, Hex, Hex),
+ 266: makeSyscallInfo("clock_adjtime", Hex, Hex),
+ 267: makeSyscallInfo("syncfs", FD),
+ 268: makeSyscallInfo("setns", FD, Hex),
+ 269: makeSyscallInfo("sendmmsg", FD, Hex, Hex, Hex),
+ 270: makeSyscallInfo("process_vm_readv", Hex, ReadIOVec, Hex, IOVec, Hex, Hex),
+ 271: makeSyscallInfo("process_vm_writev", Hex, IOVec, Hex, WriteIOVec, Hex, Hex),
+ 272: makeSyscallInfo("kcmp", Hex, Hex, Hex, Hex, Hex),
+ 273: makeSyscallInfo("finit_module", Hex, Hex, Hex),
+ 274: makeSyscallInfo("sched_setattr", Hex, Hex, Hex),
+ 275: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
+ 276: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
+ 277: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 278: makeSyscallInfo("getrandom", Hex, Hex, Hex),
+ 279: makeSyscallInfo("memfd_create", Path, Hex),
+ 280: makeSyscallInfo("bpf", Hex, Hex, Hex),
+ 281: makeSyscallInfo("execveat", FD, Path, Hex, Hex, Hex),
+ 282: makeSyscallInfo("userfaultfd", Hex),
+ 283: makeSyscallInfo("membarrier", Hex),
+ 284: makeSyscallInfo("mlock2", Hex, Hex, Hex),
+ 285: makeSyscallInfo("copy_file_range", FD, Hex, FD, Hex, Hex, Hex),
+ 286: makeSyscallInfo("preadv2", FD, ReadIOVec, Hex, Hex, Hex),
+ 287: makeSyscallInfo("pwritev2", FD, WriteIOVec, Hex, Hex, Hex),
+ 291: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
+ 292: makeSyscallInfo("io_pgetevents", Hex, Hex, Hex, Hex, Timespec, SigSet),
+ 293: makeSyscallInfo("rseq", Hex, Hex, Hex, Hex),
+ 424: makeSyscallInfo("pidfd_send_signal", FD, Signal, Hex, Hex),
+ 425: makeSyscallInfo("io_uring_setup", Hex, Hex),
+ 426: makeSyscallInfo("io_uring_enter", FD, Hex, Hex, Hex, SigSet, Hex),
+ 427: makeSyscallInfo("io_uring_register", FD, Hex, Hex, Hex),
+ 428: makeSyscallInfo("open_tree", FD, Path, Hex),
+ 429: makeSyscallInfo("move_mount", FD, Path, FD, Path, Hex),
+ 430: makeSyscallInfo("fsopen", Path, Hex), // Not quite a path, but close.
+ 431: makeSyscallInfo("fsconfig", FD, Hex, Hex, Hex, Hex),
+ 432: makeSyscallInfo("fsmount", FD, Hex, Hex),
+ 433: makeSyscallInfo("fspick", FD, Path, Hex),
+ 434: makeSyscallInfo("pidfd_open", Hex, Hex),
+ 435: makeSyscallInfo("clone3", Hex, Hex),
+}
+
+func init() {
+ syscallTables = append(syscallTables,
+ syscallTable{
+ os: abi.Linux,
+ arch: arch.ARM64,
+ syscalls: linuxARM64})
+}
diff --git a/pkg/sentry/strace/poll.go b/pkg/sentry/strace/poll.go
index 5187594a7..074e80f9b 100644
--- a/pkg/sentry/strace/poll.go
+++ b/pkg/sentry/strace/poll.go
@@ -22,7 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// PollEventSet is the set of poll(2) event flags.
diff --git a/pkg/sentry/strace/select.go b/pkg/sentry/strace/select.go
new file mode 100644
index 000000000..3a4c32aa0
--- /dev/null
+++ b/pkg/sentry/strace/select.go
@@ -0,0 +1,56 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func fdsFromSet(t *kernel.Task, set []byte) []int {
+ var fds []int
+ // Append n if the n-th bit is 1.
+ for i, v := range set {
+ for j := 0; j < 8; j++ {
+ if (v>>j)&1 == 1 {
+ fds = append(fds, i*8+j)
+ }
+ }
+ }
+ return fds
+}
+
+func fdSet(t *kernel.Task, nfds int, addr usermem.Addr) string {
+ if nfds < 0 {
+ return fmt.Sprintf("%#x (negative nfds)", addr)
+ }
+ if addr == 0 {
+ return "null"
+ }
+
+ // Calculate the size of the fd set (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ set, err := linux.CopyInFDSet(t, addr, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return fmt.Sprintf("%#x (error decoding fdset: %s)", addr, err)
+ }
+
+ return fmt.Sprintf("%#x %v", addr, fdsFromSet(t, set))
+}
diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go
index 5656d53eb..c41f36e3f 100644
--- a/pkg/sentry/strace/signal.go
+++ b/pkg/sentry/strace/signal.go
@@ -21,7 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// signalNames contains the names of all named signals.
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 94334f6d2..b51c4c941 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -22,11 +22,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// SocketFamily are the possible socket(2) families.
@@ -208,16 +207,25 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
i += linux.SizeOfControlMessageHeader
width := t.Arch().Width()
length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length < 0 {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
fds := make(linux.ControlMessageRights, numRights)
@@ -286,7 +294,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
@@ -332,7 +340,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(int(family), b, true /* strict */)
+ fa, _, err := netstack.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
@@ -410,3 +418,228 @@ func sockFlags(flags int32) string {
}
return SocketFlagSet.Parse(uint64(flags))
}
+
+func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen usermem.Addr, maximumBlobSize uint, rval uintptr) string {
+ if int(rval) < 0 {
+ return hexNum(uint64(optVal))
+ }
+ if optVal == 0 {
+ return "null"
+ }
+ l, err := copySockLen(t, optLen)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading length: %v}", optLen, err)
+ }
+ return sockOptVal(t, level, optname, optVal, uint64(l), maximumBlobSize)
+}
+
+func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string {
+ switch optLen {
+ case 1:
+ var v uint8
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 2:
+ var v uint16
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ case 4:
+ var v uint32
+ _, err := t.CopyIn(optVal, &v)
+ if err != nil {
+ return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err)
+ }
+ return fmt.Sprintf("%#x {value=%v}", optVal, v)
+ default:
+ return dump(t, optVal, uint(optLen), maximumBlobSize)
+ }
+}
+
+var sockOptLevels = abi.ValueSet{
+ linux.SOL_IP: "SOL_IP",
+ linux.SOL_SOCKET: "SOL_SOCKET",
+ linux.SOL_TCP: "SOL_TCP",
+ linux.SOL_UDP: "SOL_UDP",
+ linux.SOL_IPV6: "SOL_IPV6",
+ linux.SOL_ICMPV6: "SOL_ICMPV6",
+ linux.SOL_RAW: "SOL_RAW",
+ linux.SOL_PACKET: "SOL_PACKET",
+ linux.SOL_NETLINK: "SOL_NETLINK",
+}
+
+var sockOptNames = map[uint64]abi.ValueSet{
+ linux.SOL_IP: {
+ linux.IP_TTL: "IP_TTL",
+ linux.IP_MULTICAST_TTL: "IP_MULTICAST_TTL",
+ linux.IP_MULTICAST_IF: "IP_MULTICAST_IF",
+ linux.IP_MULTICAST_LOOP: "IP_MULTICAST_LOOP",
+ linux.IP_TOS: "IP_TOS",
+ linux.IP_RECVTOS: "IP_RECVTOS",
+ linux.IPT_SO_GET_INFO: "IPT_SO_GET_INFO",
+ linux.IPT_SO_GET_ENTRIES: "IPT_SO_GET_ENTRIES",
+ linux.IP_ADD_MEMBERSHIP: "IP_ADD_MEMBERSHIP",
+ linux.IP_DROP_MEMBERSHIP: "IP_DROP_MEMBERSHIP",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.IP_ADD_SOURCE_MEMBERSHIP: "IP_ADD_SOURCE_MEMBERSHIP",
+ linux.IP_BIND_ADDRESS_NO_PORT: "IP_BIND_ADDRESS_NO_PORT",
+ linux.IP_BLOCK_SOURCE: "IP_BLOCK_SOURCE",
+ linux.IP_CHECKSUM: "IP_CHECKSUM",
+ linux.IP_DROP_SOURCE_MEMBERSHIP: "IP_DROP_SOURCE_MEMBERSHIP",
+ linux.IP_FREEBIND: "IP_FREEBIND",
+ linux.IP_HDRINCL: "IP_HDRINCL",
+ linux.IP_IPSEC_POLICY: "IP_IPSEC_POLICY",
+ linux.IP_MINTTL: "IP_MINTTL",
+ linux.IP_MSFILTER: "IP_MSFILTER",
+ linux.IP_MTU_DISCOVER: "IP_MTU_DISCOVER",
+ linux.IP_MULTICAST_ALL: "IP_MULTICAST_ALL",
+ linux.IP_NODEFRAG: "IP_NODEFRAG",
+ linux.IP_OPTIONS: "IP_OPTIONS",
+ linux.IP_PASSSEC: "IP_PASSSEC",
+ linux.IP_PKTINFO: "IP_PKTINFO",
+ linux.IP_RECVERR: "IP_RECVERR",
+ linux.IP_RECVFRAGSIZE: "IP_RECVFRAGSIZE",
+ linux.IP_RECVOPTS: "IP_RECVOPTS",
+ linux.IP_RECVORIGDSTADDR: "IP_RECVORIGDSTADDR",
+ linux.IP_RECVTTL: "IP_RECVTTL",
+ linux.IP_RETOPTS: "IP_RETOPTS",
+ linux.IP_TRANSPARENT: "IP_TRANSPARENT",
+ linux.IP_UNBLOCK_SOURCE: "IP_UNBLOCK_SOURCE",
+ linux.IP_UNICAST_IF: "IP_UNICAST_IF",
+ linux.IP_XFRM_POLICY: "IP_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IP_ROUTER_ALERT: "IP_ROUTER_ALERT",
+ linux.IP_PKTOPTIONS: "IP_PKTOPTIONS",
+ linux.IP_MTU: "IP_MTU",
+ linux.SO_ORIGINAL_DST: "SO_ORIGINAL_DST",
+ },
+ linux.SOL_SOCKET: {
+ linux.SO_ERROR: "SO_ERROR",
+ linux.SO_PEERCRED: "SO_PEERCRED",
+ linux.SO_PASSCRED: "SO_PASSCRED",
+ linux.SO_SNDBUF: "SO_SNDBUF",
+ linux.SO_RCVBUF: "SO_RCVBUF",
+ linux.SO_REUSEADDR: "SO_REUSEADDR",
+ linux.SO_REUSEPORT: "SO_REUSEPORT",
+ linux.SO_BINDTODEVICE: "SO_BINDTODEVICE",
+ linux.SO_BROADCAST: "SO_BROADCAST",
+ linux.SO_KEEPALIVE: "SO_KEEPALIVE",
+ linux.SO_LINGER: "SO_LINGER",
+ linux.SO_SNDTIMEO: "SO_SNDTIMEO",
+ linux.SO_RCVTIMEO: "SO_RCVTIMEO",
+ linux.SO_OOBINLINE: "SO_OOBINLINE",
+ linux.SO_TIMESTAMP: "SO_TIMESTAMP",
+ },
+ linux.SOL_TCP: {
+ linux.TCP_NODELAY: "TCP_NODELAY",
+ linux.TCP_CORK: "TCP_CORK",
+ linux.TCP_QUICKACK: "TCP_QUICKACK",
+ linux.TCP_MAXSEG: "TCP_MAXSEG",
+ linux.TCP_KEEPIDLE: "TCP_KEEPIDLE",
+ linux.TCP_KEEPINTVL: "TCP_KEEPINTVL",
+ linux.TCP_USER_TIMEOUT: "TCP_USER_TIMEOUT",
+ linux.TCP_INFO: "TCP_INFO",
+ linux.TCP_CC_INFO: "TCP_CC_INFO",
+ linux.TCP_NOTSENT_LOWAT: "TCP_NOTSENT_LOWAT",
+ linux.TCP_ZEROCOPY_RECEIVE: "TCP_ZEROCOPY_RECEIVE",
+ linux.TCP_CONGESTION: "TCP_CONGESTION",
+ linux.TCP_LINGER2: "TCP_LINGER2",
+ linux.TCP_DEFER_ACCEPT: "TCP_DEFER_ACCEPT",
+ linux.TCP_REPAIR_OPTIONS: "TCP_REPAIR_OPTIONS",
+ linux.TCP_INQ: "TCP_INQ",
+ linux.TCP_FASTOPEN: "TCP_FASTOPEN",
+ linux.TCP_FASTOPEN_CONNECT: "TCP_FASTOPEN_CONNECT",
+ linux.TCP_FASTOPEN_KEY: "TCP_FASTOPEN_KEY",
+ linux.TCP_FASTOPEN_NO_COOKIE: "TCP_FASTOPEN_NO_COOKIE",
+ linux.TCP_KEEPCNT: "TCP_KEEPCNT",
+ linux.TCP_QUEUE_SEQ: "TCP_QUEUE_SEQ",
+ linux.TCP_REPAIR: "TCP_REPAIR",
+ linux.TCP_REPAIR_QUEUE: "TCP_REPAIR_QUEUE",
+ linux.TCP_REPAIR_WINDOW: "TCP_REPAIR_WINDOW",
+ linux.TCP_SAVED_SYN: "TCP_SAVED_SYN",
+ linux.TCP_SAVE_SYN: "TCP_SAVE_SYN",
+ linux.TCP_SYNCNT: "TCP_SYNCNT",
+ linux.TCP_THIN_DUPACK: "TCP_THIN_DUPACK",
+ linux.TCP_THIN_LINEAR_TIMEOUTS: "TCP_THIN_LINEAR_TIMEOUTS",
+ linux.TCP_TIMESTAMP: "TCP_TIMESTAMP",
+ linux.TCP_ULP: "TCP_ULP",
+ linux.TCP_WINDOW_CLAMP: "TCP_WINDOW_CLAMP",
+ },
+ linux.SOL_IPV6: {
+ linux.IPV6_V6ONLY: "IPV6_V6ONLY",
+ linux.IPV6_PATHMTU: "IPV6_PATHMTU",
+ linux.IPV6_TCLASS: "IPV6_TCLASS",
+ linux.IPV6_ADD_MEMBERSHIP: "IPV6_ADD_MEMBERSHIP",
+ linux.IPV6_DROP_MEMBERSHIP: "IPV6_DROP_MEMBERSHIP",
+ linux.IPV6_IPSEC_POLICY: "IPV6_IPSEC_POLICY",
+ linux.IPV6_JOIN_ANYCAST: "IPV6_JOIN_ANYCAST",
+ linux.IPV6_LEAVE_ANYCAST: "IPV6_LEAVE_ANYCAST",
+ linux.IPV6_PKTINFO: "IPV6_PKTINFO",
+ linux.IPV6_ROUTER_ALERT: "IPV6_ROUTER_ALERT",
+ linux.IPV6_XFRM_POLICY: "IPV6_XFRM_POLICY",
+ linux.MCAST_BLOCK_SOURCE: "MCAST_BLOCK_SOURCE",
+ linux.MCAST_JOIN_GROUP: "MCAST_JOIN_GROUP",
+ linux.MCAST_JOIN_SOURCE_GROUP: "MCAST_JOIN_SOURCE_GROUP",
+ linux.MCAST_LEAVE_GROUP: "MCAST_LEAVE_GROUP",
+ linux.MCAST_LEAVE_SOURCE_GROUP: "MCAST_LEAVE_SOURCE_GROUP",
+ linux.MCAST_UNBLOCK_SOURCE: "MCAST_UNBLOCK_SOURCE",
+ linux.IPV6_2292DSTOPTS: "IPV6_2292DSTOPTS",
+ linux.IPV6_2292HOPLIMIT: "IPV6_2292HOPLIMIT",
+ linux.IPV6_2292HOPOPTS: "IPV6_2292HOPOPTS",
+ linux.IPV6_2292PKTINFO: "IPV6_2292PKTINFO",
+ linux.IPV6_2292PKTOPTIONS: "IPV6_2292PKTOPTIONS",
+ linux.IPV6_2292RTHDR: "IPV6_2292RTHDR",
+ linux.IPV6_ADDR_PREFERENCES: "IPV6_ADDR_PREFERENCES",
+ linux.IPV6_AUTOFLOWLABEL: "IPV6_AUTOFLOWLABEL",
+ linux.IPV6_DONTFRAG: "IPV6_DONTFRAG",
+ linux.IPV6_DSTOPTS: "IPV6_DSTOPTS",
+ linux.IPV6_FLOWINFO: "IPV6_FLOWINFO",
+ linux.IPV6_FLOWINFO_SEND: "IPV6_FLOWINFO_SEND",
+ linux.IPV6_FLOWLABEL_MGR: "IPV6_FLOWLABEL_MGR",
+ linux.IPV6_FREEBIND: "IPV6_FREEBIND",
+ linux.IPV6_HOPOPTS: "IPV6_HOPOPTS",
+ linux.IPV6_MINHOPCOUNT: "IPV6_MINHOPCOUNT",
+ linux.IPV6_MTU: "IPV6_MTU",
+ linux.IPV6_MTU_DISCOVER: "IPV6_MTU_DISCOVER",
+ linux.IPV6_MULTICAST_ALL: "IPV6_MULTICAST_ALL",
+ linux.IPV6_MULTICAST_HOPS: "IPV6_MULTICAST_HOPS",
+ linux.IPV6_MULTICAST_IF: "IPV6_MULTICAST_IF",
+ linux.IPV6_MULTICAST_LOOP: "IPV6_MULTICAST_LOOP",
+ linux.IPV6_RECVDSTOPTS: "IPV6_RECVDSTOPTS",
+ linux.IPV6_RECVERR: "IPV6_RECVERR",
+ linux.IPV6_RECVFRAGSIZE: "IPV6_RECVFRAGSIZE",
+ linux.IPV6_RECVHOPLIMIT: "IPV6_RECVHOPLIMIT",
+ linux.IPV6_RECVHOPOPTS: "IPV6_RECVHOPOPTS",
+ linux.IPV6_RECVORIGDSTADDR: "IPV6_RECVORIGDSTADDR",
+ linux.IPV6_RECVPATHMTU: "IPV6_RECVPATHMTU",
+ linux.IPV6_RECVPKTINFO: "IPV6_RECVPKTINFO",
+ linux.IPV6_RECVRTHDR: "IPV6_RECVRTHDR",
+ linux.IPV6_RECVTCLASS: "IPV6_RECVTCLASS",
+ linux.IPV6_RTHDR: "IPV6_RTHDR",
+ linux.IPV6_RTHDRDSTOPTS: "IPV6_RTHDRDSTOPTS",
+ linux.IPV6_TRANSPARENT: "IPV6_TRANSPARENT",
+ linux.IPV6_UNICAST_HOPS: "IPV6_UNICAST_HOPS",
+ linux.IPV6_UNICAST_IF: "IPV6_UNICAST_IF",
+ linux.MCAST_MSFILTER: "MCAST_MSFILTER",
+ linux.IPV6_ADDRFORM: "IPV6_ADDRFORM",
+ },
+ linux.SOL_NETLINK: {
+ linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR",
+ linux.NETLINK_CAP_ACK: "NETLINK_CAP_ACK",
+ linux.NETLINK_DUMP_STRICT_CHK: "NETLINK_DUMP_STRICT_CHK",
+ linux.NETLINK_EXT_ACK: "NETLINK_EXT_ACK",
+ linux.NETLINK_LIST_MEMBERSHIPS: "NETLINK_LIST_MEMBERSHIPS",
+ linux.NETLINK_NO_ENOBUFS: "NETLINK_NO_ENOBUFS",
+ linux.NETLINK_PKTINFO: "NETLINK_PKTINFO",
+ },
+}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 311389547..87b239730 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -33,7 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
pb "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// DefaultLogMaximumSize is the default LogMaximumSize.
@@ -55,6 +55,14 @@ var ItimerTypes = abi.ValueSet{
linux.ITIMER_PROF: "ITIMER_PROF",
}
+func hexNum(num uint64) string {
+ return "0x" + strconv.FormatUint(num, 16)
+}
+
+func hexArg(arg arch.SyscallArgument) string {
+ return hexNum(arg.Uint64())
+}
+
func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, maxBytes uint64) string {
if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV {
return fmt.Sprintf("%#x (error decoding iovecs: invalid iovcnt)", addr)
@@ -133,16 +141,20 @@ func path(t *kernel.Task, addr usermem.Addr) string {
}
func fd(t *kernel.Task, fd int32) string {
+ if kernel.VFS2Enabled {
+ return fdVFS2(t, fd)
+ }
+
root := t.FSContext().RootDirectory()
if root != nil {
- defer root.DecRef()
+ defer root.DecRef(t)
}
if fd == linux.AT_FDCWD {
wd := t.FSContext().WorkingDirectory()
var name string
if wd != nil {
- defer wd.DecRef()
+ defer wd.DecRef(t)
name, _ = wd.FullName(root)
} else {
name = "(unknown cwd)"
@@ -155,12 +167,36 @@ func fd(t *kernel.Task, fd int32) string {
// Cast FD to uint64 to avoid printing negative hex.
return fmt.Sprintf("%#x (bad FD)", uint64(fd))
}
- defer file.DecRef()
+ defer file.DecRef(t)
name, _ := file.Dirent.FullName(root)
return fmt.Sprintf("%#x %s", fd, name)
}
+func fdVFS2(t *kernel.Task, fd int32) string {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+
+ vfsObj := root.Mount().Filesystem().VirtualFilesystem()
+ if fd == linux.AT_FDCWD {
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef(t)
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, wd)
+ return fmt.Sprintf("AT_FDCWD %s", name)
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // Cast FD to uint64 to avoid printing negative hex.
+ return fmt.Sprintf("%#x (bad FD)", uint64(fd))
+ }
+ defer file.DecRef(t)
+
+ name, _ := vfsObj.PathnameWithDeleted(t, root, file.VirtualDentry())
+ return fmt.Sprintf("%#x %s", fd, name)
+}
+
func fdpair(t *kernel.Task, addr usermem.Addr) string {
var fds [2]int32
_, err := t.CopyIn(addr, &fds)
@@ -389,6 +425,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, path(t, args[arg].Pointer()))
case ExecveStringVector:
output = append(output, stringVector(t, args[arg].Pointer()))
+ case SetSockOptVal:
+ output = append(output, sockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Uint64() /* optLen */, maximumBlobSize))
+ case SockOptLevel:
+ output = append(output, sockOptLevels.Parse(args[arg].Uint64()))
+ case SockOptName:
+ output = append(output, sockOptNames[args[arg-1].Uint64() /* level */].Parse(args[arg].Uint64()))
case SockAddr:
output = append(output, sockAddr(t, args[arg].Pointer(), uint32(args[arg+1].Uint64())))
case SockLen:
@@ -439,12 +481,20 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
+ case SelectFDSet:
+ output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
output = append(output, "0o"+strconv.FormatUint(args[arg].Uint64(), 8))
case Hex:
fallthrough
default:
- output = append(output, "0x"+strconv.FormatUint(args[arg].Uint64(), 16))
+ output = append(output, hexArg(args[arg]))
}
}
@@ -505,6 +555,14 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
+ case GetSockOptVal:
+ output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
+ case SetSockOptVal:
+ // No need to print the value again. While it usually
+ // isn't, the string version of this arg can be long.
+ output[arg] = hexArg(args[arg])
}
}
}
@@ -661,7 +719,7 @@ func (s SyscallMap) SyscallEnter(t *kernel.Task, sysno uintptr, args arch.Syscal
// SyscallExit implements kernel.Stracer.SyscallExit. It logs the syscall
// exit trace.
func (s SyscallMap) SyscallExit(context interface{}, t *kernel.Task, sysno, rval uintptr, err error) {
- errno := t.ExtractErrno(err, int(sysno))
+ errno := kernel.ExtractErrno(err, int(sysno))
c := context.(*syscallContext)
elapsed := time.Since(c.start)
@@ -720,9 +778,6 @@ func (s SyscallMap) Name(sysno uintptr) string {
//
// N.B. This is not in an init function because we can't be sure all syscall
// tables are registered with the kernel when init runs.
-//
-// TODO(gvisor.dev/issue/155): remove kernel package dependencies from this
-// package and have the kernel package self-initialize all syscall tables.
func Initialize() {
for _, table := range kernel.SyscallTables() {
// Is this known?
diff --git a/pkg/sentry/strace/strace.proto b/pkg/sentry/strace/strace.proto
index 4b2f73a5f..906c52c51 100644
--- a/pkg/sentry/strace/strace.proto
+++ b/pkg/sentry/strace/strace.proto
@@ -32,8 +32,7 @@ message Strace {
}
}
-message StraceEnter {
-}
+message StraceEnter {}
message StraceExit {
// Return value formatted as string.
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 3c389d375..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -206,6 +206,38 @@ const (
// PollFDs is an array of struct pollfd. The number of entries in the
// array is in the next argument.
PollFDs
+
+ // SelectFDSet is an fd_set argument in select(2)/pselect(2). The
+ // number of FDs represented must be the first argument.
+ SelectFDSet
+
+ // GetSockOptVal is the optval argument in getsockopt(2).
+ //
+ // Formatted after syscall execution.
+ GetSockOptVal
+
+ // SetSockOptVal is the optval argument in setsockopt(2).
+ //
+ // Contents omitted after syscall execution.
+ SetSockOptVal
+
+ // SockOptLevel is the level argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptLevel
+
+ // SockOptLevel is the optname argument in getsockopt(2) and
+ // setsockopt(2).
+ SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
@@ -246,14 +278,7 @@ type syscallTable struct {
syscalls SyscallMap
}
-// syscallTables contains all syscall tables.
-var syscallTables = []syscallTable{
- {
- os: abi.Linux,
- arch: arch.AMD64,
- syscalls: linuxAMD64,
- },
-}
+var syscallTables []syscallTable
// Lookup returns the SyscallMap for the OS/Arch combination. The returned map
// must not be changed.
diff --git a/pkg/sentry/syscalls/BUILD b/pkg/sentry/syscalls/BUILD
index 79d972202..b8d1bd415 100644
--- a/pkg/sentry/syscalls/BUILD
+++ b/pkg/sentry/syscalls/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,7 +8,6 @@ go_library(
"epoll.go",
"syscalls.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go
index 87dcad18b..d23a0068a 100644
--- a/pkg/sentry/syscalls/epoll.go
+++ b/pkg/sentry/syscalls/epoll.go
@@ -17,6 +17,7 @@ package syscalls
import (
"time"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
@@ -27,7 +28,7 @@ import (
// CreateEpoll implements the epoll_create(2) linux syscall.
func CreateEpoll(t *kernel.Task, closeOnExec bool) (int32, error) {
file := epoll.NewEventPoll(t)
- defer file.DecRef()
+ defer file.DecRef(t)
fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: closeOnExec,
@@ -46,14 +47,14 @@ func AddEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, mask
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -72,14 +73,14 @@ func UpdateEpoll(t *kernel.Task, epfd int32, fd int32, flags epoll.EntryFlags, m
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -98,14 +99,14 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
if epollfile == nil {
return syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Get the target file id.
file := t.GetFile(fd)
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
@@ -114,17 +115,17 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error {
}
// Try to remove the entry.
- return e.RemoveEntry(epoll.FileIdentifier{file, fd})
+ return e.RemoveEntry(t, epoll.FileIdentifier{file, fd})
}
// WaitEpoll implements the epoll_wait(2) linux syscall.
-func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]epoll.Event, error) {
+func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) {
// Get epoll from the file descriptor.
epollfile := t.GetFile(fd)
if epollfile == nil {
return nil, syserror.EBADF
}
- defer epollfile.DecRef()
+ defer epollfile.DecRef(t)
// Extract the epollPoll operations.
e, ok := epollfile.FileOperations.(*epoll.EventPoll)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 4c0bf96e4..4a9b04fd0 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,11 +8,11 @@ go_library(
"error.go",
"flags.go",
"linux64.go",
- "linux64_amd64.go",
- "linux64_arm64.go",
"sigset.go",
"sys_aio.go",
"sys_capability.go",
+ "sys_clone_amd64.go",
+ "sys_clone_arm64.go",
"sys_epoll.go",
"sys_eventfd.go",
"sys_file.go",
@@ -30,6 +30,7 @@ go_library(
"sys_random.go",
"sys_read.go",
"sys_rlimit.go",
+ "sys_rseq.go",
"sys_rusage.go",
"sys_sched.go",
"sys_seccomp.go",
@@ -39,6 +40,8 @@ go_library(
"sys_socket.go",
"sys_splice.go",
"sys_stat.go",
+ "sys_stat_amd64.go",
+ "sys_stat_arm64.go",
"sys_sync.go",
"sys_sysinfo.go",
"sys_syslog.go",
@@ -46,28 +49,31 @@ go_library(
"sys_time.go",
"sys_timer.go",
"sys_timerfd.go",
- "sys_tls.go",
+ "sys_tls_amd64.go",
+ "sys_tls_arm64.go",
"sys_utsname.go",
"sys_write.go",
+ "sys_xattr.go",
"timespec.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/syscalls/linux",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi",
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/bpf",
+ "//pkg/context",
"//pkg/log",
"//pkg/metric",
"//pkg/rand",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
"//pkg/sentry/fs",
"//pkg/sentry/fs/anon",
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/epoll",
@@ -82,15 +88,18 @@ go_library(
"//pkg/sentry/loader",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
- "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/syscalls",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserr",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go
index 1d9018c96..46060f6f5 100644
--- a/pkg/sentry/syscalls/linux/error.go
+++ b/pkg/sentry/syscalls/linux/error.go
@@ -16,13 +16,14 @@ package linux
import (
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -31,20 +32,58 @@ var (
partialResultOnce sync.Once
)
+// HandleIOErrorVFS2 handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// op and f are used only for panics.
+func HandleIOErrorVFS2(t *kernel.Task, partialResult bool, ioerr, intr error, op string, f *vfs.FileDescription) error {
+ known, err := handleIOErrorImpl(t, partialResult, ioerr, intr, op)
+ if err != nil {
+ return err
+ }
+ if !known {
+ // An unknown error is encountered with a partial read/write.
+ fs := f.Mount().Filesystem().VirtualFilesystem()
+ root := vfs.RootFromContext(t)
+ name, _ := fs.PathnameWithDeleted(t, root, f.VirtualDentry())
+ log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name)
+ partialResultOnce.Do(partialResultMetric.Increment)
+ }
+ return nil
+}
+
// handleIOError handles special error cases for partial results. For some
// errors, we may consume the error and return only the partial read/write.
//
// op and f are used only for panics.
-func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op string, f *fs.File) error {
+func handleIOError(t *kernel.Task, partialResult bool, ioerr, intr error, op string, f *fs.File) error {
+ known, err := handleIOErrorImpl(t, partialResult, ioerr, intr, op)
+ if err != nil {
+ return err
+ }
+ if !known {
+ // An unknown error is encountered with a partial read/write.
+ name, _ := f.Dirent.FullName(nil /* ignore chroot */)
+ log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations)
+ partialResultOnce.Do(partialResultMetric.Increment)
+ }
+ return nil
+}
+
+// handleIOError handles special error cases for partial results. For some
+// errors, we may consume the error and return only the partial read/write.
+//
+// Returns false if error is unknown.
+func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op string) (bool, error) {
switch err {
case nil:
// Typical successful syscall.
- return nil
+ return true, nil
case io.EOF:
// EOF is always consumed. If this is a partial read/write
// (result != 0), the application will see that, otherwise
// they will see 0.
- return nil
+ return true, nil
case syserror.ErrExceedsFileSizeLimit:
// Ignore partialResult because this error only applies to
// normal files, and for those files we cannot accumulate
@@ -53,20 +92,20 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin
// Do not consume the error and return it as EFBIG.
// Simultaneously send a SIGXFSZ per setrlimit(2).
t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
- return syserror.EFBIG
+ return true, syserror.EFBIG
case syserror.ErrInterrupted:
// The syscall was interrupted. Return nil if it completed
// partially, otherwise return the error code that the syscall
// needs (to indicate to the kernel what it should do).
if partialResult {
- return nil
+ return true, nil
}
- return intr
+ return true, intr
}
if !partialResult {
// Typical syscall error.
- return err
+ return true, err
}
switch err {
@@ -75,14 +114,14 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin
// read/write. Like ErrWouldBlock, since we have a
// partial read/write, we consume the error and return
// the partial result.
- return nil
+ return true, nil
case syserror.EFAULT:
// EFAULT is only shown the user if nothing was
// read/written. If we read something (this case), they see
// a partial read/write. They will then presumably try again
// with an incremented buffer, which will EFAULT with
// result == 0.
- return nil
+ return true, nil
case syserror.EPIPE:
// Writes to a pipe or socket will return EPIPE if the other
// side is gone. The partial write is returned. EPIPE will be
@@ -90,32 +129,29 @@ func handleIOError(t *kernel.Task, partialResult bool, err, intr error, op strin
//
// TODO(gvisor.dev/issue/161): In some cases SIGPIPE should
// also be sent to the application.
- return nil
+ return true, nil
case syserror.ENOSPC:
// Similar to EPIPE. Return what we wrote this time, and let
// ENOSPC be returned on the next call.
- return nil
+ return true, nil
case syserror.ECONNRESET:
// For TCP sendfile connections, we may have a reset. But we
// should just return n as the result.
- return nil
+ return true, nil
case syserror.ErrWouldBlock:
// Syscall would block, but completed a partial read/write.
// This case should only be returned by IssueIO for nonblocking
// files. Since we have a partial read/write, we consume
// ErrWouldBlock, returning the partial result.
- return nil
+ return true, nil
}
switch err.(type) {
case kernel.SyscallRestartErrno:
// Identical to the EINTR case.
- return nil
+ return true, nil
}
- // An unknown error is encountered with a partial read/write.
- name, _ := f.Dirent.FullName(nil /* ignore chroot */)
- log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, err, err, op, name, f.FileOperations)
- partialResultOnce.Do(partialResultMetric.Increment)
- return nil
+ // Error is unknown and cannot be properly handled.
+ return false, nil
}
diff --git a/pkg/sentry/syscalls/linux/flags.go b/pkg/sentry/syscalls/linux/flags.go
index 444f2b004..07961dad9 100644
--- a/pkg/sentry/syscalls/linux/flags.go
+++ b/pkg/sentry/syscalls/linux/flags.go
@@ -50,5 +50,6 @@ func linuxToFlags(mask uint) fs.FileFlags {
Directory: mask&linux.O_DIRECTORY != 0,
Async: mask&linux.O_ASYNC != 0,
LargeFile: mask&linux.O_LARGEFILE != 0,
+ Truncate: mask&linux.O_TRUNC != 0,
}
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 68589a377..80c65164a 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -15,6 +15,16 @@
// Package linux provides syscall tables for amd64 Linux.
package linux
+import (
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
const (
// LinuxSysname is the OS name advertised by gVisor.
LinuxSysname = "Linux"
@@ -25,3 +35,702 @@ const (
// LinuxVersion is the version info advertised by gVisor.
LinuxVersion = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
)
+
+// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
+// numbers from Linux 4.4.
+var AMD64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.AMD64,
+ Version: kernel.Version{
+ // Version 4.4 is chosen as a stable, longterm version of Linux, which
+ // guides the interface provided by this syscall table. The build
+ // version is that for a clean build with default kernel config, at 5
+ // minutes after v4.4 was tagged.
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
+ },
+ AuditNumber: linux.AUDIT_ARCH_X86_64,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.Supported("read", Read),
+ 1: syscalls.Supported("write", Write),
+ 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil),
+ 3: syscalls.Supported("close", Close),
+ 4: syscalls.Supported("stat", Stat),
+ 5: syscalls.Supported("fstat", Fstat),
+ 6: syscalls.Supported("lstat", Lstat),
+ 7: syscalls.Supported("poll", Poll),
+ 8: syscalls.Supported("lseek", Lseek),
+ 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
+ 10: syscalls.Supported("mprotect", Mprotect),
+ 11: syscalls.Supported("munmap", Munmap),
+ 12: syscalls.Supported("brk", Brk),
+ 13: syscalls.Supported("rt_sigaction", RtSigaction),
+ 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 15: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 17: syscalls.Supported("pread64", Pread64),
+ 18: syscalls.Supported("pwrite64", Pwrite64),
+ 19: syscalls.Supported("readv", Readv),
+ 20: syscalls.Supported("writev", Writev),
+ 21: syscalls.Supported("access", Access),
+ 22: syscalls.Supported("pipe", Pipe),
+ 23: syscalls.Supported("select", Select),
+ 24: syscalls.Supported("sched_yield", SchedYield),
+ 25: syscalls.Supported("mremap", Mremap),
+ 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 32: syscalls.Supported("dup", Dup),
+ 33: syscalls.Supported("dup2", Dup2),
+ 34: syscalls.Supported("pause", Pause),
+ 35: syscalls.Supported("nanosleep", Nanosleep),
+ 36: syscalls.Supported("getitimer", Getitimer),
+ 37: syscalls.Supported("alarm", Alarm),
+ 38: syscalls.Supported("setitimer", Setitimer),
+ 39: syscalls.Supported("getpid", Getpid),
+ 40: syscalls.Supported("sendfile", Sendfile),
+ 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 42: syscalls.Supported("connect", Connect),
+ 43: syscalls.Supported("accept", Accept),
+ 44: syscalls.Supported("sendto", SendTo),
+ 45: syscalls.Supported("recvfrom", RecvFrom),
+ 46: syscalls.Supported("sendmsg", SendMsg),
+ 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 50: syscalls.Supported("listen", Listen),
+ 51: syscalls.Supported("getsockname", GetSockName),
+ 52: syscalls.Supported("getpeername", GetPeerName),
+ 53: syscalls.Supported("socketpair", SocketPair),
+ 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 57: syscalls.Supported("fork", Fork),
+ 58: syscalls.Supported("vfork", Vfork),
+ 59: syscalls.Supported("execve", Execve),
+ 60: syscalls.Supported("exit", Exit),
+ 61: syscalls.Supported("wait4", Wait4),
+ 62: syscalls.Supported("kill", Kill),
+ 63: syscalls.Supported("uname", Uname),
+ 64: syscalls.Supported("semget", Semget),
+ 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 67: syscalls.Supported("shmdt", Shmdt),
+ 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
+ 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 76: syscalls.Supported("truncate", Truncate),
+ 77: syscalls.Supported("ftruncate", Ftruncate),
+ 78: syscalls.Supported("getdents", Getdents),
+ 79: syscalls.Supported("getcwd", Getcwd),
+ 80: syscalls.Supported("chdir", Chdir),
+ 81: syscalls.Supported("fchdir", Fchdir),
+ 82: syscalls.Supported("rename", Rename),
+ 83: syscalls.Supported("mkdir", Mkdir),
+ 84: syscalls.Supported("rmdir", Rmdir),
+ 85: syscalls.Supported("creat", Creat),
+ 86: syscalls.Supported("link", Link),
+ 87: syscalls.Supported("unlink", Unlink),
+ 88: syscalls.Supported("symlink", Symlink),
+ 89: syscalls.Supported("readlink", Readlink),
+ 90: syscalls.Supported("chmod", Chmod),
+ 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 92: syscalls.Supported("chown", Chown),
+ 93: syscalls.Supported("fchown", Fchown),
+ 94: syscalls.Supported("lchown", Lchown),
+ 95: syscalls.Supported("umask", Umask),
+ 96: syscalls.Supported("gettimeofday", Gettimeofday),
+ 97: syscalls.Supported("getrlimit", Getrlimit),
+ 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 100: syscalls.Supported("times", Times),
+ 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 102: syscalls.Supported("getuid", Getuid),
+ 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 104: syscalls.Supported("getgid", Getgid),
+ 105: syscalls.Supported("setuid", Setuid),
+ 106: syscalls.Supported("setgid", Setgid),
+ 107: syscalls.Supported("geteuid", Geteuid),
+ 108: syscalls.Supported("getegid", Getegid),
+ 109: syscalls.Supported("setpgid", Setpgid),
+ 110: syscalls.Supported("getppid", Getppid),
+ 111: syscalls.Supported("getpgrp", Getpgrp),
+ 112: syscalls.Supported("setsid", Setsid),
+ 113: syscalls.Supported("setreuid", Setreuid),
+ 114: syscalls.Supported("setregid", Setregid),
+ 115: syscalls.Supported("getgroups", Getgroups),
+ 116: syscalls.Supported("setgroups", Setgroups),
+ 117: syscalls.Supported("setresuid", Setresuid),
+ 118: syscalls.Supported("getresuid", Getresuid),
+ 119: syscalls.Supported("setresgid", Setresgid),
+ 120: syscalls.Supported("getresgid", Getresgid),
+ 121: syscalls.Supported("getpgid", Getpgid),
+ 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 124: syscalls.Supported("getsid", Getsid),
+ 125: syscalls.Supported("capget", Capget),
+ 126: syscalls.Supported("capset", Capset),
+ 127: syscalls.Supported("rt_sigpending", RtSigpending),
+ 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 131: syscalls.Supported("sigaltstack", Sigaltstack),
+ 132: syscalls.Supported("utime", Utime),
+ 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
+ 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
+ 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
+ 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
+ 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
+ 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil),
+ 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil),
+ 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil),
+ 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 161: syscalls.Supported("chroot", Chroot),
+ 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 170: syscalls.Supported("sethostname", Sethostname),
+ 171: syscalls.Supported("setdomainname", Setdomainname),
+ 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
+ 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
+ 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
+ 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
+ 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
+ 186: syscalls.Supported("gettid", Gettid),
+ 187: syscalls.Supported("readahead", Readahead),
+ 188: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
+ 189: syscalls.PartiallySupported("lsetxattr", LSetXattr, "Only supported for tmpfs.", nil),
+ 190: syscalls.PartiallySupported("fsetxattr", FSetXattr, "Only supported for tmpfs.", nil),
+ 191: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
+ 192: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
+ 193: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
+ 194: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs", nil),
+ 195: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs", nil),
+ 196: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs", nil),
+ 197: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs", nil),
+ 198: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs", nil),
+ 199: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs", nil),
+ 200: syscalls.Supported("tkill", Tkill),
+ 201: syscalls.Supported("time", Time),
+ 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
+ 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 213: syscalls.Supported("epoll_create", EpollCreate),
+ 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
+ 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
+ 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 217: syscalls.Supported("getdents64", Getdents64),
+ 218: syscalls.Supported("set_tid_address", SetTidAddress),
+ 219: syscalls.Supported("restart_syscall", RestartSyscall),
+ 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
+ 222: syscalls.Supported("timer_create", TimerCreate),
+ 223: syscalls.Supported("timer_settime", TimerSettime),
+ 224: syscalls.Supported("timer_gettime", TimerGettime),
+ 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 226: syscalls.Supported("timer_delete", TimerDelete),
+ 227: syscalls.Supported("clock_settime", ClockSettime),
+ 228: syscalls.Supported("clock_gettime", ClockGettime),
+ 229: syscalls.Supported("clock_getres", ClockGetres),
+ 230: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 231: syscalls.Supported("exit_group", ExitGroup),
+ 232: syscalls.Supported("epoll_wait", EpollWait),
+ 233: syscalls.Supported("epoll_ctl", EpollCtl),
+ 234: syscalls.Supported("tgkill", Tgkill),
+ 235: syscalls.Supported("utimes", Utimes),
+ 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil),
+ 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 247: syscalls.Supported("waitid", Waitid),
+ 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil),
+ 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 257: syscalls.Supported("openat", Openat),
+ 258: syscalls.Supported("mkdirat", Mkdirat),
+ 259: syscalls.Supported("mknodat", Mknodat),
+ 260: syscalls.Supported("fchownat", Fchownat),
+ 261: syscalls.Supported("futimesat", Futimesat),
+ 262: syscalls.Supported("fstatat", Fstatat),
+ 263: syscalls.Supported("unlinkat", Unlinkat),
+ 264: syscalls.Supported("renameat", Renameat),
+ 265: syscalls.Supported("linkat", Linkat),
+ 266: syscalls.Supported("symlinkat", Symlinkat),
+ 267: syscalls.Supported("readlinkat", Readlinkat),
+ 268: syscalls.Supported("fchmodat", Fchmodat),
+ 269: syscalls.Supported("faccessat", Faccessat),
+ 270: syscalls.Supported("pselect", Pselect),
+ 271: syscalls.Supported("ppoll", Ppoll),
+ 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 273: syscalls.Supported("set_robust_list", SetRobustList),
+ 274: syscalls.Supported("get_robust_list", GetRobustList),
+ 275: syscalls.Supported("splice", Splice),
+ 276: syscalls.Supported("tee", Tee),
+ 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 280: syscalls.Supported("utimensat", Utimensat),
+ 281: syscalls.Supported("epoll_pwait", EpollPwait),
+ 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 283: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 284: syscalls.Supported("eventfd", Eventfd),
+ 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 286: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 288: syscalls.Supported("accept4", Accept4),
+ 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 290: syscalls.Supported("eventfd2", Eventfd2),
+ 291: syscalls.Supported("epoll_create1", EpollCreate1),
+ 292: syscalls.Supported("dup3", Dup3),
+ 293: syscalls.Supported("pipe2", Pipe2),
+ 294: syscalls.Supported("inotify_init1", InotifyInit1),
+ 295: syscalls.Supported("preadv", Preadv),
+ 296: syscalls.Supported("pwritev", Pwritev),
+ 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 302: syscalls.Supported("prlimit64", Prlimit64),
+ 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 309: syscalls.Supported("getcpu", Getcpu),
+ 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 317: syscalls.Supported("seccomp", Seccomp),
+ 318: syscalls.Supported("getrandom", GetRandom),
+ 319: syscalls.Supported("memfd_create", MemfdCreate),
+ 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
+ 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 322: syscalls.Supported("execveat", Execveat),
+ 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+
+ // Syscalls implemented after 325 are "backports" from versions
+ // of Linux after 4.4.
+ 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 327: syscalls.Supported("preadv2", Preadv2),
+ 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 329: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 330: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 331: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 332: syscalls.Supported("statx", Statx),
+ 333: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 334: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ },
+ Emulate: map[usermem.Addr]uintptr{
+ 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
+ 0xffffffffff600400: 201, // vsyscall time(2)
+ 0xffffffffff600800: 309, // vsyscall getcpu(2)
+ },
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
+
+// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall
+// numbers from Linux 4.4.
+var ARM64 = &kernel.SyscallTable{
+ OS: abi.Linux,
+ Arch: arch.ARM64,
+ Version: kernel.Version{
+ Sysname: LinuxSysname,
+ Release: LinuxRelease,
+ Version: LinuxVersion,
+ },
+ AuditNumber: linux.AUDIT_ARCH_AARCH64,
+ Table: map[uintptr]kernel.Syscall{
+ 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
+ 5: syscalls.PartiallySupported("setxattr", SetXattr, "Only supported for tmpfs.", nil),
+ 6: syscalls.PartiallySupported("lsetxattr", LSetXattr, "Only supported for tmpfs.", nil),
+ 7: syscalls.PartiallySupported("fsetxattr", FSetXattr, "Only supported for tmpfs.", nil),
+ 8: syscalls.PartiallySupported("getxattr", GetXattr, "Only supported for tmpfs.", nil),
+ 9: syscalls.PartiallySupported("lgetxattr", LGetXattr, "Only supported for tmpfs.", nil),
+ 10: syscalls.PartiallySupported("fgetxattr", FGetXattr, "Only supported for tmpfs.", nil),
+ 11: syscalls.PartiallySupported("listxattr", ListXattr, "Only supported for tmpfs", nil),
+ 12: syscalls.PartiallySupported("llistxattr", LListXattr, "Only supported for tmpfs", nil),
+ 13: syscalls.PartiallySupported("flistxattr", FListXattr, "Only supported for tmpfs", nil),
+ 14: syscalls.PartiallySupported("removexattr", RemoveXattr, "Only supported for tmpfs", nil),
+ 15: syscalls.PartiallySupported("lremovexattr", LRemoveXattr, "Only supported for tmpfs", nil),
+ 16: syscalls.PartiallySupported("fremovexattr", FRemoveXattr, "Only supported for tmpfs", nil),
+ 17: syscalls.Supported("getcwd", Getcwd),
+ 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
+ 19: syscalls.Supported("eventfd2", Eventfd2),
+ 20: syscalls.Supported("epoll_create1", EpollCreate1),
+ 21: syscalls.Supported("epoll_ctl", EpollCtl),
+ 22: syscalls.Supported("epoll_pwait", EpollPwait),
+ 23: syscalls.Supported("dup", Dup),
+ 24: syscalls.Supported("dup3", Dup3),
+ 25: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
+ 26: syscalls.Supported("inotify_init1", InotifyInit1),
+ 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
+ 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
+ 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
+ 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
+ 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
+ 33: syscalls.Supported("mknodat", Mknodat),
+ 34: syscalls.Supported("mkdirat", Mkdirat),
+ 35: syscalls.Supported("unlinkat", Unlinkat),
+ 36: syscalls.Supported("symlinkat", Symlinkat),
+ 37: syscalls.Supported("linkat", Linkat),
+ 38: syscalls.Supported("renameat", Renameat),
+ 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
+ 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
+ 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
+ 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
+ 43: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
+ 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
+ 45: syscalls.Supported("truncate", Truncate),
+ 46: syscalls.Supported("ftruncate", Ftruncate),
+ 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
+ 48: syscalls.Supported("faccessat", Faccessat),
+ 49: syscalls.Supported("chdir", Chdir),
+ 50: syscalls.Supported("fchdir", Fchdir),
+ 51: syscalls.Supported("chroot", Chroot),
+ 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
+ 53: syscalls.Supported("fchmodat", Fchmodat),
+ 54: syscalls.Supported("fchownat", Fchownat),
+ 55: syscalls.Supported("fchown", Fchown),
+ 56: syscalls.Supported("openat", Openat),
+ 57: syscalls.Supported("close", Close),
+ 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
+ 59: syscalls.Supported("pipe2", Pipe2),
+ 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
+ 61: syscalls.Supported("getdents64", Getdents64),
+ 62: syscalls.Supported("lseek", Lseek),
+ 63: syscalls.Supported("read", Read),
+ 64: syscalls.Supported("write", Write),
+ 65: syscalls.Supported("readv", Readv),
+ 66: syscalls.Supported("writev", Writev),
+ 67: syscalls.Supported("pread64", Pread64),
+ 68: syscalls.Supported("pwrite64", Pwrite64),
+ 69: syscalls.Supported("preadv", Preadv),
+ 70: syscalls.Supported("pwritev", Pwritev),
+ 71: syscalls.Supported("sendfile", Sendfile),
+ 72: syscalls.Supported("pselect", Pselect),
+ 73: syscalls.Supported("ppoll", Ppoll),
+ 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
+ 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 77: syscalls.Supported("tee", Tee),
+ 78: syscalls.Supported("readlinkat", Readlinkat),
+ 79: syscalls.Supported("fstatat", Fstatat),
+ 80: syscalls.Supported("fstat", Fstat),
+ 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
+ 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
+ 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
+ 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
+ 85: syscalls.Supported("timerfd_create", TimerfdCreate),
+ 86: syscalls.Supported("timerfd_settime", TimerfdSettime),
+ 87: syscalls.Supported("timerfd_gettime", TimerfdGettime),
+ 88: syscalls.Supported("utimensat", Utimensat),
+ 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
+ 90: syscalls.Supported("capget", Capget),
+ 91: syscalls.Supported("capset", Capset),
+ 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
+ 93: syscalls.Supported("exit", Exit),
+ 94: syscalls.Supported("exit_group", ExitGroup),
+ 95: syscalls.Supported("waitid", Waitid),
+ 96: syscalls.Supported("set_tid_address", SetTidAddress),
+ 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
+ 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
+ 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
+ 101: syscalls.Supported("nanosleep", Nanosleep),
+ 102: syscalls.Supported("getitimer", Getitimer),
+ 103: syscalls.Supported("setitimer", Setitimer),
+ 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
+ 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
+ 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
+ 107: syscalls.Supported("timer_create", TimerCreate),
+ 108: syscalls.Supported("timer_gettime", TimerGettime),
+ 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
+ 110: syscalls.Supported("timer_settime", TimerSettime),
+ 111: syscalls.Supported("timer_delete", TimerDelete),
+ 112: syscalls.Supported("clock_settime", ClockSettime),
+ 113: syscalls.Supported("clock_gettime", ClockGettime),
+ 114: syscalls.Supported("clock_getres", ClockGetres),
+ 115: syscalls.Supported("clock_nanosleep", ClockNanosleep),
+ 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
+ 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
+ 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
+ 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
+ 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
+ 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
+ 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
+ 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
+ 124: syscalls.Supported("sched_yield", SchedYield),
+ 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
+ 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
+ 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
+ 128: syscalls.Supported("restart_syscall", RestartSyscall),
+ 129: syscalls.Supported("kill", Kill),
+ 130: syscalls.Supported("tkill", Tkill),
+ 131: syscalls.Supported("tgkill", Tgkill),
+ 132: syscalls.Supported("sigaltstack", Sigaltstack),
+ 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
+ 134: syscalls.Supported("rt_sigaction", RtSigaction),
+ 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
+ 136: syscalls.Supported("rt_sigpending", RtSigpending),
+ 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
+ 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
+ 139: syscalls.Supported("rt_sigreturn", RtSigreturn),
+ 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
+ 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
+ 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
+ 143: syscalls.Supported("setregid", Setregid),
+ 144: syscalls.Supported("setgid", Setgid),
+ 145: syscalls.Supported("setreuid", Setreuid),
+ 146: syscalls.Supported("setuid", Setuid),
+ 147: syscalls.Supported("setresuid", Setresuid),
+ 148: syscalls.Supported("getresuid", Getresuid),
+ 149: syscalls.Supported("setresgid", Setresgid),
+ 150: syscalls.Supported("getresgid", Getresgid),
+ 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
+ 153: syscalls.Supported("times", Times),
+ 154: syscalls.Supported("setpgid", Setpgid),
+ 155: syscalls.Supported("getpgid", Getpgid),
+ 156: syscalls.Supported("getsid", Getsid),
+ 157: syscalls.Supported("setsid", Setsid),
+ 158: syscalls.Supported("getgroups", Getgroups),
+ 159: syscalls.Supported("setgroups", Setgroups),
+ 160: syscalls.Supported("uname", Uname),
+ 161: syscalls.Supported("sethostname", Sethostname),
+ 162: syscalls.Supported("setdomainname", Setdomainname),
+ 163: syscalls.Supported("getrlimit", Getrlimit),
+ 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
+ 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
+ 166: syscalls.Supported("umask", Umask),
+ 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
+ 168: syscalls.Supported("getcpu", Getcpu),
+ 169: syscalls.Supported("gettimeofday", Gettimeofday),
+ 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
+ 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
+ 172: syscalls.Supported("getpid", Getpid),
+ 173: syscalls.Supported("getppid", Getppid),
+ 174: syscalls.Supported("getuid", Getuid),
+ 175: syscalls.Supported("geteuid", Geteuid),
+ 176: syscalls.Supported("getgid", Getgid),
+ 177: syscalls.Supported("getegid", Getegid),
+ 178: syscalls.Supported("gettid", Gettid),
+ 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
+ 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
+ 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 190: syscalls.Supported("semget", Semget),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
+ 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
+ 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
+ 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
+ 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
+ 197: syscalls.Supported("shmdt", Shmdt),
+ 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
+ 199: syscalls.Supported("socketpair", SocketPair),
+ 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
+ 201: syscalls.Supported("listen", Listen),
+ 202: syscalls.Supported("accept", Accept),
+ 203: syscalls.Supported("connect", Connect),
+ 204: syscalls.Supported("getsockname", GetSockName),
+ 205: syscalls.Supported("getpeername", GetPeerName),
+ 206: syscalls.Supported("sendto", SendTo),
+ 207: syscalls.Supported("recvfrom", RecvFrom),
+ 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
+ 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
+ 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
+ 211: syscalls.Supported("sendmsg", SendMsg),
+ 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
+ 213: syscalls.Supported("readahead", Readahead),
+ 214: syscalls.Supported("brk", Brk),
+ 215: syscalls.Supported("munmap", Munmap),
+ 216: syscalls.Supported("mremap", Mremap),
+ 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
+ 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
+ 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
+ 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
+ 221: syscalls.Supported("execve", Execve),
+ 222: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
+ 223: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
+ 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
+ 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
+ 226: syscalls.Supported("mprotect", Mprotect),
+ 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
+ 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+ 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
+ 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
+ 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
+ 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
+ 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
+ 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
+ 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
+ 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
+ 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
+ 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
+ 242: syscalls.Supported("accept4", Accept4),
+ 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
+ 260: syscalls.Supported("wait4", Wait4),
+ 261: syscalls.Supported("prlimit64", Prlimit64),
+ 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
+ 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
+ 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
+ 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
+ 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
+ 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
+ 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
+ 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
+ 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
+ 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
+ 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
+ 277: syscalls.Supported("seccomp", Seccomp),
+ 278: syscalls.Supported("getrandom", GetRandom),
+ 279: syscalls.Supported("memfd_create", MemfdCreate),
+ 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
+ 281: syscalls.Supported("execveat", Execveat),
+ 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
+ 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(gvisor.dev/issue/267)
+ 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
+
+ // Syscalls after 284 are "backports" from versions of Linux after 4.4.
+ 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
+ 286: syscalls.Supported("preadv2", Preadv2),
+ 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
+ 288: syscalls.ErrorWithEvent("pkey_mprotect", syserror.ENOSYS, "", nil),
+ 289: syscalls.ErrorWithEvent("pkey_alloc", syserror.ENOSYS, "", nil),
+ 290: syscalls.ErrorWithEvent("pkey_free", syserror.ENOSYS, "", nil),
+ 291: syscalls.Supported("statx", Statx),
+ 292: syscalls.ErrorWithEvent("io_pgetevents", syserror.ENOSYS, "", nil),
+ 293: syscalls.PartiallySupported("rseq", RSeq, "Not supported on all platforms.", nil),
+
+ // Linux skips ahead to syscall 424 to sync numbers between arches.
+ 424: syscalls.ErrorWithEvent("pidfd_send_signal", syserror.ENOSYS, "", nil),
+ 425: syscalls.ErrorWithEvent("io_uring_setup", syserror.ENOSYS, "", nil),
+ 426: syscalls.ErrorWithEvent("io_uring_enter", syserror.ENOSYS, "", nil),
+ 427: syscalls.ErrorWithEvent("io_uring_register", syserror.ENOSYS, "", nil),
+ 428: syscalls.ErrorWithEvent("open_tree", syserror.ENOSYS, "", nil),
+ 429: syscalls.ErrorWithEvent("move_mount", syserror.ENOSYS, "", nil),
+ 430: syscalls.ErrorWithEvent("fsopen", syserror.ENOSYS, "", nil),
+ 431: syscalls.ErrorWithEvent("fsconfig", syserror.ENOSYS, "", nil),
+ 432: syscalls.ErrorWithEvent("fsmount", syserror.ENOSYS, "", nil),
+ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil),
+ 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil),
+ 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil),
+ },
+ Emulate: map[usermem.Addr]uintptr{},
+ Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, syserror.ENOSYS
+ },
+}
+
+func init() {
+ kernel.RegisterSyscallTable(AMD64)
+ kernel.RegisterSyscallTable(ARM64)
+}
diff --git a/pkg/sentry/syscalls/linux/linux64_amd64.go b/pkg/sentry/syscalls/linux/linux64_amd64.go
deleted file mode 100644
index 3021440ed..000000000
--- a/pkg/sentry/syscalls/linux/linux64_amd64.go
+++ /dev/null
@@ -1,386 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package linux
-
-import (
- "gvisor.dev/gvisor/pkg/abi"
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// AMD64 is a table of Linux amd64 syscall API with the corresponding syscall
-// numbers from Linux 4.4.
-var AMD64 = &kernel.SyscallTable{
- OS: abi.Linux,
- Arch: arch.AMD64,
- Version: kernel.Version{
- // Version 4.4 is chosen as a stable, longterm version of Linux, which
- // guides the interface provided by this syscall table. The build
- // version is that for a clean build with default kernel config, at 5
- // minutes after v4.4 was tagged.
- Sysname: LinuxSysname,
- Release: LinuxRelease,
- Version: LinuxVersion,
- },
- AuditNumber: linux.AUDIT_ARCH_X86_64,
- Table: map[uintptr]kernel.Syscall{
- 0: syscalls.Supported("read", Read),
- 1: syscalls.Supported("write", Write),
- 2: syscalls.PartiallySupported("open", Open, "Options O_DIRECT, O_NOATIME, O_PATH, O_TMPFILE, O_SYNC are not supported.", nil),
- 3: syscalls.Supported("close", Close),
- 4: syscalls.Supported("stat", Stat),
- 5: syscalls.Supported("fstat", Fstat),
- 6: syscalls.Supported("lstat", Lstat),
- 7: syscalls.Supported("poll", Poll),
- 8: syscalls.Supported("lseek", Lseek),
- 9: syscalls.PartiallySupported("mmap", Mmap, "Generally supported with exceptions. Options MAP_FIXED_NOREPLACE, MAP_SHARED_VALIDATE, MAP_SYNC MAP_GROWSDOWN, MAP_HUGETLB are not supported.", nil),
- 10: syscalls.Supported("mprotect", Mprotect),
- 11: syscalls.Supported("munmap", Munmap),
- 12: syscalls.Supported("brk", Brk),
- 13: syscalls.Supported("rt_sigaction", RtSigaction),
- 14: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
- 15: syscalls.Supported("rt_sigreturn", RtSigreturn),
- 16: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
- 17: syscalls.Supported("pread64", Pread64),
- 18: syscalls.Supported("pwrite64", Pwrite64),
- 19: syscalls.Supported("readv", Readv),
- 20: syscalls.Supported("writev", Writev),
- 21: syscalls.Supported("access", Access),
- 22: syscalls.Supported("pipe", Pipe),
- 23: syscalls.Supported("select", Select),
- 24: syscalls.Supported("sched_yield", SchedYield),
- 25: syscalls.Supported("mremap", Mremap),
- 26: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
- 27: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
- 28: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
- 29: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
- 30: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
- 31: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
- 32: syscalls.Supported("dup", Dup),
- 33: syscalls.Supported("dup2", Dup2),
- 34: syscalls.Supported("pause", Pause),
- 35: syscalls.Supported("nanosleep", Nanosleep),
- 36: syscalls.Supported("getitimer", Getitimer),
- 37: syscalls.Supported("alarm", Alarm),
- 38: syscalls.Supported("setitimer", Setitimer),
- 39: syscalls.Supported("getpid", Getpid),
- 40: syscalls.Supported("sendfile", Sendfile),
- 41: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
- 42: syscalls.Supported("connect", Connect),
- 43: syscalls.Supported("accept", Accept),
- 44: syscalls.Supported("sendto", SendTo),
- 45: syscalls.Supported("recvfrom", RecvFrom),
- 46: syscalls.Supported("sendmsg", SendMsg),
- 47: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
- 48: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
- 49: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
- 50: syscalls.Supported("listen", Listen),
- 51: syscalls.Supported("getsockname", GetSockName),
- 52: syscalls.Supported("getpeername", GetPeerName),
- 53: syscalls.Supported("socketpair", SocketPair),
- 54: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
- 55: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
- 56: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
- 57: syscalls.Supported("fork", Fork),
- 58: syscalls.Supported("vfork", Vfork),
- 59: syscalls.Supported("execve", Execve),
- 60: syscalls.Supported("exit", Exit),
- 61: syscalls.Supported("wait4", Wait4),
- 62: syscalls.Supported("kill", Kill),
- 63: syscalls.Supported("uname", Uname),
- 64: syscalls.Supported("semget", Semget),
- 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
- 67: syscalls.Supported("shmdt", Shmdt),
- 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 70: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 71: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
- 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
- 74: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
- 75: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
- 76: syscalls.Supported("truncate", Truncate),
- 77: syscalls.Supported("ftruncate", Ftruncate),
- 78: syscalls.Supported("getdents", Getdents),
- 79: syscalls.Supported("getcwd", Getcwd),
- 80: syscalls.Supported("chdir", Chdir),
- 81: syscalls.Supported("fchdir", Fchdir),
- 82: syscalls.Supported("rename", Rename),
- 83: syscalls.Supported("mkdir", Mkdir),
- 84: syscalls.Supported("rmdir", Rmdir),
- 85: syscalls.Supported("creat", Creat),
- 86: syscalls.Supported("link", Link),
- 87: syscalls.Supported("unlink", Unlink),
- 88: syscalls.Supported("symlink", Symlink),
- 89: syscalls.Supported("readlink", Readlink),
- 90: syscalls.Supported("chmod", Chmod),
- 91: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
- 92: syscalls.Supported("chown", Chown),
- 93: syscalls.Supported("fchown", Fchown),
- 94: syscalls.Supported("lchown", Lchown),
- 95: syscalls.Supported("umask", Umask),
- 96: syscalls.Supported("gettimeofday", Gettimeofday),
- 97: syscalls.Supported("getrlimit", Getrlimit),
- 98: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
- 99: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
- 100: syscalls.Supported("times", Times),
- 101: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
- 102: syscalls.Supported("getuid", Getuid),
- 103: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
- 104: syscalls.Supported("getgid", Getgid),
- 105: syscalls.Supported("setuid", Setuid),
- 106: syscalls.Supported("setgid", Setgid),
- 107: syscalls.Supported("geteuid", Geteuid),
- 108: syscalls.Supported("getegid", Getegid),
- 109: syscalls.Supported("setpgid", Setpgid),
- 110: syscalls.Supported("getppid", Getppid),
- 111: syscalls.Supported("getpgrp", Getpgrp),
- 112: syscalls.Supported("setsid", Setsid),
- 113: syscalls.Supported("setreuid", Setreuid),
- 114: syscalls.Supported("setregid", Setregid),
- 115: syscalls.Supported("getgroups", Getgroups),
- 116: syscalls.Supported("setgroups", Setgroups),
- 117: syscalls.Supported("setresuid", Setresuid),
- 118: syscalls.Supported("getresuid", Getresuid),
- 119: syscalls.Supported("setresgid", Setresgid),
- 120: syscalls.Supported("getresgid", Getresgid),
- 121: syscalls.Supported("getpgid", Getpgid),
- 122: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 123: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 124: syscalls.Supported("getsid", Getsid),
- 125: syscalls.Supported("capget", Capget),
- 126: syscalls.Supported("capset", Capset),
- 127: syscalls.Supported("rt_sigpending", RtSigpending),
- 128: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
- 129: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
- 130: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
- 131: syscalls.Supported("sigaltstack", Sigaltstack),
- 132: syscalls.Supported("utime", Utime),
- 133: syscalls.PartiallySupported("mknod", Mknod, "Device creation is not generally supported. Only regular file and FIFO creation are supported.", nil),
- 134: syscalls.Error("uselib", syserror.ENOSYS, "Obsolete", nil),
- 135: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
- 136: syscalls.ErrorWithEvent("ustat", syserror.ENOSYS, "Needs filesystem support.", nil),
- 137: syscalls.PartiallySupported("statfs", Statfs, "Depends on the backing file system implementation.", nil),
- 138: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
- 139: syscalls.ErrorWithEvent("sysfs", syserror.ENOSYS, "", []string{"gvisor.dev/issue/165"}),
- 140: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
- 141: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
- 142: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
- 143: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
- 144: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
- 145: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
- 146: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
- 147: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
- 148: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
- 149: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 150: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 151: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 152: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 153: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
- 154: syscalls.Error("modify_ldt", syserror.EPERM, "", nil),
- 155: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
- 156: syscalls.Error("sysctl", syserror.EPERM, "Deprecated. Use /proc/sys instead.", nil),
- 157: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
- 158: syscalls.PartiallySupported("arch_prctl", ArchPrctl, "Options ARCH_GET_GS, ARCH_SET_GS not supported.", nil),
- 159: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
- 160: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
- 161: syscalls.Supported("chroot", Chroot),
- 162: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
- 163: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
- 164: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
- 165: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
- 166: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
- 167: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
- 168: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
- 169: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
- 170: syscalls.Supported("sethostname", Sethostname),
- 171: syscalls.Supported("setdomainname", Setdomainname),
- 172: syscalls.CapError("iopl", linux.CAP_SYS_RAWIO, "", nil),
- 173: syscalls.CapError("ioperm", linux.CAP_SYS_RAWIO, "", nil),
- 174: syscalls.CapError("create_module", linux.CAP_SYS_MODULE, "", nil),
- 175: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
- 176: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
- 177: syscalls.Error("get_kernel_syms", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
- 178: syscalls.Error("query_module", syserror.ENOSYS, "Not supported in Linux > 2.6.", nil),
- 179: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
- 180: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
- 181: syscalls.Error("getpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 182: syscalls.Error("putpmsg", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 183: syscalls.Error("afs_syscall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
- 186: syscalls.Supported("gettid", Gettid),
- 187: syscalls.Supported("readahead", Readahead),
- 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 194: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 195: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 196: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 197: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 198: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 199: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 200: syscalls.Supported("tkill", Tkill),
- 201: syscalls.Supported("time", Time),
- 202: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 203: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
- 204: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
- 205: syscalls.Error("set_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
- 206: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 207: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 208: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 209: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 210: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 211: syscalls.Error("get_thread_area", syserror.ENOSYS, "Expected to return ENOSYS on 64-bit", nil),
- 212: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
- 213: syscalls.Supported("epoll_create", EpollCreate),
- 214: syscalls.ErrorWithEvent("epoll_ctl_old", syserror.ENOSYS, "Deprecated.", nil),
- 215: syscalls.ErrorWithEvent("epoll_wait_old", syserror.ENOSYS, "Deprecated.", nil),
- 216: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
- 217: syscalls.Supported("getdents64", Getdents64),
- 218: syscalls.Supported("set_tid_address", SetTidAddress),
- 219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
- 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
- 222: syscalls.Supported("timer_create", TimerCreate),
- 223: syscalls.Supported("timer_settime", TimerSettime),
- 224: syscalls.Supported("timer_gettime", TimerGettime),
- 225: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
- 226: syscalls.Supported("timer_delete", TimerDelete),
- 227: syscalls.Supported("clock_settime", ClockSettime),
- 228: syscalls.Supported("clock_gettime", ClockGettime),
- 229: syscalls.Supported("clock_getres", ClockGetres),
- 230: syscalls.Supported("clock_nanosleep", ClockNanosleep),
- 231: syscalls.Supported("exit_group", ExitGroup),
- 232: syscalls.Supported("epoll_wait", EpollWait),
- 233: syscalls.Supported("epoll_ctl", EpollCtl),
- 234: syscalls.Supported("tgkill", Tgkill),
- 235: syscalls.Supported("utimes", Utimes),
- 236: syscalls.Error("vserver", syserror.ENOSYS, "Not implemented by Linux", nil),
- 237: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
- 238: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
- 239: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
- 240: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 241: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 242: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 243: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 244: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 245: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 246: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
- 247: syscalls.Supported("waitid", Waitid),
- 248: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
- 249: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
- 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
- 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil),
- 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
- 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
- 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
- 257: syscalls.Supported("openat", Openat),
- 258: syscalls.Supported("mkdirat", Mkdirat),
- 259: syscalls.Supported("mknodat", Mknodat),
- 260: syscalls.Supported("fchownat", Fchownat),
- 261: syscalls.Supported("futimesat", Futimesat),
- 262: syscalls.Supported("fstatat", Fstatat),
- 263: syscalls.Supported("unlinkat", Unlinkat),
- 264: syscalls.Supported("renameat", Renameat),
- 265: syscalls.Supported("linkat", Linkat),
- 266: syscalls.Supported("symlinkat", Symlinkat),
- 267: syscalls.Supported("readlinkat", Readlinkat),
- 268: syscalls.Supported("fchmodat", Fchmodat),
- 269: syscalls.Supported("faccessat", Faccessat),
- 270: syscalls.Supported("pselect", Pselect),
- 271: syscalls.Supported("ppoll", Ppoll),
- 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 275: syscalls.Supported("splice", Splice),
- 276: syscalls.Supported("tee", Tee),
- 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
- 278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
- 280: syscalls.Supported("utimensat", Utimensat),
- 281: syscalls.Supported("epoll_pwait", EpollPwait),
- 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
- 283: syscalls.Supported("timerfd_create", TimerfdCreate),
- 284: syscalls.Supported("eventfd", Eventfd),
- 285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
- 286: syscalls.Supported("timerfd_settime", TimerfdSettime),
- 287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
- 288: syscalls.Supported("accept4", Accept4),
- 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
- 290: syscalls.Supported("eventfd2", Eventfd2),
- 291: syscalls.Supported("epoll_create1", EpollCreate1),
- 292: syscalls.Supported("dup3", Dup3),
- 293: syscalls.Supported("pipe2", Pipe2),
- 294: syscalls.Supported("inotify_init1", InotifyInit1),
- 295: syscalls.Supported("preadv", Preadv),
- 296: syscalls.Supported("pwritev", Pwritev),
- 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
- 298: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
- 299: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
- 300: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 301: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 302: syscalls.Supported("prlimit64", Prlimit64),
- 303: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 304: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 305: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
- 306: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
- 307: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
- 308: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
- 309: syscalls.Supported("getcpu", Getcpu),
- 310: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 311: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 312: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
- 313: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
- 314: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 315: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 316: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
- 317: syscalls.Supported("seccomp", Seccomp),
- 318: syscalls.Supported("getrandom", GetRandom),
- 319: syscalls.Supported("memfd_create", MemfdCreate),
- 320: syscalls.CapError("kexec_file_load", linux.CAP_SYS_BOOT, "", nil),
- 321: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 322: syscalls.Supported("execveat", Execveat),
- 323: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 324: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
- 325: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
-
- // Syscalls after 325 are "backports" from versions of Linux after 4.4.
- 326: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
- 327: syscalls.Supported("preadv2", Preadv2),
- 328: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
- 332: syscalls.Supported("statx", Statx),
- },
-
- Emulate: map[usermem.Addr]uintptr{
- 0xffffffffff600000: 96, // vsyscall gettimeofday(2)
- 0xffffffffff600400: 201, // vsyscall time(2)
- 0xffffffffff600800: 309, // vsyscall getcpu(2)
- },
- Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
- t.Kernel().EmitUnimplementedEvent(t)
- return 0, syserror.ENOSYS
- },
-}
diff --git a/pkg/sentry/syscalls/linux/linux64_arm64.go b/pkg/sentry/syscalls/linux/linux64_arm64.go
deleted file mode 100644
index 4cf7f836a..000000000
--- a/pkg/sentry/syscalls/linux/linux64_arm64.go
+++ /dev/null
@@ -1,313 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package linux
-
-import (
- "gvisor.dev/gvisor/pkg/abi"
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// ARM64 is a table of Linux arm64 syscall API with the corresponding syscall
-// numbers from Linux 4.4.
-var ARM64 = &kernel.SyscallTable{
- OS: abi.Linux,
- Arch: arch.ARM64,
- Version: kernel.Version{
- Sysname: LinuxSysname,
- Release: LinuxRelease,
- Version: LinuxVersion,
- },
- AuditNumber: linux.AUDIT_ARCH_AARCH64,
- Table: map[uintptr]kernel.Syscall{
- 0: syscalls.PartiallySupported("io_setup", IoSetup, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 1: syscalls.PartiallySupported("io_destroy", IoDestroy, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 2: syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 3: syscalls.PartiallySupported("io_cancel", IoCancel, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 4: syscalls.PartiallySupported("io_getevents", IoGetevents, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}),
- 5: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 6: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 7: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 8: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 9: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 10: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 11: syscalls.ErrorWithEvent("listxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 12: syscalls.ErrorWithEvent("llistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 13: syscalls.ErrorWithEvent("flistxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 14: syscalls.ErrorWithEvent("removexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 15: syscalls.ErrorWithEvent("lremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 16: syscalls.ErrorWithEvent("fremovexattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
- 17: syscalls.Supported("getcwd", Getcwd),
- 18: syscalls.CapError("lookup_dcookie", linux.CAP_SYS_ADMIN, "", nil),
- 19: syscalls.Supported("eventfd2", Eventfd2),
- 20: syscalls.Supported("epoll_create1", EpollCreate1),
- 21: syscalls.Supported("epoll_ctl", EpollCtl),
- 22: syscalls.Supported("epoll_pwait", EpollPwait),
- 23: syscalls.Supported("dup", Dup),
- 24: syscalls.Supported("dup3", Dup3),
- 26: syscalls.Supported("inotify_init1", InotifyInit1),
- 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil),
- 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil),
- 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil),
- 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending)
- 32: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
- 33: syscalls.Supported("mknodat", Mknodat),
- 34: syscalls.Supported("mkdirat", Mkdirat),
- 35: syscalls.Supported("unlinkat", Unlinkat),
- 36: syscalls.Supported("symlinkat", Symlinkat),
- 37: syscalls.Supported("linkat", Linkat),
- 38: syscalls.Supported("renameat", Renameat),
- 39: syscalls.PartiallySupported("umount2", Umount2, "Not all options or file systems are supported.", nil),
- 40: syscalls.PartiallySupported("mount", Mount, "Not all options or file systems are supported.", nil),
- 41: syscalls.Error("pivot_root", syserror.EPERM, "", nil),
- 42: syscalls.Error("nfsservctl", syserror.ENOSYS, "Removed after Linux 3.1.", nil),
- 44: syscalls.PartiallySupported("fstatfs", Fstatfs, "Depends on the backing file system implementation.", nil),
- 46: syscalls.Supported("ftruncate", Ftruncate),
- 47: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
- 48: syscalls.Supported("faccessat", Faccessat),
- 49: syscalls.Supported("chdir", Chdir),
- 50: syscalls.Supported("fchdir", Fchdir),
- 51: syscalls.Supported("chroot", Chroot),
- 52: syscalls.PartiallySupported("fchmod", Fchmod, "Options S_ISUID and S_ISGID not supported.", nil),
- 53: syscalls.Supported("fchmodat", Fchmodat),
- 54: syscalls.Supported("fchownat", Fchownat),
- 55: syscalls.Supported("fchown", Fchown),
- 56: syscalls.Supported("openat", Openat),
- 57: syscalls.Supported("close", Close),
- 58: syscalls.CapError("vhangup", linux.CAP_SYS_TTY_CONFIG, "", nil),
- 59: syscalls.Supported("pipe2", Pipe2),
- 60: syscalls.CapError("quotactl", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_admin for most operations
- 61: syscalls.Supported("getdents64", Getdents64),
- 62: syscalls.Supported("lseek", Lseek),
- 63: syscalls.Supported("read", Read),
- 64: syscalls.Supported("write", Write),
- 65: syscalls.Supported("readv", Readv),
- 66: syscalls.Supported("writev", Writev),
- 67: syscalls.Supported("pread64", Pread64),
- 68: syscalls.Supported("pwrite64", Pwrite64),
- 69: syscalls.Supported("preadv", Preadv),
- 70: syscalls.Supported("pwritev", Pwritev),
- 71: syscalls.Supported("sendfile", Sendfile),
- 72: syscalls.Supported("pselect", Pselect),
- 73: syscalls.Supported("ppoll", Ppoll),
- 74: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
- 75: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 76: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 77: syscalls.Supported("tee", Tee),
- 78: syscalls.Supported("readlinkat", Readlinkat),
- 80: syscalls.Supported("fstat", Fstat),
- 81: syscalls.PartiallySupported("sync", Sync, "Full data flush is not guaranteed at this time.", nil),
- 82: syscalls.PartiallySupported("fsync", Fsync, "Full data flush is not guaranteed at this time.", nil),
- 83: syscalls.PartiallySupported("fdatasync", Fdatasync, "Full data flush is not guaranteed at this time.", nil),
- 84: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
- 85: syscalls.Supported("timerfd_create", TimerfdCreate),
- 86: syscalls.Supported("timerfd_settime", TimerfdSettime),
- 87: syscalls.Supported("timerfd_gettime", TimerfdGettime),
- 88: syscalls.Supported("utimensat", Utimensat),
- 89: syscalls.CapError("acct", linux.CAP_SYS_PACCT, "", nil),
- 90: syscalls.Supported("capget", Capget),
- 91: syscalls.Supported("capset", Capset),
- 92: syscalls.ErrorWithEvent("personality", syserror.EINVAL, "Unable to change personality.", nil),
- 93: syscalls.Supported("exit", Exit),
- 94: syscalls.Supported("exit_group", ExitGroup),
- 95: syscalls.Supported("waitid", Waitid),
- 96: syscalls.Supported("set_tid_address", SetTidAddress),
- 97: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
- 98: syscalls.PartiallySupported("futex", Futex, "Robust futexes not supported.", nil),
- 99: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 100: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 101: syscalls.Supported("nanosleep", Nanosleep),
- 102: syscalls.Supported("getitimer", Getitimer),
- 103: syscalls.Supported("setitimer", Setitimer),
- 104: syscalls.CapError("kexec_load", linux.CAP_SYS_BOOT, "", nil),
- 105: syscalls.CapError("init_module", linux.CAP_SYS_MODULE, "", nil),
- 106: syscalls.CapError("delete_module", linux.CAP_SYS_MODULE, "", nil),
- 107: syscalls.Supported("timer_create", TimerCreate),
- 108: syscalls.Supported("timer_gettime", TimerGettime),
- 109: syscalls.Supported("timer_getoverrun", TimerGetoverrun),
- 110: syscalls.Supported("timer_settime", TimerSettime),
- 111: syscalls.Supported("timer_delete", TimerDelete),
- 112: syscalls.Supported("clock_settime", ClockSettime),
- 113: syscalls.Supported("clock_gettime", ClockGettime),
- 114: syscalls.Supported("clock_getres", ClockGetres),
- 115: syscalls.Supported("clock_nanosleep", ClockNanosleep),
- 116: syscalls.PartiallySupported("syslog", Syslog, "Outputs a dummy message for security reasons.", nil),
- 117: syscalls.PartiallySupported("ptrace", Ptrace, "Options PTRACE_PEEKSIGINFO, PTRACE_SECCOMP_GET_FILTER not supported.", nil),
- 118: syscalls.CapError("sched_setparam", linux.CAP_SYS_NICE, "", nil),
- 119: syscalls.PartiallySupported("sched_setscheduler", SchedSetscheduler, "Stub implementation.", nil),
- 120: syscalls.PartiallySupported("sched_getscheduler", SchedGetscheduler, "Stub implementation.", nil),
- 121: syscalls.PartiallySupported("sched_getparam", SchedGetparam, "Stub implementation.", nil),
- 122: syscalls.PartiallySupported("sched_setaffinity", SchedSetaffinity, "Stub implementation.", nil),
- 123: syscalls.PartiallySupported("sched_getaffinity", SchedGetaffinity, "Stub implementation.", nil),
- 124: syscalls.Supported("sched_yield", SchedYield),
- 125: syscalls.PartiallySupported("sched_get_priority_max", SchedGetPriorityMax, "Stub implementation.", nil),
- 126: syscalls.PartiallySupported("sched_get_priority_min", SchedGetPriorityMin, "Stub implementation.", nil),
- 127: syscalls.ErrorWithEvent("sched_rr_get_interval", syserror.EPERM, "", nil),
- 128: syscalls.Supported("restart_syscall", RestartSyscall),
- 129: syscalls.Supported("kill", Kill),
- 130: syscalls.Supported("tkill", Tkill),
- 131: syscalls.Supported("tgkill", Tgkill),
- 132: syscalls.Supported("sigaltstack", Sigaltstack),
- 133: syscalls.Supported("rt_sigsuspend", RtSigsuspend),
- 134: syscalls.Supported("rt_sigaction", RtSigaction),
- 135: syscalls.Supported("rt_sigprocmask", RtSigprocmask),
- 136: syscalls.Supported("rt_sigpending", RtSigpending),
- 137: syscalls.Supported("rt_sigtimedwait", RtSigtimedwait),
- 138: syscalls.Supported("rt_sigqueueinfo", RtSigqueueinfo),
- 139: syscalls.Supported("rt_sigreturn", RtSigreturn),
- 140: syscalls.PartiallySupported("setpriority", Setpriority, "Stub implementation.", nil),
- 141: syscalls.PartiallySupported("getpriority", Getpriority, "Stub implementation.", nil),
- 142: syscalls.CapError("reboot", linux.CAP_SYS_BOOT, "", nil),
- 143: syscalls.Supported("setregid", Setregid),
- 144: syscalls.Supported("setgid", Setgid),
- 145: syscalls.Supported("setreuid", Setreuid),
- 146: syscalls.Supported("setuid", Setuid),
- 147: syscalls.Supported("setresuid", Setresuid),
- 148: syscalls.Supported("getresuid", Getresuid),
- 149: syscalls.Supported("setresgid", Setresgid),
- 150: syscalls.Supported("getresgid", Getresgid),
- 151: syscalls.ErrorWithEvent("setfsuid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 152: syscalls.ErrorWithEvent("setfsgid", syserror.ENOSYS, "", []string{"gvisor.dev/issue/260"}), // TODO(b/112851702)
- 153: syscalls.Supported("times", Times),
- 154: syscalls.Supported("setpgid", Setpgid),
- 155: syscalls.Supported("getpgid", Getpgid),
- 156: syscalls.Supported("getsid", Getsid),
- 157: syscalls.Supported("setsid", Setsid),
- 158: syscalls.Supported("getgroups", Getgroups),
- 159: syscalls.Supported("setgroups", Setgroups),
- 160: syscalls.Supported("uname", Uname),
- 161: syscalls.Supported("sethostname", Sethostname),
- 162: syscalls.Supported("setdomainname", Setdomainname),
- 163: syscalls.Supported("getrlimit", Getrlimit),
- 164: syscalls.PartiallySupported("setrlimit", Setrlimit, "Not all rlimits are enforced.", nil),
- 165: syscalls.PartiallySupported("getrusage", Getrusage, "Fields ru_maxrss, ru_minflt, ru_majflt, ru_inblock, ru_oublock are not supported. Fields ru_utime and ru_stime have low precision.", nil),
- 166: syscalls.Supported("umask", Umask),
- 167: syscalls.PartiallySupported("prctl", Prctl, "Not all options are supported.", nil),
- 168: syscalls.Supported("getcpu", Getcpu),
- 169: syscalls.Supported("gettimeofday", Gettimeofday),
- 170: syscalls.CapError("settimeofday", linux.CAP_SYS_TIME, "", nil),
- 171: syscalls.CapError("adjtimex", linux.CAP_SYS_TIME, "", nil),
- 172: syscalls.Supported("getpid", Getpid),
- 173: syscalls.Supported("getppid", Getppid),
- 174: syscalls.Supported("getuid", Getuid),
- 175: syscalls.Supported("geteuid", Geteuid),
- 176: syscalls.Supported("getgid", Getgid),
- 177: syscalls.Supported("getegid", Getegid),
- 178: syscalls.Supported("gettid", Gettid),
- 179: syscalls.PartiallySupported("sysinfo", Sysinfo, "Fields loads, sharedram, bufferram, totalswap, freeswap, totalhigh, freehigh not supported.", nil),
- 180: syscalls.ErrorWithEvent("mq_open", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 181: syscalls.ErrorWithEvent("mq_unlink", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 182: syscalls.ErrorWithEvent("mq_timedsend", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 183: syscalls.ErrorWithEvent("mq_timedreceive", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 184: syscalls.ErrorWithEvent("mq_notify", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
- 186: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 187: syscalls.ErrorWithEvent("msgctl", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil),
- 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), // TODO(b/29354920)
- 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
- 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
- 196: syscalls.PartiallySupported("shmat", Shmat, "Option SHM_RND is not supported.", nil),
- 197: syscalls.Supported("shmdt", Shmdt),
- 198: syscalls.PartiallySupported("socket", Socket, "Limited support for AF_NETLINK, NETLINK_ROUTE sockets. Limited support for SOCK_RAW.", nil),
- 199: syscalls.Supported("socketpair", SocketPair),
- 200: syscalls.PartiallySupported("bind", Bind, "Autobind for abstract Unix sockets is not supported.", nil),
- 201: syscalls.Supported("listen", Listen),
- 202: syscalls.Supported("accept", Accept),
- 203: syscalls.Supported("connect", Connect),
- 204: syscalls.Supported("getsockname", GetSockName),
- 205: syscalls.Supported("getpeername", GetPeerName),
- 206: syscalls.Supported("sendto", SendTo),
- 207: syscalls.Supported("recvfrom", RecvFrom),
- 208: syscalls.PartiallySupported("setsockopt", SetSockOpt, "Not all socket options are supported.", nil),
- 209: syscalls.PartiallySupported("getsockopt", GetSockOpt, "Not all socket options are supported.", nil),
- 210: syscalls.PartiallySupported("shutdown", Shutdown, "Not all flags and control messages are supported.", nil),
- 211: syscalls.Supported("sendmsg", SendMsg),
- 212: syscalls.PartiallySupported("recvmsg", RecvMsg, "Not all flags and control messages are supported.", nil),
- 213: syscalls.Supported("readahead", Readahead),
- 214: syscalls.Supported("brk", Brk),
- 215: syscalls.Supported("munmap", Munmap),
- 216: syscalls.Supported("mremap", Mremap),
- 217: syscalls.Error("add_key", syserror.EACCES, "Not available to user.", nil),
- 218: syscalls.Error("request_key", syserror.EACCES, "Not available to user.", nil),
- 219: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil),
- 220: syscalls.PartiallySupported("clone", Clone, "Mount namespace (CLONE_NEWNS) not supported. Options CLONE_PARENT, CLONE_SYSVSEM not supported.", nil),
- 221: syscalls.Supported("execve", Execve),
- 224: syscalls.CapError("swapon", linux.CAP_SYS_ADMIN, "", nil),
- 225: syscalls.CapError("swapoff", linux.CAP_SYS_ADMIN, "", nil),
- 226: syscalls.Supported("mprotect", Mprotect),
- 227: syscalls.PartiallySupported("msync", Msync, "Full data flush is not guaranteed at this time.", nil),
- 228: syscalls.PartiallySupported("mlock", Mlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 229: syscalls.PartiallySupported("munlock", Munlock, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 230: syscalls.PartiallySupported("mlockall", Mlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 231: syscalls.PartiallySupported("munlockall", Munlockall, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 232: syscalls.PartiallySupported("mincore", Mincore, "Stub implementation. The sandbox does not have access to this information. Reports all mapped pages are resident.", nil),
- 233: syscalls.PartiallySupported("madvise", Madvise, "Options MADV_DONTNEED, MADV_DONTFORK are supported. Other advice is ignored.", nil),
- 234: syscalls.ErrorWithEvent("remap_file_pages", syserror.ENOSYS, "Deprecated since Linux 3.16.", nil),
- 235: syscalls.PartiallySupported("mbind", Mbind, "Stub implementation. Only a single NUMA node is advertised, and mempolicy is ignored accordingly, but mbind() will succeed and has effects reflected by get_mempolicy.", []string{"gvisor.dev/issue/262"}),
- 236: syscalls.PartiallySupported("get_mempolicy", GetMempolicy, "Stub implementation.", nil),
- 237: syscalls.PartiallySupported("set_mempolicy", SetMempolicy, "Stub implementation.", nil),
- 238: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil),
- 239: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
- 240: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo),
- 241: syscalls.ErrorWithEvent("perf_event_open", syserror.ENODEV, "No support for perf counters", nil),
- 242: syscalls.Supported("accept4", Accept4),
- 243: syscalls.PartiallySupported("recvmmsg", RecvMMsg, "Not all flags and control messages are supported.", nil),
- 260: syscalls.Supported("wait4", Wait4),
- 261: syscalls.Supported("prlimit64", Prlimit64),
- 262: syscalls.ErrorWithEvent("fanotify_init", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 263: syscalls.ErrorWithEvent("fanotify_mark", syserror.ENOSYS, "Needs CONFIG_FANOTIFY", nil),
- 264: syscalls.Error("name_to_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 265: syscalls.Error("open_by_handle_at", syserror.EOPNOTSUPP, "Not supported by gVisor filesystems", nil),
- 266: syscalls.CapError("clock_adjtime", linux.CAP_SYS_TIME, "", nil),
- 267: syscalls.PartiallySupported("syncfs", Syncfs, "Depends on backing file system.", nil),
- 268: syscalls.ErrorWithEvent("setns", syserror.EOPNOTSUPP, "Needs filesystem support", []string{"gvisor.dev/issue/140"}), // TODO(b/29354995)
- 269: syscalls.PartiallySupported("sendmmsg", SendMMsg, "Not all flags and control messages are supported.", nil),
- 270: syscalls.ErrorWithEvent("process_vm_readv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 271: syscalls.ErrorWithEvent("process_vm_writev", syserror.ENOSYS, "", []string{"gvisor.dev/issue/158"}),
- 272: syscalls.CapError("kcmp", linux.CAP_SYS_PTRACE, "", nil),
- 273: syscalls.CapError("finit_module", linux.CAP_SYS_MODULE, "", nil),
- 274: syscalls.ErrorWithEvent("sched_setattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 275: syscalls.ErrorWithEvent("sched_getattr", syserror.ENOSYS, "gVisor does not implement a scheduler.", []string{"gvisor.dev/issue/264"}), // TODO(b/118902272)
- 276: syscalls.ErrorWithEvent("renameat2", syserror.ENOSYS, "", []string{"gvisor.dev/issue/263"}), // TODO(b/118902772)
- 277: syscalls.Supported("seccomp", Seccomp),
- 278: syscalls.Supported("getrandom", GetRandom),
- 279: syscalls.Supported("memfd_create", MemfdCreate),
- 280: syscalls.CapError("bpf", linux.CAP_SYS_ADMIN, "", nil),
- 281: syscalls.ErrorWithEvent("execveat", syserror.ENOSYS, "", []string{"gvisor.dev/issue/265"}), // TODO(b/118901836)
- 282: syscalls.ErrorWithEvent("userfaultfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/266"}), // TODO(b/118906345)
- 283: syscalls.ErrorWithEvent("membarrier", syserror.ENOSYS, "", []string{"gvisor.dev/issue/267"}), // TODO(b/118904897)
- 284: syscalls.PartiallySupported("mlock2", Mlock2, "Stub implementation. The sandbox lacks appropriate permissions.", nil),
- 285: syscalls.ErrorWithEvent("copy_file_range", syserror.ENOSYS, "", nil),
- 286: syscalls.Supported("preadv2", Preadv2),
- 287: syscalls.PartiallySupported("pwritev2", Pwritev2, "Flag RWF_HIPRI is not supported.", nil),
- 291: syscalls.Supported("statx", Statx),
- },
- Emulate: map[usermem.Addr]uintptr{},
-
- Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
- t.Kernel().EmitUnimplementedEvent(t)
- return 0, syserror.ENOSYS
- },
-}
diff --git a/pkg/sentry/syscalls/linux/sigset.go b/pkg/sentry/syscalls/linux/sigset.go
index 333013d8c..434559b80 100644
--- a/pkg/sentry/syscalls/linux/sigset.go
+++ b/pkg/sentry/syscalls/linux/sigset.go
@@ -17,13 +17,17 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// copyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
+// CopyInSigSet copies in a sigset_t, checks its size, and ensures that KILL and
// STOP are clear.
-func copyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) {
+//
+// TODO(gvisor.dev/issue/1624): This is only exported because
+// syscalls/vfs2/signal.go depends on it. Once vfs1 is deleted and the vfs2
+// syscalls are moved into this package, then they can be unexported.
+func CopyInSigSet(t *kernel.Task, sigSetAddr usermem.Addr, size uint) (linux.SignalSet, error) {
if size != linux.SignalSetSize {
return 0, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index f56411bfe..e9d64dec5 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -15,71 +15,18 @@
package linux
import (
- "encoding/binary"
-
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/eventfd"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
-// I/O commands.
-const (
- _IOCB_CMD_PREAD = 0
- _IOCB_CMD_PWRITE = 1
- _IOCB_CMD_FSYNC = 2
- _IOCB_CMD_FDSYNC = 3
- _IOCB_CMD_NOOP = 6
- _IOCB_CMD_PREADV = 7
- _IOCB_CMD_PWRITEV = 8
-)
-
-// I/O flags.
-const (
- _IOCB_FLAG_RESFD = 1
-)
-
-// ioCallback describes an I/O request.
-//
-// The priority field is currently ignored in the implementation below. Also
-// note that the IOCB_FLAG_RESFD feature is not supported.
-type ioCallback struct {
- Data uint64
- Key uint32
- Reserved1 uint32
-
- OpCode uint16
- ReqPrio int16
- FD int32
-
- Buf uint64
- Bytes uint64
- Offset int64
-
- Reserved2 uint64
- Flags uint32
-
- // eventfd to signal if IOCB_FLAG_RESFD is set in flags.
- ResFD int32
-}
-
-// ioEvent describes an I/O result.
-//
-// +stateify savable
-type ioEvent struct {
- Data uint64
- Obj uint64
- Result int64
- Result2 int64
-}
-
-// ioEventSize is the size of an ioEvent encoded.
-var ioEventSize = binary.Size(ioEvent{})
-
// IoSetup implements linux syscall io_setup(2).
func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
nrEvents := args[0].Int()
@@ -114,14 +61,28 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
func IoDestroy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := args[0].Uint64()
- // Destroy the given context.
- if !t.MemoryManager().DestroyAIOContext(t, id) {
+ ctx := t.MemoryManager().DestroyAIOContext(t, id)
+ if ctx == nil {
// Does not exist.
return 0, nil, syserror.EINVAL
}
- // FIXME(fvoznika): Linux blocks until all AIO to the destroyed context is
- // done.
- return 0, nil, nil
+
+ // Drain completed requests amd wait for pending requests until there are no
+ // more.
+ for {
+ ctx.Drain()
+
+ ch := ctx.WaitChannel()
+ if ch == nil {
+ // No more requests, we're done.
+ return 0, nil, nil
+ }
+ // The task cannot be interrupted during the wait. Equivalent to
+ // TASK_UNINTERRUPTIBLE in Linux.
+ t.UninterruptibleSleepStart(true /* deactivate */)
+ <-ch
+ t.UninterruptibleSleepFinish(true /* activate */)
+ }
}
// IoGetevents implements linux syscall io_getevents(2).
@@ -178,7 +139,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
}
- ev := v.(*ioEvent)
+ ev := v.(*linux.IOEvent)
// Copy out the result.
if _, err := t.CopyOut(eventsAddr, ev); err != nil {
@@ -190,7 +151,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
}
// Keep rolling.
- eventsAddr += usermem.Addr(ioEventSize)
+ eventsAddr += usermem.Addr(linux.IOEventSize)
}
// Everything finished.
@@ -200,13 +161,13 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadline ktime.Time) (interface{}, error) {
for {
if v, ok := ctx.PopRequest(); ok {
- // Request was readly available. Just return it.
+ // Request was readily available. Just return it.
return v, nil
}
// Need to wait for request completion.
- done, active := ctx.WaitChannel()
- if !active {
+ done := ctx.WaitChannel()
+ if done == nil {
// Context has been destroyed.
return nil, syserror.EINVAL
}
@@ -217,7 +178,7 @@ func waitForRequest(ctx *mm.AIOContext, t *kernel.Task, haveDeadline bool, deadl
}
// memoryFor returns appropriate memory for the given callback.
-func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
bytes := int(cb.Bytes)
if bytes < 0 {
// Linux also requires that this field fit in ssize_t.
@@ -228,17 +189,17 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
// we have no guarantee that t's AddressSpace will be active during the
// I/O.
switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PWRITE:
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
- case _IOCB_CMD_PREADV, _IOCB_CMD_PWRITEV:
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
AddressSpaceActive: false,
})
- case _IOCB_CMD_FSYNC, _IOCB_CMD_FDSYNC, _IOCB_CMD_NOOP:
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
return usermem.IOSequence{}, nil
default:
@@ -247,66 +208,78 @@ func memoryFor(t *kernel.Task, cb *ioCallback) (usermem.IOSequence, error) {
}
}
-func performCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *ioCallback, ioseq usermem.IOSequence, ctx *mm.AIOContext, eventFile *fs.File) {
- ev := &ioEvent{
- Data: cb.Data,
- Obj: uint64(cbAddr),
- }
+// IoCancel implements linux syscall io_cancel(2).
+//
+// It is not presently supported (ENOSYS indicates no support on this
+// architecture).
+func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
- // Construct a context.Context that will not be interrupted if t is
- // interrupted.
- c := t.AsyncContext()
+// LINT.IfChange
- var err error
- switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV:
- ev.Result, err = file.Preadv(c, ioseq, cb.Offset)
- case _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
- ev.Result, err = file.Pwritev(c, ioseq, cb.Offset)
- case _IOCB_CMD_FSYNC:
- err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncAll)
- case _IOCB_CMD_FDSYNC:
- err = file.Fsync(c, 0, fs.FileMaxOffset, fs.SyncData)
- }
+func getAIOCallback(t *kernel.Task, file *fs.File, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, actx *mm.AIOContext, eventFile *fs.File) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ if actx.Dead() {
+ actx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
- // Update the result.
- if err != nil {
- err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
- ev.Result = -int64(t.ExtractErrno(err, 0))
- }
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = file.Preadv(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = file.Pwritev(ctx, ioseq, cb.Offset)
+ case linux.IOCB_CMD_FSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncAll)
+ case linux.IOCB_CMD_FDSYNC:
+ err = file.Fsync(ctx, 0, fs.FileMaxOffset, fs.SyncData)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = handleIOError(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", file)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
- file.DecRef()
+ file.DecRef(ctx)
- // Queue the result for delivery.
- ctx.FinishRequest(ev)
+ // Queue the result for delivery.
+ actx.FinishRequest(ev)
- // Notify the event file if one was specified. This needs to happen
- // *after* queueing the result to avoid racing with the thread we may
- // wake up.
- if eventFile != nil {
- eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
- eventFile.DecRef()
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFile != nil {
+ eventFile.FileOperations.(*eventfd.EventOperations).Signal(1)
+ eventFile.DecRef(ctx)
+ }
}
}
// submitCallback processes a single callback.
-func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Addr) error {
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
file := t.GetFile(cb.FD)
if file == nil {
// File not found.
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Was there an eventFD? Extract it.
var eventFile *fs.File
- if cb.Flags&_IOCB_FLAG_RESFD != 0 {
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
eventFile = t.GetFile(cb.ResFD)
if eventFile == nil {
// Bad FD.
return syserror.EBADF
}
- defer eventFile.DecRef()
+ defer eventFile.DecRef(t)
// Check that it is an eventfd.
if _, ok := eventFile.FileOperations.(*eventfd.EventOperations); !ok {
@@ -322,7 +295,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad
// Check offset for reads/writes.
switch cb.OpCode {
- case _IOCB_CMD_PREAD, _IOCB_CMD_PREADV, _IOCB_CMD_PWRITE, _IOCB_CMD_PWRITEV:
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
if cb.Offset < 0 {
return syserror.EINVAL
}
@@ -348,7 +321,7 @@ func submitCallback(t *kernel.Task, id uint64, cb *ioCallback, cbAddr usermem.Ad
// Perform the request asynchronously.
file.IncRef()
- fs.Async(func() { performCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile) })
+ t.QueueAIO(getAIOCallback(t, file, cbAddr, cb, ioseq, ctx, eventFile))
// All set.
return nil
@@ -377,7 +350,7 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Copy in this callback.
- var cb ioCallback
+ var cb linux.IOCallback
cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
if _, err := t.CopyIn(cbAddr, &cb); err != nil {
@@ -406,10 +379,4 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return uintptr(nrEvents), nil, nil
}
-// IoCancel implements linux syscall io_cancel(2).
-//
-// It is not presently supported (ENOSYS indicates no support on this
-// architecture).
-func IoCancel(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- return 0, nil, syserror.ENOSYS
-}
+// LINT.ThenChange(vfs2/aio.go)
diff --git a/pkg/sentry/syscalls/linux/sys_clone_amd64.go b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
new file mode 100644
index 000000000..dd43cf18d
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_amd64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors. We implement the default one in linux 3.11
+// x86_64:
+// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ childTID := args[3].Pointer()
+ tls := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_clone_arm64.go b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
new file mode 100644
index 000000000..cf68a8949
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_clone_arm64.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Clone implements linux syscall clone(2).
+// sys_clone has so many flavors, and we implement the default one in linux 3.11
+// arm64(kernel/fork.c with CONFIG_CLONE_BACKWARDS defined in the config file):
+// sys_clone(clone_flags, newsp, parent_tidptr, tls_val, child_tidptr)
+func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := int(args[0].Int())
+ stack := args[1].Pointer()
+ parentTID := args[2].Pointer()
+ tls := args[3].Pointer()
+ childTID := args[4].Pointer()
+ return clone(t, flags, stack, parentTID, childTID, tls)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index 65b4a227b..7f460d30b 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -20,11 +20,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/epoll"
"gvisor.dev/gvisor/pkg/sentry/syscalls"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -70,7 +71,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
var data [2]int32
if op != linux.EPOLL_CTL_DEL {
var e linux.EpollEvent
- if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ if _, err := e.CopyIn(t, eventAddr); err != nil {
return 0, nil, err
}
@@ -83,8 +84,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
mask = waiter.EventMaskFromLinux(e.Events)
- data[0] = e.Fd
- data[1] = e.Data
+ data = e.Data
}
// Perform the requested operations.
@@ -104,28 +104,6 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
}
-// copyOutEvents copies epoll events from the kernel to user memory.
-func copyOutEvents(t *kernel.Task, addr usermem.Addr, e []epoll.Event) error {
- const itemLen = 12
- buffLen := len(e) * itemLen
- if _, ok := addr.AddLength(uint64(buffLen)); !ok {
- return syserror.EFAULT
- }
-
- b := t.CopyScratchBuffer(buffLen)
- for i := range e {
- usermem.ByteOrder.PutUint32(b[i*itemLen:], e[i].Events)
- usermem.ByteOrder.PutUint32(b[i*itemLen+4:], uint32(e[i].Data[0]))
- usermem.ByteOrder.PutUint32(b[i*itemLen+8:], uint32(e[i].Data[1]))
- }
-
- if _, err := t.CopyOutBytes(addr, b); err != nil {
- return err
- }
-
- return nil
-}
-
// EpollWait implements the epoll_wait(2) linux syscall.
func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
epfd := args[0].Int()
@@ -139,7 +117,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
if len(r) != 0 {
- if err := copyOutEvents(t, eventsAddr, r); err != nil {
+ if _, err := linux.CopyEpollEventSliceOut(t, eventsAddr, r); err != nil {
return 0, nil, err
}
}
@@ -153,7 +131,7 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
maskSize := uint(args[5].Uint())
if maskAddr != 0 {
- mask, err := copyInSigSet(t, maskAddr, maskSize)
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
if err != nil {
return 0, nil, err
}
@@ -165,3 +143,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_eventfd.go b/pkg/sentry/syscalls/linux/sys_eventfd.go
index 8a34c4e99..3b4f879e4 100644
--- a/pkg/sentry/syscalls/linux/sys_eventfd.go
+++ b/pkg/sentry/syscalls/linux/sys_eventfd.go
@@ -15,6 +15,7 @@
package linux
import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -22,32 +23,24 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
-const (
- // EFD_SEMAPHORE is a flag used in syscall eventfd(2) and eventfd2(2). Please
- // see its man page for more information.
- EFD_SEMAPHORE = 1
- EFD_NONBLOCK = 0x800
- EFD_CLOEXEC = 0x80000
-)
-
// Eventfd2 implements linux syscall eventfd2(2).
func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
initVal := args[0].Int()
flags := uint(args[1].Uint())
- allOps := uint(EFD_SEMAPHORE | EFD_NONBLOCK | EFD_CLOEXEC)
+ allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
if flags & ^allOps != 0 {
return 0, nil, syserror.EINVAL
}
- event := eventfd.New(t, uint64(initVal), flags&EFD_SEMAPHORE != 0)
+ event := eventfd.New(t, uint64(initVal), flags&linux.EFD_SEMAPHORE != 0)
event.SetFlags(fs.SettableFileFlags{
- NonBlocking: flags&EFD_NONBLOCK != 0,
+ NonBlocking: flags&linux.EFD_NONBLOCK != 0,
})
- defer event.DecRef()
+ defer event.DecRef(t)
fd, err := t.NewFDFrom(0, event, kernel.FDFlags{
- CloseOnExec: flags&EFD_CLOEXEC != 0,
+ CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
})
if err != nil {
return 0, nil, err
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index b9a8e3e21..1bc9b184e 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -18,8 +18,8 @@ import (
"syscall"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
@@ -28,8 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// fileOpAt performs an operation on the second last component in the path.
@@ -40,7 +40,7 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent,
// Common case: we are accessing a file in the root.
root := t.FSContext().RootDirectory()
err := fn(root, root, name, linux.MaxSymlinkTraversals)
- root.DecRef()
+ root.DecRef(t)
return err
} else if dir == "." && dirFD == linux.AT_FDCWD {
// Common case: we are accessing a file relative to the current
@@ -48,8 +48,8 @@ func fileOpAt(t *kernel.Task, dirFD int32, path string, fn func(root *fs.Dirent,
wd := t.FSContext().WorkingDirectory()
root := t.FSContext().RootDirectory()
err := fn(root, wd, name, linux.MaxSymlinkTraversals)
- wd.DecRef()
- root.DecRef()
+ wd.DecRef(t)
+ root.DecRef(t)
return err
}
@@ -97,19 +97,19 @@ func fileOpOn(t *kernel.Task, dirFD int32, path string, resolve bool, fn func(ro
} else {
d, err = t.MountNamespace().FindLink(t, root, rel, path, &remainingTraversals)
}
- root.DecRef()
+ root.DecRef(t)
if wd != nil {
- wd.DecRef()
+ wd.DecRef(t)
}
if f != nil {
- f.DecRef()
+ f.DecRef(t)
}
if err != nil {
return err
}
err = fn(root, d, remainingTraversals)
- d.DecRef()
+ d.DecRef(t)
return err
}
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -169,10 +171,14 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
if dirPath {
return syserror.ENOTDIR
}
- if flags&linux.O_TRUNC != 0 {
- if err := d.Inode.Truncate(t, d, 0); err != nil {
- return err
- }
+ }
+
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
+ if flags&linux.O_TRUNC != 0 {
+ if err := d.Inode.Truncate(t, d, 0); err != nil {
+ return err
}
}
@@ -180,7 +186,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
if err != nil {
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Success.
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
@@ -236,7 +242,7 @@ func mknodAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode
if err != nil {
return err
}
- file.DecRef()
+ file.DecRef(t)
return nil
case linux.ModeNamedPipe:
@@ -326,7 +332,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
break
}
- defer found.DecRef()
+ defer found.DecRef(t)
// We found something (possibly a symlink). If the
// O_EXCL flag was passed, then we can immediately
@@ -351,7 +357,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
resolved, err = found.Inode.Getlink(t)
if err == nil {
// No more resolution necessary.
- defer resolved.DecRef()
+ defer resolved.DecRef(t)
break
}
if err != fs.ErrResolveViaReadlink {
@@ -378,7 +384,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
break
}
- defer newParent.DecRef()
+ defer newParent.DecRef(t)
// Repeat the process with the parent and name of the
// symlink target.
@@ -396,7 +402,9 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
return err
}
- // Should we truncate the file?
+ // Truncate is called when O_TRUNC is specified for any kind of
+ // existing Dirent. Behavior is delegated to the entry's Truncate
+ // implementation.
if flags&linux.O_TRUNC != 0 {
if err := found.Inode.Truncate(t, found, 0); err != nil {
return err
@@ -408,7 +416,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
if err != nil {
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
- defer newFile.DecRef()
+ defer newFile.DecRef(t)
case syserror.ENOENT:
// File does not exist. Proceed with creation.
@@ -424,7 +432,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l
// No luck, bail.
return err
}
- defer newFile.DecRef()
+ defer newFile.DecRef(t)
found = newFile.Dirent
default:
return err
@@ -506,7 +514,7 @@ func (ac accessContext) Value(key interface{}) interface{} {
}
}
-func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode uint) error {
+func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode uint) error {
const rOK = 4
const wOK = 2
const xOK = 1
@@ -521,7 +529,7 @@ func accessAt(t *kernel.Task, dirFD int32, addr usermem.Addr, resolve bool, mode
return syserror.EINVAL
}
- return fileOpOn(t, dirFD, path, resolve, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
+ return fileOpOn(t, dirFD, path, true /* resolve */, func(root *fs.Dirent, d *fs.Dirent, _ uint) error {
// access(2) and faccessat(2) check permissions using real
// UID/GID, not effective UID/GID.
//
@@ -556,19 +564,29 @@ func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
addr := args[0].Pointer()
mode := args[1].ModeT()
- return 0, nil, accessAt(t, linux.AT_FDCWD, addr, true, mode)
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
}
// Faccessat implements linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
dirFD := args[0].Int()
addr := args[1].Pointer()
mode := args[2].ModeT()
- flags := args[3].Int()
- return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
+ return 0, nil, accessAt(t, dirFD, addr, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -578,7 +596,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Shared flags between file and socket.
switch request {
@@ -644,14 +662,18 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
size := args[1].SizeT()
cwd := t.FSContext().WorkingDirectory()
- defer cwd.DecRef()
+ defer cwd.DecRef(t)
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
// Get our fullname from the root and preprend unreachable if the root was
// unreachable from our current dirent this is the same behavior as on linux.
@@ -700,7 +722,7 @@ func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return err
}
- t.FSContext().SetRootDirectory(d)
+ t.FSContext().SetRootDirectory(t, d)
return nil
})
}
@@ -725,7 +747,7 @@ func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return err
}
- t.FSContext().SetWorkingDirectory(d)
+ t.FSContext().SetWorkingDirectory(t, d)
return nil
})
}
@@ -738,7 +760,7 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Is it a directory?
if !fs.IsDir(file.Dirent.Inode.StableAttr) {
@@ -750,10 +772,14 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
}
- t.FSContext().SetWorkingDirectory(file.Dirent)
+ t.FSContext().SetWorkingDirectory(t, file.Dirent)
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -761,11 +787,11 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Note that Remove provides a reference on the file that we may use to
// flush. It is still active until we drop the final reference below
// (and other reference-holding operations complete).
- file := t.FDTable().Remove(fd)
+ file, _ := t.FDTable().Remove(fd)
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Flush(t)
return 0, nil, handleIOError(t, false /* partial */, err, syserror.EINTR, "close", file)
@@ -779,7 +805,7 @@ func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{})
if err != nil {
@@ -800,7 +826,7 @@ func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if oldFile == nil {
return 0, nil, syserror.EBADF
}
- defer oldFile.DecRef()
+ defer oldFile.DecRef(t)
return uintptr(newfd), nil, nil
}
@@ -824,7 +850,7 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if oldFile == nil {
return 0, nil, syserror.EBADF
}
- defer oldFile.DecRef()
+ defer oldFile.DecRef(t)
err := t.NewFDAt(newfd, oldFile, kernel.FDFlags{CloseOnExec: flags&linux.O_CLOEXEC != 0})
if err != nil {
@@ -834,37 +860,60 @@ func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(newfd), nil, nil
}
-func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+func fGetOwnEx(t *kernel.Task, file *fs.File) linux.FOwnerEx {
ma := file.Async(nil)
if ma == nil {
- return 0
+ return linux.FOwnerEx{}
}
a := ma.(*fasync.FileAsync)
ot, otg, opg := a.Owner()
switch {
case ot != nil:
- return int32(t.PIDNamespace().IDOfTask(ot))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }
case otg != nil:
- return int32(t.PIDNamespace().IDOfThreadGroup(otg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }
case opg != nil:
- return int32(-t.PIDNamespace().IDOfProcessGroup(opg))
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }
default:
- return 0
+ return linux.FOwnerEx{}
+ }
+}
+
+func fGetOwn(t *kernel.Task, file *fs.File) int32 {
+ owner := fGetOwnEx(t, file)
+ if owner.Type == linux.F_OWNER_PGRP {
+ return -owner.PID
}
+ return owner.PID
}
// fSetOwn sets the file's owner with the semantics of F_SETOWN in Linux.
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) {
+func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
a := file.Async(fasync.New).(*fasync.FileAsync)
if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return syserror.EINVAL
+ }
pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who))
a.SetOwnerProcessGroup(t, pg)
+ } else {
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
+ a.SetOwnerThreadGroup(t, tg)
}
- tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who))
- a.SetOwnerThreadGroup(t, tg)
+ return nil
}
// Fcntl implements linux syscall fcntl(2).
@@ -876,7 +925,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
switch cmd {
case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
@@ -892,14 +941,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(flags.ToLinuxFDFlags()), nil, nil
case linux.F_SETFD:
flags := args[2].Uint()
- t.FDTable().SetFlags(fd, kernel.FDFlags{
+ err := t.FDTable().SetFlags(fd, kernel.FDFlags{
CloseOnExec: flags&linux.FD_CLOEXEC != 0,
})
+ return 0, nil, err
case linux.F_GETFL:
return uintptr(file.Flags().ToLinux()), nil, nil
case linux.F_SETFL:
flags := uint(args[2].Uint())
file.SetFlags(linuxToFlags(flags).Settable())
+ return 0, nil, nil
case linux.F_SETLK, linux.F_SETLKW:
// In Linux the file system can choose to provide lock operations for an inode.
// Normally pipe and socket types lack lock operations. We diverge and use a heavy
@@ -953,9 +1004,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
- // The lock uid is that of the Task's FDTable.
- lockUniqueID := lock.UniqueID(t.FDTable().ID())
-
// These locks don't block; execute the non-blocking operation using the inode's lock
// context directly.
switch flock.Type {
@@ -965,12 +1013,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
@@ -981,18 +1029,18 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if cmd == linux.F_SETLK {
// Non-blocking lock, provide a nil lock.Blocker.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, nil) {
return 0, nil, syserror.EAGAIN
}
} else {
// Blocking lock, pass in the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.Posix.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.Posix.LockRegion(t.FDTable(), lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
return 0, nil, nil
case linux.F_UNLCK:
- file.Dirent.Inode.LockCtx.Posix.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.Posix.UnlockRegion(t.FDTable(), rng)
return 0, nil, nil
default:
return 0, nil, syserror.EINVAL
@@ -1000,8 +1048,45 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- fSetOwn(t, file, args[2].Int())
- return 0, nil, nil
+ return 0, nil, fSetOwn(t, file, args[2].Int())
+ case linux.F_GETOWN_EX:
+ addr := args[2].Pointer()
+ owner := fGetOwnEx(t, file)
+ _, err := t.CopyOut(addr, &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ addr := args[2].Pointer()
+ var owner linux.FOwnerEx
+ _, err := t.CopyIn(addr, &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ a := file.Async(fasync.New).(*fasync.FileAsync)
+ switch owner.Type {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
+ if task == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return 0, nil, nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID))
+ if tg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return 0, nil, nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID))
+ if pg == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return 0, nil, nil
+ default:
+ return 0, nil, syserror.EINVAL
+ }
case linux.F_GET_SEALS:
val, err := tmpfs.GetSeals(file.Dirent.Inode)
return uintptr(val), nil, err
@@ -1029,18 +1114,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
}
- return 0, nil, nil
}
-const (
- _FADV_NORMAL = 0
- _FADV_RANDOM = 1
- _FADV_SEQUENTIAL = 2
- _FADV_WILLNEED = 3
- _FADV_DONTNEED = 4
- _FADV_NOREUSE = 5
-)
-
// Fadvise64 implements linux syscall fadvise64(2).
// This implementation currently ignores the provided advice.
func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
@@ -1057,7 +1132,7 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// If the FD refers to a pipe or FIFO, return error.
if fs.IsPipe(file.Dirent.Inode.StableAttr) {
@@ -1065,12 +1140,12 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
}
switch advice {
- case _FADV_NORMAL:
- case _FADV_RANDOM:
- case _FADV_SEQUENTIAL:
- case _FADV_WILLNEED:
- case _FADV_DONTNEED:
- case _FADV_NOREUSE:
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
default:
return 0, nil, syserror.EINVAL
}
@@ -1096,7 +1171,7 @@ func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode
switch err {
case nil:
// The directory existed.
- defer f.DecRef()
+ defer f.DecRef(t)
return syserror.EEXIST
case syserror.EACCES:
// Permission denied while walking to the directory.
@@ -1156,7 +1231,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTEMPTY
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1274,7 +1349,7 @@ func linkAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32
if target == nil {
return syserror.EBADF
}
- defer target.DecRef()
+ defer target.DecRef(t)
if err := mayLinkAt(t, target.Dirent.Inode); err != nil {
return err
}
@@ -1359,6 +1434,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1418,6 +1497,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1429,7 +1512,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTDIR
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1454,6 +1537,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1483,6 +1570,8 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if fs.IsDir(d.Inode.StableAttr) {
return syserror.EISDIR
}
+ // In contrast to open(O_TRUNC), truncate(2) is only valid for file
+ // types.
if !fs.IsFile(d.Inode.StableAttr) {
return syserror.EINVAL
}
@@ -1513,7 +1602,7 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Reject truncation if the file flags do not permit this operation.
// This is different from truncate(2) above.
@@ -1521,7 +1610,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Note that this is different from truncate(2) above, where a
+ // In contrast to open(O_TRUNC), truncate(2) is only valid for file
+ // types. Note that this is different from truncate(2) above, where a
// directory returns EISDIR.
if !fs.IsFile(file.Dirent.Inode.StableAttr) {
return 0, nil, syserror.EINVAL
@@ -1549,6 +1639,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1556,6 +1648,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1636,7 +1730,7 @@ func chownAt(t *kernel.Task, fd int32, addr usermem.Addr, resolve, allowEmpty bo
if file == nil {
return syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return chown(t, file.Dirent, uid, gid)
}
@@ -1674,7 +1768,7 @@ func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, chown(t, file.Dirent, uid, gid)
}
@@ -1739,7 +1833,7 @@ func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, chmod(t, file.Dirent, mode)
}
@@ -1799,10 +1893,10 @@ func utimes(t *kernel.Task, dirFD int32, addr usermem.Addr, ts fs.TimeSpec, reso
if f == nil {
return syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
return setTimestamp(root, f.Dirent, linux.MaxSymlinkTraversals)
}
@@ -1922,6 +2016,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -1977,6 +2075,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1988,7 +2088,7 @@ func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
if offset < 0 || length <= 0 {
return 0, nil, syserror.EINVAL
@@ -2041,27 +2141,11 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
// flock(2): EBADF fd is not an open file descriptor.
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
nonblocking := operation&linux.LOCK_NB != 0
operation &^= linux.LOCK_NB
- // flock(2):
- // Locks created by flock() are associated with an open file table entry. This means that
- // duplicate file descriptors (created by, for example, fork(2) or dup(2)) refer to the
- // same lock, and this lock may be modified or released using any of these descriptors. Furthermore,
- // the lock is released either by an explicit LOCK_UN operation on any of these duplicate
- // descriptors, or when all such descriptors have been closed.
- //
- // If a process uses open(2) (or similar) to obtain more than one descriptor for the same file,
- // these descriptors are treated independently by flock(). An attempt to lock the file using
- // one of these file descriptors may be denied by a lock that the calling process has already placed via
- // another descriptor.
- //
- // We use the File UniqueID as the lock UniqueID because it needs to reference the same lock across dup(2)
- // and fork(2).
- lockUniqueID := lock.UniqueID(file.UniqueID)
-
// A BSD style lock spans the entire file.
rng := lock.LockRange{
Start: 0,
@@ -2072,29 +2156,29 @@ func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.LOCK_EX:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.WriteLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.WriteLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_SH:
if nonblocking {
// Since we're nonblocking we pass a nil lock.Blocker implementation.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, nil) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, nil) {
return 0, nil, syserror.EWOULDBLOCK
}
} else {
// Because we're blocking we will pass the task to satisfy the lock.Blocker interface.
- if !file.Dirent.Inode.LockCtx.BSD.LockRegion(lockUniqueID, lock.ReadLock, rng, t) {
+ if !file.Dirent.Inode.LockCtx.BSD.LockRegion(file, lock.ReadLock, rng, t) {
return 0, nil, syserror.EINTR
}
}
case linux.LOCK_UN:
- file.Dirent.Inode.LockCtx.BSD.UnlockRegion(lockUniqueID, rng)
+ file.Dirent.Inode.LockCtx.BSD.UnlockRegion(file, rng)
default:
// flock(2): EINVAL operation is invalid.
return 0, nil, syserror.EINVAL
@@ -2140,8 +2224,8 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- defer dirent.DecRef()
- defer file.DecRef()
+ defer dirent.DecRef(t)
+ defer file.DecRef(t)
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: cloExec,
diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go
index b9bd25464..9d1b2edb1 100644
--- a/pkg/sentry/syscalls/linux/sys_futex.go
+++ b/pkg/sentry/syscalls/linux/sys_futex.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// futexWaitRestartBlock encapsulates the state required to restart futex(2)
@@ -73,7 +73,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo
err = t.BlockWithDeadline(w.C, true, ktime.FromTimespec(ts))
}
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
@@ -95,7 +95,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add
}
remaining, err := t.BlockWithTimeout(w.C, !forever, duration)
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
if err == nil {
return 0, nil
}
@@ -148,7 +148,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A
timer.Destroy()
}
- t.Futex().WaitComplete(w)
+ t.Futex().WaitComplete(w, t)
return syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
@@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
switch cmd {
case linux.FUTEX_WAIT:
// WAIT uses a relative timeout.
- mask = ^uint32(0)
+ mask = linux.FUTEX_BITSET_MATCH_ANY
var timeoutDur time.Duration
if !forever {
timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond
@@ -226,6 +226,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if mask == 0 {
return 0, nil, syserror.EINVAL
}
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().Wake(t, addr, private, mask, val)
return uintptr(n), nil, err
@@ -242,6 +247,11 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.FUTEX_WAKE_OP:
op := uint32(val3)
+ if val <= 0 {
+ // The Linux kernel wakes one waiter even if val is
+ // non-positive.
+ val = 1
+ }
n, err := t.Futex().WakeOp(t, addr, naddr, private, val, nreq, op)
return uintptr(n), nil, err
@@ -276,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, syserror.ENOSYS
}
}
+
+// SetRobustList implements linux syscall set_robust_list(2).
+func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ head := args[0].Pointer()
+ length := args[1].SizeT()
+
+ if length != uint(linux.SizeOfRobustListHead) {
+ return 0, nil, syserror.EINVAL
+ }
+ t.SetRobustList(head)
+ return 0, nil, nil
+}
+
+// GetRobustList implements linux syscall get_robust_list(2).
+func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // Despite the syscall using the name 'pid' for this variable, it is
+ // very much a tid.
+ tid := args[0].Int()
+ head := args[1].Pointer()
+ size := args[2].Pointer()
+
+ if tid < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ot := t
+ if tid != 0 {
+ if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil {
+ return 0, nil, syserror.ESRCH
+ }
+ }
+
+ // Copy out head pointer.
+ if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil {
+ return 0, nil, err
+ }
+
+ // Copy out size, which is a constant.
+ if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index 912cbe4ff..f5699e55d 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -23,10 +23,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -66,7 +68,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir
if dir == nil {
return 0, syserror.EBADF
}
- defer dir.DecRef()
+ defer dir.DecRef(t)
w := &usermem.IOReadWriter{
Ctx: t,
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_inotify.go b/pkg/sentry/syscalls/linux/sys_inotify.go
index b2c7b3444..cf47bb9dd 100644
--- a/pkg/sentry/syscalls/linux/sys_inotify.go
+++ b/pkg/sentry/syscalls/linux/sys_inotify.go
@@ -40,7 +40,7 @@ func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
NonBlocking: flags&linux.IN_NONBLOCK != 0,
}
n := fs.NewFile(t, dirent, fileFlags, fs.NewInotify(t))
- defer n.DecRef()
+ defer n.DecRef(t)
fd, err := t.NewFDFrom(0, n, kernel.FDFlags{
CloseOnExec: flags&linux.IN_CLOEXEC != 0,
@@ -71,7 +71,7 @@ func fdToInotify(t *kernel.Task, fd int32) (*fs.Inotify, *fs.File, error) {
ino, ok := file.FileOperations.(*fs.Inotify)
if !ok {
// Not an inotify fd.
- file.DecRef()
+ file.DecRef(t)
return nil, nil, syserror.EINVAL
}
@@ -98,7 +98,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -128,6 +128,6 @@ func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
- return 0, nil, ino.RmWatch(wd)
+ defer file.DecRef(t)
+ return 0, nil, ino.RmWatch(t, wd)
}
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..1c38f8f4f 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -31,7 +33,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
var sw fs.SeekWhence
switch whence {
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go
index f5a519d8a..9b4a5c3f1 100644
--- a/pkg/sentry/syscalls/linux/sys_mempolicy.go
+++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go
@@ -20,8 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// We unconditionally report a single NUMA node. This also means that our
@@ -162,10 +162,10 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if err != nil {
return 0, nil, err
}
- policy = 0 // maxNodes == 1
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
}
if mode != 0 {
- if _, err := t.CopyOut(mode, policy); err != nil {
+ if _, err := policy.CopyOut(t, mode); err != nil {
return 0, nil, err
}
}
@@ -199,10 +199,10 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
return 0, nil, syserror.EINVAL
}
- policy = 0 // maxNodes == 1
+ policy = linux.MPOL_DEFAULT // maxNodes == 1
}
if mode != 0 {
- if _, err := t.CopyOut(mode, policy); err != nil {
+ if _, err := policy.CopyOut(t, mode); err != nil {
return 0, nil, err
}
}
@@ -216,7 +216,7 @@ func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// SetMempolicy implements the syscall set_mempolicy(2).
func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- modeWithFlags := args[0].Int()
+ modeWithFlags := linux.NumaPolicy(args[0].Int())
nodemask := args[1].Pointer()
maxnode := args[2].Uint()
@@ -233,7 +233,7 @@ func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
length := args[1].Uint64()
- mode := args[2].Int()
+ mode := linux.NumaPolicy(args[2].Int())
nodemask := args[3].Pointer()
maxnode := args[4].Uint()
flags := args[5].Uint()
@@ -258,9 +258,9 @@ func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
}
-func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags int32, nodemask usermem.Addr, maxnode uint32) (int32, uint64, error) {
- flags := modeWithFlags & linux.MPOL_MODE_FLAGS
- mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS
+func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask usermem.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
+ flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
+ mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
if flags == linux.MPOL_MODE_FLAGS {
// Can't specify both mode flags simultaneously.
return 0, 0, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 58a05b5bb..72786b032 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -22,8 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/mm"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Brk implements linux syscall brk(2).
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -73,7 +75,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
}
defer func() {
if opts.MappingIdentity != nil {
- opts.MappingIdentity.DecRef()
+ opts.MappingIdentity.DecRef(t)
}
}()
@@ -83,7 +85,7 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
flags := file.Flags()
// mmap unconditionally requires that the FD is readable.
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_mount.go b/pkg/sentry/syscalls/linux/sys_mount.go
index 8c13e2d82..bd0633564 100644
--- a/pkg/sentry/syscalls/linux/sys_mount.go
+++ b/pkg/sentry/syscalls/linux/sys_mount.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// Mount implements Linux syscall mount(2).
@@ -115,7 +115,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}); err != nil {
// Something went wrong. Drop our ref on rootInode before
// returning the error.
- rootInode.DecRef()
+ rootInode.DecRef(t)
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go
index 418d7fa5f..3149e4aad 100644
--- a/pkg/sentry/syscalls/linux/sys_pipe.go
+++ b/pkg/sentry/syscalls/linux/sys_pipe.go
@@ -20,10 +20,12 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// pipe2 implements the actual system call with flags.
func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
@@ -32,10 +34,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize)
r.SetFlags(linuxToFlags(flags).Settable())
- defer r.DecRef()
+ defer r.DecRef(t)
w.SetFlags(linuxToFlags(flags).Settable())
- defer w.DecRef()
+ defer w.DecRef(t)
fds, err := t.NewFDs(0, []*fs.File{r, w}, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -45,10 +47,12 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) {
}
if _, err := t.CopyOut(addr, fds); err != nil {
- // The files are not closed in this case, the exact semantics
- // of this error case are not well defined, but they could have
- // already been observed by user space.
- return 0, syserror.EFAULT
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef(t)
+ }
+ }
+ return 0, err
}
return 0, nil
}
@@ -69,3 +73,5 @@ func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
n, err := pipe2(t, addr, flags)
return n, nil, err
}
+
+// LINT.ThenChange(vfs2/pipe.go)
diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go
index 7a13beac2..3435bdf77 100644
--- a/pkg/sentry/syscalls/linux/sys_poll.go
+++ b/pkg/sentry/syscalls/linux/sys_poll.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -70,7 +70,7 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
if ch == nil {
- defer file.DecRef()
+ defer file.DecRef(t)
} else {
state.file = file
state.waiter, _ = waiter.NewChannelEntry(ch)
@@ -82,11 +82,11 @@ func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan
}
// releaseState releases all the pollState in "state".
-func releaseState(state []pollState) {
+func releaseState(t *kernel.Task, state []pollState) {
for i := range state {
if state[i].file != nil {
state[i].file.EventUnregister(&state[i].waiter)
- state[i].file.DecRef()
+ state[i].file.DecRef(t)
}
}
}
@@ -107,7 +107,7 @@ func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.
// result, we stop registering for events but still go through all files
// to get their ready masks.
state := make([]pollState, len(pfd))
- defer releaseState(state)
+ defer releaseState(t, state)
n := uintptr(0)
for i := range pfd {
initReadiness(t, &pfd[i], &state[i], ch)
@@ -197,53 +197,51 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration)
return remainingTimeout, n, err
}
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
if nfds < 0 || nfds > fileCap {
return 0, syserror.EINVAL
}
- // Capture all the provided input vectors.
- //
- // N.B. This only works on little-endian architectures.
- byteCount := (nfds + 7) / 8
-
- bitsInLastPartialByte := uint(nfds % 8)
- r := make([]byte, byteCount)
- w := make([]byte, byteCount)
- e := make([]byte, byteCount)
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
- if readFDs != 0 {
- if _, err := t.CopyIn(readFDs, &r); err != nil {
- return 0, err
- }
- // Mask out bits above nfds.
- if bitsInLastPartialByte != 0 {
- r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if writeFDs != 0 {
- if _, err := t.CopyIn(writeFDs, &w); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
-
- if exceptFDs != 0 {
- if _, err := t.CopyIn(exceptFDs, &e); err != nil {
- return 0, err
- }
- if bitsInLastPartialByte != 0 {
- e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte
- }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
}
// Count how many FDs are actually being requested so that we can build
// a PollFD array.
fdCount := 0
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
v := r[i] | w[i] | e[i]
for v != 0 {
v &= (v - 1)
@@ -254,7 +252,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
// Build the PollFD array.
pfd := make([]linux.PollFD, 0, fdCount)
var fd int32
- for i := 0; i < byteCount; i++ {
+ for i := 0; i < nBytes; i++ {
rV, wV, eV := r[i], w[i], e[i]
v := rV | wV | eV
m := byte(1)
@@ -268,7 +266,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
if file == nil {
return 0, syserror.EBADF
}
- file.DecRef()
+ file.DecRef(t)
var mask int16
if (rV & m) != 0 {
@@ -295,8 +293,7 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add
}
// Do the syscall, then count the number of bits set.
- _, _, err := pollBlock(t, pfd, timeout)
- if err != nil {
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
return 0, syserror.ConvertIntr(err, syserror.EINTR)
}
@@ -446,7 +443,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
if maskAddr != 0 {
- mask, err := copyInSigSet(t, maskAddr, maskSize)
+ mask, err := CopyInSigSet(t, maskAddr, maskSize)
if err != nil {
return 0, nil, err
}
@@ -528,7 +525,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
if maskAddr != 0 {
- mask, err := copyInSigSet(t, maskAddr, size)
+ mask, err := CopyInSigSet(t, maskAddr, size)
if err != nil {
return 0, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index 98db32d77..64a725296 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/mm"
@@ -127,7 +128,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// They trying to set exe to a non-file?
if !fs.IsFile(file.Dirent.Inode.StableAttr) {
@@ -135,7 +136,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
// Set the underlying executable.
- t.MemoryManager().SetExecutable(file.Dirent)
+ t.MemoryManager().SetExecutable(t, fsbridge.NewFSFile(file))
case linux.PR_SET_MM_AUXV,
linux.PR_SET_MM_START_CODE,
@@ -160,8 +161,8 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if args[1].Int() != 1 || args[2].Int() != 0 || args[3].Int() != 0 || args[4].Int() != 0 {
return 0, nil, syserror.EINVAL
}
- // no_new_privs is assumed to always be set. See
- // kernel.Task.updateCredsForExec.
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
return 0, nil, nil
case linux.PR_GET_NO_NEW_PRIVS:
diff --git a/pkg/sentry/syscalls/linux/sys_random.go b/pkg/sentry/syscalls/linux/sys_random.go
index bc4c588bf..c0aa0fd60 100644
--- a/pkg/sentry/syscalls/linux/sys_random.go
+++ b/pkg/sentry/syscalls/linux/sys_random.go
@@ -19,11 +19,11 @@ import (
"math"
"gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index cd31e0649..3bbc3fa4b 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -23,13 +23,15 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
- // EventMaskRead contains events that can be triggerd on reads.
+ // EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
)
@@ -46,7 +48,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -82,7 +84,7 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -94,8 +96,8 @@ func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, syserror.EINVAL
}
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -116,10 +118,10 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -162,7 +164,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is readable.
if !file.Flags().Read {
@@ -193,7 +195,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -242,7 +244,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go
index 51e3f836b..d5d5b6959 100644
--- a/pkg/sentry/syscalls/linux/sys_rlimit.go
+++ b/pkg/sentry/syscalls/linux/sys_rlimit.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/limits"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// rlimit describes an implementation of 'struct rlimit', which may vary from
@@ -197,7 +197,7 @@ func Prlimit64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
// saved set user IDs of the target process must match the real user ID of
// the caller and the real, effective, and saved set group IDs of the
// target process must match the real group ID of the caller."
- if !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
+ if ot != t && !t.HasCapabilityIn(linux.CAP_SYS_RESOURCE, t.PIDNamespace().UserNamespace()) {
cred, tcred := t.Credentials(), ot.Credentials()
if cred.RealKUID != tcred.RealKUID ||
cred.RealKUID != tcred.EffectiveKUID ||
diff --git a/pkg/sentry/syscalls/linux/sys_rseq.go b/pkg/sentry/syscalls/linux/sys_rseq.go
new file mode 100644
index 000000000..90db10ea6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_rseq.go
@@ -0,0 +1,48 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// RSeq implements syscall rseq(2).
+func RSeq(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Uint()
+ flags := args[2].Int()
+ signature := args[3].Uint()
+
+ if !t.RSeqAvailable() {
+ // Event for applications that want rseq on a configuration
+ // that doesn't support them.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ switch flags {
+ case 0:
+ // Register.
+ return 0, nil, t.SetRSeq(addr, length, signature)
+ case linux.RSEQ_FLAG_UNREGISTER:
+ return 0, nil, t.ClearRSeq(addr, length, signature)
+ default:
+ // Unknown flag.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go
index 18510ead8..5b7a66f4d 100644
--- a/pkg/sentry/syscalls/linux/sys_seccomp.go
+++ b/pkg/sentry/syscalls/linux/sys_seccomp.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// userSockFprog is equivalent to Linux's struct sock_fprog on amd64.
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index cde3b54e7..5f54f2456 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -22,8 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const opsMax = 500 // SEMOPM
diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go
index d57ffb3a1..f0ae8fa8e 100644
--- a/pkg/sentry/syscalls/linux/sys_shm.go
+++ b/pkg/sentry/syscalls/linux/sys_shm.go
@@ -39,10 +39,13 @@ func Shmget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, err
}
+ defer segment.DecRef(t)
return uintptr(segment.ID), nil, nil
}
// findSegment retrives a shm segment by the given id.
+//
+// findSegment returns a reference on Shm.
func findSegment(t *kernel.Task, id shm.ID) (*shm.Shm, error) {
r := t.IPCNamespace().ShmRegistry()
segment := r.FindByID(id)
@@ -63,6 +66,7 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef(t)
opts, err := segment.ConfigureAttach(t, addr, shm.AttachOpts{
Execute: flag&linux.SHM_EXEC == linux.SHM_EXEC,
@@ -72,7 +76,6 @@ func Shmat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- defer segment.DecRef()
addr, err = t.MemoryManager().MMap(t, opts)
return uintptr(addr), nil, err
}
@@ -105,6 +108,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef(t)
stat, err := segment.IPCStat(t)
if err == nil {
@@ -128,6 +132,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if err != nil {
return 0, nil, syserror.EINVAL
}
+ defer segment.DecRef(t)
switch cmd {
case linux.IPC_SET:
@@ -140,7 +145,7 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, err
case linux.IPC_RMID:
- segment.MarkDestroyed()
+ segment.MarkDestroyed(t)
return 0, nil, nil
case linux.SHM_LOCK, linux.SHM_UNLOCK:
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index fb6efd5d8..20cb1a5cb 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// "For a process to have permission to send a signal it must
@@ -245,6 +245,11 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
sig := linux.Signal(args[0].Int())
newactarg := args[1].Pointer()
oldactarg := args[2].Pointer()
+ sigsetsize := args[3].SizeT()
+
+ if sigsetsize != linux.SignalSetSize {
+ return 0, nil, syserror.EINVAL
+ }
var newactptr *arch.SignalAct
if newactarg != 0 {
@@ -290,7 +295,7 @@ func RtSigprocmask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
}
oldmask := t.SignalMask()
if setaddr != 0 {
- mask, err := copyInSigSet(t, setaddr, sigsetsize)
+ mask, err := CopyInSigSet(t, setaddr, sigsetsize)
if err != nil {
return 0, nil, err
}
@@ -350,7 +355,7 @@ func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
func RtSigpending(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
pending := t.PendingSignals()
- _, err := t.CopyOut(addr, pending)
+ _, err := pending.CopyOut(t, addr)
return 0, nil, err
}
@@ -361,7 +366,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
timespec := args[2].Pointer()
sigsetsize := args[3].SizeT()
- mask, err := copyInSigSet(t, sigset, sigsetsize)
+ mask, err := CopyInSigSet(t, sigset, sigsetsize)
if err != nil {
return 0, nil, err
}
@@ -387,7 +392,7 @@ func RtSigtimedwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if siginfo != 0 {
si.FixSignalCodeForUser()
- if _, err := t.CopyOut(siginfo, si); err != nil {
+ if _, err := si.CopyOut(t, siginfo); err != nil {
return 0, nil, err
}
}
@@ -406,7 +411,7 @@ func RtSigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// same way), and that the code is in the allowed set. This same logic
// appears below in RtSigtgqueueinfo and should be kept in sync.
var info arch.SignalInfo
- if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
info.Signo = int32(sig)
@@ -450,7 +455,7 @@ func RtTgsigqueueinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ker
// Copy in the info. See RtSigqueueinfo above.
var info arch.SignalInfo
- if _, err := t.CopyIn(infoAddr, &info); err != nil {
+ if _, err := info.CopyIn(t, infoAddr); err != nil {
return 0, nil, err
}
info.Signo = int32(sig)
@@ -480,7 +485,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.
// Copy in the signal mask.
var mask linux.SignalSet
- if _, err := t.CopyIn(sigset, &mask); err != nil {
+ if _, err := mask.CopyIn(t, sigset); err != nil {
return 0, nil, err
}
mask &^= kernel.UnblockableSignals
@@ -513,7 +518,7 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
// sharedSignalfd is shared between the two calls.
func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
// Copy in the signal mask.
- mask, err := copyInSigSet(t, sigset, sigsetsize)
+ mask, err := CopyInSigSet(t, sigset, sigsetsize)
if err != nil {
return 0, nil, err
}
@@ -531,7 +536,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Is this a signalfd?
if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
@@ -548,7 +553,7 @@ func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize ui
if err != nil {
return 0, nil, err
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Set appropriate flags.
file.SetFlags(fs.SettableFileFlags{
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index b5a72ce63..fec1c1974 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -26,11 +26,15 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
+// LINT.IfChange
+
// minListenBacklog is the minimum reasonable backlog for listening sockets.
const minListenBacklog = 8
@@ -41,7 +45,7 @@ const maxListenBacklog = 1024
const maxAddrLen = 200
// maxOptLen is the maximum sockopt parameter length we're willing to accept.
-const maxOptLen = 1024
+const maxOptLen = 1024 * 8
// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
// willing to accept. Note that this limit is smaller than Linux, which allows
@@ -196,7 +200,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
s.SetFlags(fs.SettableFileFlags{
NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
})
- defer s.DecRef()
+ defer s.DecRef(t)
fd, err := t.NewFDFrom(0, s, kernel.FDFlags{
CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
@@ -231,8 +235,8 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
}
s1.SetFlags(fileFlags)
s2.SetFlags(fileFlags)
- defer s1.DecRef()
- defer s2.DecRef()
+ defer s1.DecRef(t)
+ defer s2.DecRef(t)
// Create the FDs for the sockets.
fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{
@@ -244,7 +248,11 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// Copy the file descriptors out.
if _, err := t.CopyOut(socks, fds); err != nil {
- // Note that we don't close files here; see pipe(2) also.
+ for _, fd := range fds {
+ if file, _ := t.FDTable().Remove(fd); file != nil {
+ file.DecRef(t)
+ }
+ }
return 0, nil, err
}
@@ -262,7 +270,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -293,7 +301,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -352,7 +360,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -379,7 +387,7 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -408,7 +416,7 @@ func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -439,7 +447,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -447,16 +455,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- // Read the length if present. Reject negative values.
+ // Read the length. Reject negative values.
optLen := int32(0)
- if optLenAddr != 0 {
- if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
- return 0, nil, err
- }
-
- if optLen < 0 {
- return 0, nil, syserror.EINVAL
- }
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
}
// Call syscall implementation then copy both value and value len out.
@@ -465,15 +470,13 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, e.ToError()
}
- if optLenAddr != 0 {
- vLen := int32(binary.Size(v))
- if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
- return 0, nil, err
- }
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
}
if v != nil {
- if _, err := t.CopyOut(optValAddr, v); err != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
return 0, nil, err
}
}
@@ -483,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
// getSockOpt tries to handle common socket options, or dispatches to a specific
// socket implementation.
-func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) {
+func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
if level == linux.SOL_SOCKET {
switch name {
case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
@@ -495,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use
switch name {
case linux.SO_TYPE:
_, skType, _ := s.Type()
- return int32(skType), nil
+ v := primitive.Int32(skType)
+ return &v, nil
case linux.SO_DOMAIN:
family, _, _ := s.Type()
- return int32(family), nil
+ v := primitive.Int32(family)
+ return &v, nil
case linux.SO_PROTOCOL:
_, _, protocol := s.Type()
- return int32(protocol), nil
+ v := primitive.Int32(protocol)
+ return &v, nil
}
}
@@ -523,7 +529,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -538,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.EINVAL
}
buf := t.CopyScratchBuffer(int(optLen))
- if _, err := t.CopyIn(optValAddr, &buf); err != nil {
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
return 0, nil, err
}
@@ -561,7 +567,7 @@ func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -589,7 +595,7 @@ func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -622,7 +628,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -675,7 +681,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -769,7 +775,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
}
if !cms.Unix.Empty() {
mflags |= linux.MSG_CTRUNC
- cms.Unix.Release()
+ cms.Release(t)
}
if int(msg.Flags) != mflags {
@@ -789,24 +795,16 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
- defer cms.Unix.Release()
+ defer cms.Release(t)
controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
}
- if cms.IP.HasTimestamp {
- controlData = control.PackTimestamp(t, cms.IP.Timestamp, controlData)
- }
-
- if cms.IP.HasInq {
- // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
- controlData = control.PackInq(t, cms.IP.Inq, controlData)
- }
-
if cms.Unix.Rights != nil {
controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
}
@@ -853,7 +851,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -882,7 +880,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag
}
n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
- cm.Unix.Release()
+ cm.Release(t)
if e != nil {
return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
}
@@ -926,7 +924,7 @@ func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -964,7 +962,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -1065,10 +1063,10 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
}
// Call the syscall implementation.
- n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: controlMessages})
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
if err != nil {
- controlMessages.Release()
+ controlMessages.Release(t)
}
return uintptr(n), err
}
@@ -1086,7 +1084,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags
if file == nil {
return 0, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Extract the socket.
s, ok := file.FileOperations.(socket.Socket)
@@ -1141,3 +1139,5 @@ func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
return n, nil, err
}
+
+// LINT.ThenChange(./vfs2/socket.go)
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index dd3a5807f..b8846a10a 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -25,6 +25,14 @@ import (
// doSplice implements a blocking splice operation.
func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonBlocking bool) (int64, error) {
+ if opts.Length < 0 || opts.SrcStart < 0 || opts.DstStart < 0 || (opts.SrcStart+opts.Length < 0) {
+ return 0, syserror.EINVAL
+ }
+
+ if opts.Length > int64(kernel.MAX_RW_COUNT) {
+ opts.Length = int64(kernel.MAX_RW_COUNT)
+ }
+
var (
total int64
n int64
@@ -72,6 +80,12 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
}
}
+ if total > 0 {
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dirent.InotifyEvent(linux.IN_MODIFY, 0)
+ }
return total, err
}
@@ -82,17 +96,12 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
offsetAddr := args[2].Pointer()
count := int64(args[3].SizeT())
- // Don't send a negative number of bytes.
- if count < 0 {
- return 0, nil, syserror.EINVAL
- }
-
// Get files.
inFile := t.GetFile(inFD)
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
if !inFile.Flags().Read {
return 0, nil, syserror.EBADF
@@ -102,7 +111,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
@@ -136,11 +145,6 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, err
}
- // The offset must be valid.
- if offset < 0 {
- return 0, nil, syserror.EINVAL
- }
-
// Do the splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
@@ -188,13 +192,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
inFile := t.GetFile(inFD)
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
// The operation is non-blocking if anything is non-blocking.
//
@@ -211,8 +215,10 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
opts := fs.SpliceOpts{
Length: count,
}
+ inFileAttr := inFile.Dirent.Inode.StableAttr
+ outFileAttr := outFile.Dirent.Inode.StableAttr
switch {
- case fs.IsPipe(inFile.Dirent.Inode.StableAttr) && !fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ case fs.IsPipe(inFileAttr) && !fs.IsPipe(outFileAttr):
if inOffset != 0 {
return 0, nil, syserror.ESPIPE
}
@@ -225,11 +231,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if _, err := t.CopyIn(outOffset, &offset); err != nil {
return 0, nil, err
}
+
// Use the destination offset.
opts.DstOffset = true
opts.DstStart = offset
}
- case !fs.IsPipe(inFile.Dirent.Inode.StableAttr) && fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ case !fs.IsPipe(inFileAttr) && fs.IsPipe(outFileAttr):
if outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
@@ -242,17 +249,18 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if _, err := t.CopyIn(inOffset, &offset); err != nil {
return 0, nil, err
}
+
// Use the source offset.
opts.SrcOffset = true
opts.SrcStart = offset
}
- case fs.IsPipe(inFile.Dirent.Inode.StableAttr) && fs.IsPipe(outFile.Dirent.Inode.StableAttr):
+ case fs.IsPipe(inFileAttr) && fs.IsPipe(outFileAttr):
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
// We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ if inFileAttr.InodeID == outFileAttr.InodeID {
return 0, nil, syserror.EINVAL
}
default:
@@ -262,6 +270,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
// Splice data.
n, err := doSplice(t, outFile, inFile, opts, nonBlock)
+ // Special files can have additional requirements for granularity. For
+ // example, read from eventfd returns EINVAL if a size is less 8 bytes.
+ // Inotify is another example. read will return EINVAL is a buffer is
+ // too small to return the next event, but a size of an event isn't
+ // fixed, it is sizeof(struct inotify_event) + {NAME_LEN} + 1.
+ if n != 0 && err != nil && (fs.IsAnonymous(inFileAttr) || fs.IsAnonymous(outFileAttr)) {
+ err = nil
+ }
+
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
}
@@ -283,13 +300,13 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
if outFile == nil {
return 0, nil, syserror.EBADF
}
- defer outFile.DecRef()
+ defer outFile.DecRef(t)
inFile := t.GetFile(inFD)
if inFile == nil {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
+ defer inFile.DecRef(t)
// All files must be pipes.
if !fs.IsPipe(inFile.Dirent.Inode.StableAttr) || !fs.IsPipe(outFile.Dirent.Inode.StableAttr) {
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 5556bc276..a5826f2dd 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -16,14 +16,15 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Stat implements linux syscall stat(2).
func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -57,7 +58,7 @@ func Fstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, fstat(t, file, statAddr)
}
@@ -99,7 +100,7 @@ func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, fstat(t, file, statAddr)
}
@@ -113,7 +114,9 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
if err != nil {
return err
}
- return copyOutStat(t, statAddr, d.Inode.StableAttr, uattr)
+ s := statFromAttrs(t, d.Inode.StableAttr, uattr)
+ _, err = s.CopyOut(t, statAddr)
+ return err
}
// fstat implements fstat for the given *fs.File.
@@ -122,56 +125,8 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
if err != nil {
return err
}
- return copyOutStat(t, statAddr, f.Dirent.Inode.StableAttr, uattr)
-}
-
-// copyOutStat copies the attributes (sattr, uattr) to the struct stat at
-// address dst in t's address space. It encodes the stat struct to bytes
-// manually, as stat() is a very common syscall for many applications, and
-// t.CopyObjectOut has noticeable performance impact due to its many slice
-// allocations and use of reflection.
-func copyOutStat(t *kernel.Task, dst usermem.Addr, sattr fs.StableAttr, uattr fs.UnstableAttr) error {
- b := t.CopyScratchBuffer(int(linux.SizeOfStat))[:0]
-
- // Dev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.DeviceID))
- // Ino (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.InodeID))
- // Nlink (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uattr.Links)
- // Mode (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, sattr.Type.LinuxType()|uint32(uattr.Perms.LinuxMode()))
- // UID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
- // GID (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()))
- // Padding (uint32)
- b = binary.AppendUint32(b, usermem.ByteOrder, 0)
- // Rdev (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)))
- // Size (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Size))
- // Blksize (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(sattr.BlockSize))
- // Blocks (uint64)
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(uattr.Usage/512))
-
- // ATime
- atime := uattr.AccessTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(atime.Nsec))
-
- // MTime
- mtime := uattr.ModificationTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(mtime.Nsec))
-
- // CTime
- ctime := uattr.StatusChangeTime.Timespec()
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Sec))
- b = binary.AppendUint64(b, usermem.ByteOrder, uint64(ctime.Nsec))
-
- _, err := t.CopyOutBytes(dst, b)
+ s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
+ _, err = s.CopyOut(t, statAddr)
return err
}
@@ -183,7 +138,10 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
mask := args[3].Uint()
statxAddr := args[4].Pointer()
- if mask&linux.STATX__RESERVED > 0 {
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.AT_SYMLINK_NOFOLLOW|linux.AT_EMPTY_PATH|linux.AT_STATX_SYNC_TYPE) != 0 {
return 0, nil, syserror.EINVAL
}
if flags&linux.AT_STATX_SYNC_TYPE == linux.AT_STATX_SYNC_TYPE {
@@ -200,7 +158,7 @@ func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
uattr, err := file.UnstableAttr(t)
if err != nil {
return 0, nil, err
@@ -291,7 +249,7 @@ func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
return 0, nil, statfsImpl(t, file.Dirent, statfsAddr)
}
@@ -328,3 +286,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_amd64.go b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
new file mode 100644
index 000000000..0a04a6113
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_amd64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uattr.Links,
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: sattr.BlockSize,
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_amd64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat_arm64.go b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
new file mode 100644
index 000000000..5a3b1bfad
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_stat_arm64.go
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// LINT.IfChange
+
+func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
+ return linux.Stat{
+ Dev: sattr.DeviceID,
+ Ino: sattr.InodeID,
+ Nlink: uint32(uattr.Links),
+ Mode: sattr.Type.LinuxType() | uint32(uattr.Perms.LinuxMode()),
+ UID: uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()),
+ GID: uint32(uattr.Owner.GID.In(t.UserNamespace()).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(sattr.DeviceFileMajor, sattr.DeviceFileMinor)),
+ Size: uattr.Size,
+ Blksize: int32(sattr.BlockSize),
+ Blocks: uattr.Usage / 512,
+ ATime: uattr.AccessTime.Timespec(),
+ MTime: uattr.ModificationTime.Timespec(),
+ CTime: uattr.StatusChangeTime.Timespec(),
+ }
+}
+
+// LINT.ThenChange(vfs2/stat_arm64.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..f2c0e5069 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -37,7 +39,7 @@ func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Use "sync-the-world" for now, it's guaranteed that fd is at least
// on the root filesystem.
@@ -52,7 +54,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll)
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
@@ -68,7 +70,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData)
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
@@ -101,7 +103,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// SYNC_FILE_RANGE_WAIT_BEFORE waits upon write-out of all pages in the
// specified range that have already been submitted to the device
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go
index a65b560c8..297de052a 100644
--- a/pkg/sentry/syscalls/linux/sys_sysinfo.go
+++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go
@@ -29,13 +29,18 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
mf.UpdateUsage()
_, totalUsage := usage.MemoryAccounting.Copy()
totalSize := usage.TotalMemory(mf.TotalSize(), totalUsage)
+ memFree := totalSize - totalUsage
+ if memFree > totalSize {
+ // Underflow.
+ memFree = 0
+ }
// Only a subset of the fields in sysinfo_t make sense to return.
si := linux.Sysinfo{
Procs: uint16(len(t.PIDNamespace().Tasks())),
Uptime: t.Kernel().MonotonicClock().Now().Seconds(),
TotalRAM: totalSize,
- FreeRAM: totalSize - totalUsage,
+ FreeRAM: memFree,
Unit: 1,
}
_, err := t.CopyOut(addr, si)
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 2476f8858..2d16e4933 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -21,11 +21,12 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
"gvisor.dev/gvisor/pkg/sentry/loader"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const (
@@ -116,10 +117,11 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
resolveFinal := flags&linux.AT_SYMLINK_NOFOLLOW == 0
root := t.FSContext().RootDirectory()
- defer root.DecRef()
+ defer root.DecRef(t)
var wd *fs.Dirent
- var executable *fs.File
+ var executable fsbridge.File
+ var closeOnExec bool
if dirFD == linux.AT_FDCWD || path.IsAbs(pathname) {
// Even if the pathname is absolute, we may still need the wd
// for interpreter scripts if the path of the interpreter is
@@ -127,14 +129,23 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
wd = t.FSContext().WorkingDirectory()
} else {
// Need to extract the given FD.
- f := t.GetFile(dirFD)
+ f, fdFlags := t.FDTable().Get(dirFD)
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
+ closeOnExec = fdFlags.CloseOnExec
if atEmptyPath && len(pathname) == 0 {
- executable = f
+ // TODO(gvisor.dev/issue/160): Linux requires only execute permission,
+ // not read. However, our backing filesystems may prevent us from reading
+ // the file without read permission. Additionally, a task with a
+ // non-readable executable has additional constraints on access via
+ // ptrace and procfs.
+ if err := f.Dirent.Inode.CheckPermission(t, fs.PermMask{Read: true, Execute: true}); err != nil {
+ return 0, nil, err
+ }
+ executable = fsbridge.NewFSFile(f)
} else {
wd = f.Dirent
wd.IncRef()
@@ -144,19 +155,18 @@ func execveat(t *kernel.Task, dirFD int32, pathnameAddr, argvAddr, envvAddr user
}
}
if wd != nil {
- defer wd.DecRef()
+ defer wd.DecRef(t)
}
// Load the new TaskContext.
remainingTraversals := uint(linux.MaxSymlinkTraversals)
loadArgs := loader.LoadArgs{
- Mounts: t.MountNamespace(),
- Root: root,
- WorkingDirectory: wd,
+ Opener: fsbridge.NewFSLookup(t.MountNamespace(), root, wd),
RemainingTraversals: &remainingTraversals,
ResolveFinal: resolveFinal,
Filename: pathname,
File: executable,
+ CloseOnExec: closeOnExec,
Argv: argv,
Envv: envv,
Features: t.Arch().FeatureSet(),
@@ -217,19 +227,6 @@ func clone(t *kernel.Task, flags int, stack usermem.Addr, parentTID usermem.Addr
return uintptr(ntid), ctrl, err
}
-// Clone implements linux syscall clone(2).
-// sys_clone has so many flavors. We implement the default one in linux 3.11
-// x86_64:
-// sys_clone(clone_flags, newsp, parent_tidptr, child_tidptr, tls_val)
-func Clone(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- flags := int(args[0].Int())
- stack := args[1].Pointer()
- parentTID := args[2].Pointer()
- childTID := args[3].Pointer()
- tls := args[4].Pointer()
- return clone(t, flags, stack, parentTID, childTID, tls)
-}
-
// Fork implements Linux syscall fork(2).
func Fork(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
// "A call to fork() is equivalent to a call to clone(2) specifying flags
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index b887fa9d7..2d2aa0819 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -22,8 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// The most significant 29 bits hold either a pid or a file descriptor.
diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go
index d4134207b..a4c400f87 100644
--- a/pkg/sentry/syscalls/linux/sys_timer.go
+++ b/pkg/sentry/syscalls/linux/sys_timer.go
@@ -20,8 +20,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
const nsecPerSec = int64(time.Second)
@@ -146,7 +146,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, err
}
- return uintptr(id), nil, nil
+ return 0, nil, nil
}
// TimerSettime implements linux syscall timer_settime(2).
diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go
index cf49b43db..34b03e4ee 100644
--- a/pkg/sentry/syscalls/linux/sys_timerfd.go
+++ b/pkg/sentry/syscalls/linux/sys_timerfd.go
@@ -43,7 +43,7 @@ func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.EINVAL
}
f := timerfd.NewFile(t, c)
- defer f.DecRef()
+ defer f.DecRef(t)
f.SetFlags(fs.SettableFileFlags{
NonBlocking: flags&linux.TFD_NONBLOCK != 0,
})
@@ -73,7 +73,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
@@ -107,7 +107,7 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
if f == nil {
return 0, nil, syserror.EBADF
}
- defer f.DecRef()
+ defer f.DecRef(t)
tf, ok := f.FileOperations.(*timerfd.TimerOperations)
if !ok {
diff --git a/pkg/sentry/syscalls/linux/sys_tls.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
index b3eb96a1c..b3eb96a1c 100644
--- a/pkg/sentry/syscalls/linux/sys_tls.go
+++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go
diff --git a/pkg/sentry/syscalls/linux/sys_tls_arm64.go b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
new file mode 100644
index 000000000..fb08a356e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_tls_arm64.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//+build arm64
+
+package linux
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// ArchPrctl is not defined for ARM64.
+func ArchPrctl(*kernel.Task, arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, syserror.ENOSYS
+}
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 748e8dd8d..e9d702e8e 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64
-
package linux
import (
@@ -35,7 +33,15 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
copy(u.Nodename[:], uts.HostName())
copy(u.Release[:], version.Release)
copy(u.Version[:], version.Version)
- copy(u.Machine[:], "x86_64") // build tag above.
+ // build tag above.
+ switch t.SyscallTable().Arch {
+ case arch.AMD64:
+ copy(u.Machine[:], "x86_64")
+ case arch.ARM64:
+ copy(u.Machine[:], "aarch64")
+ default:
+ copy(u.Machine[:], "unknown")
+ }
copy(u.Domainname[:], uts.DomainName())
// Copy out the result.
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index ad4b67806..485526e28 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -23,11 +23,13 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -46,7 +48,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is writable.
if !file.Flags().Write {
@@ -83,10 +85,10 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
- // Check that the offset is legitimate.
- if offset < 0 {
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
return 0, nil, syserror.EINVAL
}
@@ -129,7 +131,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the file is writable.
if !file.Flags().Write {
@@ -160,7 +162,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < 0 {
@@ -213,7 +215,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
if file == nil {
return 0, nil, syserror.EBADF
}
- defer file.DecRef()
+ defer file.DecRef(t)
// Check that the offset is legitimate.
if offset < -1 {
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
new file mode 100644
index 000000000..97474fd3c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -0,0 +1,432 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+import (
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// LINT.IfChange
+
+// GetXattr implements linux syscall getxattr(2).
+func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getXattrFromPath(t, args, true)
+}
+
+// LGetXattr implements linux syscall lgetxattr(2).
+func LGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getXattrFromPath(t, args, false)
+}
+
+// FGetXattr implements linux syscall fgetxattr(2).
+func FGetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef(t)
+
+ n, err := getXattr(t, f.Dirent, nameAddr, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func getXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ n, err = getXattr(t, d, nameAddr, valueAddr, size)
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+// getXattr implements getxattr(2) from the given *fs.Dirent.
+func getXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64) (int, error) {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Read: true}); err != nil {
+ return 0, err
+ }
+
+ // TODO(b/148380782): Support xattrs in namespaces other than "user".
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ // If getxattr(2) is called with size 0, the size of the value will be
+ // returned successfully even if it is nonzero. In that case, we need to
+ // retrieve the entire attribute value so we can return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+
+ value, err := d.Inode.GetXattr(t, name, requestedSize)
+ if err != nil {
+ return 0, err
+ }
+ n := len(value)
+ if uint64(n) > requestedSize {
+ return 0, syserror.ERANGE
+ }
+
+ // Don't copy out the attribute value if size is 0.
+ if size == 0 {
+ return n, nil
+ }
+
+ if _, err = t.CopyOutBytes(valueAddr, []byte(value)); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
+// SetXattr implements linux syscall setxattr(2).
+func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return setXattrFromPath(t, args, true)
+}
+
+// LSetXattr implements linux syscall lsetxattr(2).
+func LSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return setXattrFromPath(t, args, false)
+}
+
+// FSetXattr implements linux syscall fsetxattr(2).
+func FSetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+ flags := args[4].Uint()
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef(t)
+
+ return 0, nil, setXattr(t, f.Dirent, nameAddr, valueAddr, uint64(size), flags)
+}
+
+func setXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := uint64(args[3].SizeT())
+ flags := args[4].Uint()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ return setXattr(t, d, nameAddr, valueAddr, uint64(size), flags)
+ })
+}
+
+// setXattr implements setxattr(2) from the given *fs.Dirent.
+func setXattr(t *kernel.Task, d *fs.Dirent, nameAddr, valueAddr usermem.Addr, size uint64, flags uint32) error {
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if size > linux.XATTR_SIZE_MAX {
+ return syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return err
+ }
+ value := string(buf)
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ if err := d.Inode.SetXattr(t, d, name, value, flags); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+// Restrict xattrs to regular files and directories.
+//
+// TODO(b/148380782): In Linux, this restriction technically only applies to
+// xattrs in the "user.*" namespace. Make file type checks specific to the
+// namespace once we allow other xattr prefixes.
+func xattrFileTypeOk(i *fs.Inode) bool {
+ return fs.IsRegular(i.StableAttr) || fs.IsDir(i.StableAttr)
+}
+
+func checkXattrPermissions(t *kernel.Task, i *fs.Inode, perms fs.PermMask) error {
+ // Restrict xattrs to regular files and directories.
+ if !xattrFileTypeOk(i) {
+ if perms.Write {
+ return syserror.EPERM
+ }
+ return syserror.ENODATA
+ }
+
+ return i.CheckPermission(t, perms)
+}
+
+// ListXattr implements linux syscall listxattr(2).
+func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, true)
+}
+
+// LListXattr implements linux syscall llistxattr(2).
+func LListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listXattrFromPath(t, args, false)
+}
+
+// FListXattr implements linux syscall flistxattr(2).
+func FListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef(t)
+
+ n, err := listXattr(t, f.Dirent, listAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := uint64(args[2].SizeT())
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n := 0
+ err = fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ n, err = listXattr(t, d, listAddr, size)
+ return err
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(n), nil, nil
+}
+
+func listXattr(t *kernel.Task, d *fs.Dirent, addr usermem.Addr, size uint64) (int, error) {
+ if !xattrFileTypeOk(d.Inode) {
+ return 0, nil
+ }
+
+ // If listxattr(2) is called with size 0, the buffer size needed to contain
+ // the xattr list will be returned successfully even if it is nonzero. In
+ // that case, we need to retrieve the entire list so we can compute and
+ // return the correct size.
+ requestedSize := size
+ if size == 0 || size > linux.XATTR_SIZE_MAX {
+ requestedSize = linux.XATTR_SIZE_MAX
+ }
+ xattrs, err := d.Inode.ListXattr(t, requestedSize)
+ if err != nil {
+ return 0, err
+ }
+
+ // TODO(b/148380782): support namespaces other than "user".
+ for x := range xattrs {
+ if !strings.HasPrefix(x, linux.XATTR_USER_PREFIX) {
+ delete(xattrs, x)
+ }
+ }
+
+ listSize := xattrListSize(xattrs)
+ if listSize > linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ if uint64(listSize) > requestedSize {
+ return 0, syserror.ERANGE
+ }
+
+ // Don't copy out the attributes if size is 0.
+ if size == 0 {
+ return listSize, nil
+ }
+
+ buf := make([]byte, 0, listSize)
+ for x := range xattrs {
+ buf = append(buf, []byte(x)...)
+ buf = append(buf, 0)
+ }
+ if _, err := t.CopyOutBytes(addr, buf); err != nil {
+ return 0, err
+ }
+
+ return len(buf), nil
+}
+
+func xattrListSize(xattrs map[string]struct{}) int {
+ size := 0
+ for x := range xattrs {
+ size += len(x) + 1
+ }
+ return size
+}
+
+// RemoveXattr implements linux syscall removexattr(2).
+func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, true)
+}
+
+// LRemoveXattr implements linux syscall lremovexattr(2).
+func LRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return removeXattrFromPath(t, args, false)
+}
+
+// FRemoveXattr implements linux syscall fremovexattr(2).
+func FRemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ // TODO(b/113957122): Return EBADF if the fd was opened with O_PATH.
+ f := t.GetFile(fd)
+ if f == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer f.DecRef(t)
+
+ return 0, nil, removeXattr(t, f.Dirent, nameAddr)
+}
+
+func removeXattrFromPath(t *kernel.Task, args arch.SyscallArguments, resolveSymlink bool) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, dirPath, err := copyInPath(t, pathAddr, false /* allowEmpty */)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, fileOpOn(t, linux.AT_FDCWD, path, resolveSymlink, func(_ *fs.Dirent, d *fs.Dirent, _ uint) error {
+ if dirPath && !fs.IsDir(d.Inode.StableAttr) {
+ return syserror.ENOTDIR
+ }
+
+ return removeXattr(t, d, nameAddr)
+ })
+}
+
+// removeXattr implements removexattr(2) from the given *fs.Dirent.
+func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ if err := checkXattrPermissions(t, d.Inode, fs.PermMask{Write: true}); err != nil {
+ return err
+ }
+
+ if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) {
+ return syserror.EOPNOTSUPP
+ }
+
+ if err := d.Inode.RemoveXattr(t, d, name); err != nil {
+ return err
+ }
+ d.InotifyEvent(linux.IN_ATTRIB, 0)
+ return nil
+}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/timespec.go b/pkg/sentry/syscalls/linux/timespec.go
index 4ff8f9234..ddc3ee26e 100644
--- a/pkg/sentry/syscalls/linux/timespec.go
+++ b/pkg/sentry/syscalls/linux/timespec.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// copyTimespecIn copies a Timespec from the untrusted app range to the kernel.
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
new file mode 100644
index 000000000..64696b438
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -0,0 +1,78 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "vfs2",
+ srcs = [
+ "aio.go",
+ "epoll.go",
+ "eventfd.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "inotify.go",
+ "ioctl.go",
+ "lock.go",
+ "memfd.go",
+ "mmap.go",
+ "mount.go",
+ "path.go",
+ "pipe.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "signal.go",
+ "socket.go",
+ "splice.go",
+ "stat.go",
+ "stat_amd64.go",
+ "stat_arm64.go",
+ "sync.go",
+ "timerfd.go",
+ "vfs2.go",
+ "xattr.go",
+ ],
+ marshal = True,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/bits",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/eventfd",
+ "//pkg/sentry/fsimpl/pipefs",
+ "//pkg/sentry/fsimpl/signalfd",
+ "//pkg/sentry/fsimpl/timerfd",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/fasync",
+ "//pkg/sentry/kernel/pipe",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/mm",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/syscalls",
+ "//pkg/sentry/syscalls/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
+ ],
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
new file mode 100644
index 000000000..42559bf69
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -0,0 +1,219 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/mm"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// IoSubmit implements linux syscall io_submit(2).
+func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := args[0].Uint64()
+ nrEvents := args[1].Int()
+ addr := args[2].Pointer()
+
+ if nrEvents < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ for i := int32(0); i < nrEvents; i++ {
+ // Copy in the address.
+ cbAddrNative := t.Arch().Native(0)
+ if _, err := t.CopyIn(addr, cbAddrNative); err != nil {
+ if i > 0 {
+ // Some successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Copy in this callback.
+ var cb linux.IOCallback
+ cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative))
+ if _, err := t.CopyIn(cbAddr, &cb); err != nil {
+ if i > 0 {
+ // Some have been successful.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Process this callback.
+ if err := submitCallback(t, id, &cb, cbAddr); err != nil {
+ if i > 0 {
+ // Partial success.
+ return uintptr(i), nil, nil
+ }
+ // Nothing done.
+ return 0, nil, err
+ }
+
+ // Advance to the next one.
+ addr += usermem.Addr(t.Arch().Width())
+ }
+
+ return uintptr(nrEvents), nil, nil
+}
+
+// submitCallback processes a single callback.
+func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr usermem.Addr) error {
+ if cb.Reserved2 != 0 {
+ return syserror.EINVAL
+ }
+
+ fd := t.GetFileVFS2(cb.FD)
+ if fd == nil {
+ return syserror.EBADF
+ }
+ defer fd.DecRef(t)
+
+ // Was there an eventFD? Extract it.
+ var eventFD *vfs.FileDescription
+ if cb.Flags&linux.IOCB_FLAG_RESFD != 0 {
+ eventFD = t.GetFileVFS2(cb.ResFD)
+ if eventFD == nil {
+ return syserror.EBADF
+ }
+ defer eventFD.DecRef(t)
+
+ // Check that it is an eventfd.
+ if _, ok := eventFD.Impl().(*eventfd.EventFileDescription); !ok {
+ return syserror.EINVAL
+ }
+ }
+
+ ioseq, err := memoryFor(t, cb)
+ if err != nil {
+ return err
+ }
+
+ // Check offset for reads/writes.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ if cb.Offset < 0 {
+ return syserror.EINVAL
+ }
+ }
+
+ // Prepare the request.
+ aioCtx, ok := t.MemoryManager().LookupAIOContext(t, id)
+ if !ok {
+ return syserror.EINVAL
+ }
+ if ready := aioCtx.Prepare(); !ready {
+ // Context is busy.
+ return syserror.EAGAIN
+ }
+
+ if eventFD != nil {
+ // The request is set. Make sure there's a ref on the file.
+ //
+ // This is necessary when the callback executes on completion,
+ // which is also what will release this reference.
+ eventFD.IncRef()
+ }
+
+ // Perform the request asynchronously.
+ fd.IncRef()
+ t.QueueAIO(getAIOCallback(t, fd, eventFD, cbAddr, cb, ioseq, aioCtx))
+ return nil
+}
+
+func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback {
+ return func(ctx context.Context) {
+ // Release references after completing the callback.
+ defer fd.DecRef(ctx)
+ if eventFD != nil {
+ defer eventFD.DecRef(ctx)
+ }
+
+ if aioCtx.Dead() {
+ aioCtx.CancelPendingRequest()
+ return
+ }
+ ev := &linux.IOEvent{
+ Data: cb.Data,
+ Obj: uint64(cbAddr),
+ }
+
+ var err error
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PREADV:
+ ev.Result, err = fd.PRead(ctx, ioseq, cb.Offset, vfs.ReadOptions{})
+ case linux.IOCB_CMD_PWRITE, linux.IOCB_CMD_PWRITEV:
+ ev.Result, err = fd.PWrite(ctx, ioseq, cb.Offset, vfs.WriteOptions{})
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC:
+ err = fd.Sync(ctx)
+ }
+
+ // Update the result.
+ if err != nil {
+ err = slinux.HandleIOErrorVFS2(t, ev.Result != 0 /* partial */, err, nil /* never interrupted */, "aio", fd)
+ ev.Result = -int64(kernel.ExtractErrno(err, 0))
+ }
+
+ // Queue the result for delivery.
+ aioCtx.FinishRequest(ev)
+
+ // Notify the event file if one was specified. This needs to happen
+ // *after* queueing the result to avoid racing with the thread we may
+ // wake up.
+ if eventFD != nil {
+ eventFD.Impl().(*eventfd.EventFileDescription).Signal(1)
+ }
+ }
+}
+
+// memoryFor returns appropriate memory for the given callback.
+func memoryFor(t *kernel.Task, cb *linux.IOCallback) (usermem.IOSequence, error) {
+ bytes := int(cb.Bytes)
+ if bytes < 0 {
+ // Linux also requires that this field fit in ssize_t.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+
+ // Since this I/O will be asynchronous with respect to t's task goroutine,
+ // we have no guarantee that t's AddressSpace will be active during the
+ // I/O.
+ switch cb.OpCode {
+ case linux.IOCB_CMD_PREAD, linux.IOCB_CMD_PWRITE:
+ return t.SingleIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_PREADV, linux.IOCB_CMD_PWRITEV:
+ return t.IovecsIOSequence(usermem.Addr(cb.Buf), bytes, usermem.IOOpts{
+ AddressSpaceActive: false,
+ })
+
+ case linux.IOCB_CMD_FSYNC, linux.IOCB_CMD_FDSYNC, linux.IOCB_CMD_NOOP:
+ return usermem.IOSequence{}, nil
+
+ default:
+ // Not a supported command.
+ return usermem.IOSequence{}, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..c62f03509
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,228 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+var sizeofEpollEvent = (*linux.EpollEvent)(nil).SizeBytes()
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef(t)
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if _, err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef(t)
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n])
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/eventfd.go b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
new file mode 100644
index 000000000..807f909da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/eventfd.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Eventfd2 implements linux syscall eventfd2(2).
+func Eventfd2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ initVal := uint64(args[0].Uint())
+ flags := uint(args[1].Uint())
+ allOps := uint(linux.EFD_SEMAPHORE | linux.EFD_NONBLOCK | linux.EFD_CLOEXEC)
+
+ if flags & ^allOps != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ vfsObj := t.Kernel().VFS()
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.EFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+ semMode := flags&linux.EFD_SEMAPHORE != 0
+ eventfd, err := eventfd.New(t, vfsObj, initVal, semMode, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer eventfd.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, eventfd, kernel.FDFlags{
+ CloseOnExec: flags&linux.EFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// Eventfd implements linux syscall eventfd(2).
+func Eventfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[1].Value = 0
+ return Eventfd2(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..066ee0863
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef(t)
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef(t)
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef(t)
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..4856554fe
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,355 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/fasync"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef(t)
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, err
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ case linux.F_SETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ n, err := pipefile.SetPipeSize(int64(args[2].Int()))
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+ case linux.F_GETOWN:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ if owner.Type == linux.F_OWNER_PGRP {
+ return uintptr(-owner.PID), nil, nil
+ }
+ return uintptr(owner.PID), nil, nil
+ case linux.F_SETOWN:
+ who := args[2].Int()
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ case linux.F_GETOWN_EX:
+ owner, hasOwner := getAsyncOwner(t, file)
+ if !hasOwner {
+ return 0, nil, nil
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &owner)
+ return 0, nil, err
+ case linux.F_SETOWN_EX:
+ var owner linux.FOwnerEx
+ _, err := t.CopyIn(args[2].Pointer(), &owner)
+ if err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ case linux.F_GETPIPE_SZ:
+ pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
+ if !ok {
+ return 0, nil, syserror.EBADF
+ }
+ return uintptr(pipefile.PipeSize()), nil, nil
+ case linux.F_GET_SEALS:
+ val, err := tmpfs.GetSeals(file)
+ return uintptr(val), nil, err
+ case linux.F_ADD_SEALS:
+ if !file.IsWritable() {
+ return 0, nil, syserror.EPERM
+ }
+ err := tmpfs.AddSeals(file, args[2].Uint())
+ return 0, nil, err
+ case linux.F_SETLK, linux.F_SETLKW:
+ return 0, nil, posixLock(t, args, file, cmd)
+ default:
+ // Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwnerEx, hasOwner bool) {
+ a := fd.AsyncHandler()
+ if a == nil {
+ return linux.FOwnerEx{}, false
+ }
+
+ ot, otg, opg := a.(*fasync.FileAsync).Owner()
+ switch {
+ case ot != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_TID,
+ PID: int32(t.PIDNamespace().IDOfTask(ot)),
+ }, true
+ case otg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PID,
+ PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)),
+ }, true
+ case opg != nil:
+ return linux.FOwnerEx{
+ Type: linux.F_OWNER_PGRP,
+ PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)),
+ }, true
+ default:
+ return linux.FOwnerEx{}, true
+ }
+}
+
+func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+ switch ownerType {
+ case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
+ // Acceptable type.
+ default:
+ return syserror.EINVAL
+ }
+
+ a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ if pid == 0 {
+ a.ClearOwner()
+ return nil
+ }
+
+ switch ownerType {
+ case linux.F_OWNER_TID:
+ task := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid))
+ if task == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerTask(t, task)
+ return nil
+ case linux.F_OWNER_PID:
+ tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(pid))
+ if tg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerThreadGroup(t, tg)
+ return nil
+ case linux.F_OWNER_PGRP:
+ pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(pid))
+ if pg == nil {
+ return syserror.ESRCH
+ }
+ a.SetOwnerProcessGroup(t, pg)
+ return nil
+ default:
+ return syserror.EINVAL
+ }
+}
+
+func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error {
+ // Copy in the lock request.
+ flockAddr := args[2].Pointer()
+ var flock linux.Flock
+ if _, err := t.CopyIn(flockAddr, &flock); err != nil {
+ return err
+ }
+
+ var blocker lock.Blocker
+ if cmd == linux.F_SETLKW {
+ blocker = t
+ }
+
+ switch flock.Type {
+ case linux.F_RDLCK:
+ if !file.IsReadable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.ReadLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_WRLCK:
+ if !file.IsWritable() {
+ return syserror.EBADF
+ }
+ return file.LockPOSIX(t, t.FDTable(), lock.WriteLock, uint64(flock.Start), uint64(flock.Len), flock.Whence, blocker)
+
+ case linux.F_UNLCK:
+ return file.UnlockPOSIX(t, t.FDTable(), uint64(flock.Start), uint64(flock.Len), flock.Whence)
+
+ default:
+ return syserror.EINVAL
+ }
+}
+
+// Fadvise64 implements fadvise64(2).
+// This implementation currently ignores the provided advice.
+func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[2].Int64()
+ advice := args[3].Int()
+
+ // Note: offset is allowed to be negative.
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // If the FD refers to a pipe or FIFO, return error.
+ if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+
+ switch advice {
+ case linux.POSIX_FADV_NORMAL:
+ case linux.POSIX_FADV_RANDOM:
+ case linux.POSIX_FADV_SEQUENTIAL:
+ case linux.POSIX_FADV_WILLNEED:
+ case linux.POSIX_FADV_DONTNEED:
+ case linux.POSIX_FADV_NOREUSE:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Sure, whatever.
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..01e0f9010
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,334 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release(t)
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release(t)
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ if mode.FileType() == 0 {
+ mode |= linux.ModeRegular
+ }
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: mode &^ linux.FileMode(t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags | linux.O_LARGEFILE,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release(t)
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release(t)
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ if len(target) == 0 {
+ return syserror.ENOENT
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..a7d4d2a36
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef(t)
+ wd.DecRef(t)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(t, vd)
+ vd.DecRef(t)
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(t, vd)
+ vd.DecRef(t)
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(t, vd)
+ vd.DecRef(t)
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..5517595b5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,161 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of 8
+ if size > cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name.
+ bufTail := buf[19+len(dirent.Name):]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ size = (size + 7) &^ 7 // round up to multiple of sizeof(long)
+ if size > cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ // Zero out all remaining bytes in buf, including the NUL terminator
+ // after dirent.Name and the zero padding byte between the name and
+ // dirent type.
+ bufTail := buf[18+len(dirent.Name) : size-1]
+ for i := range bufTail {
+ bufTail[i] = 0
+ }
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go
new file mode 100644
index 000000000..11753d8e5
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const allFlags = linux.IN_NONBLOCK | linux.IN_CLOEXEC
+
+// InotifyInit1 implements the inotify_init1() syscalls.
+func InotifyInit1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^allFlags != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ ino, err := vfs.NewInotifyFD(t, t.Kernel().VFS(), uint32(flags))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer ino.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, ino, kernel.FDFlags{
+ CloseOnExec: flags&linux.IN_CLOEXEC != 0,
+ })
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// InotifyInit implements the inotify_init() syscalls.
+func InotifyInit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ args[0].Value = 0
+ return InotifyInit1(t, args)
+}
+
+// fdToInotify resolves an fd to an inotify object. If successful, the file will
+// have an extra ref and the caller is responsible for releasing the ref.
+func fdToInotify(t *kernel.Task, fd int32) (*vfs.Inotify, *vfs.FileDescription, error) {
+ f := t.GetFileVFS2(fd)
+ if f == nil {
+ // Invalid fd.
+ return nil, nil, syserror.EBADF
+ }
+
+ ino, ok := f.Impl().(*vfs.Inotify)
+ if !ok {
+ // Not an inotify fd.
+ f.DecRef(t)
+ return nil, nil, syserror.EINVAL
+ }
+
+ return ino, f, nil
+}
+
+// InotifyAddWatch implements the inotify_add_watch() syscall.
+func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ mask := args[2].Uint()
+
+ // "EINVAL: The given event mask contains no valid events."
+ // -- inotify_add_watch(2)
+ if mask&linux.ALL_INOTIFY_BITS == 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "IN_DONT_FOLLOW: Don't dereference pathname if it is a symbolic link."
+ // -- inotify(7)
+ follow := followFinalSymlink
+ if mask&linux.IN_DONT_FOLLOW == 0 {
+ follow = nofollowFinalSymlink
+ }
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef(t)
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ if mask&linux.IN_ONLYDIR != 0 {
+ path.Dir = true
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, follow)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+ d, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{})
+ if err != nil {
+ return 0, nil, err
+ }
+ defer d.DecRef(t)
+
+ fd, err = ino.AddWatch(d.Dentry(), mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// InotifyRmWatch implements the inotify_rm_watch() syscall.
+func InotifyRmWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ wd := args[1].Int()
+
+ ino, f, err := fdToInotify(t, fd)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer f.DecRef(t)
+ return 0, nil, ino.RmWatch(t, wd)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..38778a388
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,107 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Handle ioctls that apply to all FDs.
+ switch args[1].Int() {
+ case linux.FIONCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: false,
+ })
+ return 0, nil, nil
+
+ case linux.FIOCLEX:
+ t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{
+ CloseOnExec: true,
+ })
+ return 0, nil, nil
+
+ case linux.FIONBIO:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_NONBLOCK
+ } else {
+ flags &^= linux.O_NONBLOCK
+ }
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags)
+
+ case linux.FIOASYNC:
+ var set int32
+ if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil {
+ return 0, nil, err
+ }
+ flags := file.StatusFlags()
+ if set != 0 {
+ flags |= linux.O_ASYNC
+ } else {
+ flags &^= linux.O_ASYNC
+ }
+ file.SetStatusFlags(t, t.Credentials(), flags)
+ return 0, nil, nil
+
+ case linux.FIOGETOWN, linux.SIOCGPGRP:
+ var who int32
+ owner, hasOwner := getAsyncOwner(t, file)
+ if hasOwner {
+ if owner.Type == linux.F_OWNER_PGRP {
+ who = -owner.PID
+ } else {
+ who = owner.PID
+ }
+ }
+ _, err := t.CopyOut(args[2].Pointer(), &who)
+ return 0, nil, err
+
+ case linux.FIOSETOWN, linux.SIOCSPGRP:
+ var who int32
+ if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil {
+ return 0, nil, err
+ }
+ ownerType := int32(linux.F_OWNER_PID)
+ if who < 0 {
+ // Check for overflow before flipping the sign.
+ if who-1 > who {
+ return 0, nil, syserror.EINVAL
+ }
+ ownerType = linux.F_OWNER_PGRP
+ who = -who
+ }
+ return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ }
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/lock.go b/pkg/sentry/syscalls/linux/vfs2/lock.go
new file mode 100644
index 000000000..b910b5a74
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/lock.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Flock implements linux syscall flock(2).
+func Flock(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ operation := args[1].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ // flock(2): EBADF fd is not an open file descriptor.
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ nonblocking := operation&linux.LOCK_NB != 0
+ operation &^= linux.LOCK_NB
+
+ var blocker lock.Blocker
+ if !nonblocking {
+ blocker = t
+ }
+
+ switch operation {
+ case linux.LOCK_EX:
+ if err := file.LockBSD(t, lock.WriteLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_SH:
+ if err := file.LockBSD(t, lock.ReadLock, blocker); err != nil {
+ return 0, nil, err
+ }
+ case linux.LOCK_UN:
+ if err := file.UnlockBSD(t); err != nil {
+ return 0, nil, err
+ }
+ default:
+ // flock(2): EINVAL operation is invalid.
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go
new file mode 100644
index 000000000..c4c0f9e0a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go
@@ -0,0 +1,64 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ memfdPrefix = "memfd:"
+ memfdMaxNameLen = linux.NAME_MAX - len(memfdPrefix)
+ memfdAllFlags = uint32(linux.MFD_CLOEXEC | linux.MFD_ALLOW_SEALING)
+)
+
+// MemfdCreate implements the linux syscall memfd_create(2).
+func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+
+ if flags&^memfdAllFlags != 0 {
+ // Unknown bits in flags.
+ return 0, nil, syserror.EINVAL
+ }
+
+ allowSeals := flags&linux.MFD_ALLOW_SEALING != 0
+ cloExec := flags&linux.MFD_CLOEXEC != 0
+
+ name, err := t.CopyInString(addr, memfdMaxNameLen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ shmMount := t.Kernel().ShmMount()
+ file, err := tmpfs.NewMemfd(t, t.Credentials(), shmMount, allowSeals, memfdPrefix+name)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: cloExec,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..dc05c2994
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(t)
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go
new file mode 100644
index 000000000..4bd5c7ca2
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mount.go
@@ -0,0 +1,150 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mount implements Linux syscall mount(2).
+func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ sourceAddr := args[0].Pointer()
+ targetAddr := args[1].Pointer()
+ typeAddr := args[2].Pointer()
+ flags := args[3].Uint64()
+ dataAddr := args[4].Pointer()
+
+ // For null-terminated strings related to mount(2), Linux copies in at most
+ // a page worth of data. See fs/namespace.c:copy_mount_string().
+ fsType, err := t.CopyInString(typeAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ source, err := t.CopyInString(sourceAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ targetPath, err := copyInPath(t, targetAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ data := ""
+ if dataAddr != 0 {
+ // In Linux, a full page is always copied in regardless of null
+ // character placement, and the address is passed to each file system.
+ // Most file systems always treat this data as a string, though, and so
+ // do all of the ones we implement.
+ data, err = t.CopyInString(dataAddr, usermem.PageSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ // Ignore magic value that was required before Linux 2.4.
+ if flags&linux.MS_MGC_MSK == linux.MS_MGC_VAL {
+ flags = flags &^ linux.MS_MGC_MSK
+ }
+
+ // Must have CAP_SYS_ADMIN in the current mount namespace's associated user
+ // namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupportedOps = linux.MS_REMOUNT | linux.MS_BIND |
+ linux.MS_SHARED | linux.MS_PRIVATE | linux.MS_SLAVE |
+ linux.MS_UNBINDABLE | linux.MS_MOVE
+
+ // Silently allow MS_NOSUID, since we don't implement set-id bits
+ // anyway.
+ const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME
+
+ // Linux just allows passing any flags to mount(2) - it won't fail when
+ // unknown or unsupported flags are passed. Since we don't implement
+ // everything, we fail explicitly on flags that are unimplemented.
+ if flags&(unsupportedOps|unsupportedFlags) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var opts vfs.MountOptions
+ if flags&linux.MS_NOATIME == linux.MS_NOATIME {
+ opts.Flags.NoATime = true
+ }
+ if flags&linux.MS_NOEXEC == linux.MS_NOEXEC {
+ opts.Flags.NoExec = true
+ }
+ if flags&linux.MS_NODEV == linux.MS_NODEV {
+ opts.Flags.NoDev = true
+ }
+ if flags&linux.MS_NOSUID == linux.MS_NOSUID {
+ opts.Flags.NoSUID = true
+ }
+ if flags&linux.MS_RDONLY == linux.MS_RDONLY {
+ opts.ReadOnly = true
+ }
+ opts.GetFilesystemOptions.Data = data
+
+ target, err := getTaskPathOperation(t, linux.AT_FDCWD, targetPath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer target.Release(t)
+
+ return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts)
+}
+
+// Umount2 implements Linux syscall umount2(2).
+func Umount2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+
+ // Must have CAP_SYS_ADMIN in the mount namespace's associated user
+ // namespace.
+ //
+ // Currently, this is always the init task's user namespace.
+ creds := t.Credentials()
+ if !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.MountNamespaceVFS2().Owner) {
+ return 0, nil, syserror.EPERM
+ }
+
+ const unsupported = linux.MNT_FORCE | linux.MNT_EXPIRE
+ if flags&unsupported != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ opts := vfs.UmountOptions{
+ Flags: uint32(flags),
+ }
+
+ return 0, nil, t.Kernel().VFS().UmountAt(t, creds, &tpop.pop, &opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..90a511d9a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef(t)
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef(t)
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef(t)
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release(t *kernel.Task) {
+ tpop.pop.Root.DecRef(t)
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef(t)
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
new file mode 100644
index 000000000..9b4848d9e
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Pipe implements Linux syscall pipe(2).
+func Pipe(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ return 0, nil, pipe2(t, addr, 0)
+}
+
+// Pipe2 implements Linux syscall pipe2(2).
+func Pipe2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Int()
+ return 0, nil, pipe2(t, addr, flags)
+}
+
+func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
+ if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
+ return syserror.EINVAL
+ }
+ r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ defer r.DecRef(t)
+ defer w.DecRef(t)
+
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{r, w}, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return err
+ }
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef(t)
+ }
+ }
+ return err
+ }
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..7b9d5e18a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,586 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef(t)
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(t *kernel.Task, state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef(t)
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(t, state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef(t)
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ _, err := tsRemaining.CopyOut(t, timespecAddr)
+ return err
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ _, err := tvRemaining.CopyOut(t, timevalAddr)
+ return err
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if _, err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if _, err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if _, err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..a905dae0a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,641 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ allowBlock, deadline, hasDeadline := blockPolicy(t, file)
+ if !allowBlock {
+ if n > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ if err = t.BlockWithDeadline(ch, hasDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ if total > 0 {
+ file.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ }
+ return total, err
+}
+
+func blockPolicy(t *kernel.Task, file *vfs.FileDescription) (allowBlock bool, deadline ktime.Time, hasDeadline bool) {
+ if file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return false, ktime.Time{}, false
+ }
+ // Sockets support read/write timeouts.
+ if s, ok := file.Impl().(socket.SocketVFS2); ok {
+ dl := s.RecvTimeout()
+ if dl < 0 {
+ return false, ktime.Time{}, false
+ }
+ if dl > 0 {
+ return true, t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond), true
+ }
+ }
+ return true, ktime.Time{}, false
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
+
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Check that the file is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate and does not overflow.
+ if offset < 0 || offset+int64(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..5e6eb13ba
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,484 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ err = setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ NeedWritePerm: true,
+ })
+ return 0, nil, handleSetSizeError(t, err)
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ if !file.IsWritable() {
+ return 0, nil, syserror.EINVAL
+ }
+
+ err := file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+ return 0, nil, handleSetSizeError(t, err)
+}
+
+// Fallocate implements linux system call fallocate(2).
+func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].Uint64()
+ offset := args[2].Int64()
+ length := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ if !file.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ if mode != 0 {
+ return 0, nil, syserror.ENOTSUP
+ }
+
+ if offset < 0 || length <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ size := offset + length
+
+ if size < 0 {
+ return 0, nil, syserror.EFBIG
+ }
+
+ limit := limits.FromContext(t).Get(limits.FileSize).Cur
+
+ if uint64(size) >= limit {
+ t.SendSignal(&arch.SignalInfo{
+ Signo: int32(linux.SIGXFSZ),
+ Code: arch.SignalInfoUser,
+ })
+ return 0, nil, syserror.EFBIG
+ }
+
+ if err := file.Allocate(t, mode, uint64(offset), uint64(length)); err != nil {
+ return 0, nil, err
+ }
+
+ file.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ return 0, nil, nil
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if _, err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimesat implements Linux syscall futimesat(2).
+func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimes(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, followFinalSymlink, &opts)
+}
+
+func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ return nil
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ // Linux requires that the UTIME_OMIT check occur before checking path or
+ // flags.
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+ if opts.Stat.Mask == 0 {
+ return 0, nil, nil
+ }
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // "If filename is NULL and dfd refers to an open file, then operate on the
+ // file. Otherwise look up filename, possibly using dfd as a starting
+ // point." - fs/utimes.c
+ var path fspath.Path
+ shouldAllowEmptyPath := allowEmptyPath
+ if dirfd == linux.AT_FDCWD || pathAddr != 0 {
+ var err error
+ path, err = copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ shouldAllowEmptyPath = disallowEmptyPath
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, shouldAllowEmptyPath, shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ if times[0].Nsec != linux.UTIME_NOW && (times[0].Nsec < 0 || times[0].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ if times[1].Nsec != linux.UTIME_NOW && (times[1].Nsec < 0 || times[1].Nsec > 999999999) {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef(t)
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef(t)
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
+
+func handleSetSizeError(t *kernel.Task, err error) error {
+ if err == syserror.ErrExceedsFileSizeLimit {
+ // Convert error to EFBIG and send a SIGXFSZ per setrlimit(2).
+ t.SendSignal(kernel.SignalInfoNoInfo(linux.SIGXFSZ, t, t))
+ return syserror.EFBIG
+ }
+ return err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/signal.go b/pkg/sentry/syscalls/linux/vfs2/signal.go
new file mode 100644
index 000000000..b89f34cdb
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/signal.go
@@ -0,0 +1,100 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// sharedSignalfd is shared between the two calls.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := slinux.CopyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Is this a signalfd?
+ if sfd, ok := file.Impl().(*signalfd.SignalFileDescription); ok {
+ sfd.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.SFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+
+ // Create a new file.
+ vfsObj := t.Kernel().VFS()
+ file, err := signalfd.New(vfsObj, t, mask, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
new file mode 100644
index 000000000..4a68c64f3
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -0,0 +1,1144 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
+)
+
+// minListenBacklog is the minimum reasonable backlog for listening sockets.
+const minListenBacklog = 8
+
+// maxListenBacklog is the maximum allowed backlog for listening sockets.
+const maxListenBacklog = 1024
+
+// maxAddrLen is the maximum socket address length we're willing to accept.
+const maxAddrLen = 200
+
+// maxOptLen is the maximum sockopt parameter length we're willing to accept.
+const maxOptLen = 1024 * 8
+
+// maxControlLen is the maximum length of the msghdr.msg_control buffer we're
+// willing to accept. Note that this limit is smaller than Linux, which allows
+// buffers upto INT_MAX.
+const maxControlLen = 10 * 1024 * 1024
+
+// nameLenOffset is the offset from the start of the MessageHeader64 struct to
+// the NameLen field.
+const nameLenOffset = 8
+
+// controlLenOffset is the offset form the start of the MessageHeader64 struct
+// to the ControlLen field.
+const controlLenOffset = 40
+
+// flagsOffset is the offset form the start of the MessageHeader64 struct
+// to the Flags field.
+const flagsOffset = 48
+
+const sizeOfInt32 = 4
+
+// messageHeader64Len is the length of a MessageHeader64 struct.
+var messageHeader64Len = uint64(binary.Size(MessageHeader64{}))
+
+// multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
+var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{}))
+
+// baseRecvFlags are the flags that are accepted across recvmsg(2),
+// recvmmsg(2), and recvfrom(2).
+const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
+
+// MessageHeader64 is the 64-bit representation of the msghdr struct used in
+// the recvmsg and sendmsg syscalls.
+type MessageHeader64 struct {
+ // Name is the optional pointer to a network address buffer.
+ Name uint64
+
+ // NameLen is the length of the buffer pointed to by Name.
+ NameLen uint32
+ _ uint32
+
+ // Iov is a pointer to an array of io vectors that describe the memory
+ // locations involved in the io operation.
+ Iov uint64
+
+ // IovLen is the length of the array pointed to by Iov.
+ IovLen uint64
+
+ // Control is the optional pointer to ancillary control data.
+ Control uint64
+
+ // ControlLen is the length of the data pointed to by Control.
+ ControlLen uint64
+
+ // Flags on the sent/received message.
+ Flags int32
+ _ int32
+}
+
+// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
+// the recvmmsg and sendmmsg syscalls.
+type multipleMessageHeader64 struct {
+ msgHdr MessageHeader64
+ msgLen uint32
+ _ int32
+}
+
+// CopyInMessageHeader64 copies a message header from user to kernel memory.
+func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error {
+ b := t.CopyScratchBuffer(52)
+ if _, err := t.CopyInBytes(addr, b); err != nil {
+ return err
+ }
+
+ msg.Name = usermem.ByteOrder.Uint64(b[0:])
+ msg.NameLen = usermem.ByteOrder.Uint32(b[8:])
+ msg.Iov = usermem.ByteOrder.Uint64(b[16:])
+ msg.IovLen = usermem.ByteOrder.Uint64(b[24:])
+ msg.Control = usermem.ByteOrder.Uint64(b[32:])
+ msg.ControlLen = usermem.ByteOrder.Uint64(b[40:])
+ msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:]))
+
+ return nil
+}
+
+// CaptureAddress allocates memory for and copies a socket address structure
+// from the untrusted address space range.
+func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) {
+ if addrlen > maxAddrLen {
+ return nil, syserror.EINVAL
+ }
+
+ addrBuf := make([]byte, addrlen)
+ if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
+ return nil, err
+ }
+
+ return addrBuf, nil
+}
+
+// writeAddress writes a sockaddr structure and its length to an output buffer
+// in the unstrusted address space range. If the address is bigger than the
+// buffer, it is truncated.
+func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error {
+ // Get the buffer length.
+ var bufLen uint32
+ if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil {
+ return err
+ }
+
+ if int32(bufLen) < 0 {
+ return syserror.EINVAL
+ }
+
+ // Write the length unconditionally.
+ if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil {
+ return err
+ }
+
+ if addr == nil {
+ return nil
+ }
+
+ if bufLen > addrLen {
+ bufLen = addrLen
+ }
+
+ // Copy as much of the address as will fit in the buffer.
+ encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr)
+ if bufLen > uint32(len(encodedAddr)) {
+ bufLen = uint32(len(encodedAddr))
+ }
+ _, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
+ return err
+}
+
+// Socket implements the linux syscall socket(2).
+func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the new socket.
+ s, e := socket.NewVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ defer s.DecRef(t)
+
+ if err := s.SetStatusFlags(t, t.Credentials(), uint32(stype&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, err
+ }
+
+ fd, err := t.NewFDFromVFS2(0, s, kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(fd), nil, nil
+}
+
+// SocketPair implements the linux syscall socketpair(2).
+func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ domain := int(args[0].Int())
+ stype := args[1].Int()
+ protocol := int(args[2].Int())
+ addr := args[3].Pointer()
+
+ // Check and initialize the flags.
+ if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create the socket pair.
+ s1, s2, e := socket.PairVFS2(t, domain, linux.SockType(stype&0xf), protocol)
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+ // Adding to the FD table will cause an extra reference to be acquired.
+ defer s1.DecRef(t)
+ defer s2.DecRef(t)
+
+ nonblocking := uint32(stype & linux.SOCK_NONBLOCK)
+ if err := s1.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+ if err := s2.SetStatusFlags(t, t.Credentials(), nonblocking); err != nil {
+ return 0, nil, err
+ }
+
+ // Create the FDs for the sockets.
+ flags := kernel.FDFlags{
+ CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
+ }
+ fds, err := t.NewFDsVFS2(0, []*vfs.FileDescription{s1, s2}, flags)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if _, err := t.CopyOut(addr, fds); err != nil {
+ for _, fd := range fds {
+ if _, file := t.FDTable().Remove(fd); file != nil {
+ file.DecRef(t)
+ }
+ }
+ return 0, nil, err
+ }
+
+ return 0, nil, nil
+}
+
+// Connect implements the linux syscall connect(2).
+func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+ return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS)
+}
+
+// accept is the implementation of the accept syscall. It is called by accept
+// and accept4 syscall handlers.
+func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, flags int) (uintptr, error) {
+ // Check that no unsupported flags are passed in.
+ if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ // Call the syscall implementation for this socket, then copy the
+ // output address if one is specified.
+ blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0
+
+ peerRequested := addrLen != 0
+ nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ if peerRequested {
+ // NOTE(magi): Linux does not give you an error if it can't
+ // write the data back out so neither do we.
+ if err := writeAddress(t, peer, peerLen, addr, addrLen); err == syserror.EINVAL {
+ return 0, err
+ }
+ }
+ return uintptr(nfd), nil
+}
+
+// Accept4 implements the linux syscall accept4(2).
+func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+ flags := int(args[3].Int())
+
+ n, err := accept(t, fd, addr, addrlen, flags)
+ return n, nil, err
+}
+
+// Accept implements the linux syscall accept(2).
+func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ n, err := accept(t, fd, addr, addrlen, 0)
+ return n, nil, err
+}
+
+// Bind implements the linux syscall bind(2).
+func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Uint()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Capture address and call syscall implementation.
+ a, err := CaptureAddress(t, addr, addrlen)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, s.Bind(t, a).ToError()
+}
+
+// Listen implements the linux syscall listen(2).
+func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ backlog := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Per Linux, the backlog is silently capped to reasonable values.
+ if backlog <= 0 {
+ backlog = minListenBacklog
+ }
+ if backlog > maxListenBacklog {
+ backlog = maxListenBacklog
+ }
+
+ return 0, nil, s.Listen(t, int(backlog)).ToError()
+}
+
+// Shutdown implements the linux syscall shutdown(2).
+func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ how := args[1].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Validate how, then call syscall implementation.
+ switch how {
+ case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+
+ return 0, nil, s.Shutdown(t, int(how)).ToError()
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2).
+func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLenAddr := args[4].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Read the length. Reject negative values.
+ optLen := int32(0)
+ if _, err := t.CopyIn(optLenAddr, &optLen); err != nil {
+ return 0, nil, err
+ }
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Call syscall implementation then copy both value and value len out.
+ v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
+ if e != nil {
+ return 0, nil, e.ToError()
+ }
+
+ vLen := int32(binary.Size(v))
+ if _, err := t.CopyOut(optLenAddr, vLen); err != nil {
+ return 0, nil, err
+ }
+
+ if v != nil {
+ if _, err := v.CopyOut(t, optValAddr); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ return 0, nil, nil
+}
+
+// getSockOpt tries to handle common socket options, or dispatches to a specific
+// socket implementation.
+func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) {
+ if level == linux.SOL_SOCKET {
+ switch name {
+ case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
+ if len < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+
+ switch name {
+ case linux.SO_TYPE:
+ _, skType, _ := s.Type()
+ v := primitive.Int32(skType)
+ return &v, nil
+ case linux.SO_DOMAIN:
+ family, _, _ := s.Type()
+ v := primitive.Int32(family)
+ return &v, nil
+ case linux.SO_PROTOCOL:
+ _, _, protocol := s.Type()
+ v := primitive.Int32(protocol)
+ return &v, nil
+ }
+ }
+
+ return s.GetSockOpt(t, level, name, optValAddr, len)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2).
+//
+// Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
+func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ level := args[1].Int()
+ name := args[2].Int()
+ optValAddr := args[3].Pointer()
+ optLen := args[4].Int()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if optLen < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if optLen > maxOptLen {
+ return 0, nil, syserror.EINVAL
+ }
+ buf := t.CopyScratchBuffer(int(optLen))
+ if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
+ return 0, nil, err
+ }
+
+ // Call syscall implementation.
+ if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2).
+func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket name and copy it to the caller.
+ v, vl, err := s.GetSockName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// GetPeerName implements the linux syscall getpeername(2).
+func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ addrlen := args[2].Pointer()
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Get the socket peer name and copy it to the caller.
+ v, vl, err := s.GetPeerName(t)
+ if err != nil {
+ return 0, nil, err.ToError()
+ }
+
+ return 0, nil, writeAddress(t, v, vl, addr, addrlen)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2).
+func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
+ return n, nil, err
+}
+
+// RecvMMsg implements the linux syscall recvmmsg(2).
+func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+ toPtr := args[4].Pointer()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if toPtr != 0 {
+ var ts linux.Timespec
+ if _, err := ts.CopyIn(t, toPtr); err != nil {
+ return 0, nil, err
+ }
+ if !ts.Valid() {
+ return 0, nil, syserror.EINVAL
+ }
+ deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
+ haveDeadline = true
+ }
+
+ if !haveDeadline {
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
+ // Capture the message header and io vectors.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ dst, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ // FIXME(b/63594852): Pretend we have an empty error queue.
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return 0, syserror.EAGAIN
+ }
+
+ // Fast path when no control message nor name buffers are provided.
+ if msg.ControlLen == 0 && msg.NameLen == 0 {
+ n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
+ if err != nil {
+ return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS)
+ }
+ if !cms.Unix.Empty() {
+ mflags |= linux.MSG_CTRUNC
+ cms.Release(t)
+ }
+
+ if int(msg.Flags) != mflags {
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+ }
+
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+ defer cms.Release(t)
+
+ controlData := make([]byte, 0, msg.ControlLen)
+ controlData = control.PackControlMessages(t, cms, controlData)
+
+ if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
+ creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
+ controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
+ }
+
+ if cms.Unix.Rights != nil {
+ controlData, mflags = control.PackRightsVFS2(t, cms.Unix.Rights.(control.SCMRightsVFS2), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
+ }
+
+ // Copy the address to the caller.
+ if msg.NameLen != 0 {
+ if err := writeAddress(t, sender, senderLen, usermem.Addr(msg.Name), usermem.Addr(msgPtr+nameLenOffset)); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy the control data to the caller.
+ if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
+ return 0, err
+ }
+ if len(controlData) > 0 {
+ if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Copy out the flags to the caller.
+ if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil {
+ return 0, err
+ }
+
+ return uintptr(n), nil
+}
+
+// recvFrom is the implementation of the recvfrom syscall. It is called by
+// recvfrom and recv syscall handlers.
+func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLenPtr usermem.Addr) (uintptr, error) {
+ if int(bufLen) < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.RecvTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
+ cm.Release(t)
+ if e != nil {
+ return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS)
+ }
+
+ // Copy the address to the caller.
+ if nameLenPtr != 0 {
+ if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
+ return 0, err
+ }
+ }
+
+ return uintptr(n), nil
+}
+
+// RecvFrom implements the linux syscall recvfrom(2).
+func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLenPtr := args[5].Pointer()
+
+ n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
+ return n, nil, err
+}
+
+// SendMsg implements the linux syscall sendmsg(2).
+func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ n, err := sendSingleMsg(t, s, file, msgPtr, flags)
+ return n, nil, err
+}
+
+// SendMMsg implements the linux syscall sendmmsg(2).
+func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ msgPtr := args[1].Pointer()
+ vlen := args[2].Uint()
+ flags := args[3].Int()
+
+ if t.Arch().Width() != 8 {
+ // We only handle 64-bit for now.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, nil, syserror.ENOTSOCK
+ }
+
+ // Reject flags that we don't handle yet.
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ var count uint32
+ var err error
+ for i := uint64(0); i < uint64(vlen); i++ {
+ mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ var n uintptr
+ if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
+ break
+ }
+
+ // Copy the received length to the caller.
+ lp, ok := mp.AddLength(messageHeader64Len)
+ if !ok {
+ return 0, nil, syserror.EFAULT
+ }
+ if _, err = t.CopyOut(lp, uint32(n)); err != nil {
+ break
+ }
+ count++
+ }
+
+ if count == 0 {
+ return 0, nil, err
+ }
+ return uintptr(count), nil, nil
+}
+
+func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) {
+ // Capture the message header.
+ var msg MessageHeader64
+ if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil {
+ return 0, err
+ }
+
+ var controlData []byte
+ if msg.ControlLen > 0 {
+ // Put an upper bound to prevent large allocations.
+ if msg.ControlLen > maxControlLen {
+ return 0, syserror.ENOBUFS
+ }
+ controlData = make([]byte, msg.ControlLen)
+ if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil {
+ return 0, err
+ }
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ if msg.NameLen != 0 {
+ var err error
+ to, err = CaptureAddress(t, usermem.Addr(msg.Name), msg.NameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ // Read data then call the sendmsg implementation.
+ if msg.IovLen > linux.UIO_MAXIOV {
+ return 0, syserror.EMSGSIZE
+ }
+ src, err := t.IovecsIOSequence(usermem.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ controlMessages, err := control.Parse(t, s, controlData)
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
+ err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file)
+ if err != nil {
+ controlMessages.Release(t)
+ }
+ return uintptr(n), err
+}
+
+// sendTo is the implementation of the sendto syscall. It is called by sendto
+// and send syscall handlers.
+func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags int32, namePtr usermem.Addr, nameLen uint32) (uintptr, error) {
+ bl := int(bufLen)
+ if bl < 0 {
+ return 0, syserror.EINVAL
+ }
+
+ // Get socket from the file descriptor.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // Extract the socket.
+ s, ok := file.Impl().(socket.SocketVFS2)
+ if !ok {
+ return 0, syserror.ENOTSOCK
+ }
+
+ if (file.StatusFlags() & linux.SOCK_NONBLOCK) != 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Read the destination address if one is specified.
+ var to []byte
+ var err error
+ if namePtr != 0 {
+ to, err = CaptureAddress(t, namePtr, nameLen)
+ if err != nil {
+ return 0, err
+ }
+ }
+
+ src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, err
+ }
+
+ var haveDeadline bool
+ var deadline ktime.Time
+ if dl := s.SendTimeout(); dl > 0 {
+ deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
+ haveDeadline = true
+ } else if dl < 0 {
+ flags |= linux.MSG_DONTWAIT
+ }
+
+ // Call the syscall implementation.
+ n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
+ return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file)
+}
+
+// SendTo implements the linux syscall sendto(2).
+func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufPtr := args[1].Pointer()
+ bufLen := args[2].Uint64()
+ flags := args[3].Int()
+ namePtr := args[4].Pointer()
+ nameLen := args[5].Uint()
+
+ n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
+ return n, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go
new file mode 100644
index 000000000..75bfa2c79
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/splice.go
@@ -0,0 +1,490 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/pipe"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Splice implements Linux syscall splice(2).
+func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ inOffsetPtr := args[1].Pointer()
+ outFD := args[2].Int()
+ outOffsetPtr := args[3].Pointer()
+ count := int64(args[4].SizeT())
+ flags := args[5].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef(t)
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef(t)
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // At least one file description must represent a pipe.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe && !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy in offsets.
+ inOffset := int64(-1)
+ if inOffsetPtr != 0 {
+ if inIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ if inOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+ outOffset := int64(-1)
+ if outOffsetPtr != 0 {
+ if outIsPipe {
+ return 0, nil, syserror.ESPIPE
+ }
+ if outFile.Options().DenyPWrite {
+ return 0, nil, syserror.EINVAL
+ }
+ if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ if outOffset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Move data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ for {
+ // If both input and output are pipes, delegate to the pipe
+ // implementation. Otherwise, exactly one end is a pipe, which
+ // we ensure is consistently ordered after the non-pipe FD's
+ // locks by passing the pipe FD as usermem.IO to the non-pipe
+ // end.
+ switch {
+ case inIsPipe && outIsPipe:
+ n, err = pipe.Splice(t, outPipeFD, inPipeFD, count)
+ case inIsPipe:
+ if outOffset != -1 {
+ n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{})
+ outOffset += n
+ } else {
+ n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{})
+ }
+ case outIsPipe:
+ if inOffset != -1 {
+ n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{})
+ inOffset += n
+ } else {
+ n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ default:
+ panic("not possible")
+ }
+
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+
+ // Copy updated offsets out.
+ if inOffsetPtr != 0 {
+ if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+ if outOffsetPtr != 0 {
+ if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ // On Linux, inotify behavior is not very consistent with splice(2). We try
+ // our best to emulate Linux for very basic calls to splice, where for some
+ // reason, events are generated for output files, but not input files.
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Tee implements Linux syscall tee(2).
+func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ inFD := args[0].Int()
+ outFD := args[1].Int()
+ count := int64(args[2].SizeT())
+ flags := args[3].Int()
+
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Check for invalid flags.
+ if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get file descriptions.
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef(t)
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef(t)
+
+ // Check that both files support the required directionality.
+ if !inFile.IsReadable() || !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := ((inFile.StatusFlags()|outFile.StatusFlags())&linux.O_NONBLOCK != 0) || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
+ // Both file descriptions must represent pipes.
+ inPipeFD, inIsPipe := inFile.Impl().(*pipe.VFSPipeFD)
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ if !inIsPipe || !outIsPipe {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ for {
+ n, err = pipe.Tee(t, outPipeFD, inPipeFD, count)
+ if n != 0 || err != syserror.ErrWouldBlock || nonBlock {
+ break
+ }
+ if err = dw.waitForBoth(t); err != nil {
+ break
+ }
+ }
+ if n == 0 {
+ return 0, nil, err
+ }
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// Sendfile implements linux system call sendfile(2).
+func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ outFD := args[0].Int()
+ inFD := args[1].Int()
+ offsetAddr := args[2].Pointer()
+ count := int64(args[3].SizeT())
+
+ inFile := t.GetFileVFS2(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef(t)
+ if !inFile.IsReadable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ outFile := t.GetFileVFS2(outFD)
+ if outFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer outFile.DecRef(t)
+ if !outFile.IsWritable() {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Verify that the outFile Append flag is not set.
+ if outFile.StatusFlags()&linux.O_APPEND != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Verify that inFile is a regular file or block device. This is a
+ // requirement; the same check appears in Linux
+ // (fs/splice.c:splice_direct_to_actor).
+ if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil {
+ return 0, nil, err
+ } else if stat.Mask&linux.STATX_TYPE == 0 ||
+ (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Copy offset if it exists.
+ offset := int64(-1)
+ if offsetAddr != 0 {
+ if inFile.Options().DenyPRead {
+ return 0, nil, syserror.ESPIPE
+ }
+ if _, err := t.CopyIn(offsetAddr, &offset); err != nil {
+ return 0, nil, err
+ }
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if offset+count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ }
+
+ // Validate count. This must come after offset checks.
+ if count < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if count == 0 {
+ return 0, nil, nil
+ }
+ if count > int64(kernel.MAX_RW_COUNT) {
+ count = int64(kernel.MAX_RW_COUNT)
+ }
+
+ // Copy data.
+ var (
+ n int64
+ err error
+ )
+ dw := dualWaiter{
+ inFile: inFile,
+ outFile: outFile,
+ }
+ defer dw.destroy()
+ outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD)
+ // Reading from input file should never block, since it is regular or
+ // block device. We only need to check if writing to the output file
+ // can block.
+ nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0
+ if outIsPipe {
+ for n < count {
+ var spliceN int64
+ if offset != -1 {
+ spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{})
+ offset += spliceN
+ } else {
+ spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{})
+ }
+ if spliceN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the error and exit the loop.
+ err = nil
+ break
+ }
+ n += spliceN
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
+ }
+ }
+ } else {
+ // Read inFile to buffer, then write the contents to outFile.
+ buf := make([]byte, count)
+ for n < count {
+ var readN int64
+ if offset != -1 {
+ readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{})
+ offset += readN
+ } else {
+ readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ }
+ if readN == 0 && err == io.EOF {
+ // We reached the end of the file. Eat the error and exit the loop.
+ err = nil
+ break
+ }
+ n += readN
+ if err != nil {
+ break
+ }
+
+ // Write all of the bytes that we read. This may need
+ // multiple write calls to complete.
+ wbuf := buf[:n]
+ for len(wbuf) > 0 {
+ var writeN int64
+ writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{})
+ wbuf = wbuf[writeN:]
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForOut(t)
+ }
+ if err != nil {
+ // We didn't complete the write. Only
+ // report the bytes that were actually
+ // written, and rewind the offset.
+ notWritten := int64(len(wbuf))
+ n -= notWritten
+ if offset != -1 {
+ offset -= notWritten
+ }
+ break
+ }
+ }
+ if err == syserror.ErrWouldBlock && !nonBlock {
+ err = dw.waitForBoth(t)
+ }
+ if err != nil {
+ break
+ }
+ }
+ }
+
+ if offsetAddr != 0 {
+ // Copy out the new offset.
+ if _, err := t.CopyOut(offsetAddr, offset); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ if n == 0 {
+ return 0, nil, err
+ }
+
+ inFile.Dentry().InotifyWithParent(t, linux.IN_ACCESS, 0, vfs.PathEvent)
+ outFile.Dentry().InotifyWithParent(t, linux.IN_MODIFY, 0, vfs.PathEvent)
+ return uintptr(n), nil, nil
+}
+
+// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not
+// thread-safe, and does not take a reference on the vfs.FileDescriptions.
+//
+// Users must call destroy() when finished.
+type dualWaiter struct {
+ inFile *vfs.FileDescription
+ outFile *vfs.FileDescription
+
+ inW waiter.Entry
+ inCh chan struct{}
+ outW waiter.Entry
+ outCh chan struct{}
+}
+
+// waitForBoth waits for both dw.inFile and dw.outFile to be ready.
+func (dw *dualWaiter) waitForBoth(t *kernel.Task) error {
+ if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 {
+ if dw.inCh == nil {
+ dw.inW, dw.inCh = waiter.NewChannelEntry(nil)
+ dw.inFile.EventRegister(&dw.inW, eventMaskRead)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.inCh); err != nil {
+ return err
+ }
+ }
+ return dw.waitForOut(t)
+}
+
+// waitForOut waits for dw.outfile to be read.
+func (dw *dualWaiter) waitForOut(t *kernel.Task) error {
+ if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 {
+ if dw.outCh == nil {
+ dw.outW, dw.outCh = waiter.NewChannelEntry(nil)
+ dw.outFile.EventRegister(&dw.outW, eventMaskWrite)
+ // We might be ready now. Try again before blocking.
+ return nil
+ }
+ if err := t.Block(dw.outCh); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// destroy cleans up resources help by dw. No more calls to wait* can occur
+// after destroy is called.
+func (dw *dualWaiter) destroy() {
+ if dw.inCh != nil {
+ dw.inFile.EventUnregister(&dw.inW)
+ dw.inCh = nil
+ }
+ if dw.outCh != nil {
+ dw.outFile.EventUnregister(&dw.outW)
+ dw.outCh = nil
+ }
+ dw.inFile = nil
+ dw.outFile = nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..0f5d5189c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,388 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef(t)
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef(t)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return err
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ _, err = stat.CopyOut(t, statAddr)
+ return 0, nil, err
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW|linux.AT_STATX_SYNC_TYPE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ // Make sure that only one sync type option is set.
+ syncType := uint32(flags & linux.AT_STATX_SYNC_TYPE)
+ if syncType != 0 && !bits.IsPowerOfTwo32(syncType) {
+ return 0, nil, syserror.EINVAL
+ }
+ if mask&linux.STATX__RESERVED != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef(t)
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef(t)
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef(t)
+ dirfile.DecRef(t)
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ _, err = statx.CopyOut(t, statxAddr)
+ return 0, nil, err
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+
+ return 0, nil, accessAt(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Faccessat implements Linux syscall faccessat(2).
+//
+// Note that the faccessat() system call does not take a flags argument:
+// "The raw faccessat() system call takes only the first three arguments. The
+// AT_EACCESS and AT_SYMLINK_NOFOLLOW flags are actually implemented within
+// the glibc wrapper function for faccessat(). If either of these flags is
+// specified, then the wrapper function employs fstatat(2) to determine access
+// permissions." - faccessat(2)
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+
+ return 0, nil, accessAt(t, dirfd, addr, mode)
+}
+
+func accessAt(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ const rOK = 4
+ const wOK = 2
+ const xOK = 1
+
+ // Sanity check the mode.
+ if mode&^(rOK|wOK|xOK) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+
+ // access(2) and faccessat(2) check permissions using real
+ // UID/GID, not effective UID/GID.
+ //
+ // "access() needs to use the real uid/gid, not the effective
+ // uid/gid. We do this by temporarily clearing all FS-related
+ // capabilities and switching the fsuid/fsgid around to the
+ // real ones." -fs/open.c:faccessat
+ creds := t.Credentials().Fork()
+ creds.EffectiveKUID = creds.RealKUID
+ creds.EffectiveKGID = creds.RealKGID
+ if creds.EffectiveKUID.In(creds.UserNamespace) == auth.RootUID {
+ creds.EffectiveCaps = creds.PermittedCaps
+ } else {
+ creds.EffectiveCaps = 0
+ }
+
+ return t.Kernel().VFS().AccessAt(t, creds, vfs.AccessTypes(mode), &tpop.pop)
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ _, err = statfs.CopyOut(t, bufAddr)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
new file mode 100644
index 000000000..2da538fc6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_amd64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
new file mode 100644
index 000000000..88b9c7627
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat_arm64.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint32(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int32(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..a6491ac37
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ // Check for negative values and overflow.
+ if offset < 0 || offset+nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support
+ // is a full-file sync, i.e. fsync(2). As a result, there are severe
+ // limitations on how much we support sync_file_range:
+ // - In Linux, sync_file_range(2) doesn't write out the file's metadata, even
+ // if the file size is changed. We do.
+ // - We always sync the entire file instead of [offset, offset+nbytes).
+ // - We do not support the use of WAIT_BEFORE without WAIT_AFTER. For
+ // correctness, we would have to perform a write-out every time WAIT_BEFORE
+ // was used, but this would be much more expensive than expected if there
+ // were no write-out operations in progress.
+ // - Whenever WAIT_AFTER is used, we sync the file.
+ // - Ignore WRITE. If this flag is used with WAIT_AFTER, then the file will
+ // be synced anyway. If this flag is used without WAIT_AFTER, then it is
+ // safe (and less expensive) to do nothing, because the syscall will not
+ // wait for the write-out to complete--we only need to make sure that the
+ // next time WAIT_BEFORE or WAIT_AFTER are used, the write-out completes.
+ // - According to fs/sync.c, WAIT_BEFORE|WAIT_AFTER "will detect any I/O
+ // errors or ENOSPC conditions and will return those to the caller, after
+ // clearing the EIO and ENOSPC flags in the address_space." We don't do
+ // this.
+
+ if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 &&
+ flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 {
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, syserror.ENOSYS
+ }
+
+ if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 {
+ if err := file.Sync(t); err != nil {
+ return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
+ }
+ }
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
new file mode 100644
index 000000000..7a26890ef
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go
@@ -0,0 +1,127 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/timerfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// TimerfdCreate implements Linux syscall timerfd_create(2).
+func TimerfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ clockID := args[0].Int()
+ flags := args[1].Int()
+
+ if flags&^(linux.TFD_CLOEXEC|linux.TFD_NONBLOCK) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Timerfds aren't writable per se (their implementation of Write just
+ // returns EINVAL), but they are "opened for writing", which is necessary
+ // to actually reach said implementation of Write.
+ fileFlags := uint32(linux.O_RDWR)
+ if flags&linux.TFD_NONBLOCK != 0 {
+ fileFlags |= linux.O_NONBLOCK
+ }
+
+ var clock ktime.Clock
+ switch clockID {
+ case linux.CLOCK_REALTIME:
+ clock = t.Kernel().RealtimeClock()
+ case linux.CLOCK_MONOTONIC, linux.CLOCK_BOOTTIME:
+ clock = t.Kernel().MonotonicClock()
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+ vfsObj := t.Kernel().VFS()
+ file, err := timerfd.New(t, vfsObj, clock, fileFlags)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef(t)
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.TFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// TimerfdSettime implements Linux syscall timerfd_settime(2).
+func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ flags := args[1].Int()
+ newValAddr := args[2].Pointer()
+ oldValAddr := args[3].Pointer()
+
+ if flags&^(linux.TFD_TIMER_ABSTIME) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var newVal linux.Itimerspec
+ if _, err := t.CopyIn(newValAddr, &newVal); err != nil {
+ return 0, nil, err
+ }
+ newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock())
+ if err != nil {
+ return 0, nil, err
+ }
+ tm, oldS := tfd.SetTime(newS)
+ if oldValAddr != 0 {
+ oldVal := ktime.ItimerspecFromSetting(tm, oldS)
+ if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil {
+ return 0, nil, err
+ }
+ }
+ return 0, nil, nil
+}
+
+// TimerfdGettime implements Linux syscall timerfd_gettime(2).
+func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ curValAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ tfd, ok := file.Impl().(*timerfd.TimerFileDescription)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ tm, s := tfd.GetTime()
+ curVal := ktime.ItimerspecFromSetting(tm, s)
+ _, err := t.CopyOut(curValAddr, &curVal)
+ return 0, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
new file mode 100644
index 000000000..c576d9475
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -0,0 +1,268 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package vfs2 provides syscall implementations that use VFS2.
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/syscalls"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+)
+
+// Override syscall table to add syscalls implementations from this package.
+func Override() {
+ // Override AMD64.
+ s := linux.AMD64
+ s.Table[0] = syscalls.Supported("read", Read)
+ s.Table[1] = syscalls.Supported("write", Write)
+ s.Table[2] = syscalls.Supported("open", Open)
+ s.Table[3] = syscalls.Supported("close", Close)
+ s.Table[4] = syscalls.Supported("stat", Stat)
+ s.Table[5] = syscalls.Supported("fstat", Fstat)
+ s.Table[6] = syscalls.Supported("lstat", Lstat)
+ s.Table[7] = syscalls.Supported("poll", Poll)
+ s.Table[8] = syscalls.Supported("lseek", Lseek)
+ s.Table[9] = syscalls.Supported("mmap", Mmap)
+ s.Table[16] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[17] = syscalls.Supported("pread64", Pread64)
+ s.Table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[19] = syscalls.Supported("readv", Readv)
+ s.Table[20] = syscalls.Supported("writev", Writev)
+ s.Table[21] = syscalls.Supported("access", Access)
+ s.Table[22] = syscalls.Supported("pipe", Pipe)
+ s.Table[23] = syscalls.Supported("select", Select)
+ s.Table[32] = syscalls.Supported("dup", Dup)
+ s.Table[33] = syscalls.Supported("dup2", Dup2)
+ s.Table[40] = syscalls.Supported("sendfile", Sendfile)
+ s.Table[41] = syscalls.Supported("socket", Socket)
+ s.Table[42] = syscalls.Supported("connect", Connect)
+ s.Table[43] = syscalls.Supported("accept", Accept)
+ s.Table[44] = syscalls.Supported("sendto", SendTo)
+ s.Table[45] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[46] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[47] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[48] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[49] = syscalls.Supported("bind", Bind)
+ s.Table[50] = syscalls.Supported("listen", Listen)
+ s.Table[51] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[52] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[53] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[54] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[59] = syscalls.Supported("execve", Execve)
+ s.Table[72] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[73] = syscalls.Supported("flock", Flock)
+ s.Table[74] = syscalls.Supported("fsync", Fsync)
+ s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[76] = syscalls.Supported("truncate", Truncate)
+ s.Table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[78] = syscalls.Supported("getdents", Getdents)
+ s.Table[79] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[80] = syscalls.Supported("chdir", Chdir)
+ s.Table[81] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[82] = syscalls.Supported("rename", Rename)
+ s.Table[83] = syscalls.Supported("mkdir", Mkdir)
+ s.Table[84] = syscalls.Supported("rmdir", Rmdir)
+ s.Table[85] = syscalls.Supported("creat", Creat)
+ s.Table[86] = syscalls.Supported("link", Link)
+ s.Table[87] = syscalls.Supported("unlink", Unlink)
+ s.Table[88] = syscalls.Supported("symlink", Symlink)
+ s.Table[89] = syscalls.Supported("readlink", Readlink)
+ s.Table[90] = syscalls.Supported("chmod", Chmod)
+ s.Table[91] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[92] = syscalls.Supported("chown", Chown)
+ s.Table[93] = syscalls.Supported("fchown", Fchown)
+ s.Table[94] = syscalls.Supported("lchown", Lchown)
+ s.Table[132] = syscalls.Supported("utime", Utime)
+ s.Table[133] = syscalls.Supported("mknod", Mknod)
+ s.Table[137] = syscalls.Supported("statfs", Statfs)
+ s.Table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[161] = syscalls.Supported("chroot", Chroot)
+ s.Table[162] = syscalls.Supported("sync", Sync)
+ s.Table[165] = syscalls.Supported("mount", Mount)
+ s.Table[166] = syscalls.Supported("umount2", Umount2)
+ s.Table[187] = syscalls.Supported("readahead", Readahead)
+ s.Table[188] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[191] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[194] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[197] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"})
+ s.Table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ s.Table[217] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[221] = syscalls.PartiallySupported("fadvise64", Fadvise64, "The syscall is 'supported', but ignores all provided advice.", nil)
+ s.Table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[235] = syscalls.Supported("utimes", Utimes)
+ s.Table[253] = syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil)
+ s.Table[254] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[255] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[257] = syscalls.Supported("openat", Openat)
+ s.Table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[259] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[260] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[261] = syscalls.Supported("futimesat", Futimesat)
+ s.Table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ s.Table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[264] = syscalls.Supported("renameat", Renameat)
+ s.Table[265] = syscalls.Supported("linkat", Linkat)
+ s.Table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[269] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[270] = syscalls.Supported("pselect", Pselect)
+ s.Table[271] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[275] = syscalls.Supported("splice", Splice)
+ s.Table[276] = syscalls.Supported("tee", Tee)
+ s.Table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[280] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[282] = syscalls.Supported("signalfd", Signalfd)
+ s.Table[283] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[284] = syscalls.Supported("eventfd", Eventfd)
+ s.Table[285] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil)
+ s.Table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[288] = syscalls.Supported("accept4", Accept4)
+ s.Table[289] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[290] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[292] = syscalls.Supported("dup3", Dup3)
+ s.Table[293] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[294] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[295] = syscalls.Supported("preadv", Preadv)
+ s.Table[296] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[299] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[306] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[307] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[316] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[319] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[322] = syscalls.Supported("execveat", Execveat)
+ s.Table[327] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[332] = syscalls.Supported("statx", Statx)
+ s.Init()
+
+ // Override ARM64.
+ s = linux.ARM64
+ s.Table[5] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[8] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[11] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[12] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[13] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[14] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[17] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[19] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[23] = syscalls.Supported("dup", Dup)
+ s.Table[24] = syscalls.Supported("dup3", Dup3)
+ s.Table[25] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[29] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[32] = syscalls.Supported("flock", Flock)
+ s.Table[33] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[34] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[35] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[36] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[37] = syscalls.Supported("linkat", Linkat)
+ s.Table[38] = syscalls.Supported("renameat", Renameat)
+ s.Table[39] = syscalls.Supported("umount2", Umount2)
+ s.Table[40] = syscalls.Supported("mount", Mount)
+ s.Table[43] = syscalls.Supported("statfs", Statfs)
+ s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[45] = syscalls.Supported("truncate", Truncate)
+ s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[48] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[49] = syscalls.Supported("chdir", Chdir)
+ s.Table[50] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[51] = syscalls.Supported("chroot", Chroot)
+ s.Table[52] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[53] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[54] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[55] = syscalls.Supported("fchown", Fchown)
+ s.Table[56] = syscalls.Supported("openat", Openat)
+ s.Table[57] = syscalls.Supported("close", Close)
+ s.Table[59] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[61] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[62] = syscalls.Supported("lseek", Lseek)
+ s.Table[63] = syscalls.Supported("read", Read)
+ s.Table[64] = syscalls.Supported("write", Write)
+ s.Table[65] = syscalls.Supported("readv", Readv)
+ s.Table[66] = syscalls.Supported("writev", Writev)
+ s.Table[67] = syscalls.Supported("pread64", Pread64)
+ s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[69] = syscalls.Supported("preadv", Preadv)
+ s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[72] = syscalls.Supported("pselect", Pselect)
+ s.Table[73] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[76] = syscalls.Supported("splice", Splice)
+ s.Table[77] = syscalls.Supported("tee", Tee)
+ s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[80] = syscalls.Supported("fstat", Fstat)
+ s.Table[81] = syscalls.Supported("sync", Sync)
+ s.Table[82] = syscalls.Supported("fsync", Fsync)
+ s.Table[83] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[88] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[198] = syscalls.Supported("socket", Socket)
+ s.Table[199] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[200] = syscalls.Supported("bind", Bind)
+ s.Table[201] = syscalls.Supported("listen", Listen)
+ s.Table[202] = syscalls.Supported("accept", Accept)
+ s.Table[203] = syscalls.Supported("connect", Connect)
+ s.Table[204] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[205] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[206] = syscalls.Supported("sendto", SendTo)
+ s.Table[207] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[210] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[221] = syscalls.Supported("execve", Execve)
+ s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[242] = syscalls.Supported("accept4", Accept4)
+ s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[267] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[276] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[281] = syscalls.Supported("execveat", Execveat)
+ s.Table[286] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[291] = syscalls.Supported("statx", Statx)
+
+ s.Init()
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..ef99246ed
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,356 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size))
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ names, err := file.Listxattr(t, uint64(size))
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{
+ Name: name,
+ Size: uint64(size),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)})
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef(t)
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index d3a4cd943..04f81a35b 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -9,7 +8,7 @@ go_template_instance(
out = "seqatomic_parameters_unsafe.go",
package = "time",
suffix = "Parameters",
- template = "//third_party/gvsync:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "Parameters",
},
@@ -31,13 +30,12 @@ go_library(
"tsc_amd64.s",
"tsc_arm64.s",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/time",
visibility = ["//:sandbox"],
deps = [
"//pkg/log",
"//pkg/metric",
+ "//pkg/sync",
"//pkg/syserror",
- "//third_party/gvsync",
],
)
@@ -48,5 +46,5 @@ go_test(
"parameters_test.go",
"sampler_test.go",
],
- embed = [":time"],
+ library = ":time",
)
diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go
index 318503277..f9a93115d 100644
--- a/pkg/sentry/time/calibrated_clock.go
+++ b/pkg/sentry/time/calibrated_clock.go
@@ -17,11 +17,11 @@
package time
import (
- "sync"
"time"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/time/muldiv_arm64.s b/pkg/sentry/time/muldiv_arm64.s
index 5ad57a8a3..8afc62d53 100644
--- a/pkg/sentry/time/muldiv_arm64.s
+++ b/pkg/sentry/time/muldiv_arm64.s
@@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "funcdata.h"
#include "textflag.h"
// Documentation is available in parameters.go.
//
// func muldiv64(value, multiplier, divisor uint64) (uint64, bool)
TEXT ·muldiv64(SB),NOSPLIT,$40-33
+ GO_ARGS
+ NO_LOCAL_POINTERS
MOVD value+0(FP), R0
MOVD multiplier+8(FP), R1
MOVD divisor+16(FP), R2
diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go
index 65868cb26..cd1b95117 100644
--- a/pkg/sentry/time/parameters.go
+++ b/pkg/sentry/time/parameters.go
@@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par
//
// The log level is determined by the error severity.
func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) {
- fn := log.Debugf
- if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() {
+ magNS := int64(errorNS.Magnitude())
+ if magNS <= 10*time.Microsecond.Nanoseconds() {
+ // Don't log small errors.
+ return
+ }
+ fn := log.Infof
+ if magNS > time.Millisecond.Nanoseconds() {
+ // Upgrade large errors to warning.
fn = log.Warningf
- } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() {
- fn = log.Infof
}
fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency)
diff --git a/pkg/sentry/time/parameters_test.go b/pkg/sentry/time/parameters_test.go
index e1b9084ac..0ce1257f6 100644
--- a/pkg/sentry/time/parameters_test.go
+++ b/pkg/sentry/time/parameters_test.go
@@ -484,3 +484,18 @@ func TestMulDivOverflow(t *testing.T) {
})
}
}
+
+func BenchmarkMuldiv64(b *testing.B) {
+ var v uint64 = math.MaxUint64
+ for i := uint64(1); i <= 1000000; i++ {
+ mult := uint64(1000000000)
+ div := i * mult
+ res, ok := muldiv64(v, mult, div)
+ if !ok {
+ b.Errorf("Result of %v * %v / %v ok got false want true", v, mult, div)
+ }
+ if want := v / i; res != want {
+ b.Errorf("Bad result of %v * %v / %v: got %v, want %v", v, mult, div, res, want)
+ }
+ }
+}
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index fc7614fff..5d4aa3a63 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,37 +1,20 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("//tools:defs.bzl", "go_library", "proto_library")
package(licenses = ["notice"])
proto_library(
- name = "unimplemented_syscall_proto",
+ name = "unimplemented_syscall",
srcs = ["unimplemented_syscall.proto"],
visibility = ["//visibility:public"],
deps = ["//pkg/sentry/arch:registers_proto"],
)
-cc_proto_library(
- name = "unimplemented_syscall_cc_proto",
- visibility = ["//visibility:public"],
- deps = [":unimplemented_syscall_proto"],
-)
-
-go_proto_library(
- name = "unimplemented_syscall_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
- proto = ":unimplemented_syscall_proto",
- visibility = ["//visibility:public"],
- deps = ["//pkg/sentry/arch:registers_go_proto"],
-)
-
go_library(
name = "unimpl",
srcs = ["events.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/context",
"//pkg/log",
- "//pkg/sentry/context",
],
)
diff --git a/pkg/sentry/unimpl/events.go b/pkg/sentry/unimpl/events.go
index 79b5de9e4..73ed9372f 100644
--- a/pkg/sentry/unimpl/events.go
+++ b/pkg/sentry/unimpl/events.go
@@ -17,8 +17,8 @@
package unimpl
import (
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
)
// contextID is the events package's type for context.Context.Value keys.
diff --git a/pkg/sentry/uniqueid/BUILD b/pkg/sentry/uniqueid/BUILD
index 86a87edd4..7467e6398 100644
--- a/pkg/sentry/uniqueid/BUILD
+++ b/pkg/sentry/uniqueid/BUILD
@@ -1,14 +1,13 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "uniqueid",
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/uniqueid",
visibility = ["//pkg/sentry:internal"],
deps = [
- "//pkg/sentry/context",
+ "//pkg/context",
"//pkg/sentry/socket/unix/transport",
],
)
diff --git a/pkg/sentry/uniqueid/context.go b/pkg/sentry/uniqueid/context.go
index 4e466d66d..1fb884a90 100644
--- a/pkg/sentry/uniqueid/context.go
+++ b/pkg/sentry/uniqueid/context.go
@@ -17,7 +17,7 @@
package uniqueid
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
diff --git a/pkg/sentry/usage/BUILD b/pkg/sentry/usage/BUILD
index c32fe3241..099315613 100644
--- a/pkg/sentry/usage/BUILD
+++ b/pkg/sentry/usage/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -11,12 +11,12 @@ go_library(
"memory_unsafe.go",
"usage.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/usage",
visibility = [
- "//pkg/sentry:internal",
+ "//:sandbox",
],
deps = [
"//pkg/bits",
"//pkg/memutil",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index d6ef644d8..ab1d140d2 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -17,12 +17,12 @@ package usage
import (
"fmt"
"os"
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/memutil"
+ "gvisor.dev/gvisor/pkg/sync"
)
// MemoryKind represents a type of memory used by the application.
@@ -252,14 +252,23 @@ func (m *MemoryLocked) Copy() (MemoryStats, uint64) {
return ms, m.totalLocked()
}
-// MinimumTotalMemoryBytes is the minimum reported total system memory.
-var MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
+// These options control how much total memory the is reported to the application.
+// They may only be set before the application starts executing, and must not
+// be modified.
+var (
+ // MinimumTotalMemoryBytes is the minimum reported total system memory.
+ MinimumTotalMemoryBytes uint64 = 2 << 30 // 2 GB
+
+ // MaximumTotalMemoryBytes is the maximum reported total system memory.
+ // The 0 value indicates no maximum.
+ MaximumTotalMemoryBytes uint64
+)
// TotalMemory returns the "total usable memory" available.
//
// This number doesn't really have a true value so it's based on the following
-// inputs and further bounded to be above some minimum guaranteed value (2GB),
-// additionally ensuring that total memory reported is always less than used.
+// inputs and further bounded to be above the MinumumTotalMemoryBytes and below
+// MaximumTotalMemoryBytes.
//
// memSize should be the platform.Memory size reported by platform.Memory.TotalSize()
// used is the total memory reported by MemoryLocked.Total()
@@ -275,5 +284,8 @@ func TotalMemory(memSize, used uint64) uint64 {
memSize = uint64(1) << (uint(msb) + 1)
}
}
+ if MaximumTotalMemoryBytes > 0 && memSize > MaximumTotalMemoryBytes {
+ memSize = MaximumTotalMemoryBytes
+ }
return memSize
}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index eff4b44f6..642769e7c 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -1,40 +1,83 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
-package(licenses = ["notice"])
+licenses(["notice"])
+
+go_template_instance(
+ name = "epoll_interest_list",
+ out = "epoll_interest_list.go",
+ package = "vfs",
+ prefix = "epollInterest",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*epollInterest",
+ "Linker": "*epollInterest",
+ },
+)
+
+go_template_instance(
+ name = "event_list",
+ out = "event_list.go",
+ package = "vfs",
+ prefix = "event",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Event",
+ "Linker": "*Event",
+ },
+)
go_library(
name = "vfs",
srcs = [
+ "anonfs.go",
"context.go",
"debug.go",
"dentry.go",
+ "device.go",
+ "epoll.go",
+ "epoll_interest_list.go",
+ "event_list.go",
"file_description.go",
"file_description_impl_util.go",
"filesystem.go",
+ "filesystem_impl_util.go",
"filesystem_type.go",
+ "inotify.go",
+ "lock.go",
"mount.go",
"mount_unsafe.go",
"options.go",
+ "pathname.go",
"permissions.go",
"resolving_path.go",
- "syscalls.go",
- "testutil.go",
"vfs.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/vfs",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fd",
+ "//pkg/fdnotifier",
"//pkg/fspath",
+ "//pkg/gohacks",
+ "//pkg/log",
+ "//pkg/safemem",
"//pkg/sentry/arch",
- "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/lock",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
"//pkg/sentry/memmap",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/uniqueid",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
"//pkg/waiter",
- "//third_party/gvsync",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -45,13 +88,13 @@ go_test(
"file_description_impl_util_test.go",
"mount_test.go",
],
- embed = [":vfs"],
+ library = ":vfs",
deps = [
"//pkg/abi/linux",
- "//pkg/sentry/context",
- "//pkg/sentry/context/contexttest",
- "//pkg/sentry/kernel/auth",
- "//pkg/sentry/usermem",
+ "//pkg/context",
+ "//pkg/sentry/contexttest",
+ "//pkg/sync",
"//pkg/syserror",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md
index 7847854bc..4b9faf2ea 100644
--- a/pkg/sentry/vfs/README.md
+++ b/pkg/sentry/vfs/README.md
@@ -39,8 +39,8 @@ Mount references are held by:
- Mount: Each referenced Mount holds a reference on its parent, which is the
mount containing its mount point.
-- VirtualFilesystem: A reference is held on all Mounts that are attached
- (reachable by Mount traversal).
+- VirtualFilesystem: A reference is held on each Mount that has been connected
+ to a mount point, but not yet umounted.
MountNamespace and FileDescription references are held by users of VFS. The
expectation is that each `kernel.Task` holds a reference on its corresponding
@@ -169,8 +169,6 @@ This construction, which is essentially a type-safe analogue to Linux's
- binder, which is similarly far too incomplete to use.
- - whitelistfs, which we are already actively attempting to remove.
-
- Save/restore. For instance, it is unclear if the current implementation of
the `state` package supports the inheritance pattern described above.
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
new file mode 100644
index 000000000..5a0e3e6b5
--- /dev/null
+++ b/pkg/sentry/vfs/anonfs.go
@@ -0,0 +1,314 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// NewAnonVirtualDentry returns a VirtualDentry with the given synthetic name,
+// consistent with Linux's fs/anon_inodes.c:anon_inode_getfile(). References
+// are taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) NewAnonVirtualDentry(name string) VirtualDentry {
+ d := anonDentry{
+ name: name,
+ }
+ d.vfsd.Init(&d)
+ vfs.anonMount.IncRef()
+ // anonDentry no-ops refcounting.
+ return VirtualDentry{
+ mount: vfs.anonMount,
+ dentry: &d.vfsd,
+ }
+}
+
+const (
+ anonfsBlockSize = usermem.PageSize // via fs/libfs.c:pseudo_fs_fill_super()
+
+ // Mode, UID, and GID for a generic anonfs file.
+ anonFileMode = 0600 // no type is correct
+ anonFileUID = auth.RootKUID
+ anonFileGID = auth.RootKGID
+)
+
+// anonFilesystemType implements FilesystemType.
+type anonFilesystemType struct{}
+
+// GetFilesystem implements FilesystemType.GetFilesystem.
+func (anonFilesystemType) GetFilesystem(context.Context, *VirtualFilesystem, *auth.Credentials, string, GetFilesystemOptions) (*Filesystem, *Dentry, error) {
+ panic("cannot instaniate an anon filesystem")
+}
+
+// Name implemenents FilesystemType.Name.
+func (anonFilesystemType) Name() string {
+ return "none"
+}
+
+// anonFilesystem is the implementation of FilesystemImpl that backs
+// VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
+//
+// Since all Dentries in anonFilesystem are non-directories, all FilesystemImpl
+// methods that would require an anonDentry to be a directory return ENOTDIR.
+type anonFilesystem struct {
+ vfsfs Filesystem
+
+ devMinor uint32
+}
+
+type anonDentry struct {
+ vfsd Dentry
+
+ name string
+}
+
+// Release implements FilesystemImpl.Release.
+func (fs *anonFilesystem) Release(ctx context.Context) {
+}
+
+// Sync implements FilesystemImpl.Sync.
+func (fs *anonFilesystem) Sync(ctx context.Context) error {
+ return nil
+}
+
+// AccessAt implements vfs.Filesystem.Impl.AccessAt.
+func (fs *anonFilesystem) AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return GenericCheckPermissions(creds, ats, anonFileMode, anonFileUID, anonFileGID)
+}
+
+// GetDentryAt implements FilesystemImpl.GetDentryAt.
+func (fs *anonFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ if opts.CheckSearchable {
+ return nil, syserror.ENOTDIR
+ }
+ // anonDentry no-ops refcounting.
+ return rp.Start(), nil
+}
+
+// GetParentDentryAt implements FilesystemImpl.GetParentDentryAt.
+func (fs *anonFilesystem) GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ // anonDentry no-ops refcounting.
+ return rp.Start(), nil
+}
+
+// LinkAt implements FilesystemImpl.LinkAt.
+func (fs *anonFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// MkdirAt implements FilesystemImpl.MkdirAt.
+func (fs *anonFilesystem) MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// MknodAt implements FilesystemImpl.MknodAt.
+func (fs *anonFilesystem) MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// OpenAt implements FilesystemImpl.OpenAt.
+func (fs *anonFilesystem) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, syserror.ENODEV
+}
+
+// ReadlinkAt implements FilesystemImpl.ReadlinkAt.
+func (fs *anonFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) {
+ if !rp.Done() {
+ return "", syserror.ENOTDIR
+ }
+ return "", syserror.EINVAL
+}
+
+// RenameAt implements FilesystemImpl.RenameAt.
+func (fs *anonFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// RmdirAt implements FilesystemImpl.RmdirAt.
+func (fs *anonFilesystem) RmdirAt(ctx context.Context, rp *ResolvingPath) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// SetStatAt implements FilesystemImpl.SetStatAt.
+func (fs *anonFilesystem) SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ // Linux actually permits anon_inode_inode's metadata to be set, which is
+ // visible to all users of anon_inode_inode. We just silently ignore
+ // metadata changes.
+ return nil
+}
+
+// StatAt implements FilesystemImpl.StatAt.
+func (fs *anonFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error) {
+ if !rp.Done() {
+ return linux.Statx{}, syserror.ENOTDIR
+ }
+ // See fs/anon_inodes.c:anon_inode_init() => fs/libfs.c:alloc_anon_inode().
+ return linux.Statx{
+ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS,
+ Blksize: anonfsBlockSize,
+ Nlink: 1,
+ UID: uint32(anonFileUID),
+ GID: uint32(anonFileGID),
+ Mode: anonFileMode,
+ Ino: 1,
+ Size: 0,
+ Blocks: 0,
+ DevMajor: linux.UNNAMED_MAJOR,
+ DevMinor: fs.devMinor,
+ }, nil
+}
+
+// StatFSAt implements FilesystemImpl.StatFSAt.
+func (fs *anonFilesystem) StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) {
+ if !rp.Done() {
+ return linux.Statfs{}, syserror.ENOTDIR
+ }
+ return linux.Statfs{
+ Type: linux.ANON_INODE_FS_MAGIC,
+ BlockSize: anonfsBlockSize,
+ }, nil
+}
+
+// SymlinkAt implements FilesystemImpl.SymlinkAt.
+func (fs *anonFilesystem) SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// UnlinkAt implements FilesystemImpl.UnlinkAt.
+func (fs *anonFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error {
+ if !rp.Final() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt.
+func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ if !rp.Final() {
+ return nil, syserror.ENOTDIR
+ }
+ if err := GenericCheckPermissions(rp.Credentials(), MayWrite, anonFileMode, anonFileUID, anonFileGID); err != nil {
+ return nil, err
+ }
+ return nil, syserror.ECONNREFUSED
+}
+
+// ListxattrAt implements FilesystemImpl.ListxattrAt.
+func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) {
+ if !rp.Done() {
+ return nil, syserror.ENOTDIR
+ }
+ return nil, nil
+}
+
+// GetxattrAt implements FilesystemImpl.GetxattrAt.
+func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) {
+ if !rp.Done() {
+ return "", syserror.ENOTDIR
+ }
+ return "", syserror.ENOTSUP
+}
+
+// SetxattrAt implements FilesystemImpl.SetxattrAt.
+func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// RemovexattrAt implements FilesystemImpl.RemovexattrAt.
+func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error {
+ if !rp.Done() {
+ return syserror.ENOTDIR
+ }
+ return syserror.EPERM
+}
+
+// PrependPath implements FilesystemImpl.PrependPath.
+func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error {
+ b.PrependComponent(fmt.Sprintf("anon_inode:%s", vd.dentry.impl.(*anonDentry).name))
+ return PrependPathSyntheticError{}
+}
+
+// IncRef implements DentryImpl.IncRef.
+func (d *anonDentry) IncRef() {
+ // no-op
+}
+
+// TryIncRef implements DentryImpl.TryIncRef.
+func (d *anonDentry) TryIncRef() bool {
+ return true
+}
+
+// DecRef implements DentryImpl.DecRef.
+func (d *anonDentry) DecRef(ctx context.Context) {
+ // no-op
+}
+
+// InotifyWithParent implements DentryImpl.InotifyWithParent.
+//
+// Although Linux technically supports inotify on pseudo filesystems (inotify
+// is implemented at the vfs layer), it is not particularly useful. It is left
+// unimplemented until someone actually needs it.
+func (d *anonDentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {}
+
+// Watches implements DentryImpl.Watches.
+func (d *anonDentry) Watches() *Watches {
+ return nil
+}
+
+// OnZeroWatches implements Dentry.OnZeroWatches.
+func (d *anonDentry) OnZeroWatches(context.Context) {}
diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go
index 32cf9151b..c9e724fef 100644
--- a/pkg/sentry/vfs/context.go
+++ b/pkg/sentry/vfs/context.go
@@ -15,7 +15,7 @@
package vfs
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// contextID is this package's type for context.Context.Value keys.
@@ -24,14 +24,52 @@ type contextID int
const (
// CtxMountNamespace is a Context.Value key for a MountNamespace.
CtxMountNamespace contextID = iota
+
+ // CtxRoot is a Context.Value key for a VFS root.
+ CtxRoot
)
-// MountNamespaceFromContext returns the MountNamespace used by ctx. It does
-// not take a reference on the returned MountNamespace. If ctx is not
-// associated with a MountNamespace, MountNamespaceFromContext returns nil.
+// MountNamespaceFromContext returns the MountNamespace used by ctx. If ctx is
+// not associated with a MountNamespace, MountNamespaceFromContext returns nil.
+//
+// A reference is taken on the returned MountNamespace.
func MountNamespaceFromContext(ctx context.Context) *MountNamespace {
if v := ctx.Value(CtxMountNamespace); v != nil {
return v.(*MountNamespace)
}
return nil
}
+
+// RootFromContext returns the VFS root used by ctx. It takes a reference on
+// the returned VirtualDentry. If ctx does not have a specific VFS root,
+// RootFromContext returns a zero-value VirtualDentry.
+func RootFromContext(ctx context.Context) VirtualDentry {
+ if v := ctx.Value(CtxRoot); v != nil {
+ return v.(VirtualDentry)
+ }
+ return VirtualDentry{}
+}
+
+type rootContext struct {
+ context.Context
+ root VirtualDentry
+}
+
+// WithRoot returns a copy of ctx with the given root.
+func WithRoot(ctx context.Context, root VirtualDentry) context.Context {
+ return &rootContext{
+ Context: ctx,
+ root: root,
+ }
+}
+
+// Value implements Context.Value.
+func (rc rootContext) Value(key interface{}) interface{} {
+ switch key {
+ case CtxRoot:
+ rc.root.IncRef()
+ return rc.root
+ default:
+ return rc.Context.Value(key)
+ }
+}
diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go
index 45912fc58..bc7ea93ea 100644
--- a/pkg/sentry/vfs/dentry.go
+++ b/pkg/sentry/vfs/dentry.go
@@ -15,33 +15,18 @@
package vfs
import (
- "fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
-// Dentry represents a node in a Filesystem tree which may represent a file.
+// Dentry represents a node in a Filesystem tree at which a file exists.
//
// Dentries are reference-counted. Unless otherwise specified, all Dentry
// methods require that a reference is held.
//
-// A Dentry transitions through up to 3 different states through its lifetime:
-//
-// - Dentries are initially "independent". Independent Dentries have no parent,
-// and consequently no name.
-//
-// - Dentry.InsertChild() causes an independent Dentry to become a "child" of
-// another Dentry. A child node has a parent node, and a name in that parent,
-// both of which are mutable by DentryMoveChild(). Each child Dentry's name is
-// unique within its parent.
-//
-// - Dentry.RemoveChild() causes a child Dentry to become "disowned". A
-// disowned Dentry can still refer to its former parent and its former name in
-// said parent, but the disowned Dentry is no longer reachable from its parent,
-// and a new Dentry with the same name may become a child of the parent. (This
-// is analogous to a struct dentry being "unhashed" in Linux.)
-//
// Dentry is loosely analogous to Linux's struct dentry, but:
//
// - VFS does not associate Dentries with inodes. gVisor interacts primarily
@@ -50,15 +35,12 @@ import (
// and not inodes. Furthermore, when parties outside the scope of VFS can
// rename inodes on such filesystems, VFS generally cannot "follow" the rename,
// both due to synchronization issues and because it may not even be able to
-// name the destination path; this implies that it would in fact be *incorrect*
+// name the destination path; this implies that it would in fact be incorrect
// for Dentries to be associated with inodes on such filesystems. Consequently,
// operations that are inode operations in Linux are FilesystemImpl methods
// and/or FileDescriptionImpl methods in gVisor's VFS. Filesystems that do
// support inodes may store appropriate state in implementations of DentryImpl.
//
-// - VFS does not provide synchronization for mutable Dentry fields, other than
-// mount-related ones.
-//
// - VFS does not require that Dentries are instantiated for all paths accessed
// through VFS, only those that are tracked beyond the scope of a single
// Filesystem operation. This includes file descriptions, mount points, mount
@@ -66,38 +48,34 @@ import (
// of Dentries for operations on mutable remote filesystems that can't actually
// cache any state in the Dentry.
//
+// - VFS does not track filesystem structure (i.e. relationships between
+// Dentries), since both the relevant state and synchronization are
+// filesystem-specific.
+//
// - For the reasons above, VFS is not directly responsible for managing Dentry
// lifetime. Dentry reference counts only indicate the extent to which VFS
// requires Dentries to exist; Filesystems may elect to cache or discard
// Dentries with zero references.
+//
+// +stateify savable
type Dentry struct {
- // parent is this Dentry's parent in this Filesystem. If this Dentry is
- // independent, parent is nil.
- parent *Dentry
+ // mu synchronizes deletion/invalidation and mounting over this Dentry.
+ mu sync.Mutex `state:"nosave"`
- // name is this Dentry's name in parent.
- name string
-
- flags uint32
+ // dead is true if the file represented by this Dentry has been deleted (by
+ // CommitDeleteDentry or CommitRenameReplaceDentry) or invalidated (by
+ // InvalidateDentry). dead is protected by mu.
+ dead bool
// mounts is the number of Mounts for which this Dentry is Mount.point.
// mounts is accessed using atomic memory operations.
mounts uint32
- // children are child Dentries.
- children map[string]*Dentry
-
// impl is the DentryImpl associated with this Dentry. impl is immutable.
// This should be the last field in Dentry.
impl DentryImpl
}
-const (
- // dflagsDisownedMask is set in Dentry.flags if the Dentry has been
- // disowned.
- dflagsDisownedMask = 1 << iota
-)
-
// Init must be called before first use of d.
func (d *Dentry) Init(impl DentryImpl) {
d.impl = impl
@@ -114,7 +92,7 @@ func (d *Dentry) Impl() DentryImpl {
type DentryImpl interface {
// IncRef increments the Dentry's reference count. A Dentry with a non-zero
// reference count must remain coherent with the state of the filesystem.
- IncRef(fs *Filesystem)
+ IncRef()
// TryIncRef increments the Dentry's reference count and returns true. If
// the Dentry's reference count is zero, TryIncRef may do nothing and
@@ -122,148 +100,140 @@ type DentryImpl interface {
// guarantee that the Dentry is coherent with the state of the filesystem.)
//
// TryIncRef does not require that a reference is held on the Dentry.
- TryIncRef(fs *Filesystem) bool
+ TryIncRef() bool
// DecRef decrements the Dentry's reference count.
- DecRef(fs *Filesystem)
-}
+ DecRef(ctx context.Context)
-// IsDisowned returns true if d is disowned.
-func (d *Dentry) IsDisowned() bool {
- return atomic.LoadUint32(&d.flags)&dflagsDisownedMask != 0
-}
+ // InotifyWithParent notifies all watches on the targets represented by this
+ // dentry and its parent. The parent's watches are notified first, followed
+ // by this dentry's.
+ //
+ // InotifyWithParent automatically adds the IN_ISDIR flag for dentries
+ // representing directories.
+ //
+ // Note that the events may not actually propagate up to the user, depending
+ // on the event masks.
+ InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType)
-// Preconditions: !d.IsDisowned().
-func (d *Dentry) setDisowned() {
- atomic.AddUint32(&d.flags, dflagsDisownedMask)
-}
+ // Watches returns the set of inotify watches for the file corresponding to
+ // the Dentry. Dentries that are hard links to the same underlying file
+ // share the same watches.
+ //
+ // Watches may return nil if the dentry belongs to a FilesystemImpl that
+ // does not support inotify. If an implementation returns a non-nil watch
+ // set, it must always return a non-nil watch set. Likewise, if an
+ // implementation returns a nil watch set, it must always return a nil watch
+ // set.
+ //
+ // The caller does not need to hold a reference on the dentry.
+ Watches() *Watches
-func (d *Dentry) isMounted() bool {
- return atomic.LoadUint32(&d.mounts) != 0
+ // OnZeroWatches is called whenever the number of watches on a dentry drops
+ // to zero. This is needed by some FilesystemImpls (e.g. gofer) to manage
+ // dentry lifetime.
+ //
+ // The caller does not need to hold a reference on the dentry. OnZeroWatches
+ // may acquire inotify locks, so to prevent deadlock, no inotify locks should
+ // be held by the caller.
+ OnZeroWatches(ctx context.Context)
}
-func (d *Dentry) incRef(fs *Filesystem) {
- d.impl.IncRef(fs)
+// IncRef increments d's reference count.
+func (d *Dentry) IncRef() {
+ d.impl.IncRef()
}
-func (d *Dentry) tryIncRef(fs *Filesystem) bool {
- return d.impl.TryIncRef(fs)
+// TryIncRef increments d's reference count and returns true. If d's reference
+// count is zero, TryIncRef may instead do nothing and return false.
+func (d *Dentry) TryIncRef() bool {
+ return d.impl.TryIncRef()
}
-func (d *Dentry) decRef(fs *Filesystem) {
- d.impl.DecRef(fs)
+// DecRef decrements d's reference count.
+func (d *Dentry) DecRef(ctx context.Context) {
+ d.impl.DecRef(ctx)
}
-// These functions are exported so that filesystem implementations can use
-// them. The vfs package, and users of VFS, should not call these functions.
-// Unless otherwise specified, these methods require that there are no
-// concurrent mutators of d.
-
-// Name returns d's name in its parent in its owning Filesystem. If d is
-// independent, Name returns an empty string.
-func (d *Dentry) Name() string {
- return d.name
+// IsDead returns true if d has been deleted or invalidated by its owning
+// filesystem.
+func (d *Dentry) IsDead() bool {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dead
}
-// Parent returns d's parent in its owning Filesystem. It does not take a
-// reference on the returned Dentry. If d is independent, Parent returns nil.
-func (d *Dentry) Parent() *Dentry {
- return d.parent
+func (d *Dentry) isMounted() bool {
+ return atomic.LoadUint32(&d.mounts) != 0
}
-// ParentOrSelf is equivalent to Parent, but returns d if d is independent.
-func (d *Dentry) ParentOrSelf() *Dentry {
- if d.parent == nil {
- return d
- }
- return d.parent
+// InotifyWithParent notifies all watches on the targets represented by d and
+// its parent of events.
+func (d *Dentry) InotifyWithParent(ctx context.Context, events, cookie uint32, et EventType) {
+ d.impl.InotifyWithParent(ctx, events, cookie, et)
}
-// Child returns d's child with the given name in its owning Filesystem. It
-// does not take a reference on the returned Dentry. If no such child exists,
-// Child returns nil.
-func (d *Dentry) Child(name string) *Dentry {
- return d.children[name]
+// Watches returns the set of inotify watches associated with d.
+//
+// Watches will return nil if d belongs to a FilesystemImpl that does not
+// support inotify.
+func (d *Dentry) Watches() *Watches {
+ return d.impl.Watches()
}
-// HasChildren returns true if d has any children.
-func (d *Dentry) HasChildren() bool {
- return len(d.children) != 0
+// OnZeroWatches performs cleanup tasks whenever the number of watches on a
+// dentry drops to zero.
+func (d *Dentry) OnZeroWatches(ctx context.Context) {
+ d.impl.OnZeroWatches(ctx)
}
-// InsertChild makes child a child of d with the given name.
-//
-// InsertChild is a mutator of d and child.
-//
-// Preconditions: child must be an independent Dentry. d and child must be from
-// the same Filesystem. d must not already have a child with the given name.
-func (d *Dentry) InsertChild(child *Dentry, name string) {
- if checkInvariants {
- if _, ok := d.children[name]; ok {
- panic(fmt.Sprintf("parent already contains a child named %q", name))
- }
- if child.parent != nil || child.name != "" {
- panic(fmt.Sprintf("child is not independent: parent = %v, name = %q", child.parent, child.name))
- }
- }
- if d.children == nil {
- d.children = make(map[string]*Dentry)
- }
- d.children[name] = child
- child.parent = d
- child.name = name
-}
+// The following functions are exported so that filesystem implementations can
+// use them. The vfs package, and users of VFS, should not call these
+// functions.
// PrepareDeleteDentry must be called before attempting to delete the file
// represented by d. If PrepareDeleteDentry succeeds, the caller must call
// AbortDeleteDentry or CommitDeleteDentry depending on the deletion's outcome.
-//
-// Preconditions: d is a child Dentry.
func (vfs *VirtualFilesystem) PrepareDeleteDentry(mntns *MountNamespace, d *Dentry) error {
- if checkInvariants {
- if d.parent == nil {
- panic("d is independent")
- }
- if d.IsDisowned() {
- panic("d is already disowned")
- }
- }
- vfs.mountMu.RLock()
- if _, ok := mntns.mountpoints[d]; ok {
- vfs.mountMu.RUnlock()
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[d] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
- // Return with vfs.mountMu locked, which will be unlocked by
- // AbortDeleteDentry or CommitDeleteDentry.
+ d.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with d.mu locked to block attempts to mount over it; it will be
+ // unlocked by AbortDeleteDentry or CommitDeleteDentry.
return nil
}
// AbortDeleteDentry must be called after PrepareDeleteDentry if the deletion
// fails.
-func (vfs *VirtualFilesystem) AbortDeleteDentry() {
- vfs.mountMu.RUnlock()
+func (vfs *VirtualFilesystem) AbortDeleteDentry(d *Dentry) {
+ d.mu.Unlock()
}
-// CommitDeleteDentry must be called after the file represented by d is
-// deleted, and causes d to become disowned.
-//
-// Preconditions: PrepareDeleteDentry was previously called on d.
-func (vfs *VirtualFilesystem) CommitDeleteDentry(d *Dentry) {
- delete(d.parent.children, d.name)
- d.setDisowned()
- // TODO: lazily unmount mounts at d
- vfs.mountMu.RUnlock()
+// CommitDeleteDentry must be called after PrepareDeleteDentry if the deletion
+// succeeds.
+func (vfs *VirtualFilesystem) CommitDeleteDentry(ctx context.Context, d *Dentry) {
+ d.dead = true
+ d.mu.Unlock()
+ if d.isMounted() {
+ vfs.forgetDeadMountpoint(ctx, d)
+ }
}
-// DeleteDentry combines PrepareDeleteDentry and CommitDeleteDentry, as
-// appropriate for in-memory filesystems that don't need to ensure that some
-// external state change succeeds before committing the deletion.
-func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) error {
- if err := vfs.PrepareDeleteDentry(mntns, d); err != nil {
- return err
+// InvalidateDentry is called when d ceases to represent the file it formerly
+// did for reasons outside of VFS' control (e.g. d represents the local state
+// of a file on a remote filesystem on which the file has already been
+// deleted).
+func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) {
+ d.mu.Lock()
+ d.dead = true
+ d.mu.Unlock()
+ if d.isMounted() {
+ vfs.forgetDeadMountpoint(ctx, d)
}
- vfs.CommitDeleteDentry(d)
- return nil
}
// PrepareRenameDentry must be called before attempting to rename the file
@@ -272,37 +242,24 @@ func (vfs *VirtualFilesystem) DeleteDentry(mntns *MountNamespace, d *Dentry) err
// caller must call AbortRenameDentry, CommitRenameReplaceDentry, or
// CommitRenameExchangeDentry depending on the rename's outcome.
//
-// Preconditions: from is a child Dentry. If to is not nil, it must be a child
-// Dentry from the same Filesystem.
+// Preconditions: If to is not nil, it must be a child Dentry from the same
+// Filesystem. from != to.
func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error {
- if checkInvariants {
- if from.parent == nil {
- panic("from is independent")
- }
- if from.IsDisowned() {
- panic("from is already disowned")
- }
- if to != nil {
- if to.parent == nil {
- panic("to is independent")
- }
- if to.IsDisowned() {
- panic("to is already disowned")
- }
- }
- }
- vfs.mountMu.RLock()
- if _, ok := mntns.mountpoints[from]; ok {
- vfs.mountMu.RUnlock()
+ vfs.mountMu.Lock()
+ if mntns.mountpoints[from] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
if to != nil {
- if _, ok := mntns.mountpoints[to]; ok {
- vfs.mountMu.RUnlock()
+ if mntns.mountpoints[to] != 0 {
+ vfs.mountMu.Unlock()
return syserror.EBUSY
}
+ to.mu.Lock()
}
- // Return with vfs.mountMu locked, which will be unlocked by
+ from.mu.Lock()
+ vfs.mountMu.Unlock()
+ // Return with from.mu and to.mu locked, which will be unlocked by
// AbortRenameDentry, CommitRenameReplaceDentry, or
// CommitRenameExchangeDentry.
return nil
@@ -310,8 +267,11 @@ func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, t
// AbortRenameDentry must be called after PrepareRenameDentry if the rename
// fails.
-func (vfs *VirtualFilesystem) AbortRenameDentry() {
- vfs.mountMu.RUnlock()
+func (vfs *VirtualFilesystem) AbortRenameDentry(from, to *Dentry) {
+ from.mu.Unlock()
+ if to != nil {
+ to.mu.Unlock()
+ }
}
// CommitRenameReplaceDentry must be called after the file represented by from
@@ -319,19 +279,15 @@ func (vfs *VirtualFilesystem) AbortRenameDentry() {
// that was replaced by from.
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
-// newParent.Child(newName) == to.
-func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry, newName string, to *Dentry) {
+func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(ctx context.Context, from, to *Dentry) {
+ from.mu.Unlock()
if to != nil {
- to.setDisowned()
- // TODO: lazily unmount mounts at d
- }
- if newParent.children == nil {
- newParent.children = make(map[string]*Dentry)
+ to.dead = true
+ to.mu.Unlock()
+ if to.isMounted() {
+ vfs.forgetDeadMountpoint(ctx, to)
+ }
}
- newParent.children[newName] = from
- from.parent = newParent
- from.name = newName
- vfs.mountMu.RUnlock()
}
// CommitRenameExchangeDentry must be called after the files represented by
@@ -339,9 +295,31 @@ func (vfs *VirtualFilesystem) CommitRenameReplaceDentry(from, newParent *Dentry,
//
// Preconditions: PrepareRenameDentry was previously called on from and to.
func (vfs *VirtualFilesystem) CommitRenameExchangeDentry(from, to *Dentry) {
- from.parent, to.parent = to.parent, from.parent
- from.name, to.name = to.name, from.name
- from.parent.children[from.name] = from
- to.parent.children[to.name] = to
- vfs.mountMu.RUnlock()
+ from.mu.Unlock()
+ to.mu.Unlock()
+}
+
+// forgetDeadMountpoint is called when a mount point is deleted or invalidated
+// to umount all mounts using it in all other mount namespaces.
+//
+// forgetDeadMountpoint is analogous to Linux's
+// fs/namespace.c:__detach_mounts().
+func (vfs *VirtualFilesystem) forgetDeadMountpoint(ctx context.Context, d *Dentry) {
+ var (
+ vdsToDecRef []VirtualDentry
+ mountsToDecRef []*Mount
+ )
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ for mnt := range vfs.mountpoints[d] {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(mnt, &umountRecursiveOptions{}, vdsToDecRef, mountsToDecRef)
+ }
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef(ctx)
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef(ctx)
+ }
}
diff --git a/pkg/sentry/vfs/device.go b/pkg/sentry/vfs/device.go
new file mode 100644
index 000000000..1e9dffc8f
--- /dev/null
+++ b/pkg/sentry/vfs/device.go
@@ -0,0 +1,132 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// DeviceKind indicates whether a device is a block or character device.
+type DeviceKind uint32
+
+const (
+ // BlockDevice indicates a block device.
+ BlockDevice DeviceKind = iota
+
+ // CharDevice indicates a character device.
+ CharDevice
+)
+
+// String implements fmt.Stringer.String.
+func (kind DeviceKind) String() string {
+ switch kind {
+ case BlockDevice:
+ return "block"
+ case CharDevice:
+ return "character"
+ default:
+ return fmt.Sprintf("invalid device kind %d", kind)
+ }
+}
+
+type devTuple struct {
+ kind DeviceKind
+ major uint32
+ minor uint32
+}
+
+// A Device backs device special files.
+type Device interface {
+ // Open returns a FileDescription representing this device.
+ Open(ctx context.Context, mnt *Mount, d *Dentry, opts OpenOptions) (*FileDescription, error)
+}
+
+// +stateify savable
+type registeredDevice struct {
+ dev Device
+ opts RegisterDeviceOptions
+}
+
+// RegisterDeviceOptions contains options to
+// VirtualFilesystem.RegisterDevice().
+//
+// +stateify savable
+type RegisterDeviceOptions struct {
+ // GroupName is the name shown for this device registration in
+ // /proc/devices. If GroupName is empty, this registration will not be
+ // shown in /proc/devices.
+ GroupName string
+}
+
+// RegisterDevice registers the given Device in vfs with the given major and
+// minor device numbers.
+func (vfs *VirtualFilesystem) RegisterDevice(kind DeviceKind, major, minor uint32, dev Device, opts *RegisterDeviceOptions) error {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.Lock()
+ defer vfs.devicesMu.Unlock()
+ if existing, ok := vfs.devices[tup]; ok {
+ return fmt.Errorf("%s device number (%d, %d) is already registered to device type %T", kind, major, minor, existing.dev)
+ }
+ vfs.devices[tup] = &registeredDevice{
+ dev: dev,
+ opts: *opts,
+ }
+ return nil
+}
+
+// OpenDeviceSpecialFile returns a FileDescription representing the given
+// device.
+func (vfs *VirtualFilesystem) OpenDeviceSpecialFile(ctx context.Context, mnt *Mount, d *Dentry, kind DeviceKind, major, minor uint32, opts *OpenOptions) (*FileDescription, error) {
+ tup := devTuple{kind, major, minor}
+ vfs.devicesMu.RLock()
+ defer vfs.devicesMu.RUnlock()
+ rd, ok := vfs.devices[tup]
+ if !ok {
+ return nil, syserror.ENXIO
+ }
+ return rd.dev.Open(ctx, mnt, d, *opts)
+}
+
+// GetAnonBlockDevMinor allocates and returns an unused minor device number for
+// an "anonymous" block device with major number UNNAMED_MAJOR.
+func (vfs *VirtualFilesystem) GetAnonBlockDevMinor() (uint32, error) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ minor := vfs.anonBlockDevMinorNext
+ const maxDevMinor = (1 << 20) - 1
+ for minor < maxDevMinor {
+ if _, ok := vfs.anonBlockDevMinor[minor]; !ok {
+ vfs.anonBlockDevMinor[minor] = struct{}{}
+ vfs.anonBlockDevMinorNext = minor + 1
+ return minor, nil
+ }
+ minor++
+ }
+ return 0, syserror.EMFILE
+}
+
+// PutAnonBlockDevMinor deallocates a minor device number returned by a
+// previous call to GetAnonBlockDevMinor.
+func (vfs *VirtualFilesystem) PutAnonBlockDevMinor(minor uint32) {
+ vfs.anonBlockDevMinorMu.Lock()
+ defer vfs.anonBlockDevMinorMu.Unlock()
+ delete(vfs.anonBlockDevMinor, minor)
+ if minor < vfs.anonBlockDevMinorNext {
+ vfs.anonBlockDevMinorNext = minor
+ }
+}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
new file mode 100644
index 000000000..1b5af9f73
--- /dev/null
+++ b/pkg/sentry/vfs/epoll.go
@@ -0,0 +1,383 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// epollCycleMu serializes attempts to register EpollInstances with other
+// EpollInstances in order to check for cycles.
+var epollCycleMu sync.Mutex
+
+// EpollInstance represents an epoll instance, as described by epoll(7).
+type EpollInstance struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+ NoLockFD
+
+ // q holds waiters on this EpollInstance.
+ q waiter.Queue
+
+ // interest is the set of file descriptors that are registered with the
+ // EpollInstance for monitoring. interest is protected by interestMu.
+ interestMu sync.Mutex
+ interest map[epollInterestKey]*epollInterest
+
+ // mu protects fields in registered epollInterests.
+ mu sync.Mutex
+
+ // ready is the set of file descriptors that may be "ready" for I/O. Note
+ // that this must be an ordered list, not a map: "If more than maxevents
+ // file descriptors are ready when epoll_wait() is called, then successive
+ // epoll_wait() calls will round robin through the set of ready file
+ // descriptors. This behavior helps avoid starvation scenarios, where a
+ // process fails to notice that additional file descriptors are ready
+ // because it focuses on a set of file descriptors that are already known
+ // to be ready." - epoll_wait(2)
+ ready epollInterestList
+}
+
+type epollInterestKey struct {
+ // file is the registered FileDescription. No reference is held on file;
+ // instead, when the last reference is dropped, FileDescription.DecRef()
+ // removes the FileDescription from all EpollInstances. file is immutable.
+ file *FileDescription
+
+ // num is the file descriptor number with which this entry was registered.
+ // num is immutable.
+ num int32
+}
+
+// epollInterest represents an EpollInstance's interest in a file descriptor.
+type epollInterest struct {
+ // epoll is the owning EpollInstance. epoll is immutable.
+ epoll *EpollInstance
+
+ // key is the file to which this epollInterest applies. key is immutable.
+ key epollInterestKey
+
+ // waiter is registered with key.file. entry is protected by epoll.mu.
+ waiter waiter.Entry
+
+ // mask is the event mask associated with this registration, including
+ // flags EPOLLET and EPOLLONESHOT. mask is protected by epoll.mu.
+ mask uint32
+
+ // ready is true if epollInterestEntry is linked into epoll.ready. ready
+ // and epollInterestEntry are protected by epoll.mu.
+ ready bool
+ epollInterestEntry
+
+ // userData is the struct epoll_event::data associated with this
+ // epollInterest. userData is protected by epoll.mu.
+ userData [2]int32
+}
+
+// NewEpollInstanceFD returns a FileDescription representing a new epoll
+// instance. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) NewEpollInstanceFD(ctx context.Context) (*FileDescription, error) {
+ vd := vfs.NewAnonVirtualDentry("[eventpoll]")
+ defer vd.DecRef(ctx)
+ ep := &EpollInstance{
+ interest: make(map[epollInterestKey]*epollInterest),
+ }
+ if err := ep.vfsfd.Init(ep, linux.O_RDWR, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &ep.vfsfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release.
+func (ep *EpollInstance) Release(ctx context.Context) {
+ // Unregister all polled fds.
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+ for key, epi := range ep.interest {
+ file := key.file
+ file.epollMu.Lock()
+ delete(file.epolls, epi)
+ file.epollMu.Unlock()
+ file.EventUnregister(&epi.waiter)
+ }
+ ep.interest = nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (ep *EpollInstance) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn == 0 {
+ return 0
+ }
+ ep.mu.Lock()
+ for epi := ep.ready.Front(); epi != nil; epi = epi.Next() {
+ wmask := waiter.EventMaskFromLinux(epi.mask)
+ if epi.key.file.Readiness(wmask)&wmask != 0 {
+ ep.mu.Unlock()
+ return waiter.EventIn
+ }
+ }
+ ep.mu.Unlock()
+ return 0
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (ep *EpollInstance) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ ep.q.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (ep *EpollInstance) EventUnregister(e *waiter.Entry) {
+ ep.q.EventUnregister(e)
+}
+
+// Seek implements FileDescriptionImpl.Seek.
+func (ep *EpollInstance) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ // Linux: fs/eventpoll.c:eventpoll_fops.llseek == noop_llseek
+ return 0, nil
+}
+
+// AddInterest implements the semantics of EPOLL_CTL_ADD.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
+ // Check for cyclic polling if necessary.
+ subep, _ := file.impl.(*EpollInstance)
+ if subep != nil {
+ epollCycleMu.Lock()
+ // epollCycleMu must be locked for the rest of AddInterest to ensure
+ // that cyclic polling is not introduced after the check.
+ defer epollCycleMu.Unlock()
+ if subep.mightPoll(ep) {
+ return syserror.ELOOP
+ }
+ }
+
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is already registered.
+ key := epollInterestKey{
+ file: file,
+ num: num,
+ }
+ if _, ok := ep.interest[key]; ok {
+ return syserror.EEXIST
+ }
+
+ // Register interest in file.
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP
+ epi := &epollInterest{
+ epoll: ep,
+ key: key,
+ mask: mask,
+ userData: event.Data,
+ }
+ epi.waiter.Callback = epi
+ ep.interest[key] = epi
+ wmask := waiter.EventMaskFromLinux(mask)
+ file.EventRegister(&epi.waiter, wmask)
+
+ // Check if the file is already ready.
+ if file.Readiness(wmask)&wmask != 0 {
+ epi.Callback(nil)
+ }
+
+ // Add epi to file.epolls so that it is removed when the last
+ // FileDescription reference is dropped.
+ file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
+ file.epolls[epi] = struct{}{}
+ file.epollMu.Unlock()
+
+ return nil
+}
+
+func (ep *EpollInstance) mightPoll(ep2 *EpollInstance) bool {
+ return ep.mightPollRecursive(ep2, 4) // Linux: fs/eventpoll.c:EP_MAX_NESTS
+}
+
+func (ep *EpollInstance) mightPollRecursive(ep2 *EpollInstance, remainingRecursion int) bool {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+ for key := range ep.interest {
+ nextep, ok := key.file.impl.(*EpollInstance)
+ if !ok {
+ continue
+ }
+ if nextep == ep2 {
+ return true
+ }
+ if remainingRecursion == 0 {
+ return true
+ }
+ if nextep.mightPollRecursive(ep2, remainingRecursion-1) {
+ return true
+ }
+ }
+ return false
+}
+
+// ModifyInterest implements the semantics of EPOLL_CTL_MOD.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event linux.EpollEvent) error {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is not already registered.
+ epi, ok := ep.interest[epollInterestKey{
+ file: file,
+ num: num,
+ }]
+ if !ok {
+ return syserror.ENOENT
+ }
+
+ // Update epi for the next call to ep.ReadEvents().
+ mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP
+ ep.mu.Lock()
+ epi.mask = mask
+ epi.userData = event.Data
+ ep.mu.Unlock()
+
+ // Re-register with the new mask.
+ file.EventUnregister(&epi.waiter)
+ wmask := waiter.EventMaskFromLinux(mask)
+ file.EventRegister(&epi.waiter, wmask)
+
+ // Check if the file is already ready with the new mask.
+ if file.Readiness(wmask)&wmask != 0 {
+ epi.Callback(nil)
+ }
+
+ return nil
+}
+
+// DeleteInterest implements the semantics of EPOLL_CTL_DEL.
+//
+// Preconditions: A reference must be held on file.
+func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error {
+ ep.interestMu.Lock()
+ defer ep.interestMu.Unlock()
+
+ // Fail if the key is not already registered.
+ epi, ok := ep.interest[epollInterestKey{
+ file: file,
+ num: num,
+ }]
+ if !ok {
+ return syserror.ENOENT
+ }
+
+ // Unregister from the file so that epi will no longer be readied.
+ file.EventUnregister(&epi.waiter)
+
+ // Forget about epi.
+ ep.removeLocked(epi)
+
+ file.epollMu.Lock()
+ delete(file.epolls, epi)
+ file.epollMu.Unlock()
+
+ return nil
+}
+
+// Callback implements waiter.EntryCallback.Callback.
+func (epi *epollInterest) Callback(*waiter.Entry) {
+ newReady := false
+ epi.epoll.mu.Lock()
+ if !epi.ready {
+ newReady = true
+ epi.ready = true
+ epi.epoll.ready.PushBack(epi)
+ }
+ epi.epoll.mu.Unlock()
+ if newReady {
+ epi.epoll.q.Notify(waiter.EventIn)
+ }
+}
+
+// Preconditions: ep.interestMu must be locked.
+func (ep *EpollInstance) removeLocked(epi *epollInterest) {
+ delete(ep.interest, epi.key)
+ ep.mu.Lock()
+ if epi.ready {
+ epi.ready = false
+ ep.ready.Remove(epi)
+ }
+ ep.mu.Unlock()
+}
+
+// ReadEvents reads up to len(events) ready events into events and returns the
+// number of events read.
+//
+// Preconditions: len(events) != 0.
+func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int {
+ i := 0
+ // Hot path: avoid defer.
+ ep.mu.Lock()
+ var next *epollInterest
+ var requeue epollInterestList
+ for epi := ep.ready.Front(); epi != nil; epi = next {
+ next = epi.Next()
+ // Regardless of what else happens, epi is initially removed from the
+ // ready list.
+ ep.ready.Remove(epi)
+ wmask := waiter.EventMaskFromLinux(epi.mask)
+ ievents := epi.key.file.Readiness(wmask) & wmask
+ if ievents == 0 {
+ // Leave epi off the ready list.
+ epi.ready = false
+ continue
+ }
+ // Determine what we should do with epi.
+ switch {
+ case epi.mask&linux.EPOLLONESHOT != 0:
+ // Clear all events from the mask; they must be re-added by
+ // EPOLL_CTL_MOD.
+ epi.mask &= linux.EP_PRIVATE_BITS
+ fallthrough
+ case epi.mask&linux.EPOLLET != 0:
+ // Leave epi off the ready list.
+ epi.ready = false
+ default:
+ // Queue epi to be moved to the end of the ready list.
+ requeue.PushBack(epi)
+ }
+ // Report ievents.
+ events[i] = linux.EpollEvent{
+ Events: ievents.ToLinux(),
+ Data: epi.userData,
+ }
+ i++
+ if i == len(events) {
+ break
+ }
+ }
+ ep.ready.PushBackList(&requeue)
+ ep.mu.Unlock()
+ return i
+}
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 3a9665800..dcafffe57 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -18,10 +18,14 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -38,30 +42,193 @@ type FileDescription struct {
// operations.
refs int64
+ // flagsMu protects statusFlags and asyncHandler below.
+ flagsMu sync.Mutex
+
+ // statusFlags contains status flags, "initialized by open(2) and possibly
+ // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic
+ // memory operations when it does not need to be synchronized with an
+ // access to asyncHandler.
+ statusFlags uint32
+
+ // asyncHandler handles O_ASYNC signal generation. It is set with the
+ // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
+ // also be set by fcntl(2).
+ asyncHandler FileAsync
+
+ // epolls is the set of epollInterests registered for this FileDescription.
+ // epolls is protected by epollMu.
+ epollMu sync.Mutex
+ epolls map[*epollInterest]struct{}
+
// vd is the filesystem location at which this FileDescription was opened.
// A reference is held on vd. vd is immutable.
vd VirtualDentry
+ // opts contains options passed to FileDescription.Init(). opts is
+ // immutable.
+ opts FileDescriptionOptions
+
+ // readable is MayReadFileWithOpenFlags(statusFlags). readable is
+ // immutable.
+ //
+ // readable is analogous to Linux's FMODE_READ.
+ readable bool
+
+ // writable is MayWriteFileWithOpenFlags(statusFlags). If writable is true,
+ // the FileDescription holds a write count on vd.mount. writable is
+ // immutable.
+ //
+ // writable is analogous to Linux's FMODE_WRITE.
+ writable bool
+
+ usedLockBSD uint32
+
// impl is the FileDescriptionImpl associated with this Filesystem. impl is
// immutable. This should be the last field in FileDescription.
impl FileDescriptionImpl
}
-// Init must be called before first use of fd. It takes references on mnt and
-// d.
-func (fd *FileDescription) Init(impl FileDescriptionImpl, mnt *Mount, d *Dentry) {
+// FileDescriptionOptions contains options to FileDescription.Init().
+type FileDescriptionOptions struct {
+ // If AllowDirectIO is true, allow O_DIRECT to be set on the file.
+ AllowDirectIO bool
+
+ // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE.
+ DenyPRead bool
+
+ // If DenyPWrite is true, calls to FileDescription.PWrite() return
+ // ESPIPE.
+ DenyPWrite bool
+
+ // If UseDentryMetadata is true, calls to FileDescription methods that
+ // interact with file and filesystem metadata (Stat, SetStat, StatFS,
+ // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling
+ // the corresponding FilesystemImpl methods instead of the corresponding
+ // FileDescriptionImpl methods.
+ //
+ // UseDentryMetadata is intended for file descriptions that are implemented
+ // outside of individual filesystems, such as pipes, sockets, and device
+ // special files. FileDescriptions for which UseDentryMetadata is true may
+ // embed DentryMetadataFileDescriptionImpl to obtain appropriate
+ // implementations of FileDescriptionImpl methods that should not be
+ // called.
+ UseDentryMetadata bool
+}
+
+// FileCreationFlags are the set of flags passed to FileDescription.Init() but
+// omitted from FileDescription.StatusFlags().
+const FileCreationFlags = linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC
+
+// Init must be called before first use of fd. If it succeeds, it takes
+// references on mnt and d. flags is the initial file description flags, which
+// is usually the full set of flags passed to open(2).
+func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mount, d *Dentry, opts *FileDescriptionOptions) error {
+ writable := MayWriteFileWithOpenFlags(flags)
+ if writable {
+ if err := mnt.CheckBeginWrite(); err != nil {
+ return err
+ }
+ }
+
fd.refs = 1
+
+ // Remove "file creation flags" to mirror the behavior from file.f_flags in
+ // fs/open.c:do_dentry_open.
+ fd.statusFlags = flags &^ FileCreationFlags
fd.vd = VirtualDentry{
mount: mnt,
dentry: d,
}
- fd.vd.IncRef()
+ mnt.IncRef()
+ d.IncRef()
+ fd.opts = *opts
+ fd.readable = MayReadFileWithOpenFlags(flags)
+ fd.writable = writable
fd.impl = impl
+ return nil
}
-// Impl returns the FileDescriptionImpl associated with fd.
-func (fd *FileDescription) Impl() FileDescriptionImpl {
- return fd.impl
+// IncRef increments fd's reference count.
+func (fd *FileDescription) IncRef() {
+ atomic.AddInt64(&fd.refs, 1)
+}
+
+// TryIncRef increments fd's reference count and returns true. If fd's
+// reference count is already zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fd.
+func (fd *FileDescription) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fd.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// DecRef decrements fd's reference count.
+func (fd *FileDescription) DecRef(ctx context.Context) {
+ if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 {
+ // Unregister fd from all epoll instances.
+ fd.epollMu.Lock()
+ epolls := fd.epolls
+ fd.epolls = nil
+ fd.epollMu.Unlock()
+ for epi := range epolls {
+ ep := epi.epoll
+ ep.interestMu.Lock()
+ // Check that epi has not been concurrently unregistered by
+ // EpollInstance.DeleteInterest() or EpollInstance.Release().
+ if _, ok := ep.interest[epi.key]; ok {
+ fd.EventUnregister(&epi.waiter)
+ ep.removeLocked(epi)
+ }
+ ep.interestMu.Unlock()
+ }
+
+ // If BSD locks were used, release any lock that it may have acquired.
+ if atomic.LoadUint32(&fd.usedLockBSD) != 0 {
+ fd.impl.UnlockBSD(context.Background(), fd)
+ }
+
+ // Release implementation resources.
+ fd.impl.Release(ctx)
+ if fd.writable {
+ fd.vd.mount.EndWrite()
+ }
+ fd.vd.DecRef(ctx)
+ fd.flagsMu.Lock()
+ // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1.
+ if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ fd.asyncHandler.Unregister(fd)
+ }
+ fd.asyncHandler = nil
+ fd.flagsMu.Unlock()
+ } else if refs < 0 {
+ panic("FileDescription.DecRef() called without holding a reference")
+ }
+}
+
+// Refs returns the current number of references. The returned count
+// is inherently racy and is unsafe to use without external synchronization.
+func (fd *FileDescription) Refs() int64 {
+ return atomic.LoadInt64(&fd.refs)
+}
+
+// Mount returns the mount on which fd was opened. It does not take a reference
+// on the returned Mount.
+func (fd *FileDescription) Mount() *Mount {
+ return fd.vd.mount
+}
+
+// Dentry returns the dentry at which fd was opened. It does not take a
+// reference on the returned Dentry.
+func (fd *FileDescription) Dentry() *Dentry {
+ return fd.vd.dentry
}
// VirtualDentry returns the location at which fd was opened. It does not take
@@ -70,19 +237,88 @@ func (fd *FileDescription) VirtualDentry() VirtualDentry {
return fd.vd
}
-// IncRef increments fd's reference count.
-func (fd *FileDescription) IncRef() {
- atomic.AddInt64(&fd.refs, 1)
+// Options returns the options passed to fd.Init().
+func (fd *FileDescription) Options() FileDescriptionOptions {
+ return fd.opts
}
-// DecRef decrements fd's reference count.
-func (fd *FileDescription) DecRef() {
- if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 {
- fd.impl.Release()
- fd.vd.DecRef()
- } else if refs < 0 {
- panic("FileDescription.DecRef() called without holding a reference")
+// StatusFlags returns file description status flags, as for fcntl(F_GETFL).
+func (fd *FileDescription) StatusFlags() uint32 {
+ return atomic.LoadUint32(&fd.statusFlags)
+}
+
+// SetStatusFlags sets file description status flags, as for fcntl(F_SETFL).
+func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Credentials, flags uint32) error {
+ // Compare Linux's fs/fcntl.c:setfl().
+ oldFlags := fd.StatusFlags()
+ // Linux documents this check as "O_APPEND cannot be cleared if the file is
+ // marked as append-only and the file is open for write", which would make
+ // sense. However, the check as actually implemented seems to be "O_APPEND
+ // cannot be changed if the file is marked as append-only".
+ if (flags^oldFlags)&linux.O_APPEND != 0 {
+ stat, err := fd.Stat(ctx, StatOptions{
+ // There is no mask bit for stx_attributes.
+ Mask: 0,
+ // Linux just reads inode::i_flags directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if (stat.AttributesMask&linux.STATX_ATTR_APPEND != 0) && (stat.Attributes&linux.STATX_ATTR_APPEND != 0) {
+ return syserror.EPERM
+ }
}
+ if (flags&linux.O_NOATIME != 0) && (oldFlags&linux.O_NOATIME == 0) {
+ stat, err := fd.Stat(ctx, StatOptions{
+ Mask: linux.STATX_UID,
+ // Linux's inode_owner_or_capable() just reads inode::i_uid
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return err
+ }
+ if stat.Mask&linux.STATX_UID == 0 {
+ return syserror.EPERM
+ }
+ if !CanActAsOwner(creds, auth.KUID(stat.UID)) {
+ return syserror.EPERM
+ }
+ }
+ if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO {
+ return syserror.EINVAL
+ }
+ // TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()?
+ const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK
+ fd.flagsMu.Lock()
+ if fd.asyncHandler != nil {
+ // Use fd.statusFlags instead of oldFlags, which may have become outdated,
+ // to avoid double registering/unregistering.
+ if fd.statusFlags&linux.O_ASYNC == 0 && flags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ } else if fd.statusFlags&linux.O_ASYNC != 0 && flags&linux.O_ASYNC == 0 {
+ fd.asyncHandler.Unregister(fd)
+ }
+ }
+ atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags))
+ fd.flagsMu.Unlock()
+ return nil
+}
+
+// IsReadable returns true if fd was opened for reading.
+func (fd *FileDescription) IsReadable() bool {
+ return fd.readable
+}
+
+// IsWritable returns true if fd was opened for writing.
+func (fd *FileDescription) IsWritable() bool {
+ return fd.writable
+}
+
+// Impl returns the FileDescriptionImpl associated with fd.
+func (fd *FileDescription) Impl() FileDescriptionImpl {
+ return fd.impl
}
// FileDescriptionImpl contains implementation details for an FileDescription.
@@ -93,42 +329,50 @@ func (fd *FileDescription) DecRef() {
// be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID and
// auth.KGID respectively).
//
+// All methods may return errors not specified.
+//
// FileDescriptionImpl is analogous to Linux's struct file_operations.
type FileDescriptionImpl interface {
// Release is called when the associated FileDescription reaches zero
// references.
- Release()
+ Release(ctx context.Context)
// OnClose is called when a file descriptor representing the
// FileDescription is closed. Note that returning a non-nil error does not
// prevent the file descriptor from being closed.
OnClose(ctx context.Context) error
- // StatusFlags returns file description status flags, as for
- // fcntl(F_GETFL).
- StatusFlags(ctx context.Context) (uint32, error)
-
- // SetStatusFlags sets file description status flags, as for
- // fcntl(F_SETFL).
- SetStatusFlags(ctx context.Context, flags uint32) error
-
// Stat returns metadata for the file represented by the FileDescription.
Stat(ctx context.Context, opts StatOptions) (linux.Statx, error)
// SetStat updates metadata for the file represented by the
- // FileDescription.
+ // FileDescription. Implementations are responsible for checking if the
+ // operation can be performed (see vfs.CheckSetStat() for common checks).
SetStat(ctx context.Context, opts SetStatOptions) error
// StatFS returns metadata for the filesystem containing the file
// represented by the FileDescription.
StatFS(ctx context.Context) (linux.Statfs, error)
+ // Allocate grows the file to offset + length bytes.
+ // Only mode == 0 is supported currently.
+ //
+ // Preconditions: The FileDescription was opened for writing.
+ Allocate(ctx context.Context, mode, offset, length uint64) error
+
// waiter.Waitable methods may be used to poll for I/O events.
waiter.Waitable
// PRead reads from the file into dst, starting at the given offset, and
// returns the number of bytes read. PRead is permitted to return partial
// reads with a nil error.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for reading.
+ // FileDescriptionOptions.DenyPRead == false.
PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error)
// Read is similar to PRead, but does not specify an offset.
@@ -138,6 +382,12 @@ type FileDescriptionImpl interface {
// the number of bytes read; note that POSIX 2.9.7 "Thread Interactions
// with Regular File Operations" requires that all operations that may
// mutate the FileDescription offset are serialized.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Read returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for reading.
Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error)
// PWrite writes src to the file, starting at the given offset, and returns
@@ -147,6 +397,14 @@ type FileDescriptionImpl interface {
// As in Linux (but not POSIX), if O_APPEND is in effect for the
// FileDescription, PWrite should ignore the offset and append data to the
// end of the file.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, PWrite returns
+ // EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for writing.
+ // FileDescriptionOptions.DenyPWrite == false.
PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error)
// Write is similar to PWrite, but does not specify an offset, which is
@@ -156,6 +414,12 @@ type FileDescriptionImpl interface {
// PWrite that uses a FileDescription offset, to make it possible for
// remote filesystems to implement O_APPEND correctly (i.e. atomically with
// respect to writers outside the scope of VFS).
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies unsupported options, Write returns EOPNOTSUPP.
+ //
+ // Preconditions: The FileDescription was opened for writing.
Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error)
// IterDirents invokes cb on each entry in the directory represented by the
@@ -185,7 +449,31 @@ type FileDescriptionImpl interface {
// Ioctl implements the ioctl(2) syscall.
Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error)
- // TODO: extended attributes; file locking
+ // Listxattr returns all extended attribute names for the file.
+ Listxattr(ctx context.Context, size uint64) ([]string, error)
+
+ // Getxattr returns the value associated with the given extended attribute
+ // for the file.
+ Getxattr(ctx context.Context, opts GetxattrOptions) (string, error)
+
+ // Setxattr changes the value associated with the given extended attribute
+ // for the file.
+ Setxattr(ctx context.Context, opts SetxattrOptions) error
+
+ // Removexattr removes the given extended attribute from the file.
+ Removexattr(ctx context.Context, name string) error
+
+ // LockBSD tries to acquire a BSD-style advisory file lock.
+ LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error
+
+ // UnlockBSD releases a BSD-style advisory file lock.
+ UnlockBSD(ctx context.Context, uid lock.UniqueID) error
+
+ // LockPOSIX tries to acquire a POSIX-style advisory file lock.
+ LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, length uint64, whence int16, block lock.Blocker) error
+
+ // UnlockPOSIX releases a POSIX-style advisory file lock.
+ UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, length uint64, whence int16) error
}
// Dirent holds the information contained in struct linux_dirent64.
@@ -208,9 +496,352 @@ type Dirent struct {
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
type IterDirentsCallback interface {
- // Handle handles the given iterated Dirent. It returns true if iteration
- // should continue, and false if FileDescriptionImpl.IterDirents should
- // terminate now and restart with the same Dirent the next time it is
- // called.
- Handle(dirent Dirent) bool
+ // Handle handles the given iterated Dirent. If Handle returns a non-nil
+ // error, FileDescriptionImpl.IterDirents must stop iteration and return
+ // the error; the next call to FileDescriptionImpl.IterDirents should
+ // restart with the same Dirent.
+ Handle(dirent Dirent) error
+}
+
+// IterDirentsCallbackFunc implements IterDirentsCallback for a function with
+// the semantics of IterDirentsCallback.Handle.
+type IterDirentsCallbackFunc func(dirent Dirent) error
+
+// Handle implements IterDirentsCallback.Handle.
+func (f IterDirentsCallbackFunc) Handle(dirent Dirent) error {
+ return f(dirent)
+}
+
+// OnClose is called when a file descriptor representing the FileDescription is
+// closed. Returning a non-nil error should not prevent the file descriptor
+// from being closed.
+func (fd *FileDescription) OnClose(ctx context.Context) error {
+ return fd.impl.OnClose(ctx)
+}
+
+// Stat returns metadata for the file represented by fd.
+func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(ctx, rp)
+ return stat, err
+ }
+ return fd.impl.Stat(ctx, opts)
+}
+
+// SetStat updates metadata for the file represented by fd.
+func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts)
+ vfsObj.putResolvingPath(ctx, rp)
+ return err
+ }
+ return fd.impl.SetStat(ctx, opts)
+}
+
+// StatFS returns metadata for the filesystem containing the file represented
+// by fd.
+func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp)
+ vfsObj.putResolvingPath(ctx, rp)
+ return statfs, err
+ }
+ return fd.impl.StatFS(ctx)
+}
+
+// Allocate grows file represented by FileDescription to offset + length bytes.
+func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ if !fd.IsWritable() {
+ return syserror.EBADF
+ }
+ return fd.impl.Allocate(ctx, mode, offset, length)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// It returns fd's I/O readiness.
+func (fd *FileDescription) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fd.impl.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+//
+// It registers e for I/O readiness events in mask.
+func (fd *FileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.impl.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+//
+// It unregisters e for I/O readiness events.
+func (fd *FileDescription) EventUnregister(e *waiter.Entry) {
+ fd.impl.EventUnregister(e)
+}
+
+// PRead reads from the file represented by fd into dst, starting at the given
+// offset, and returns the number of bytes read. PRead is permitted to return
+// partial reads with a nil error.
+func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ if fd.opts.DenyPRead {
+ return 0, syserror.ESPIPE
+ }
+ if !fd.readable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.PRead(ctx, dst, offset, opts)
+}
+
+// Read is similar to PRead, but does not specify an offset.
+func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ if !fd.readable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.Read(ctx, dst, opts)
+}
+
+// PWrite writes src to the file represented by fd, starting at the given
+// offset, and returns the number of bytes written. PWrite is permitted to
+// return partial writes with a nil error.
+func (fd *FileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if fd.opts.DenyPWrite {
+ return 0, syserror.ESPIPE
+ }
+ if !fd.writable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.PWrite(ctx, src, offset, opts)
+}
+
+// Write is similar to PWrite, but does not specify an offset.
+func (fd *FileDescription) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ if !fd.writable {
+ return 0, syserror.EBADF
+ }
+ return fd.impl.Write(ctx, src, opts)
+}
+
+// IterDirents invokes cb on each entry in the directory represented by fd. If
+// IterDirents has been called since the last call to Seek, it continues
+// iteration from the end of the last call.
+func (fd *FileDescription) IterDirents(ctx context.Context, cb IterDirentsCallback) error {
+ return fd.impl.IterDirents(ctx, cb)
+}
+
+// Seek changes fd's offset (assuming one exists) and returns its new value.
+func (fd *FileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
+ return fd.impl.Seek(ctx, offset, whence)
+}
+
+// Sync has the semantics of fsync(2).
+func (fd *FileDescription) Sync(ctx context.Context) error {
+ return fd.impl.Sync(ctx)
+}
+
+// ConfigureMMap mutates opts to implement mmap(2) for the file represented by
+// fd.
+func (fd *FileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ return fd.impl.ConfigureMMap(ctx, opts)
+}
+
+// Ioctl implements the ioctl(2) syscall.
+func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return fd.impl.Ioctl(ctx, uio, args)
+}
+
+// Listxattr returns all extended attribute names for the file represented by
+// fd.
+//
+// If the size of the list (including a NUL terminating byte after every entry)
+// would exceed size, ERANGE may be returned. Note that implementations
+// are free to ignore size entirely and return without error). In all cases,
+// if size is 0, the list should be returned without error, regardless of size.
+func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ vfsObj.putResolvingPath(ctx, rp)
+ return names, err
+ }
+ names, err := fd.impl.Listxattr(ctx, size)
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by default
+ // don't exist.
+ return nil, nil
+ }
+ return names, err
+}
+
+// Getxattr returns the value associated with the given extended attribute for
+// the file represented by fd.
+//
+// If the size of the return value exceeds opts.Size, ERANGE may be returned
+// (note that implementations are free to ignore opts.Size entirely and return
+// without error). In all cases, if opts.Size is 0, the value should be
+// returned without error, regardless of size.
+func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ vfsObj.putResolvingPath(ctx, rp)
+ return val, err
+ }
+ return fd.impl.Getxattr(ctx, *opts)
+}
+
+// Setxattr changes the value associated with the given extended attribute for
+// the file represented by fd.
+func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ vfsObj.putResolvingPath(ctx, rp)
+ return err
+ }
+ return fd.impl.Setxattr(ctx, *opts)
+}
+
+// Removexattr removes the given extended attribute from the file represented
+// by fd.
+func (fd *FileDescription) Removexattr(ctx context.Context, name string) error {
+ if fd.opts.UseDentryMetadata {
+ vfsObj := fd.vd.mount.vfs
+ rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{
+ Root: fd.vd,
+ Start: fd.vd,
+ })
+ err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ vfsObj.putResolvingPath(ctx, rp)
+ return err
+ }
+ return fd.impl.Removexattr(ctx, name)
+}
+
+// SyncFS instructs the filesystem containing fd to execute the semantics of
+// syncfs(2).
+func (fd *FileDescription) SyncFS(ctx context.Context) error {
+ return fd.vd.mount.fs.impl.Sync(ctx)
+}
+
+// MappedName implements memmap.MappingIdentity.MappedName.
+func (fd *FileDescription) MappedName(ctx context.Context) string {
+ vfsroot := RootFromContext(ctx)
+ s, _ := fd.vd.mount.vfs.PathnameWithDeleted(ctx, vfsroot, fd.vd)
+ if vfsroot.Ok() {
+ vfsroot.DecRef(ctx)
+ }
+ return s
+}
+
+// DeviceID implements memmap.MappingIdentity.DeviceID.
+func (fd *FileDescription) DeviceID() uint64 {
+ stat, err := fd.Stat(context.Background(), StatOptions{
+ // There is no STATX_DEV; we assume that Stat will return it if it's
+ // available regardless of mask.
+ Mask: 0,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_sb->s_dev
+ // directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil {
+ return 0
+ }
+ return uint64(linux.MakeDeviceID(uint16(stat.DevMajor), stat.DevMinor))
+}
+
+// InodeID implements memmap.MappingIdentity.InodeID.
+func (fd *FileDescription) InodeID() uint64 {
+ stat, err := fd.Stat(context.Background(), StatOptions{
+ Mask: linux.STATX_INO,
+ // fs/proc/task_mmu.c:show_map_vma() just reads inode::i_ino directly.
+ Sync: linux.AT_STATX_DONT_SYNC,
+ })
+ if err != nil || stat.Mask&linux.STATX_INO == 0 {
+ return 0
+ }
+ return stat.Ino
+}
+
+// Msync implements memmap.MappingIdentity.Msync.
+func (fd *FileDescription) Msync(ctx context.Context, mr memmap.MappableRange) error {
+ return fd.Sync(ctx)
+}
+
+// LockBSD tries to acquire a BSD-style advisory file lock.
+func (fd *FileDescription) LockBSD(ctx context.Context, lockType lock.LockType, blocker lock.Blocker) error {
+ atomic.StoreUint32(&fd.usedLockBSD, 1)
+ return fd.impl.LockBSD(ctx, fd, lockType, blocker)
+}
+
+// UnlockBSD releases a BSD-style advisory file lock.
+func (fd *FileDescription) UnlockBSD(ctx context.Context) error {
+ return fd.impl.UnlockBSD(ctx, fd)
+}
+
+// LockPOSIX locks a POSIX-style file range lock.
+func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t lock.LockType, start, end uint64, whence int16, block lock.Blocker) error {
+ return fd.impl.LockPOSIX(ctx, uid, t, start, end, whence, block)
+}
+
+// UnlockPOSIX unlocks a POSIX-style file range lock.
+func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error {
+ return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence)
+}
+
+// A FileAsync sends signals to its owner when w is ready for IO. This is only
+// implemented by pkg/sentry/fasync:FileAsync, but we unfortunately need this
+// interface to avoid circular dependencies.
+type FileAsync interface {
+ Register(w waiter.Waitable)
+ Unregister(w waiter.Waitable)
+}
+
+// AsyncHandler returns the FileAsync for fd.
+func (fd *FileDescription) AsyncHandler() FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ return fd.asyncHandler
+}
+
+// SetAsyncHandler sets fd.asyncHandler if it has not been set before and
+// returns it.
+func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsync {
+ fd.flagsMu.Lock()
+ defer fd.flagsMu.Unlock()
+ if fd.asyncHandler == nil {
+ fd.asyncHandler = newHandler()
+ if fd.statusFlags&linux.O_ASYNC != 0 {
+ fd.asyncHandler.Register(fd)
+ }
+ }
+ return fd.asyncHandler
}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 4fbad7840..6b8b4ad49 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -17,14 +17,15 @@ package vfs
import (
"bytes"
"io"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -32,8 +33,8 @@ import (
// implementations to adapt:
// - Have a local fileDescription struct (containing FileDescription) which
// embeds FileDescriptionDefaultImpl and overrides the default methods
-// which are common to all fd implementations for that for that filesystem
-// like StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
+// which are common to all fd implementations for that filesystem like
+// StatusFlags, SetStatusFlags, Stat, SetStat, StatFS, etc.
// - This should be embedded in all file description implementations as the
// first field by value.
// - Directory FDs would also embed DirectoryFileDescriptionDefaultImpl.
@@ -55,6 +56,12 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err
return linux.Statfs{}, syserror.ENOSYS
}
+// Allocate implements FileDescriptionImpl.Allocate analogously to
+// fallocate called on regular file, directory or FIFO in Linux.
+func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
// Readiness implements waiter.Waitable.Readiness analogously to
// file_operations::poll == NULL in Linux.
func (FileDescriptionDefaultImpl) Readiness(mask waiter.EventMask) waiter.EventMask {
@@ -127,11 +134,41 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg
return 0, syserror.ENOTTY
}
+// Listxattr implements FileDescriptionImpl.Listxattr analogously to
+// inode_operations::listxattr == NULL in Linux.
+func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) {
+ // This isn't exactly accurate; see FileDescription.Listxattr.
+ return nil, syserror.ENOTSUP
+}
+
+// Getxattr implements FileDescriptionImpl.Getxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) {
+ return "", syserror.ENOTSUP
+}
+
+// Setxattr implements FileDescriptionImpl.Setxattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error {
+ return syserror.ENOTSUP
+}
+
+// Removexattr implements FileDescriptionImpl.Removexattr analogously to
+// inode::i_opflags & IOP_XATTR == 0 in Linux.
+func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error {
+ return syserror.ENOTSUP
+}
+
// DirectoryFileDescriptionDefaultImpl may be embedded by implementations of
// FileDescriptionImpl that always represent directories to obtain
// implementations of non-directory I/O methods that return EISDIR.
type DirectoryFileDescriptionDefaultImpl struct{}
+// Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate.
+func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.EISDIR
+}
+
// PRead implements FileDescriptionImpl.PRead.
func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
return 0, syserror.EISDIR
@@ -152,6 +189,48 @@ func (DirectoryFileDescriptionDefaultImpl) Write(ctx context.Context, src userme
return 0, syserror.EISDIR
}
+// DentryMetadataFileDescriptionImpl may be embedded by implementations of
+// FileDescriptionImpl for which FileDescriptionOptions.UseDentryMetadata is
+// true to obtain implementations of Stat and SetStat that panic.
+type DentryMetadataFileDescriptionImpl struct{}
+
+// Stat implements FileDescriptionImpl.Stat.
+func (DentryMetadataFileDescriptionImpl) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.Stat")
+}
+
+// SetStat implements FileDescriptionImpl.SetStat.
+func (DentryMetadataFileDescriptionImpl) SetStat(ctx context.Context, opts SetStatOptions) error {
+ panic("illegal call to DentryMetadataFileDescriptionImpl.SetStat")
+}
+
+// DynamicBytesSource represents a data source for a
+// DynamicBytesFileDescriptionImpl.
+type DynamicBytesSource interface {
+ // Generate writes the file's contents to buf.
+ Generate(ctx context.Context, buf *bytes.Buffer) error
+}
+
+// StaticData implements DynamicBytesSource over a static string.
+type StaticData struct {
+ Data string
+}
+
+// Generate implements DynamicBytesSource.
+func (s *StaticData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(s.Data)
+ return nil
+}
+
+// WritableDynamicBytesSource extends DynamicBytesSource to allow writes to the
+// underlying source.
+type WritableDynamicBytesSource interface {
+ DynamicBytesSource
+
+ // Write sends writes to the source.
+ Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error)
+}
+
// DynamicBytesFileDescriptionImpl may be embedded by implementations of
// FileDescriptionImpl that represent read-only regular files whose contents
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
@@ -167,13 +246,6 @@ type DynamicBytesFileDescriptionImpl struct {
lastRead int64 // offset at which the last Read, PRead, or Seek ended
}
-// DynamicBytesSource represents a data source for a
-// DynamicBytesFileDescriptionImpl.
-type DynamicBytesSource interface {
- // Generate writes the file's contents to buf.
- Generate(ctx context.Context, buf *bytes.Buffer) error
-}
-
// SetDataSource must be called exactly once on fd before first use.
func (fd *DynamicBytesFileDescriptionImpl) SetDataSource(data DynamicBytesSource) {
fd.data = data
@@ -252,3 +324,105 @@ func (fd *DynamicBytesFileDescriptionImpl) Seek(ctx context.Context, offset int6
fd.off = offset
return offset, nil
}
+
+// Preconditions: fd.mu must be locked.
+func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+ limit, err := CheckLimit(ctx, offset, src.NumBytes())
+ if err != nil {
+ return 0, err
+ }
+ src = src.TakeFirst64(limit)
+
+ writable, ok := fd.data.(WritableDynamicBytesSource)
+ if !ok {
+ return 0, syserror.EIO
+ }
+ n, err := writable.Write(ctx, src, offset)
+ if err != nil {
+ return 0, err
+ }
+
+ // Invalidate cached data that might exist prior to this call.
+ fd.buf.Reset()
+ return n, nil
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (fd *DynamicBytesFileDescriptionImpl) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.pwriteLocked(ctx, src, offset, opts)
+ fd.mu.Unlock()
+ return n, err
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (fd *DynamicBytesFileDescriptionImpl) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ n, err := fd.pwriteLocked(ctx, src, fd.off, opts)
+ fd.off += n
+ fd.mu.Unlock()
+ return n, err
+}
+
+// GenericConfigureMMap may be used by most implementations of
+// FileDescriptionImpl.ConfigureMMap.
+func GenericConfigureMMap(fd *FileDescription, m memmap.Mappable, opts *memmap.MMapOpts) error {
+ opts.Mappable = m
+ opts.MappingIdentity = fd
+ fd.IncRef()
+ return nil
+}
+
+// LockFD may be used by most implementations of FileDescriptionImpl.Lock*
+// functions. Caller must call Init().
+type LockFD struct {
+ locks *FileLocks
+}
+
+// Init initializes fd with FileLocks to use.
+func (fd *LockFD) Init(locks *FileLocks) {
+ fd.locks = locks
+}
+
+// Locks returns the locks associated with this file.
+func (fd *LockFD) Locks() *FileLocks {
+ return fd.locks
+}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (fd *LockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return fd.locks.LockBSD(uid, t, block)
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (fd *LockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ fd.locks.UnlockBSD(uid)
+ return nil
+}
+
+// NoLockFD implements Lock*/Unlock* portion of FileDescriptionImpl interface
+// returning ENOLCK.
+type NoLockFD struct{}
+
+// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
+func (NoLockFD) LockBSD(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockBSD implements vfs.FileDescriptionImpl.UnlockBSD.
+func (NoLockFD) UnlockBSD(ctx context.Context, uid fslock.UniqueID) error {
+ return syserror.ENOLCK
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (NoLockFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return syserror.ENOLCK
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (NoLockFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return syserror.ENOLCK
+}
diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go
index 511b829fc..1cd607c0a 100644
--- a/pkg/sentry/vfs/file_description_impl_util_test.go
+++ b/pkg/sentry/vfs/file_description_impl_util_test.go
@@ -22,11 +22,10 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
)
// fileDescription is the common fd struct which a filesystem implementation
@@ -34,76 +33,92 @@ import (
type fileDescription struct {
vfsfd FileDescription
FileDescriptionDefaultImpl
+ NoLockFD
}
-// genCountFD is a read-only FileDescriptionImpl representing a regular file
-// that contains the number of times its DynamicBytesSource.Generate()
+// genCount contains the number of times its DynamicBytesSource.Generate()
// implementation has been called.
-type genCountFD struct {
- fileDescription
- DynamicBytesFileDescriptionImpl
-
+type genCount struct {
count uint64 // accessed using atomic memory ops
}
-func newGenCountFD(mnt *Mount, vfsd *Dentry) *FileDescription {
- var fd genCountFD
- fd.vfsfd.Init(&fd, mnt, vfsd)
- fd.DynamicBytesFileDescriptionImpl.SetDataSource(&fd)
- return &fd.vfsfd
+// Generate implements DynamicBytesSource.Generate.
+func (g *genCount) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ fmt.Fprintf(buf, "%d", atomic.AddUint64(&g.count, 1))
+ return nil
}
-// Release implements FileDescriptionImpl.Release.
-func (fd *genCountFD) Release() {
+type storeData struct {
+ data string
}
-// StatusFlags implements FileDescriptionImpl.StatusFlags.
-func (fd *genCountFD) StatusFlags(ctx context.Context) (uint32, error) {
+var _ WritableDynamicBytesSource = (*storeData)(nil)
+
+// Generate implements DynamicBytesSource.
+func (d *storeData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ buf.WriteString(d.data)
+ return nil
+}
+
+// Generate implements WritableDynamicBytesSource.
+func (d *storeData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ buf := make([]byte, src.NumBytes())
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+
+ d.data = string(buf[:n])
return 0, nil
}
-// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags.
-func (fd *genCountFD) SetStatusFlags(ctx context.Context, flags uint32) error {
- return syserror.EPERM
+// testFD is a read-only FileDescriptionImpl representing a regular file.
+type testFD struct {
+ fileDescription
+ DynamicBytesFileDescriptionImpl
+
+ data DynamicBytesSource
+}
+
+func newTestFD(ctx context.Context, vfsObj *VirtualFilesystem, statusFlags uint32, data DynamicBytesSource) *FileDescription {
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef(ctx)
+ var fd testFD
+ fd.vfsfd.Init(&fd, statusFlags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{})
+ fd.DynamicBytesFileDescriptionImpl.SetDataSource(data)
+ return &fd.vfsfd
}
+// Release implements FileDescriptionImpl.Release.
+func (fd *testFD) Release(context.Context) {
+}
+
+// SetStatusFlags implements FileDescriptionImpl.SetStatusFlags.
// Stat implements FileDescriptionImpl.Stat.
-func (fd *genCountFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
+func (fd *testFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, error) {
// Note that Statx.Mask == 0 in the return value.
return linux.Statx{}, nil
}
// SetStat implements FileDescriptionImpl.SetStat.
-func (fd *genCountFD) SetStat(ctx context.Context, opts SetStatOptions) error {
+func (fd *testFD) SetStat(ctx context.Context, opts SetStatOptions) error {
return syserror.EPERM
}
-// Generate implements DynamicBytesSource.Generate.
-func (fd *genCountFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
- fmt.Fprintf(buf, "%d", atomic.AddUint64(&fd.count, 1))
- return nil
-}
-
func TestGenCountFD(t *testing.T) {
ctx := contexttest.Context(t)
- creds := auth.CredentialsFromContext(ctx)
- vfsObj := New() // vfs.New()
- vfsObj.MustRegisterFilesystemType("testfs", FDTestFilesystemType{})
- mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "testfs", &NewFilesystemOptions{})
- if err != nil {
- t.Fatalf("failed to create testfs root mount: %v", err)
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
}
- vd := mntns.Root()
- defer vd.DecRef()
-
- fd := newGenCountFD(vd.Mount(), vd.Dentry())
- defer fd.DecRef()
+ fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &genCount{})
+ defer fd.DecRef(ctx)
// The first read causes Generate to be called to fill the FD's buffer.
buf := make([]byte, 2)
ioseq := usermem.BytesIOSequence(buf)
- n, err := fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err := fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("first Read: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -112,17 +127,17 @@ func TestGenCountFD(t *testing.T) {
}
// A second read without seeking is still at EOF.
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 0 || err != io.EOF {
t.Fatalf("second Read: got (%d, %v), wanted (0, EOF)", n, err)
}
// Seeking to the beginning of the file causes it to be regenerated.
- n, err = fd.Impl().Seek(ctx, 0, linux.SEEK_SET)
+ n, err = fd.Seek(ctx, 0, linux.SEEK_SET)
if n != 0 || err != nil {
t.Fatalf("Seek: got (%d, %v), wanted (0, nil)", n, err)
}
- n, err = fd.Impl().Read(ctx, ioseq, ReadOptions{})
+ n, err = fd.Read(ctx, ioseq, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("Read after Seek: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
@@ -131,11 +146,79 @@ func TestGenCountFD(t *testing.T) {
}
// PRead at the beginning of the file also causes it to be regenerated.
- n, err = fd.Impl().PRead(ctx, ioseq, 0, ReadOptions{})
+ n, err = fd.PRead(ctx, ioseq, 0, ReadOptions{})
if n != 1 || (err != nil && err != io.EOF) {
t.Fatalf("PRead: got (%d, %v), wanted (1, nil or EOF)", n, err)
}
if want := byte('3'); buf[0] != want {
t.Errorf("PRead: got byte %c, wanted %c", buf[0], want)
}
+
+ // Write and PWrite fails.
+ if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
+ }
+ if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO {
+ t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO)
+ }
+}
+
+func TestWritable(t *testing.T) {
+ ctx := contexttest.Context(t)
+
+ vfsObj := &VirtualFilesystem{}
+ if err := vfsObj.Init(ctx); err != nil {
+ t.Fatalf("VFS init: %v", err)
+ }
+ fd := newTestFD(ctx, vfsObj, linux.O_RDWR, &storeData{data: "init"})
+ defer fd.DecRef(ctx)
+
+ buf := make([]byte, 10)
+ ioseq := usermem.BytesIOSequence(buf)
+ if n, err := fd.Read(ctx, ioseq, ReadOptions{}); n != 4 && err != io.EOF {
+ t.Fatalf("Read: got (%v, %v), wanted (4, EOF)", n, err)
+ }
+ if want := "init"; want == string(buf) {
+ t.Fatalf("Read: got %v, wanted %v", string(buf), want)
+ }
+
+ // Test PWrite.
+ want := "write"
+ writeIOSeq := usermem.BytesIOSequence([]byte(want))
+ if n, err := fd.PWrite(ctx, writeIOSeq, 0, WriteOptions{}); int(n) != len(want) && err != nil {
+ t.Errorf("PWrite: got err (%v, %v), wanted (%v, nil)", n, err, len(want))
+ }
+ if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF {
+ t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want))
+ }
+ if want == string(buf) {
+ t.Fatalf("PRead: got %v, wanted %v", string(buf), want)
+ }
+
+ // Test Seek to 0 followed by Write.
+ want = "write2"
+ writeIOSeq = usermem.BytesIOSequence([]byte(want))
+ if n, err := fd.Seek(ctx, 0, linux.SEEK_SET); n != 0 && err != nil {
+ t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err)
+ }
+ if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); int(n) != len(want) && err != nil {
+ t.Errorf("Write: got err (%v, %v), wanted (%v, nil)", n, err, len(want))
+ }
+ if n, err := fd.PRead(ctx, ioseq, 0, ReadOptions{}); int(n) != len(want) && err != io.EOF {
+ t.Fatalf("PRead: got (%v, %v), wanted (%v, EOF)", n, err, len(want))
+ }
+ if want == string(buf) {
+ t.Fatalf("PRead: got %v, wanted %v", string(buf), want)
+ }
+
+ // Test failure if offset != 0.
+ if n, err := fd.Seek(ctx, 1, linux.SEEK_SET); n != 0 && err != nil {
+ t.Errorf("Seek: got err (%v, %v), wanted (0, nil)", n, err)
+ }
+ if n, err := fd.Write(ctx, writeIOSeq, WriteOptions{}); n != 0 && err != syserror.EINVAL {
+ t.Errorf("Write: got err (%v, %v), wanted (0, EINVAL)", n, err)
+ }
+ if n, err := fd.PWrite(ctx, writeIOSeq, 2, WriteOptions{}); n != 0 && err != syserror.EINVAL {
+ t.Errorf("PWrite: got err (%v, %v), wanted (0, EINVAL)", n, err)
+ }
}
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 7a074b718..df3758fd1 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -18,7 +18,10 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// A Filesystem is a tree of nodes represented by Dentries, which forms part of
@@ -28,20 +31,44 @@ import (
// Filesystem methods require that a reference is held.
//
// Filesystem is analogous to Linux's struct super_block.
+//
+// +stateify savable
type Filesystem struct {
// refs is the reference count. refs is accessed using atomic memory
// operations.
refs int64
+ // vfs is the VirtualFilesystem that uses this Filesystem. vfs is
+ // immutable.
+ vfs *VirtualFilesystem
+
+ // fsType is the FilesystemType of this Filesystem.
+ fsType FilesystemType
+
// impl is the FilesystemImpl associated with this Filesystem. impl is
// immutable. This should be the last field in Dentry.
impl FilesystemImpl
}
// Init must be called before first use of fs.
-func (fs *Filesystem) Init(impl FilesystemImpl) {
+func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) {
fs.refs = 1
+ fs.vfs = vfsObj
+ fs.fsType = fsType
fs.impl = impl
+ vfsObj.filesystemsMu.Lock()
+ vfsObj.filesystems[fs] = struct{}{}
+ vfsObj.filesystemsMu.Unlock()
+}
+
+// FilesystemType returns the FilesystemType for this Filesystem.
+func (fs *Filesystem) FilesystemType() FilesystemType {
+ return fs.fsType
+}
+
+// VirtualFilesystem returns the containing VirtualFilesystem.
+func (fs *Filesystem) VirtualFilesystem() *VirtualFilesystem {
+ return fs.vfs
}
// Impl returns the FilesystemImpl associated with fs.
@@ -49,15 +76,36 @@ func (fs *Filesystem) Impl() FilesystemImpl {
return fs.impl
}
-func (fs *Filesystem) incRef() {
+// IncRef increments fs' reference count.
+func (fs *Filesystem) IncRef() {
if atomic.AddInt64(&fs.refs, 1) <= 1 {
- panic("Filesystem.incRef() called without holding a reference")
+ panic("Filesystem.IncRef() called without holding a reference")
+ }
+}
+
+// TryIncRef increments fs' reference count and returns true. If fs' reference
+// count is zero, TryIncRef does nothing and returns false.
+//
+// TryIncRef does not require that a reference is held on fs.
+func (fs *Filesystem) TryIncRef() bool {
+ for {
+ refs := atomic.LoadInt64(&fs.refs)
+ if refs <= 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) {
+ return true
+ }
}
}
-func (fs *Filesystem) decRef() {
+// DecRef decrements fs' reference count.
+func (fs *Filesystem) DecRef(ctx context.Context) {
if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 {
- fs.impl.Release()
+ fs.vfs.filesystemsMu.Lock()
+ delete(fs.vfs.filesystems, fs)
+ fs.vfs.filesystemsMu.Unlock()
+ fs.impl.Release(ctx)
} else if refs < 0 {
panic("Filesystem.decRef() called without holding a reference")
}
@@ -73,6 +121,24 @@ func (fs *Filesystem) decRef() {
// (responsible for actually implementing the operation) isn't known until path
// resolution is complete.
//
+// Unless otherwise specified, FilesystemImpl methods are responsible for
+// performing permission checks. In many cases, vfs package functions in
+// permissions.go may be used to help perform these checks.
+//
+// When multiple specified error conditions apply to a given method call, the
+// implementation may return any applicable errno unless otherwise specified,
+// but returning the earliest error specified is preferable to maximize
+// compatibility with Linux.
+//
+// All methods may return errors not specified, notably including:
+//
+// - ENOENT if a required path component does not exist.
+//
+// - ENOTDIR if an intermediate path component is not a directory.
+//
+// - Errors from vfs-package functions (ResolvingPath.Resolve*(),
+// Mount.CheckBeginWrite(), permission-checking functions, etc.)
+//
// For all methods that take or return linux.Statx, Statx.Uid and Statx.Gid
// should be interpreted as IDs in the root UserNamespace (i.e. as auth.KUID
// and auth.KGID respectively).
@@ -83,58 +149,243 @@ func (fs *Filesystem) decRef() {
type FilesystemImpl interface {
// Release is called when the associated Filesystem reaches zero
// references.
- Release()
+ Release(ctx context.Context)
// Sync "causes all pending modifications to filesystem metadata and cached
// file data to be written to the underlying [filesystem]", as by syncfs(2).
Sync(ctx context.Context) error
+ // AccessAt checks whether a user with creds can access the file at rp.
+ AccessAt(ctx context.Context, rp *ResolvingPath, creds *auth.Credentials, ats AccessTypes) error
+
// GetDentryAt returns a Dentry representing the file at rp. A reference is
// taken on the returned Dentry.
//
// GetDentryAt does not correspond directly to a Linux syscall; it is used
// in the implementation of:
//
- // - Syscalls that need to resolve two paths: rename(), renameat(),
- // renameat2(), link(), linkat().
+ // - Syscalls that need to resolve two paths: link(), linkat().
//
// - Syscalls that need to refer to a filesystem position outside the
// context of a file description: chdir(), fchdir(), chroot(), mount(),
// umount().
GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error)
+ // GetParentDentryAt returns a Dentry representing the directory at the
+ // second-to-last path component in rp. (Note that, despite the name, this
+ // is not necessarily the parent directory of the file at rp, since the
+ // last path component in rp may be "." or "..".) A reference is taken on
+ // the returned Dentry.
+ //
+ // GetParentDentryAt does not correspond directly to a Linux syscall; it is
+ // used in the implementation of the rename() family of syscalls, which
+ // must resolve the parent directories of two paths.
+ //
+ // Preconditions: !rp.Done().
+ //
+ // Postconditions: If GetParentDentryAt returns a nil error, then
+ // rp.Final(). If GetParentDentryAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ GetParentDentryAt(ctx context.Context, rp *ResolvingPath) (*Dentry, error)
+
// LinkAt creates a hard link at rp representing the same file as vd. It
// does not take ownership of references on vd.
//
- // The implementation is responsible for checking that vd.Mount() ==
- // rp.Mount(), and that vd does not represent a directory.
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", LinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, LinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), LinkAt returns ENOENT.
+ //
+ // - If the directory in which the link would be created has been removed
+ // by RmdirAt or RenameAt, LinkAt returns ENOENT.
+ //
+ // - If rp.Mount != vd.Mount(), LinkAt returns EXDEV.
+ //
+ // - If vd represents a directory, LinkAt returns EPERM.
+ //
+ // - If vd represents a file for which all existing links have been
+ // removed, or a file created by open(O_TMPFILE|O_EXCL), LinkAt returns
+ // ENOENT. Equivalently, if vd represents a file with a link count of 0 not
+ // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If LinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error
// MkdirAt creates a directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MkdirAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MkdirAt returns EEXIST.
+ //
+ // - If the directory in which the new directory would be created has been
+ // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MkdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error
// MknodAt creates a regular file, device special file, or named pipe at
// rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", MknodAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, MknodAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), MknodAt returns ENOENT.
+ //
+ // - If the directory in which the file would be created has been removed
+ // by RmdirAt or RenameAt, MknodAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If MknodAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error
// OpenAt returns an FileDescription providing access to the file at rp. A
// reference is taken on the returned FileDescription.
+ //
+ // Errors:
+ //
+ // - If opts.Flags specifies O_TMPFILE and this feature is unsupported by
+ // the implementation, OpenAt returns EOPNOTSUPP. (All other unsupported
+ // features are silently ignored, consistently with Linux's open*(2).)
OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error)
// ReadlinkAt returns the target of the symbolic link at rp.
+ //
+ // Errors:
+ //
+ // - If the file at rp is not a symbolic link, ReadlinkAt returns EINVAL.
ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error)
- // RenameAt renames the Dentry represented by vd to rp. It does not take
- // ownership of references on vd.
+ // RenameAt renames the file named oldName in directory oldParentVD to rp.
+ // It does not take ownership of references on oldParentVD.
+ //
+ // Errors [1]:
+ //
+ // - If opts.Flags specifies unsupported options, RenameAt returns EINVAL.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags
+ // contains RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If the last path component in rp is "." or "..", and opts.Flags does
+ // not contain RENAME_NOREPLACE, RenameAt returns EBUSY.
+ //
+ // - If rp.Mount != oldParentVD.Mount(), RenameAt returns EXDEV.
+ //
+ // - If the renamed file is not a directory, and opts.MustBeDir is true,
+ // RenameAt returns ENOTDIR.
+ //
+ // - If renaming would replace an existing file and opts.Flags contains
+ // RENAME_NOREPLACE, RenameAt returns EEXIST.
+ //
+ // - If there is no existing file at rp and opts.Flags contains
+ // RENAME_EXCHANGE, RenameAt returns ENOENT.
+ //
+ // - If there is an existing non-directory file at rp, and rp.MustBeDir()
+ // is true, RenameAt returns ENOTDIR.
+ //
+ // - If the renamed file is not a directory, opts.Flags does not contain
+ // RENAME_EXCHANGE, and rp.MustBeDir() is true, RenameAt returns ENOTDIR.
+ // (This check is not subsumed by the check for directory replacement below
+ // since it applies even if there is no file to replace.)
+ //
+ // - If the renamed file is a directory, and the new parent directory of
+ // the renamed file is either the renamed directory or a descendant
+ // subdirectory of the renamed directory, RenameAt returns EINVAL.
+ //
+ // - If renaming would exchange the renamed file with an ancestor directory
+ // of the renamed file, RenameAt returns EINVAL.
+ //
+ // - If renaming would replace an ancestor directory of the renamed file,
+ // RenameAt returns ENOTEMPTY. (This check would be subsumed by the
+ // non-empty directory check below; however, this check takes place before
+ // the self-rename check.)
+ //
+ // - If the renamed file would replace or exchange with itself (i.e. the
+ // source and destination paths resolve to the same file), RenameAt returns
+ // nil, skipping the checks described below.
+ //
+ // - If the source or destination directory is not writable by the provider
+ // of rp.Credentials(), RenameAt returns EACCES.
+ //
+ // - If the renamed file is a directory, and renaming would replace a
+ // non-directory file, RenameAt returns ENOTDIR.
//
- // The implementation is responsible for checking that vd.Mount() ==
- // rp.Mount().
- RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error
+ // - If the renamed file is not a directory, and renaming would replace a
+ // directory, RenameAt returns EISDIR.
+ //
+ // - If the new parent directory of the renamed file has been removed by
+ // RmdirAt or a preceding call to RenameAt, RenameAt returns ENOENT.
+ //
+ // - If the renamed file is a directory, it is not writable by the
+ // provider of rp.Credentials(), and the source and destination parent
+ // directories are different, RenameAt returns EACCES. (This is nominally
+ // required to change the ".." entry in the renamed directory.)
+ //
+ // - If renaming would replace a non-empty directory, RenameAt returns
+ // ENOTEMPTY.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink(). oldParentVD.Dentry() was obtained from a
+ // previous call to
+ // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). oldName is
+ // not "." or "..".
+ //
+ // Postconditions: If RenameAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
+ //
+ // [1] "The worst of all namespace operations - renaming directory.
+ // "Perverted" doesn't even start to describe it. Somebody in UCB had a
+ // heck of a trip..." - fs/namei.c:vfs_rename()
+ RenameAt(ctx context.Context, rp *ResolvingPath, oldParentVD VirtualDentry, oldName string, opts RenameOptions) error
// RmdirAt removes the directory at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is ".", RmdirAt returns EINVAL.
+ //
+ // - If the last path component in rp is "..", RmdirAt returns ENOTEMPTY.
+ //
+ // - If no file exists at rp, RmdirAt returns ENOENT.
+ //
+ // - If the file at rp exists but is not a directory, RmdirAt returns
+ // ENOTDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If RmdirAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
RmdirAt(ctx context.Context, rp *ResolvingPath) error
- // SetStatAt updates metadata for the file at the given path.
+ // SetStatAt updates metadata for the file at the given path. Implementations
+ // are responsible for checking if the operation can be performed
+ // (see vfs.CheckSetStat() for common checks).
+ //
+ // Errors:
+ //
+ // - If opts specifies unsupported options, SetStatAt returns EINVAL.
SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error
// StatAt returns metadata for the file at rp.
@@ -146,10 +397,160 @@ type FilesystemImpl interface {
StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error)
// SymlinkAt creates a symbolic link at rp referring to the given target.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", SymlinkAt returns
+ // EEXIST.
+ //
+ // - If a file already exists at rp, SymlinkAt returns EEXIST.
+ //
+ // - If rp.MustBeDir(), SymlinkAt returns ENOENT.
+ //
+ // - If the directory in which the symbolic link would be created has been
+ // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If SymlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error
- // UnlinkAt removes the non-directory file at rp.
+ // UnlinkAt removes the file at rp.
+ //
+ // Errors:
+ //
+ // - If the last path component in rp is "." or "..", UnlinkAt returns
+ // EISDIR.
+ //
+ // - If no file exists at rp, UnlinkAt returns ENOENT.
+ //
+ // - If rp.MustBeDir(), and the file at rp exists and is not a directory,
+ // UnlinkAt returns ENOTDIR.
+ //
+ // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR.
+ //
+ // Preconditions: !rp.Done(). For the final path component in rp,
+ // !rp.ShouldFollowSymlink().
+ //
+ // Postconditions: If UnlinkAt returns an error returned by
+ // ResolvingPath.Resolve*(), then !rp.Done().
UnlinkAt(ctx context.Context, rp *ResolvingPath) error
- // TODO: d_path(); extended attributes; inotify_add_watch(); bind()
+ // ListxattrAt returns all extended attribute names for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // ListxattrAt returns ENOTSUP.
+ //
+ // - If the size of the list (including a NUL terminating byte after every
+ // entry) would exceed size, ERANGE may be returned. Note that
+ // implementations are free to ignore size entirely and return without
+ // error). In all cases, if size is 0, the list should be returned without
+ // error, regardless of size.
+ ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error)
+
+ // GetxattrAt returns the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, GetxattrAt
+ // returns ENOTSUP.
+ //
+ // - If an extended attribute named opts.Name does not exist, ENODATA is
+ // returned.
+ //
+ // - If the size of the return value exceeds opts.Size, ERANGE may be
+ // returned (note that implementations are free to ignore opts.Size entirely
+ // and return without error). In all cases, if opts.Size is 0, the value
+ // should be returned without error, regardless of size.
+ GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error)
+
+ // SetxattrAt changes the value associated with the given extended
+ // attribute for the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem, SetxattrAt
+ // returns ENOTSUP.
+ //
+ // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists,
+ // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist,
+ // ENODATA is returned.
+ SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error
+
+ // RemovexattrAt removes the given extended attribute from the file at rp.
+ //
+ // Errors:
+ //
+ // - If extended attributes are not supported by the filesystem,
+ // RemovexattrAt returns ENOTSUP.
+ //
+ // - If name does not exist, ENODATA is returned.
+ RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error
+
+ // BoundEndpointAt returns the Unix socket endpoint bound at the path rp.
+ //
+ // Errors:
+ //
+ // - If the file does not have write permissions, then BoundEndpointAt
+ // returns EACCES.
+ //
+ // - If a non-socket file exists at rp, then BoundEndpointAt returns
+ // ECONNREFUSED.
+ BoundEndpointAt(ctx context.Context, rp *ResolvingPath, opts BoundEndpointOptions) (transport.BoundEndpoint, error)
+
+ // PrependPath prepends a path from vd to vd.Mount().Root() to b.
+ //
+ // If vfsroot.Ok(), it is the contextual VFS root; if it is encountered
+ // before vd.Mount().Root(), PrependPath should stop prepending path
+ // components and return a PrependPathAtVFSRootError.
+ //
+ // If traversal of vd.Dentry()'s ancestors encounters an independent
+ // ("root") Dentry that is not vd.Mount().Root() (i.e. vd.Dentry() is not a
+ // descendant of vd.Mount().Root()), PrependPath should stop prepending
+ // path components and return a PrependPathAtNonMountRootError.
+ //
+ // Filesystems for which Dentries do not have meaningful paths may prepend
+ // an arbitrary descriptive string to b and then return a
+ // PrependPathSyntheticError.
+ //
+ // Most implementations can acquire the appropriate locks to ensure that
+ // Dentry.Name() and Dentry.Parent() are fixed for vd.Dentry() and all of
+ // its ancestors, then call GenericPrependPath.
+ //
+ // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
+ PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
+}
+
+// PrependPathAtVFSRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter the contextual VFS root.
+type PrependPathAtVFSRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtVFSRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached VFS root"
+}
+
+// PrependPathAtNonMountRootError is returned by implementations of
+// FilesystemImpl.PrependPath() when they encounter an independent ancestor
+// Dentry that is not the Mount root.
+type PrependPathAtNonMountRootError struct{}
+
+// Error implements error.Error.
+func (PrependPathAtNonMountRootError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() reached root other than Mount root"
+}
+
+// PrependPathSyntheticError is returned by implementations of
+// FilesystemImpl.PrependPath() for which prepended names do not represent real
+// paths.
+type PrependPathSyntheticError struct{}
+
+// Error implements error.Error.
+func (PrependPathSyntheticError) Error() string {
+ return "vfs.FilesystemImpl.PrependPath() prepended synthetic name"
}
diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go
new file mode 100644
index 000000000..465e610e0
--- /dev/null
+++ b/pkg/sentry/vfs/filesystem_impl_util.go
@@ -0,0 +1,43 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "strings"
+)
+
+// GenericParseMountOptions parses a comma-separated list of options of the
+// form "key" or "key=value", where neither key nor value contain commas, and
+// returns it as a map. If str contains duplicate keys, then the last value
+// wins. For example:
+//
+// str = "key0=value0,key1,key2=value2,key0=value3" -> map{'key0':'value3','key1':'','key2':'value2'}
+//
+// GenericParseMountOptions is not appropriate if values may contain commas,
+// e.g. in the case of the mpol mount option for tmpfs(5).
+func GenericParseMountOptions(str string) map[string]string {
+ m := make(map[string]string)
+ for _, opt := range strings.Split(str, ",") {
+ if len(opt) > 0 {
+ res := strings.SplitN(opt, "=", 2)
+ if len(res) == 2 {
+ m[res[0]] = res[1]
+ } else {
+ m[opt] = ""
+ }
+ }
+ }
+ return m
+}
diff --git a/pkg/sentry/vfs/filesystem_type.go b/pkg/sentry/vfs/filesystem_type.go
index f401ad7f3..f2298f7f6 100644
--- a/pkg/sentry/vfs/filesystem_type.go
+++ b/pkg/sentry/vfs/filesystem_type.go
@@ -15,9 +15,10 @@
package vfs
import (
+ "bytes"
"fmt"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -25,46 +26,92 @@ import (
//
// FilesystemType is analogous to Linux's struct file_system_type.
type FilesystemType interface {
- // NewFilesystem returns a Filesystem configured by the given options,
+ // GetFilesystem returns a Filesystem configured by the given options,
// along with its mount root. A reference is taken on the returned
// Filesystem and Dentry.
- NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error)
+ GetFilesystem(ctx context.Context, vfsObj *VirtualFilesystem, creds *auth.Credentials, source string, opts GetFilesystemOptions) (*Filesystem, *Dentry, error)
+
+ // Name returns the name of this FilesystemType.
+ Name() string
}
-// NewFilesystemOptions contains options to FilesystemType.NewFilesystem.
-type NewFilesystemOptions struct {
+// GetFilesystemOptions contains options to FilesystemType.GetFilesystem.
+type GetFilesystemOptions struct {
// Data is the string passed as the 5th argument to mount(2), which is
// usually a comma-separated list of filesystem-specific mount options.
Data string
// InternalData holds opaque FilesystemType-specific data. There is
// intentionally no way for applications to specify InternalData; if it is
- // not nil, the call to NewFilesystem originates from within the sentry.
+ // not nil, the call to GetFilesystem originates from within the sentry.
InternalData interface{}
}
+// +stateify savable
+type registeredFilesystemType struct {
+ fsType FilesystemType
+ opts RegisterFilesystemTypeOptions
+}
+
+// RegisterFilesystemTypeOptions contains options to
+// VirtualFilesystem.RegisterFilesystem().
+type RegisterFilesystemTypeOptions struct {
+ // If AllowUserMount is true, allow calls to VirtualFilesystem.MountAt()
+ // for which MountOptions.InternalMount == false to use this filesystem
+ // type.
+ AllowUserMount bool
+
+ // If AllowUserList is true, make this filesystem type visible in
+ // /proc/filesystems.
+ AllowUserList bool
+
+ // If RequiresDevice is true, indicate that mounting this filesystem
+ // requires a block device as the mount source in /proc/filesystems.
+ RequiresDevice bool
+}
+
// RegisterFilesystemType registers the given FilesystemType in vfs with the
// given name.
-func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType) error {
+func (vfs *VirtualFilesystem) RegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) error {
vfs.fsTypesMu.Lock()
defer vfs.fsTypesMu.Unlock()
if existing, ok := vfs.fsTypes[name]; ok {
- return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing)
+ return fmt.Errorf("name %q is already registered to filesystem type %T", name, existing.fsType)
+ }
+ vfs.fsTypes[name] = &registeredFilesystemType{
+ fsType: fsType,
+ opts: *opts,
}
- vfs.fsTypes[name] = fsType
return nil
}
// MustRegisterFilesystemType is equivalent to RegisterFilesystemType but
// panics on failure.
-func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType) {
- if err := vfs.RegisterFilesystemType(name, fsType); err != nil {
+func (vfs *VirtualFilesystem) MustRegisterFilesystemType(name string, fsType FilesystemType, opts *RegisterFilesystemTypeOptions) {
+ if err := vfs.RegisterFilesystemType(name, fsType, opts); err != nil {
panic(fmt.Sprintf("failed to register filesystem type %T: %v", fsType, err))
}
}
-func (vfs *VirtualFilesystem) getFilesystemType(name string) FilesystemType {
+func (vfs *VirtualFilesystem) getFilesystemType(name string) *registeredFilesystemType {
vfs.fsTypesMu.RLock()
defer vfs.fsTypesMu.RUnlock()
return vfs.fsTypes[name]
}
+
+// GenerateProcFilesystems emits the contents of /proc/filesystems for vfs to
+// buf.
+func (vfs *VirtualFilesystem) GenerateProcFilesystems(buf *bytes.Buffer) {
+ vfs.fsTypesMu.RLock()
+ defer vfs.fsTypesMu.RUnlock()
+ for name, rft := range vfs.fsTypes {
+ if !rft.opts.AllowUserList {
+ continue
+ }
+ var nodev string
+ if !rft.opts.RequiresDevice {
+ nodev = "nodev"
+ }
+ fmt.Fprintf(buf, "%s\t%s\n", nodev, name)
+ }
+}
diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md
new file mode 100644
index 000000000..e7da49faa
--- /dev/null
+++ b/pkg/sentry/vfs/g3doc/inotify.md
@@ -0,0 +1,210 @@
+# Inotify
+
+Inotify is a mechanism for monitoring filesystem events in Linux--see
+inotify(7). An inotify instance can be used to monitor files and directories for
+modifications, creation/deletion, etc. The inotify API consists of system calls
+that create inotify instances (inotify_init/inotify_init1) and add/remove
+watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are
+generated from various places in the sentry, including the syscall layer, the
+vfs layer, the process fd table, and within each filesystem implementation. This
+document outlines the implementation details of inotify in VFS2.
+
+## Inotify Objects
+
+Inotify data structures are implemented in the vfs package.
+
+### vfs.Inotify
+
+Inotify instances are represented by vfs.Inotify objects, which implement
+vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a
+pseudo-filesystem (anonfs). Each inotify instance receives events from a set of
+vfs.Watch objects, which can be modified with inotify_add_watch(2) and
+inotify_rm_watch(2). An application can retrieve events by reading the inotify
+fd.
+
+### vfs.Watches
+
+The set of all watches held on a single file (i.e., the watch target) is stored
+in vfs.Watches. Each watch will belong to a different inotify instance (an
+instance can only have one watch on any watch target). The watches are stored in
+a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions
+to a single file will all share the same vfs.Watches. Activity on the target
+causes its vfs.Watches to generate notifications on its watches’ inotify
+instances.
+
+### vfs.Watch
+
+A single watch, owned by one inotify instance and applied to one watch target.
+Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch,
+which leads to some complicated locking behavior (see Lock Ordering). Whenever a
+watch is notified of an event on its target, it will queue events to its inotify
+instance for delivery to the user.
+
+### vfs.Event
+
+vfs.Event is a simple struct encapsulating all the fields for an inotify event.
+It is generated by vfs.Watches and forwarded to the watches' owners. It is
+serialized to the user during read(2) syscalls on the associated fs.Inotify's
+fd.
+
+## Lock Ordering
+
+There are three locks related to the inotify implementation:
+
+Inotify.mu: the inotify instance lock. Inotify.evMu: the inotify event queue
+lock. Watches.mu: the watch set lock, used to protect the collection of watches
+on a target.
+
+The correct lock ordering for inotify code is:
+
+Inotify.mu -> Watches.mu -> Inotify.evMu.
+
+Note that we use a distinct lock to protect the inotify event queue. If we
+simply used Inotify.mu, we could simultaneously have locks being acquired in the
+order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would
+cause deadlocks. For instance, adding a watch to an inotify instance would
+require locking Inotify.mu, and then adding the same watch to the target would
+cause Watches.mu to be held. At the same time, generating an event on the target
+would require Watches.mu to be held before iterating through each watch, and
+then notifying the owner of each watch would cause Inotify.mu to be held.
+
+See the vfs package comment to understand how inotify locks fit into the overall
+ordering of filesystem locks.
+
+## Watch Targets in Different Filesystem Implementations
+
+In Linux, watches reside on inodes at the virtual filesystem layer. As a result,
+all hard links and file descriptions on a single file will all share the same
+watch set. In VFS2, there is no common inode structure across filesystem types
+(some may not even have inodes), so we have to plumb inotify support through
+each specific filesystem implementation. Some of the technical considerations
+are outlined below.
+
+### Tmpfs
+
+For filesystems with inodes, like tmpfs, the design is quite similar to that of
+Linux, where watches reside on the inode.
+
+### Pseudo-filesystems
+
+Technically, because inotify is implemented at the vfs layer in Linux,
+pseudo-filesystems on top of kernfs support inotify passively. However, watches
+can only track explicit filesystem operations like read/write, open/close,
+mknod, etc., so watches on a target like /proc/self/fd will not generate events
+every time a new fd is added or removed. As of this writing, we leave inotify
+unimplemented in kernfs and anonfs; it does not seem particularly useful.
+
+### Gofer Filesystem (fsimpl/gofer)
+
+The gofer filesystem has several traits that make it difficult to support
+inotify:
+
+* **There are no inodes.** A file is represented as a dentry that holds an
+ unopened p9 file (and possibly an open FID), through which the Sentry
+ interacts with the gofer.
+ * *Solution:* Because there is no inode structure stored in the sandbox,
+ inotify watches must be held on the dentry. This would be an issue in
+ the presence of hard links, where multiple dentries would need to share
+ the same set of watches, but in VFS2, we do not support the internal
+ creation of hard links on gofer fs. As a result, we make the assumption
+ that every dentry corresponds to a unique inode. However, the next point
+ raises an issue with this assumption:
+* **The Sentry cannot always be aware of hard links on the remote
+ filesystem.** There is no way for us to confirm whether two files on the
+ remote filesystem are actually links to the same inode. QIDs and inodes are
+ not always 1:1. The assumption that dentries and inodes are 1:1 is
+ inevitably broken if there are remote hard links that we cannot detect.
+ * *Solution:* this is an issue with gofer fs in general, not only inotify,
+ and we will have to live with it.
+* **Dentries can be cached, and then evicted.** Dentry lifetime does not
+ correspond to file lifetime. Because gofer fs is not entirely in-memory, the
+ absence of a dentry does not mean that the corresponding file does not
+ exist, nor does a dentry reaching zero references mean that the
+ corresponding file no longer exists. When a dentry reaches zero references,
+ it will be cached, in case the file at that path is needed again in the
+ future. However, the dentry may be evicted from the cache, which will cause
+ a new dentry to be created next time the same file path is used. The
+ existing watches will be lost.
+ * *Solution:* When a dentry reaches zero references, do not cache it if it
+ has any watches, so we can avoid eviction/destruction. Note that if the
+ dentry was deleted or invalidated (d.vfsd.IsDead()), we should still
+ destroy it along with its watches. Additionally, when a dentry’s last
+ watch is removed, we cache it if it also has zero references. This way,
+ the dentry can eventually be evicted from memory if it is no longer
+ needed.
+* **Dentries can be invalidated.** Another issue with dentry lifetime is that
+ the remote file at the file path represented may change from underneath the
+ dentry. In this case, the next time that the dentry is used, it will be
+ invalidated and a new dentry will replace it. In this case, it is not clear
+ what should be done with the watches on the old dentry.
+ * *Solution:* Silently destroy the watches when invalidation occurs. We
+ have no way of knowing exactly what happened, when it happens. Inotify
+ instances on NFS files in Linux probably behave in a similar fashion,
+ since inotify is implemented at the vfs layer and is not aware of the
+ complexities of remote file systems.
+ * An alternative would be to issue some kind of event upon invalidation,
+ e.g. a delete event, but this has several issues:
+ * We cannot discern whether the remote file was invalidated because it was
+ moved, deleted, etc. This information is crucial, because these cases
+ should result in different events. Furthermore, the watches should only
+ be destroyed if the file has been deleted.
+ * Moreover, the mechanism for detecting whether the underlying file has
+ changed is to check whether a new QID is given by the gofer. This may
+ result in false positives, e.g. suppose that the server closed and
+ re-opened the same file, which may result in a new QID.
+ * Finally, the time of the event may be completely different from the time
+ of the file modification, since a dentry is not immediately notified
+ when the underlying file has changed. It would be quite unexpected to
+ receive the notification when invalidation was triggered, i.e. the next
+ time the file was accessed within the sandbox, because then the
+ read/write/etc. operation on the file would not result in the expected
+ event.
+ * Another point in favor of the first solution: inotify in Linux can
+ already be lossy on local filesystems (one of the sacrifices made so
+ that filesystem performance isn’t killed), and it is lossy on NFS for
+ similar reasons to gofer fs. Therefore, it is better for inotify to be
+ silent than to emit incorrect notifications.
+* **There may be external users of the remote filesystem.** We can only track
+ operations performed on the file within the sandbox. This is sufficient
+ under InteropModeExclusive, but whenever there are external users, the set
+ of actions we are aware of is incomplete.
+ * *Solution:* We could either return an error or just issue a warning when
+ inotify is used without InteropModeExclusive. Although faulty, VFS1
+ allows it when the filesystem is shared, and Linux does the same for
+ remote filesystems (as mentioned above, inotify sits at the vfs level).
+
+## Dentry Interface
+
+For events that must be generated above the vfs layer, we provide the following
+DentryImpl methods to allow interactions with targets on any FilesystemImpl:
+
+* **InotifyWithParent()** generates events on the dentry’s watches as well as
+ its parent’s.
+* **Watches()** retrieves the watch set of the target represented by the
+ dentry. This is used to access and modify watches on a target.
+* **OnZeroWatches()** performs cleanup tasks after the last watch is removed
+ from a dentry. This is needed by gofer fs, which must allow a watched dentry
+ to be cached once it has no more watches. Most implementations can just do
+ nothing. Note that OnZeroWatches() must be called after all inotify locks
+ are released to preserve lock ordering, since it may acquire
+ FilesystemImpl-specific locks.
+
+## IN_EXCL_UNLINK
+
+There are several options that can be set for a watch, specified as part of the
+mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some
+additional support in each filesystem.
+
+A watch with IN_EXCL_UNLINK will not generate events for its target if it
+corresponds to a path that was unlinked. For instance, if an fd is opened on
+“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the
+fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This
+requires each DentryImpl to keep track of whether it has been unlinked, in order
+to determine whether events should be sent to watches with IN_EXCL_UNLINK.
+
+## IN_ONESHOT
+
+One-shot watches expire after generating a single event. When an event occurs,
+all one-shot watches on the target that successfully generated an event are
+removed. Lock ordering can cause the management of one-shot watches to be quite
+expensive; see Watches.Notify() for more information.
diff --git a/pkg/sentry/vfs/genericfstree/BUILD b/pkg/sentry/vfs/genericfstree/BUILD
new file mode 100644
index 000000000..d8fd92677
--- /dev/null
+++ b/pkg/sentry/vfs/genericfstree/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_generics:defs.bzl", "go_template")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+go_template(
+ name = "generic_fstree",
+ srcs = [
+ "genericfstree.go",
+ ],
+ types = [
+ "Dentry",
+ ],
+)
diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go
new file mode 100644
index 000000000..8882fa84a
--- /dev/null
+++ b/pkg/sentry/vfs/genericfstree/genericfstree.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package genericfstree provides tools for implementing vfs.FilesystemImpls
+// where a single statically-determined lock or set of locks is sufficient to
+// ensure that a Dentry's name and parent are contextually immutable.
+//
+// Clients using this package must use the go_template_instance rule in
+// tools/go_generics/defs.bzl to create an instantiation of this template
+// package, providing types to use in place of Dentry.
+package genericfstree
+
+import (
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Dentry is a required type parameter that is a struct with the given fields.
+type Dentry struct {
+ // vfsd is the embedded vfs.Dentry corresponding to this vfs.DentryImpl.
+ vfsd vfs.Dentry
+
+ // parent is the parent of this Dentry in the filesystem's tree. If this
+ // Dentry is a filesystem root, parent is nil.
+ parent *Dentry
+
+ // name is the name of this Dentry in its parent. If this Dentry is a
+ // filesystem root, name is unspecified.
+ name string
+}
+
+// IsAncestorDentry returns true if d is an ancestor of d2; that is, d is
+// either d2's parent or an ancestor of d2's parent.
+func IsAncestorDentry(d, d2 *Dentry) bool {
+ for d2 != nil { // Stop at root, where d2.parent == nil.
+ if d2.parent == d {
+ return true
+ }
+ if d2.parent == d2 {
+ return false
+ }
+ d2 = d2.parent
+ }
+ return false
+}
+
+// ParentOrSelf returns d.parent. If d.parent is nil, ParentOrSelf returns d.
+func ParentOrSelf(d *Dentry) *Dentry {
+ if d.parent != nil {
+ return d.parent
+ }
+ return d
+}
+
+// PrependPath is a generic implementation of FilesystemImpl.PrependPath().
+func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath.Builder) error {
+ for {
+ if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() {
+ return vfs.PrependPathAtVFSRootError{}
+ }
+ if &d.vfsd == mnt.Root() {
+ return nil
+ }
+ if d.parent == nil {
+ return vfs.PrependPathAtNonMountRootError{}
+ }
+ b.PrependComponent(d.name)
+ d = d.parent
+ }
+}
diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go
new file mode 100644
index 000000000..aff220a61
--- /dev/null
+++ b/pkg/sentry/vfs/inotify.go
@@ -0,0 +1,774 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "bytes"
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/uniqueid"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// inotifyEventBaseSize is the base size of linux's struct inotify_event. This
+// must be a power 2 for rounding below.
+const inotifyEventBaseSize = 16
+
+// EventType defines different kinds of inotfiy events.
+//
+// The way events are labelled appears somewhat arbitrary, but they must match
+// Linux so that IN_EXCL_UNLINK behaves as it does in Linux.
+type EventType uint8
+
+// PathEvent and InodeEvent correspond to FSNOTIFY_EVENT_PATH and
+// FSNOTIFY_EVENT_INODE in Linux.
+const (
+ PathEvent EventType = iota
+ InodeEvent EventType = iota
+)
+
+// Inotify represents an inotify instance created by inotify_init(2) or
+// inotify_init1(2). Inotify implements FileDescriptionImpl.
+//
+// +stateify savable
+type Inotify struct {
+ vfsfd FileDescription
+ FileDescriptionDefaultImpl
+ DentryMetadataFileDescriptionImpl
+ NoLockFD
+
+ // Unique identifier for this inotify instance. We don't just reuse the
+ // inotify fd because fds can be duped. These should not be exposed to the
+ // user, since we may aggressively reuse an id on S/R.
+ id uint64
+
+ // queue is used to notify interested parties when the inotify instance
+ // becomes readable or writable.
+ queue waiter.Queue `state:"nosave"`
+
+ // evMu *only* protects the events list. We need a separate lock while
+ // queuing events: using mu may violate lock ordering, since at that point
+ // the calling goroutine may already hold Watches.mu.
+ evMu sync.Mutex `state:"nosave"`
+
+ // A list of pending events for this inotify instance. Protected by evMu.
+ events eventList
+
+ // A scratch buffer, used to serialize inotify events. Allocate this
+ // ahead of time for the sake of performance. Protected by evMu.
+ scratch []byte
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // nextWatchMinusOne is used to allocate watch descriptors on this Inotify
+ // instance. Note that Linux starts numbering watch descriptors from 1.
+ nextWatchMinusOne int32
+
+ // Map from watch descriptors to watch objects.
+ watches map[int32]*Watch
+}
+
+var _ FileDescriptionImpl = (*Inotify)(nil)
+
+// NewInotifyFD constructs a new Inotify instance.
+func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) (*FileDescription, error) {
+ // O_CLOEXEC affects file descriptors, so it must be handled outside of vfs.
+ flags &^= linux.O_CLOEXEC
+ if flags&^linux.O_NONBLOCK != 0 {
+ return nil, syserror.EINVAL
+ }
+
+ id := uniqueid.GlobalFromContext(ctx)
+ vd := vfsObj.NewAnonVirtualDentry(fmt.Sprintf("[inotifyfd:%d]", id))
+ defer vd.DecRef(ctx)
+ fd := &Inotify{
+ id: id,
+ scratch: make([]byte, inotifyEventBaseSize),
+ watches: make(map[int32]*Watch),
+ }
+ if err := fd.vfsfd.Init(fd, flags, vd.Mount(), vd.Dentry(), &FileDescriptionOptions{
+ UseDentryMetadata: true,
+ DenyPRead: true,
+ DenyPWrite: true,
+ }); err != nil {
+ return nil, err
+ }
+ return &fd.vfsfd, nil
+}
+
+// Release implements FileDescriptionImpl.Release. Release removes all
+// watches and frees all resources for an inotify instance.
+func (i *Inotify) Release(ctx context.Context) {
+ var ds []*Dentry
+
+ // We need to hold i.mu to avoid a race with concurrent calls to
+ // Inotify.handleDeletion from Watches. There's no risk of Watches
+ // accessing this Inotify after the destructor ends, because we remove all
+ // references to it below.
+ i.mu.Lock()
+ for _, w := range i.watches {
+ // Remove references to the watch from the watches set on the target. We
+ // don't need to worry about the references from i.watches, since this
+ // file description is about to be destroyed.
+ d := w.target
+ ws := d.Watches()
+ // Watchable dentries should never return a nil watch set.
+ if ws == nil {
+ panic("Cannot remove watch from an unwatchable dentry")
+ }
+ ws.Remove(i.id)
+ if ws.Size() == 0 {
+ ds = append(ds, d)
+ }
+ }
+ i.mu.Unlock()
+
+ for _, d := range ds {
+ d.OnZeroWatches(ctx)
+ }
+}
+
+// Allocate implements FileDescription.Allocate.
+func (i *Inotify) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ panic("Allocate should not be called on read-only inotify fds")
+}
+
+// EventRegister implements waiter.Waitable.
+func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ i.queue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.
+func (i *Inotify) EventUnregister(e *waiter.Entry) {
+ i.queue.EventUnregister(e)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readiness indicates whether there are pending events for an inotify instance.
+func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask {
+ ready := waiter.EventMask(0)
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if !i.events.Empty() {
+ ready |= waiter.EventIn
+ }
+
+ return mask & ready
+}
+
+// PRead implements FileDescriptionImpl.PRead.
+func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// PWrite implements FileDescriptionImpl.PWrite.
+func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements FileDescriptionImpl.Write.
+func (*Inotify) Write(ctx context.Context, src usermem.IOSequence, opts WriteOptions) (int64, error) {
+ return 0, syserror.EBADF
+}
+
+// Read implements FileDescriptionImpl.Read.
+func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOptions) (int64, error) {
+ if dst.NumBytes() < inotifyEventBaseSize {
+ return 0, syserror.EINVAL
+ }
+
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+
+ if i.events.Empty() {
+ // Nothing to read yet, tell caller to block.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var writeLen int64
+ for it := i.events.Front(); it != nil; {
+ // Advance `it` before the element is removed from the list, or else
+ // it.Next() will always be nil.
+ event := it
+ it = it.Next()
+
+ // Does the buffer have enough remaining space to hold the event we're
+ // about to write out?
+ if dst.NumBytes() < int64(event.sizeOf()) {
+ if writeLen > 0 {
+ // Buffer wasn't big enough for all pending events, but we did
+ // write some events out.
+ return writeLen, nil
+ }
+ return 0, syserror.EINVAL
+ }
+
+ // Linux always dequeues an available event as long as there's enough
+ // buffer space to copy it out, even if the copy below fails. Emulate
+ // this behaviour.
+ i.events.Remove(event)
+
+ // Buffer has enough space, copy event to the read buffer.
+ n, err := event.CopyTo(ctx, i.scratch, dst)
+ if err != nil {
+ return 0, err
+ }
+
+ writeLen += n
+ dst = dst.DropFirst64(n)
+ }
+ return writeLen, nil
+}
+
+// Ioctl implements FileDescriptionImpl.Ioctl.
+func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch args[1].Int() {
+ case linux.FIONREAD:
+ i.evMu.Lock()
+ defer i.evMu.Unlock()
+ var n uint32
+ for e := i.events.Front(); e != nil; e = e.Next() {
+ n += uint32(e.sizeOf())
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], n)
+ _, err := uio.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{})
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func (i *Inotify) queueEvent(ev *Event) {
+ i.evMu.Lock()
+
+ // Check if we should coalesce the event we're about to queue with the last
+ // one currently in the queue. Events are coalesced if they are identical.
+ if last := i.events.Back(); last != nil {
+ if ev.equals(last) {
+ // "Coalesce" the two events by simply not queuing the new one. We
+ // don't need to raise a waiter.EventIn notification because no new
+ // data is available for reading.
+ i.evMu.Unlock()
+ return
+ }
+ }
+
+ i.events.PushBack(ev)
+
+ // Release mutex before notifying waiters because we don't control what they
+ // can do.
+ i.evMu.Unlock()
+
+ i.queue.Notify(waiter.EventIn)
+}
+
+// newWatchLocked creates and adds a new watch to target.
+//
+// Precondition: i.mu must be locked. ws must be the watch set for target d.
+func (i *Inotify) newWatchLocked(d *Dentry, ws *Watches, mask uint32) *Watch {
+ w := &Watch{
+ owner: i,
+ wd: i.nextWatchIDLocked(),
+ target: d,
+ mask: mask,
+ }
+
+ // Hold the watch in this inotify instance as well as the watch set on the
+ // target.
+ i.watches[w.wd] = w
+ ws.Add(w)
+ return w
+}
+
+// newWatchIDLocked allocates and returns a new watch descriptor.
+//
+// Precondition: i.mu must be locked.
+func (i *Inotify) nextWatchIDLocked() int32 {
+ i.nextWatchMinusOne++
+ return i.nextWatchMinusOne
+}
+
+// AddWatch constructs a new inotify watch and adds it to the target. It
+// returns the watch descriptor returned by inotify_add_watch(2).
+//
+// The caller must hold a reference on target.
+func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) {
+ // Note: Locking this inotify instance protects the result returned by
+ // Lookup() below. With the lock held, we know for sure the lookup result
+ // won't become stale because it's impossible for *this* instance to
+ // add/remove watches on target.
+ i.mu.Lock()
+ defer i.mu.Unlock()
+
+ ws := target.Watches()
+ if ws == nil {
+ // While Linux supports inotify watches on all filesystem types, watches on
+ // filesystems like kernfs are not generally useful, so we do not.
+ return 0, syserror.EPERM
+ }
+ // Does the target already have a watch from this inotify instance?
+ if existing := ws.Lookup(i.id); existing != nil {
+ newmask := mask
+ if mask&linux.IN_MASK_ADD != 0 {
+ // "Add (OR) events to watch mask for this pathname if it already
+ // exists (instead of replacing mask)." -- inotify(7)
+ newmask |= atomic.LoadUint32(&existing.mask)
+ }
+ atomic.StoreUint32(&existing.mask, newmask)
+ return existing.wd, nil
+ }
+
+ // No existing watch, create a new watch.
+ w := i.newWatchLocked(target, ws, mask)
+ return w.wd, nil
+}
+
+// RmWatch looks up an inotify watch for the given 'wd' and configures the
+// target to stop sending events to this inotify instance.
+func (i *Inotify) RmWatch(ctx context.Context, wd int32) error {
+ i.mu.Lock()
+
+ // Find the watch we were asked to removed.
+ w, ok := i.watches[wd]
+ if !ok {
+ i.mu.Unlock()
+ return syserror.EINVAL
+ }
+
+ // Remove the watch from this instance.
+ delete(i.watches, wd)
+
+ // Remove the watch from the watch target.
+ ws := w.target.Watches()
+ // AddWatch ensures that w.target has a non-nil watch set.
+ if ws == nil {
+ panic("Watched dentry cannot have nil watch set")
+ }
+ ws.Remove(w.OwnerID())
+ remaining := ws.Size()
+ i.mu.Unlock()
+
+ if remaining == 0 {
+ w.target.OnZeroWatches(ctx)
+ }
+
+ // Generate the event for the removal.
+ i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0))
+
+ return nil
+}
+
+// Watches is the collection of all inotify watches on a single file.
+//
+// +stateify savable
+type Watches struct {
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // ws is the map of active watches in this collection, keyed by the inotify
+ // instance id of the owner.
+ ws map[uint64]*Watch
+}
+
+// Size returns the number of watches held by w.
+func (w *Watches) Size() int {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return len(w.ws)
+}
+
+// Lookup returns the watch owned by an inotify instance with the given id.
+// Returns nil if no such watch exists.
+//
+// Precondition: the inotify instance with the given id must be locked to
+// prevent the returned watch from being concurrently modified or replaced in
+// Inotify.watches.
+func (w *Watches) Lookup(id uint64) *Watch {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.ws[id]
+}
+
+// Add adds watch into this set of watches.
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Add(watch *Watch) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ owner := watch.OwnerID()
+ // Sanity check, we should never have two watches for one owner on the
+ // same target.
+ if _, exists := w.ws[owner]; exists {
+ panic(fmt.Sprintf("Watch collision with ID %+v", owner))
+ }
+ if w.ws == nil {
+ w.ws = make(map[uint64]*Watch)
+ }
+ w.ws[owner] = watch
+}
+
+// Remove removes a watch with the given id from this set of watches and
+// releases it. The caller is responsible for generating any watch removal
+// event, as appropriate. The provided id must match an existing watch in this
+// collection.
+//
+// Precondition: the inotify instance with the given id must be locked.
+func (w *Watches) Remove(id uint64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if w.ws == nil {
+ // This watch set is being destroyed. The thread executing the
+ // destructor is already in the process of deleting all our watches. We
+ // got here with no references on the target because we raced with the
+ // destructor notifying all the watch owners of destruction. See the
+ // comment in Watches.HandleDeletion for why this race exists.
+ return
+ }
+
+ // It is possible for w.Remove() to be called for the same watch multiple
+ // times. See the treatment of one-shot watches in Watches.Notify().
+ if _, ok := w.ws[id]; ok {
+ delete(w.ws, id)
+ }
+}
+
+// Notify queues a new event with watches in this set. Watches with
+// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been
+// unlinked.
+func (w *Watches) Notify(ctx context.Context, name string, events, cookie uint32, et EventType, unlinked bool) {
+ var hasExpired bool
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if unlinked && watch.ExcludeUnlinked() && et == PathEvent {
+ continue
+ }
+ if watch.Notify(name, events, cookie) {
+ hasExpired = true
+ }
+ }
+ w.mu.RUnlock()
+
+ if hasExpired {
+ w.cleanupExpiredWatches(ctx)
+ }
+}
+
+// This function is relatively expensive and should only be called where there
+// are expired watches.
+func (w *Watches) cleanupExpiredWatches(ctx context.Context) {
+ // Because of lock ordering, we cannot acquire Inotify.mu for each watch
+ // owner while holding w.mu. As a result, store expired watches locally
+ // before removing.
+ var toRemove []*Watch
+ w.mu.RLock()
+ for _, watch := range w.ws {
+ if atomic.LoadInt32(&watch.expired) == 1 {
+ toRemove = append(toRemove, watch)
+ }
+ }
+ w.mu.RUnlock()
+ for _, watch := range toRemove {
+ watch.owner.RmWatch(ctx, watch.wd)
+ }
+}
+
+// HandleDeletion is called when the watch target is destroyed. Clear the
+// watch set, detach watches from the inotify instances they belong to, and
+// generate the appropriate events.
+func (w *Watches) HandleDeletion(ctx context.Context) {
+ w.Notify(ctx, "", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */)
+
+ // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for
+ // the owner of each watch being deleted. Instead, atomically store the
+ // watches map in a local variable and set it to nil so we can iterate over
+ // it with the assurance that there will be no concurrent accesses.
+ var ws map[uint64]*Watch
+ w.mu.Lock()
+ ws = w.ws
+ w.ws = nil
+ w.mu.Unlock()
+
+ // Remove each watch from its owner's watch set, and generate a corresponding
+ // watch removal event.
+ for _, watch := range ws {
+ i := watch.owner
+ i.mu.Lock()
+ _, found := i.watches[watch.wd]
+ delete(i.watches, watch.wd)
+
+ // Release mutex before notifying waiters because we don't control what
+ // they can do.
+ i.mu.Unlock()
+
+ // If watch was not found, it was removed from the inotify instance before
+ // we could get to it, in which case we should not generate an event.
+ if found {
+ i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0))
+ }
+ }
+}
+
+// Watch represent a particular inotify watch created by inotify_add_watch.
+//
+// +stateify savable
+type Watch struct {
+ // Inotify instance which owns this watch.
+ //
+ // This field is immutable after creation.
+ owner *Inotify
+
+ // Descriptor for this watch. This is unique across an inotify instance.
+ //
+ // This field is immutable after creation.
+ wd int32
+
+ // target is a dentry representing the watch target. Its watch set contains this watch.
+ //
+ // This field is immutable after creation.
+ target *Dentry
+
+ // Events being monitored via this watch. Must be accessed with atomic
+ // memory operations.
+ mask uint32
+
+ // expired is set to 1 to indicate that this watch is a one-shot that has
+ // already sent a notification and therefore can be removed. Must be accessed
+ // with atomic memory operations.
+ expired int32
+}
+
+// OwnerID returns the id of the inotify instance that owns this watch.
+func (w *Watch) OwnerID() uint64 {
+ return w.owner.id
+}
+
+// ExcludeUnlinked indicates whether the watched object should continue to be
+// notified of events originating from a path that has been unlinked.
+//
+// For example, if "foo/bar" is opened and then unlinked, operations on the
+// open fd may be ignored by watches on "foo" and "foo/bar" with IN_EXCL_UNLINK.
+func (w *Watch) ExcludeUnlinked() bool {
+ return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0
+}
+
+// Notify queues a new event on this watch. Returns true if this is a one-shot
+// watch that should be deleted, after this event was successfully queued.
+func (w *Watch) Notify(name string, events uint32, cookie uint32) bool {
+ if atomic.LoadInt32(&w.expired) == 1 {
+ // This is a one-shot watch that is already in the process of being
+ // removed. This may happen if a second event reaches the watch target
+ // before this watch has been removed.
+ return false
+ }
+
+ mask := atomic.LoadUint32(&w.mask)
+ if mask&events == 0 {
+ // We weren't watching for this event.
+ return false
+ }
+
+ // Event mask should include bits matched from the watch plus all control
+ // event bits.
+ unmaskableBits := ^uint32(0) &^ linux.IN_ALL_EVENTS
+ effectiveMask := unmaskableBits | mask
+ matchedEvents := effectiveMask & events
+ w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie))
+ if mask&linux.IN_ONESHOT != 0 {
+ atomic.StoreInt32(&w.expired, 1)
+ return true
+ }
+ return false
+}
+
+// Event represents a struct inotify_event from linux.
+//
+// +stateify savable
+type Event struct {
+ eventEntry
+
+ wd int32
+ mask uint32
+ cookie uint32
+
+ // len is computed based on the name field is set automatically by
+ // Event.setName. It should be 0 when no name is set; otherwise it is the
+ // length of the name slice.
+ len uint32
+
+ // The name field has special padding requirements and should only be set by
+ // calling Event.setName.
+ name []byte
+}
+
+func newEvent(wd int32, name string, events, cookie uint32) *Event {
+ e := &Event{
+ wd: wd,
+ mask: events,
+ cookie: cookie,
+ }
+ if name != "" {
+ e.setName(name)
+ }
+ return e
+}
+
+// paddedBytes converts a go string to a null-terminated c-string, padded with
+// null bytes to a total size of 'l'. 'l' must be large enough for all the bytes
+// in the 's' plus at least one null byte.
+func paddedBytes(s string, l uint32) []byte {
+ if l < uint32(len(s)+1) {
+ panic("Converting string to byte array results in truncation, this can lead to buffer-overflow due to the missing null-byte!")
+ }
+ b := make([]byte, l)
+ copy(b, s)
+
+ // b was zero-value initialized during make(), so the rest of the slice is
+ // already filled with null bytes.
+
+ return b
+}
+
+// setName sets the optional name for this event.
+func (e *Event) setName(name string) {
+ // We need to pad the name such that the entire event length ends up a
+ // multiple of inotifyEventBaseSize.
+ unpaddedLen := len(name) + 1
+ // Round up to nearest multiple of inotifyEventBaseSize.
+ e.len = uint32((unpaddedLen + inotifyEventBaseSize - 1) & ^(inotifyEventBaseSize - 1))
+ // Make sure we haven't overflowed and wrapped around when rounding.
+ if unpaddedLen > int(e.len) {
+ panic("Overflow when rounding inotify event size, the 'name' field was too big.")
+ }
+ e.name = paddedBytes(name, e.len)
+}
+
+func (e *Event) sizeOf() int {
+ s := inotifyEventBaseSize + int(e.len)
+ if s < inotifyEventBaseSize {
+ panic("Overflowed event size")
+ }
+ return s
+}
+
+// CopyTo serializes this event to dst. buf is used as a scratch buffer to
+// construct the output. We use a buffer allocated ahead of time for
+// performance. buf must be at least inotifyEventBaseSize bytes.
+func (e *Event) CopyTo(ctx context.Context, buf []byte, dst usermem.IOSequence) (int64, error) {
+ usermem.ByteOrder.PutUint32(buf[0:], uint32(e.wd))
+ usermem.ByteOrder.PutUint32(buf[4:], e.mask)
+ usermem.ByteOrder.PutUint32(buf[8:], e.cookie)
+ usermem.ByteOrder.PutUint32(buf[12:], e.len)
+
+ writeLen := 0
+
+ n, err := dst.CopyOut(ctx, buf)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ dst = dst.DropFirst(n)
+
+ if e.len > 0 {
+ n, err = dst.CopyOut(ctx, e.name)
+ if err != nil {
+ return 0, err
+ }
+ writeLen += n
+ }
+
+ // Santiy check.
+ if writeLen != e.sizeOf() {
+ panic(fmt.Sprintf("Serialized unexpected amount of data for an event, expected %d, wrote %d.", e.sizeOf(), writeLen))
+ }
+
+ return int64(writeLen), nil
+}
+
+func (e *Event) equals(other *Event) bool {
+ return e.wd == other.wd &&
+ e.mask == other.mask &&
+ e.cookie == other.cookie &&
+ e.len == other.len &&
+ bytes.Equal(e.name, other.name)
+}
+
+// InotifyEventFromStatMask generates the appropriate events for an operation
+// that set the stats specified in mask.
+func InotifyEventFromStatMask(mask uint32) uint32 {
+ var ev uint32
+ if mask&(linux.STATX_UID|linux.STATX_GID|linux.STATX_MODE) != 0 {
+ ev |= linux.IN_ATTRIB
+ }
+ if mask&linux.STATX_SIZE != 0 {
+ ev |= linux.IN_MODIFY
+ }
+
+ if (mask & (linux.STATX_ATIME | linux.STATX_MTIME)) == (linux.STATX_ATIME | linux.STATX_MTIME) {
+ // Both times indicates a utime(s) call.
+ ev |= linux.IN_ATTRIB
+ } else if mask&linux.STATX_ATIME != 0 {
+ ev |= linux.IN_ACCESS
+ } else if mask&linux.STATX_MTIME != 0 {
+ mask |= linux.IN_MODIFY
+ }
+ return ev
+}
+
+// InotifyRemoveChild sends the appriopriate notifications to the watch sets of
+// the child being removed and its parent. Note that unlike most pairs of
+// parent/child notifications, the child is notified first in this case.
+func InotifyRemoveChild(ctx context.Context, self, parent *Watches, name string) {
+ if self != nil {
+ self.Notify(ctx, "", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */)
+ }
+ if parent != nil {
+ parent.Notify(ctx, name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */)
+ }
+}
+
+// InotifyRename sends the appriopriate notifications to the watch sets of the
+// file being renamed and its old/new parents.
+func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, oldName, newName string, isDir bool) {
+ var dirEv uint32
+ if isDir {
+ dirEv = linux.IN_ISDIR
+ }
+ cookie := uniqueid.InotifyCookie(ctx)
+ if oldParent != nil {
+ oldParent.Notify(ctx, oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */)
+ }
+ if newParent != nil {
+ newParent.Notify(ctx, newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */)
+ }
+ // Somewhat surprisingly, self move events do not have a cookie.
+ if renamed != nil {
+ renamed.Notify(ctx, "", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */)
+ }
+}
diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go
new file mode 100644
index 000000000..6c7583a81
--- /dev/null
+++ b/pkg/sentry/vfs/lock.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package lock provides POSIX and BSD style file locking for VFS2 file
+// implementations.
+//
+// The actual implementations can be found in the lock package under
+// sentry/fs/lock.
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// FileLocks supports POSIX and BSD style locks, which correspond to fcntl(2)
+// and flock(2) respectively in Linux. It can be embedded into various file
+// implementations for VFS2 that support locking.
+//
+// Note that in Linux these two types of locks are _not_ cooperative, because
+// race and deadlock conditions make merging them prohibitive. We do the same
+// and keep them oblivious to each other.
+type FileLocks struct {
+ // bsd is a set of BSD-style advisory file wide locks, see flock(2).
+ bsd fslock.Locks
+
+ // posix is a set of POSIX-style regional advisory locks, see fcntl(2).
+ posix fslock.Locks
+}
+
+// LockBSD tries to acquire a BSD-style lock on the entire file.
+func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fslock.Blocker) error {
+ if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockBSD releases a BSD-style lock on the entire file.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockBSD(uid fslock.UniqueID) {
+ fl.bsd.UnlockRegion(uid, fslock.LockRange{0, fslock.LockEOF})
+}
+
+// LockPOSIX tries to acquire a POSIX-style lock on a file region.
+func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
+ if fl.posix.LockRegion(uid, t, rng, block) {
+ return nil
+ }
+ return syserror.ErrWouldBlock
+}
+
+// UnlockPOSIX releases a POSIX-style lock on a file region.
+//
+// This operation is always successful, even if there did not exist a lock on
+// the requested region held by uid in the first place.
+func (fl *FileLocks) UnlockPOSIX(ctx context.Context, fd *FileDescription, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ rng, err := computeRange(ctx, fd, start, length, whence)
+ if err != nil {
+ return err
+ }
+ fl.posix.UnlockRegion(uid, rng)
+ return nil
+}
+
+func computeRange(ctx context.Context, fd *FileDescription, start uint64, length uint64, whence int16) (fslock.LockRange, error) {
+ var off int64
+ switch whence {
+ case linux.SEEK_SET:
+ off = 0
+ case linux.SEEK_CUR:
+ // Note that Linux does not hold any mutexes while retrieving the file
+ // offset, see fs/locks.c:flock_to_posix_lock and fs/locks.c:fcntl_setlk.
+ curOff, err := fd.Seek(ctx, 0, linux.SEEK_CUR)
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = curOff
+ case linux.SEEK_END:
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_SIZE})
+ if err != nil {
+ return fslock.LockRange{}, err
+ }
+ off = int64(stat.Size)
+ default:
+ return fslock.LockRange{}, syserror.EINVAL
+ }
+
+ return fslock.ComputeRange(int64(start), int64(length), off)
+}
diff --git a/pkg/sentry/vfs/memxattr/BUILD b/pkg/sentry/vfs/memxattr/BUILD
new file mode 100644
index 000000000..d8c4d27b9
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "memxattr",
+ srcs = ["xattr.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ ],
+)
diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go
new file mode 100644
index 000000000..cc1e7d764
--- /dev/null
+++ b/pkg/sentry/vfs/memxattr/xattr.go
@@ -0,0 +1,102 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package memxattr provides a default, in-memory extended attribute
+// implementation.
+package memxattr
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SimpleExtendedAttributes implements extended attributes using a map of
+// names to values.
+//
+// +stateify savable
+type SimpleExtendedAttributes struct {
+ // mu protects the below fields.
+ mu sync.RWMutex `state:"nosave"`
+ xattrs map[string]string
+}
+
+// Getxattr returns the value at 'name'.
+func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) {
+ x.mu.RLock()
+ value, ok := x.xattrs[opts.Name]
+ x.mu.RUnlock()
+ if !ok {
+ return "", syserror.ENODATA
+ }
+ // Check that the size of the buffer provided in getxattr(2) is large enough
+ // to contain the value.
+ if opts.Size != 0 && uint64(len(value)) > opts.Size {
+ return "", syserror.ERANGE
+ }
+ return value, nil
+}
+
+// Setxattr sets 'value' at 'name'.
+func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if x.xattrs == nil {
+ if opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+ x.xattrs = make(map[string]string)
+ }
+
+ _, ok := x.xattrs[opts.Name]
+ if ok && opts.Flags&linux.XATTR_CREATE != 0 {
+ return syserror.EEXIST
+ }
+ if !ok && opts.Flags&linux.XATTR_REPLACE != 0 {
+ return syserror.ENODATA
+ }
+
+ x.xattrs[opts.Name] = opts.Value
+ return nil
+}
+
+// Listxattr returns all names in xattrs.
+func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) {
+ // Keep track of the size of the buffer needed in listxattr(2) for the list.
+ listSize := 0
+ x.mu.RLock()
+ names := make([]string, 0, len(x.xattrs))
+ for n := range x.xattrs {
+ names = append(names, n)
+ // Add one byte per null terminator.
+ listSize += len(n) + 1
+ }
+ x.mu.RUnlock()
+ if size != 0 && uint64(listSize) > size {
+ return nil, syserror.ERANGE
+ }
+ return names, nil
+}
+
+// Removexattr removes the xattr at 'name'.
+func (x *SimpleExtendedAttributes) Removexattr(name string) error {
+ x.mu.Lock()
+ defer x.mu.Unlock()
+ if _, ok := x.xattrs[name]; !ok {
+ return syserror.ENODATA
+ }
+ delete(x.xattrs, name)
+ return nil
+}
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 11702f720..d1d29d0cd 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -15,10 +15,15 @@
package vfs
import (
+ "bytes"
+ "fmt"
"math"
+ "sort"
+ "strings"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -37,49 +42,103 @@ import (
//
// Mount is analogous to Linux's struct mount. (gVisor does not distinguish
// between struct mount and struct vfsmount.)
+//
+// +stateify savable
type Mount struct {
+ // vfs, fs, root are immutable. References are held on fs and root.
+ //
+ // Invariant: root belongs to fs.
+ vfs *VirtualFilesystem
+ fs *Filesystem
+ root *Dentry
+
+ // ID is the immutable mount ID.
+ ID uint64
+
+ // Flags contains settings as specified for mount(2), e.g. MS_NOEXEC, except
+ // for MS_RDONLY which is tracked in "writers". Immutable.
+ Flags MountFlags
+
+ // key is protected by VirtualFilesystem.mountMu and
+ // VirtualFilesystem.mounts.seq, and may be nil. References are held on
+ // key.parent and key.point if they are not nil.
+ //
+ // Invariant: key.parent != nil iff key.point != nil. key.point belongs to
+ // key.parent.fs.
+ key mountKey
+
+ // ns is the namespace in which this Mount was mounted. ns is protected by
+ // VirtualFilesystem.mountMu.
+ ns *MountNamespace
+
// The lower 63 bits of refs are a reference count. The MSB of refs is set
- // if the Mount has been eagerly unmounted, as by umount(2) without the
+ // if the Mount has been eagerly umounted, as by umount(2) without the
// MNT_DETACH flag. refs is accessed using atomic memory operations.
refs int64
+ // children is the set of all Mounts for which Mount.key.parent is this
+ // Mount. children is protected by VirtualFilesystem.mountMu.
+ children map[*Mount]struct{}
+
+ // umounted is true if VFS.umountRecursiveLocked() has been called on this
+ // Mount. VirtualFilesystem does not hold a reference on Mounts for which
+ // umounted is true. umounted is protected by VirtualFilesystem.mountMu.
+ umounted bool
+
// The lower 63 bits of writers is the number of calls to
// Mount.CheckBeginWrite() that have not yet been paired with a call to
// Mount.EndWrite(). The MSB of writers is set if MS_RDONLY is in effect.
// writers is accessed using atomic memory operations.
writers int64
+}
- // key is protected by VirtualFilesystem.mountMu and
- // VirtualFilesystem.mounts.seq, and may be nil. References are held on
- // key.parent and key.point if they are not nil.
- //
- // Invariant: key.parent != nil iff key.point != nil. key.point belongs to
- // key.parent.fs.
- key mountKey
+func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *MountNamespace, opts *MountOptions) *Mount {
+ mnt := &Mount{
+ ID: atomic.AddUint64(&vfs.lastMountID, 1),
+ Flags: opts.Flags,
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ ns: mntns,
+ refs: 1,
+ }
+ if opts.ReadOnly {
+ mnt.setReadOnlyLocked(true)
+ }
+ return mnt
+}
- // fs, root, and ns are immutable. References are held on fs and root (but
- // not ns).
- //
- // Invariant: root belongs to fs.
- fs *Filesystem
- root *Dentry
- ns *MountNamespace
+// Options returns a copy of the MountOptions currently applicable to mnt.
+func (mnt *Mount) Options() MountOptions {
+ mnt.vfs.mountMu.Lock()
+ defer mnt.vfs.mountMu.Unlock()
+ return MountOptions{
+ Flags: mnt.Flags,
+ ReadOnly: mnt.readOnly(),
+ }
}
-// A MountNamespace is a collection of Mounts.
-//
+// A MountNamespace is a collection of Mounts.//
// MountNamespaces are reference-counted. Unless otherwise specified, all
// MountNamespace methods require that a reference is held.
//
// MountNamespace is analogous to Linux's struct mnt_namespace.
+//
+// +stateify savable
type MountNamespace struct {
- refs int64 // accessed using atomic memory operations
+ // Owner is the usernamespace that owns this mount namespace.
+ Owner *auth.UserNamespace
// root is the MountNamespace's root mount. root is immutable.
root *Mount
- // mountpoints contains all Dentries which are mount points in this
- // namespace. mountpoints is protected by VirtualFilesystem.mountMu.
+ // refs is the reference count. refs is accessed using atomic memory
+ // operations.
+ refs int64
+
+ // mountpoints maps all Dentries which are mount points in this namespace
+ // to the number of Mounts for which they are mount points. mountpoints is
+ // protected by VirtualFilesystem.mountMu.
//
// mountpoints is used to determine if a Dentry can be moved or removed
// (which requires that the Dentry is not a mount point in the calling
@@ -89,59 +148,80 @@ type MountNamespace struct {
// MountNamespace; this is required to ensure that
// VFS.PrepareDeleteDentry() and VFS.PrepareRemoveDentry() operate
// correctly on unreferenced MountNamespaces.
- mountpoints map[*Dentry]struct{}
+ mountpoints map[*Dentry]uint32
}
// NewMountNamespace returns a new mount namespace with a root filesystem
// configured by the given arguments. A reference is taken on the returned
// MountNamespace.
-func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *NewFilesystemOptions) (*MountNamespace, error) {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
+func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
+ ctx.Warningf("Unknown filesystem type: %s", fsTypeName)
return nil, syserror.ENODEV
}
- fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts)
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts)
if err != nil {
return nil, err
}
mntns := &MountNamespace{
+ Owner: creds.UserNamespace,
refs: 1,
- mountpoints: make(map[*Dentry]struct{}),
- }
- mntns.root = &Mount{
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
+ mountpoints: make(map[*Dentry]uint32),
}
+ mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{})
return mntns, nil
}
-// NewMount creates and mounts a new Filesystem.
-func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *NewFilesystemOptions) error {
- fsType := vfs.getFilesystemType(fsTypeName)
- if fsType == nil {
- return syserror.ENODEV
+// NewDisconnectedMount returns a Mount representing fs with the given root
+// (which may be nil). The new Mount is not associated with any MountNamespace
+// and is not connected to any other Mounts. References are taken on fs and
+// root.
+func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) {
+ fs.IncRef()
+ if root != nil {
+ root.IncRef()
+ }
+ return newMount(vfs, fs, root, nil /* mntns */, opts), nil
+}
+
+// MountDisconnected creates a Filesystem configured by the given arguments,
+// then returns a Mount representing it. The new Mount is not associated with
+// any MountNamespace and is not connected to any other Mounts.
+func (vfs *VirtualFilesystem) MountDisconnected(ctx context.Context, creds *auth.Credentials, source string, fsTypeName string, opts *MountOptions) (*Mount, error) {
+ rft := vfs.getFilesystemType(fsTypeName)
+ if rft == nil {
+ return nil, syserror.ENODEV
+ }
+ if !opts.InternalMount && !rft.opts.AllowUserMount {
+ return nil, syserror.ENODEV
}
- fs, root, err := fsType.NewFilesystem(ctx, creds, source, *opts)
+ fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions)
if err != nil {
- return err
+ return nil, err
}
+ defer root.DecRef(ctx)
+ defer fs.DecRef(ctx)
+ return vfs.NewDisconnectedMount(fs, root, opts)
+}
+
+// ConnectMountAt connects mnt at the path represented by target.
+//
+// Preconditions: mnt must be disconnected.
+func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Credentials, mnt *Mount, target *PathOperation) error {
// We can't hold vfs.mountMu while calling FilesystemImpl methods due to
// lock ordering.
vd, err := vfs.GetDentryAt(ctx, creds, target, &GetDentryOptions{})
if err != nil {
- root.decRef(fs)
- fs.decRef()
return err
}
vfs.mountMu.Lock()
+ vd.dentry.mu.Lock()
for {
- if vd.dentry.IsDisowned() {
+ if vd.dentry.dead {
+ vd.dentry.mu.Unlock()
vfs.mountMu.Unlock()
- vd.DecRef()
- root.decRef(fs)
- fs.decRef()
+ vd.DecRef(ctx)
return syserror.ENOENT
}
// vd might have been mounted over between vfs.GetDentryAt() and
@@ -153,36 +233,298 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti
if nextmnt == nil {
break
}
- nextmnt.incRef()
- nextmnt.root.incRef(nextmnt.fs)
- vd.DecRef()
+ // It's possible that nextmnt has been umounted but not disconnected,
+ // in which case vfs no longer holds a reference on it, and the last
+ // reference may be concurrently dropped even though we're holding
+ // vfs.mountMu.
+ if !nextmnt.tryIncMountedRef() {
+ break
+ }
+ // This can't fail since we're holding vfs.mountMu.
+ nextmnt.root.IncRef()
+ vd.dentry.mu.Unlock()
+ vd.DecRef(ctx)
vd = VirtualDentry{
mount: nextmnt,
dentry: nextmnt.root,
}
+ vd.dentry.mu.Lock()
}
- // TODO: Linux requires that either both the mount point and the mount root
- // are directories, or neither are, and returns ENOTDIR if this is not the
- // case.
+ // TODO(gvisor.dev/issue/1035): Linux requires that either both the mount
+ // point and the mount root are directories, or neither are, and returns
+ // ENOTDIR if this is not the case.
mntns := vd.mount.ns
- mnt := &Mount{
- fs: fs,
- root: root,
- ns: mntns,
- refs: 1,
+ vfs.mounts.seq.BeginWrite()
+ vfs.connectLocked(mnt, vd, mntns)
+ vfs.mounts.seq.EndWrite()
+ vd.dentry.mu.Unlock()
+ vfs.mountMu.Unlock()
+ return nil
+}
+
+// MountAt creates and mounts a Filesystem configured by the given arguments.
+func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
+ mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts)
+ if err != nil {
+ return err
+ }
+ defer mnt.DecRef(ctx)
+ if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil {
+ return err
}
- mnt.storeKey(vd.mount, vd.dentry)
+ return nil
+}
+
+// UmountAt removes the Mount at the given path.
+func (vfs *VirtualFilesystem) UmountAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *UmountOptions) error {
+ if opts.Flags&^(linux.MNT_FORCE|linux.MNT_DETACH) != 0 {
+ return syserror.EINVAL
+ }
+
+ // MNT_FORCE is currently unimplemented except for the permission check.
+ // Force unmounting specifically requires CAP_SYS_ADMIN in the root user
+ // namespace, and not in the owner user namespace for the target mount. See
+ // fs/namespace.c:SYSCALL_DEFINE2(umount, ...)
+ if opts.Flags&linux.MNT_FORCE != 0 && creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) {
+ return syserror.EPERM
+ }
+
+ vd, err := vfs.GetDentryAt(ctx, creds, pop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+ defer vd.DecRef(ctx)
+ if vd.dentry != vd.mount.root {
+ return syserror.EINVAL
+ }
+ vfs.mountMu.Lock()
+ if mntns := MountNamespaceFromContext(ctx); mntns != nil {
+ defer mntns.DecRef(ctx)
+ if mntns != vd.mount.ns {
+ vfs.mountMu.Unlock()
+ return syserror.EINVAL
+ }
+ }
+
+ // TODO(gvisor.dev/issue/1035): Linux special-cases umount of the caller's
+ // root, which we don't implement yet (we'll just fail it since the caller
+ // holds a reference on it).
+
+ vfs.mounts.seq.BeginWrite()
+ if opts.Flags&linux.MNT_DETACH == 0 {
+ if len(vd.mount.children) != 0 {
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ // We are holding a reference on vd.mount.
+ expectedRefs := int64(1)
+ if !vd.mount.umounted {
+ expectedRefs = 2
+ }
+ if atomic.LoadInt64(&vd.mount.refs)&^math.MinInt64 != expectedRefs { // mask out MSB
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ return syserror.EBUSY
+ }
+ }
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(vd.mount, &umountRecursiveOptions{
+ eager: opts.Flags&linux.MNT_DETACH == 0,
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef(ctx)
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef(ctx)
+ }
+ return nil
+}
+
+type umountRecursiveOptions struct {
+ // If eager is true, ensure that future calls to Mount.tryIncMountedRef()
+ // on umounted mounts fail.
+ //
+ // eager is analogous to Linux's UMOUNT_SYNC.
+ eager bool
+
+ // If disconnectHierarchy is true, Mounts that are umounted hierarchically
+ // should be disconnected from their parents. (Mounts whose parents are not
+ // umounted, which in most cases means the Mount passed to the initial call
+ // to umountRecursiveLocked, are unconditionally disconnected for
+ // consistency with Linux.)
+ //
+ // disconnectHierarchy is analogous to Linux's !UMOUNT_CONNECTED.
+ disconnectHierarchy bool
+}
+
+// umountRecursiveLocked marks mnt and its descendants as umounted. It does not
+// release mount or dentry references; instead, it appends VirtualDentries and
+// Mounts on which references must be dropped to vdsToDecRef and mountsToDecRef
+// respectively, and returns updated slices. (This is necessary because
+// filesystem locks possibly taken by DentryImpl.DecRef() may precede
+// vfs.mountMu in the lock order, and Mount.DecRef() may lock vfs.mountMu.)
+//
+// umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree().
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section.
+func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) {
+ if !mnt.umounted {
+ mnt.umounted = true
+ mountsToDecRef = append(mountsToDecRef, mnt)
+ if parent := mnt.parent(); parent != nil && (opts.disconnectHierarchy || !parent.umounted) {
+ vdsToDecRef = append(vdsToDecRef, vfs.disconnectLocked(mnt))
+ }
+ }
+ if opts.eager {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs < 0 {
+ break
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs|math.MinInt64) {
+ break
+ }
+ }
+ }
+ for child := range mnt.children {
+ vdsToDecRef, mountsToDecRef = vfs.umountRecursiveLocked(child, opts, vdsToDecRef, mountsToDecRef)
+ }
+ return vdsToDecRef, mountsToDecRef
+}
+
+// connectLocked makes vd the mount parent/point for mnt. It consumes
+// references held by vd.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt
+// must not already be connected.
+func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) {
+ if checkInvariants {
+ if mnt.parent() != nil {
+ panic("VFS.connectLocked called on connected mount")
+ }
+ }
+ mnt.IncRef() // dropped by callers of umountRecursiveLocked
+ mnt.storeKey(vd)
+ if vd.mount.children == nil {
+ vd.mount.children = make(map[*Mount]struct{})
+ }
+ vd.mount.children[mnt] = struct{}{}
atomic.AddUint32(&vd.dentry.mounts, 1)
- mntns.mountpoints[vd.dentry] = struct{}{}
+ mnt.ns = mntns
+ mntns.mountpoints[vd.dentry]++
+ vfs.mounts.insertSeqed(mnt)
vfsmpmounts, ok := vfs.mountpoints[vd.dentry]
if !ok {
vfsmpmounts = make(map[*Mount]struct{})
vfs.mountpoints[vd.dentry] = vfsmpmounts
}
vfsmpmounts[mnt] = struct{}{}
- vfs.mounts.Insert(mnt)
- vfs.mountMu.Unlock()
- return nil
+}
+
+// disconnectLocked makes vd have no mount parent/point and returns its old
+// mount parent/point with a reference held.
+//
+// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a
+// writer critical section. mnt.parent() != nil.
+func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry {
+ vd := mnt.loadKey()
+ if checkInvariants {
+ if vd.mount != nil {
+ panic("VFS.disconnectLocked called on disconnected mount")
+ }
+ }
+ mnt.storeKey(VirtualDentry{})
+ delete(vd.mount.children, mnt)
+ atomic.AddUint32(&vd.dentry.mounts, math.MaxUint32) // -1
+ mnt.ns.mountpoints[vd.dentry]--
+ if mnt.ns.mountpoints[vd.dentry] == 0 {
+ delete(mnt.ns.mountpoints, vd.dentry)
+ }
+ vfs.mounts.removeSeqed(mnt)
+ vfsmpmounts := vfs.mountpoints[vd.dentry]
+ delete(vfsmpmounts, mnt)
+ if len(vfsmpmounts) == 0 {
+ delete(vfs.mountpoints, vd.dentry)
+ }
+ return vd
+}
+
+// tryIncMountedRef increments mnt's reference count and returns true. If mnt's
+// reference count is already zero, or has been eagerly umounted,
+// tryIncMountedRef does nothing and returns false.
+//
+// tryIncMountedRef does not require that a reference is held on mnt.
+func (mnt *Mount) tryIncMountedRef() bool {
+ for {
+ refs := atomic.LoadInt64(&mnt.refs)
+ if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
+ return false
+ }
+ if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
+ return true
+ }
+ }
+}
+
+// IncRef increments mnt's reference count.
+func (mnt *Mount) IncRef() {
+ // In general, negative values for mnt.refs are valid because the MSB is
+ // the eager-unmount bit.
+ atomic.AddInt64(&mnt.refs, 1)
+}
+
+// DecRef decrements mnt's reference count.
+func (mnt *Mount) DecRef(ctx context.Context) {
+ refs := atomic.AddInt64(&mnt.refs, -1)
+ if refs&^math.MinInt64 == 0 { // mask out MSB
+ var vd VirtualDentry
+ if mnt.parent() != nil {
+ mnt.vfs.mountMu.Lock()
+ mnt.vfs.mounts.seq.BeginWrite()
+ vd = mnt.vfs.disconnectLocked(mnt)
+ mnt.vfs.mounts.seq.EndWrite()
+ mnt.vfs.mountMu.Unlock()
+ }
+ mnt.root.DecRef(ctx)
+ mnt.fs.DecRef(ctx)
+ if vd.Ok() {
+ vd.DecRef(ctx)
+ }
+ }
+}
+
+// IncRef increments mntns' reference count.
+func (mntns *MountNamespace) IncRef() {
+ if atomic.AddInt64(&mntns.refs, 1) <= 1 {
+ panic("MountNamespace.IncRef() called without holding a reference")
+ }
+}
+
+// DecRef decrements mntns' reference count.
+func (mntns *MountNamespace) DecRef(ctx context.Context) {
+ vfs := mntns.root.fs.VirtualFilesystem()
+ if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 {
+ vfs.mountMu.Lock()
+ vfs.mounts.seq.BeginWrite()
+ vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{
+ disconnectHierarchy: true,
+ }, nil, nil)
+ vfs.mounts.seq.EndWrite()
+ vfs.mountMu.Unlock()
+ for _, vd := range vdsToDecRef {
+ vd.DecRef(ctx)
+ }
+ for _, mnt := range mountsToDecRef {
+ mnt.DecRef(ctx)
+ }
+ } else if refs < 0 {
+ panic("MountNamespace.DecRef() called without holding a reference")
+ }
}
// getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes
@@ -192,7 +534,7 @@ func (vfs *VirtualFilesystem) NewMount(ctx context.Context, creds *auth.Credenti
// getMountAt is analogous to Linux's fs/namei.c:follow_mount().
//
// Preconditions: References are held on mnt and d.
-func (vfs *VirtualFilesystem) getMountAt(mnt *Mount, d *Dentry) *Mount {
+func (vfs *VirtualFilesystem) getMountAt(ctx context.Context, mnt *Mount, d *Dentry) *Mount {
// The first mount is special-cased:
//
// - The caller is assumed to have checked d.isMounted() already. (This
@@ -223,7 +565,7 @@ retryFirst:
// Raced with umount.
continue
}
- mnt.decRef()
+ mnt.DecRef(ctx)
mnt = next
d = next.root
}
@@ -231,12 +573,12 @@ retryFirst:
}
// getMountpointAt returns the mount point for the stack of Mounts including
-// mnt. It takes a reference on the returned Mount and Dentry. If no such mount
+// mnt. It takes a reference on the returned VirtualDentry. If no such mount
// point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil).
//
// Preconditions: References are held on mnt and root. vfsroot is not (mnt,
// mnt.root).
-func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry) (*Mount, *Dentry) {
+func (vfs *VirtualFilesystem) getMountpointAt(ctx context.Context, mnt *Mount, vfsroot VirtualDentry) VirtualDentry {
// The first mount is special-cased:
//
// - The caller must have already checked mnt against vfsroot.
@@ -246,21 +588,26 @@ func (vfs *VirtualFilesystem) getMountpointAt(mnt *Mount, vfsroot VirtualDentry)
// - We don't drop the caller's reference on mnt.
retryFirst:
epoch := vfs.mounts.seq.BeginRead()
- parent, point := mnt.loadKey()
+ parent, point := mnt.parent(), mnt.point()
if !vfs.mounts.seq.ReadOk(epoch) {
goto retryFirst
}
if parent == nil {
- return nil, nil
+ return VirtualDentry{}
}
if !parent.tryIncMountedRef() {
// Raced with umount.
goto retryFirst
}
- if !point.tryIncRef(parent.fs) {
+ if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can only
// happen due to a racing change to Mount.key.
- parent.decRef()
+ parent.DecRef(ctx)
+ goto retryFirst
+ }
+ if !vfs.mounts.seq.ReadOk(epoch) {
+ point.DecRef(ctx)
+ parent.DecRef(ctx)
goto retryFirst
}
mnt = parent
@@ -274,7 +621,7 @@ retryFirst:
}
retryNotFirst:
epoch := vfs.mounts.seq.BeginRead()
- parent, point := mnt.loadKey()
+ parent, point := mnt.parent(), mnt.point()
if !vfs.mounts.seq.ReadOk(epoch) {
goto retryNotFirst
}
@@ -285,59 +632,23 @@ retryFirst:
// Raced with umount.
goto retryNotFirst
}
- if !point.tryIncRef(parent.fs) {
+ if !point.TryIncRef() {
// Since Mount holds a reference on Mount.key.point, this can
// only happen due to a racing change to Mount.key.
- parent.decRef()
+ parent.DecRef(ctx)
goto retryNotFirst
}
if !vfs.mounts.seq.ReadOk(epoch) {
- point.decRef(parent.fs)
- parent.decRef()
+ point.DecRef(ctx)
+ parent.DecRef(ctx)
goto retryNotFirst
}
- d.decRef(mnt.fs)
- mnt.decRef()
+ d.DecRef(ctx)
+ mnt.DecRef(ctx)
mnt = parent
d = point
}
- return mnt, d
-}
-
-// tryIncMountedRef increments mnt's reference count and returns true. If mnt's
-// reference count is already zero, or has been eagerly unmounted,
-// tryIncMountedRef does nothing and returns false.
-//
-// tryIncMountedRef does not require that a reference is held on mnt.
-func (mnt *Mount) tryIncMountedRef() bool {
- for {
- refs := atomic.LoadInt64(&mnt.refs)
- if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted
- return false
- }
- if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) {
- return true
- }
- }
-}
-
-func (mnt *Mount) incRef() {
- // In general, negative values for mnt.refs are valid because the MSB is
- // the eager-unmount bit.
- atomic.AddInt64(&mnt.refs, 1)
-}
-
-func (mnt *Mount) decRef() {
- refs := atomic.AddInt64(&mnt.refs, -1)
- if refs&^math.MinInt64 == 0 { // mask out MSB
- parent, point := mnt.loadKey()
- if point != nil {
- point.decRef(parent.fs)
- parent.decRef()
- }
- mnt.root.decRef(mnt.fs)
- mnt.fs.decRef()
- }
+ return VirtualDentry{mnt, d}
}
// CheckBeginWrite increments the counter of in-progress write operations on
@@ -360,7 +671,7 @@ func (mnt *Mount) EndWrite() {
atomic.AddInt64(&mnt.writers, -1)
}
-// Preconditions: VirtualFilesystem.mountMu must be locked for writing.
+// Preconditions: VirtualFilesystem.mountMu must be locked.
func (mnt *Mount) setReadOnlyLocked(ro bool) error {
if oldRO := atomic.LoadInt64(&mnt.writers) < 0; oldRO == ro {
return nil
@@ -377,26 +688,32 @@ func (mnt *Mount) setReadOnlyLocked(ro bool) error {
return nil
}
+func (mnt *Mount) readOnly() bool {
+ return atomic.LoadInt64(&mnt.writers) < 0
+}
+
// Filesystem returns the mounted Filesystem. It does not take a reference on
// the returned Filesystem.
func (mnt *Mount) Filesystem() *Filesystem {
return mnt.fs
}
-// IncRef increments mntns' reference count.
-func (mntns *MountNamespace) IncRef() {
- if atomic.AddInt64(&mntns.refs, 1) <= 1 {
- panic("MountNamespace.IncRef() called without holding a reference")
+// submountsLocked returns this Mount and all Mounts that are descendents of
+// it.
+//
+// Precondition: mnt.vfs.mountMu must be held.
+func (mnt *Mount) submountsLocked() []*Mount {
+ mounts := []*Mount{mnt}
+ for m := range mnt.children {
+ mounts = append(mounts, m.submountsLocked()...)
}
+ return mounts
}
-// DecRef decrements mntns' reference count.
-func (mntns *MountNamespace) DecRef() {
- if refs := atomic.AddInt64(&mntns.refs, 0); refs == 0 {
- // TODO: unmount mntns.root
- } else if refs < 0 {
- panic("MountNamespace.DecRef() called without holding a reference")
- }
+// Root returns the mount's root. It does not take a reference on the returned
+// Dentry.
+func (mnt *Mount) Root() *Dentry {
+ return mnt.root
}
// Root returns mntns' root. A reference is taken on the returned
@@ -409,3 +726,178 @@ func (mntns *MountNamespace) Root() VirtualDentry {
vd.IncRef()
return vd
}
+
+// GenerateProcMounts emits the contents of /proc/[pid]/mounts for vfs to buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+ break
+ }
+
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
+ opts += ",noexec"
+ }
+
+ // Format:
+ // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order>
+ //
+ // The "needs dump" and "fsck order" flags are always 0, which
+ // is allowed.
+ fmt.Fprintf(buf, "%s %s %s %s %d %d\n", "none", path, mnt.fs.FilesystemType().Name(), opts, 0, 0)
+ }
+}
+
+// GenerateProcMountInfo emits the contents of /proc/[pid]/mountinfo for vfs to
+// buf.
+//
+// Preconditions: taskRootDir.Ok().
+func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) {
+ vfs.mountMu.Lock()
+ defer vfs.mountMu.Unlock()
+ rootMnt := taskRootDir.mount
+ mounts := rootMnt.submountsLocked()
+ sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID })
+ for _, mnt := range mounts {
+ // Get the path to this mount relative to task root.
+ mntRootVD := VirtualDentry{
+ mount: mnt,
+ dentry: mnt.root,
+ }
+ path, err := vfs.PathnameReachable(ctx, taskRootDir, mntRootVD)
+ if err != nil {
+ // For some reason we didn't get a path. Log a warning
+ // and run with empty path.
+ ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err)
+ path = ""
+ }
+ if path == "" {
+ // Either an error occurred, or path is not reachable
+ // from root.
+ break
+ }
+ // Stat the mount root to get the major/minor device numbers.
+ pop := &PathOperation{
+ Root: mntRootVD,
+ Start: mntRootVD,
+ }
+ statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{})
+ if err != nil {
+ // Well that's not good. Ignore this mount.
+ break
+ }
+
+ // Format:
+ // 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
+ // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
+
+ // (1) Mount ID.
+ fmt.Fprintf(buf, "%d ", mnt.ID)
+
+ // (2) Parent ID (or this ID if there is no parent).
+ pID := mnt.ID
+ if p := mnt.parent(); p != nil {
+ pID = p.ID
+ }
+ fmt.Fprintf(buf, "%d ", pID)
+
+ // (3) Major:Minor device ID. We don't have a superblock, so we
+ // just use the root inode device number.
+ fmt.Fprintf(buf, "%d:%d ", statx.DevMajor, statx.DevMinor)
+
+ // (4) Root: the pathname of the directory in the filesystem
+ // which forms the root of this mount.
+ //
+ // NOTE(b/78135857): This will always be "/" until we implement
+ // bind mounts.
+ fmt.Fprintf(buf, "/ ")
+
+ // (5) Mount point (relative to process root).
+ fmt.Fprintf(buf, "%s ", manglePath(path))
+
+ // (6) Mount options.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+ if mnt.Flags.NoATime {
+ opts = ",noatime"
+ }
+ if mnt.Flags.NoExec {
+ opts += ",noexec"
+ }
+ fmt.Fprintf(buf, "%s ", opts)
+
+ // (7) Optional fields: zero or more fields of the form "tag[:value]".
+ // (8) Separator: the end of the optional fields is marked by a single hyphen.
+ fmt.Fprintf(buf, "- ")
+
+ // (9) Filesystem type.
+ fmt.Fprintf(buf, "%s ", mnt.fs.FilesystemType().Name())
+
+ // (10) Mount source: filesystem-specific information or "none".
+ fmt.Fprintf(buf, "none ")
+
+ // (11) Superblock options, and final newline.
+ fmt.Fprintf(buf, "%s\n", superBlockOpts(path, mnt))
+ }
+}
+
+// manglePath replaces ' ', '\t', '\n', and '\\' with their octal equivalents.
+// See Linux fs/seq_file.c:mangle_path.
+func manglePath(p string) string {
+ r := strings.NewReplacer(" ", "\\040", "\t", "\\011", "\n", "\\012", "\\", "\\134")
+ return r.Replace(p)
+}
+
+// superBlockOpts returns the super block options string for the the mount at
+// the given path.
+func superBlockOpts(mountPath string, mnt *Mount) string {
+ // gVisor doesn't (yet) have a concept of super block options, so we
+ // use the ro/rw bit from the mount flag.
+ opts := "rw"
+ if mnt.readOnly() {
+ opts = "ro"
+ }
+
+ // NOTE(b/147673608): If the mount is a cgroup, we also need to include
+ // the cgroup name in the options. For now we just read that from the
+ // path.
+ //
+ // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
+ // should get this value from the cgroup itself, and not rely on the
+ // path.
+ if mnt.fs.FilesystemType().Name() == "cgroup" {
+ splitPath := strings.Split(mountPath, "/")
+ cgroupType := splitPath[len(splitPath)-1]
+ opts += "," + cgroupType
+ }
+ return opts
+}
diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go
index f394d7483..3335e4057 100644
--- a/pkg/sentry/vfs/mount_test.go
+++ b/pkg/sentry/vfs/mount_test.go
@@ -17,8 +17,9 @@ package vfs
import (
"fmt"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestMountTableLookupEmpty(t *testing.T) {
@@ -37,7 +38,7 @@ func TestMountTableInsertLookup(t *testing.T) {
mt.Init()
mount := &Mount{}
- mount.storeKey(&Mount{}, &Dentry{})
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
mt.Insert(mount)
if m := mt.Lookup(mount.parent(), mount.point()); m != mount {
@@ -54,7 +55,7 @@ func TestMountTableInsertLookup(t *testing.T) {
}
}
-// TODO: concurrent lookup/insertion/removal
+// TODO(gvisor.dev/issue/1035): concurrent lookup/insertion/removal.
// must be powers of 2
var benchNumMounts = []int{1 << 2, 1 << 5, 1 << 8}
@@ -78,18 +79,10 @@ const enableComparativeBenchmarks = false
func newBenchMount() *Mount {
mount := &Mount{}
- mount.storeKey(&Mount{}, &Dentry{})
+ mount.storeKey(VirtualDentry{&Mount{}, &Dentry{}})
return mount
}
-func vdkey(mnt *Mount) VirtualDentry {
- parent, point := mnt.loadKey()
- return VirtualDentry{
- mount: parent,
- dentry: point,
- }
-}
-
func BenchmarkMountTableParallelLookup(b *testing.B) {
for numG, maxG := 1, runtime.GOMAXPROCS(0); numG >= 0 && numG <= maxG; numG *= 2 {
for _, numMounts := range benchNumMounts {
@@ -101,7 +94,7 @@ func BenchmarkMountTableParallelLookup(b *testing.B) {
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
mt.Insert(mount)
- keys = append(keys, vdkey(mount))
+ keys = append(keys, mount.loadKey())
}
var ready sync.WaitGroup
@@ -153,7 +146,7 @@ func BenchmarkMountMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := vdkey(mount)
+ key := mount.loadKey()
ms[key] = mount
keys = append(keys, key)
}
@@ -208,7 +201,7 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) {
keys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- key := vdkey(mount)
+ key := mount.loadKey()
ms.Store(key, mount)
keys = append(keys, key)
}
@@ -290,7 +283,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -325,7 +318,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) {
var ms sync.Map
for i := 0; i < numMounts; i++ {
mount := newBenchMount()
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
negkeys := make([]VirtualDentry, 0, numMounts)
for i := 0; i < numMounts; i++ {
@@ -379,7 +372,7 @@ func BenchmarkMountMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
}
@@ -399,7 +392,7 @@ func BenchmarkMountSyncMapInsert(b *testing.B) {
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
}
@@ -432,13 +425,13 @@ func BenchmarkMountMapRemove(b *testing.B) {
ms := make(map[VirtualDentry]*Mount)
for i := range mounts {
mount := mounts[i]
- ms[vdkey(mount)] = mount
+ ms[mount.loadKey()] = mount
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- delete(ms, vdkey(mount))
+ delete(ms, mount.loadKey())
}
}
@@ -454,12 +447,12 @@ func BenchmarkMountSyncMapRemove(b *testing.B) {
var ms sync.Map
for i := range mounts {
mount := mounts[i]
- ms.Store(vdkey(mount), mount)
+ ms.Store(mount.loadKey(), mount)
}
b.ResetTimer()
for i := range mounts {
mount := mounts[i]
- ms.Delete(vdkey(mount))
+ ms.Delete(mount.loadKey())
}
}
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index b0511aa40..70f850ca4 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -26,7 +26,8 @@ import (
"sync/atomic"
"unsafe"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sync"
)
// mountKey represents the location at which a Mount is mounted. It is
@@ -38,16 +39,6 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
-// Invariant: mnt.key's fields are nil. parent and point are non-nil.
-func (mnt *Mount) storeKey(parent *Mount, point *Dentry) {
- atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(parent))
- atomic.StorePointer(&mnt.key.point, unsafe.Pointer(point))
-}
-
-func (mnt *Mount) loadKey() (*Mount, *Dentry) {
- return (*Mount)(atomic.LoadPointer(&mnt.key.parent)), (*Dentry)(atomic.LoadPointer(&mnt.key.point))
-}
-
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,11 +47,26 @@ func (mnt *Mount) point() *Dentry {
return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
}
+func (mnt *Mount) loadKey() VirtualDentry {
+ return VirtualDentry{
+ mount: mnt.parent(),
+ dentry: mnt.point(),
+ }
+}
+
+// Invariant: mnt.key.parent == nil. vd.Ok().
+func (mnt *Mount) storeKey(vd VirtualDentry) {
+ atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
+ atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
+}
+
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
+//
+// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -72,8 +78,8 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq gvsync.SeqCount
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
+ seed uint32 // for hashing keys
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -86,7 +92,7 @@ type mountTable struct {
// length and cap in separate uint32s) for ~free.
size uint64
- slots unsafe.Pointer // []mountSlot; never nil after Init
+ slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
type mountSlot struct {
@@ -155,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -201,9 +207,19 @@ loop:
// Insert inserts the given mount into mt.
//
-// Preconditions: There are no concurrent mutators of mt. mt must not already
-// contain a Mount with the same mount point and parent.
+// Preconditions: mt must not already contain a Mount with the same mount point
+// and parent.
func (mt *mountTable) Insert(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.insertSeqed(mount)
+ mt.seq.EndWrite()
+}
+
+// insertSeqed inserts the given mount into mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must not
+// already contain a Mount with the same mount point and parent.
+func (mt *mountTable) insertSeqed(mount *Mount) {
hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
// We're under the maximum load factor if:
@@ -215,10 +231,8 @@ func (mt *mountTable) Insert(mount *Mount) {
tcap := uintptr(1) << order
if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) {
// Atomically insert the new element into the table.
- mt.seq.BeginWrite()
atomic.AddUint64(&mt.size, mtSizeLenOne)
mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash)
- mt.seq.EndWrite()
return
}
@@ -241,8 +255,6 @@ func (mt *mountTable) Insert(mount *Mount) {
for {
oldSlot := (*mountSlot)(oldCur)
if oldSlot.value != nil {
- // Don't need to lock mt.seq yet since newSlots isn't visible
- // to readers.
mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash)
}
if oldCur == oldLast {
@@ -252,11 +264,9 @@ func (mt *mountTable) Insert(mount *Mount) {
}
// Insert the new element into the new table.
mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash)
- // Atomically switch to the new table.
- mt.seq.BeginWrite()
+ // Switch to the new table.
atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne)
atomic.StorePointer(&mt.slots, newSlots)
- mt.seq.EndWrite()
}
// Preconditions: There are no concurrent mutators of the table (slots, cap).
@@ -294,9 +304,18 @@ func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, has
// Remove removes the given mount from mt.
//
-// Preconditions: There are no concurrent mutators of mt. mt must contain
-// mount.
+// Preconditions: mt must contain mount.
func (mt *mountTable) Remove(mount *Mount) {
+ mt.seq.BeginWrite()
+ mt.removeSeqed(mount)
+ mt.seq.EndWrite()
+}
+
+// removeSeqed removes the given mount from mt.
+//
+// Preconditions: mt.seq must be in a writer critical section. mt must contain
+// mount.
+func (mt *mountTable) removeSeqed(mount *Mount) {
hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
@@ -311,7 +330,6 @@ func (mt *mountTable) Remove(mount *Mount) {
// backward until we either find an empty slot, or an element that
// is already in its first-probed slot. (This is backward shift
// deletion.)
- mt.seq.BeginWrite()
for {
nextOff := (off + mountSlotBytes) & offmask
nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff))
@@ -330,7 +348,6 @@ func (mt *mountTable) Remove(mount *Mount) {
}
atomic.StorePointer(&slot.value, nil)
atomic.AddUint64(&mt.size, mtSizeLenNegOne)
- mt.seq.EndWrite()
return
}
if checkInvariants && slotValue == nil {
@@ -345,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go
index 3aa73d911..dfc8573fd 100644
--- a/pkg/sentry/vfs/options.go
+++ b/pkg/sentry/vfs/options.go
@@ -16,6 +16,7 @@ package vfs
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
)
// GetDentryOptions contains options to VirtualFilesystem.GetDentryAt() and
@@ -32,6 +33,25 @@ type GetDentryOptions struct {
type MkdirOptions struct {
// Mode is the file mode bits for the created directory.
Mode linux.FileMode
+
+ // If ForSyntheticMountpoint is true, FilesystemImpl.MkdirAt() may create
+ // the given directory in memory only (as opposed to persistent storage).
+ // The created directory should be able to support the creation of
+ // subdirectories with ForSyntheticMountpoint == true. It does not need to
+ // support the creation of subdirectories with ForSyntheticMountpoint ==
+ // false, or files of other types.
+ //
+ // FilesystemImpls are permitted to ignore the ForSyntheticMountpoint
+ // option.
+ //
+ // The ForSyntheticMountpoint option exists because, unlike mount(2), the
+ // OCI Runtime Specification permits the specification of mount points that
+ // do not exist, under the expectation that container runtimes will create
+ // them. (More accurately, the OCI Runtime Specification completely fails
+ // to document this feature, but it's implemented by runc.)
+ // ForSyntheticMountpoint allows such mount points to be created even when
+ // the underlying persistent filesystem is immutable.
+ ForSyntheticMountpoint bool
}
// MknodOptions contains options to VirtualFilesystem.MknodAt() and
@@ -44,6 +64,48 @@ type MknodOptions struct {
// DevMinor are the major and minor device numbers for the created device.
DevMajor uint32
DevMinor uint32
+
+ // Endpoint is the endpoint to bind to the created file, if a socket file is
+ // being created for bind(2) on a Unix domain socket.
+ Endpoint transport.BoundEndpoint
+}
+
+// MountFlags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+// MS_RDONLY is not part of MountFlags because it's tracked in Mount.writers.
+type MountFlags struct {
+ // NoExec is equivalent to MS_NOEXEC.
+ NoExec bool
+
+ // NoATime is equivalent to MS_NOATIME and indicates that the
+ // filesystem should not update access time in-place.
+ NoATime bool
+
+ // NoDev is equivalent to MS_NODEV and indicates that the
+ // filesystem should not allow access to devices (special files).
+ // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE
+ // filesystems.
+ NoDev bool
+
+ // NoSUID is equivalent to MS_NOSUID and indicates that the
+ // filesystem should not honor set-user-ID and set-group-ID bits or
+ // file capabilities when executing programs.
+ NoSUID bool
+}
+
+// MountOptions contains options to VirtualFilesystem.MountAt().
+type MountOptions struct {
+ // Flags contains flags as specified for mount(2), e.g. MS_NOEXEC.
+ Flags MountFlags
+
+ // ReadOnly is equivalent to MS_RDONLY.
+ ReadOnly bool
+
+ // GetFilesystemOptions contains options to FilesystemType.GetFilesystem().
+ GetFilesystemOptions GetFilesystemOptions
+
+ // If InternalMount is true, allow the use of filesystem types for which
+ // RegisterFilesystemTypeOptions.AllowUserMount == false.
+ InternalMount bool
}
// OpenOptions contains options to VirtualFilesystem.OpenAt() and
@@ -51,7 +113,7 @@ type MknodOptions struct {
type OpenOptions struct {
// Flags contains access mode and flags as specified for open(2).
//
- // FilesystemImpls is reponsible for implementing the following flags:
+ // FilesystemImpls are responsible for implementing the following flags:
// O_RDONLY, O_WRONLY, O_RDWR, O_APPEND, O_CREAT, O_DIRECT, O_DSYNC,
// O_EXCL, O_NOATIME, O_NOCTTY, O_NONBLOCK, O_PATH, O_SYNC, O_TMPFILE, and
// O_TRUNC. VFS is responsible for handling O_DIRECTORY, O_LARGEFILE, and
@@ -62,6 +124,12 @@ type OpenOptions struct {
// If FilesystemImpl.OpenAt() creates a file, Mode is the file mode for the
// created file.
Mode linux.FileMode
+
+ // FileExec is set when the file is being opened to be executed.
+ // VirtualFilesystem.OpenAt() checks that the caller has execute permissions
+ // on the file, that the file is a regular file, and that the mount doesn't
+ // have MS_NOEXEC set.
+ FileExec bool
}
// ReadOptions contains options to FileDescription.PRead(),
@@ -77,6 +145,9 @@ type ReadOptions struct {
type RenameOptions struct {
// Flags contains flags as specified for renameat2(2).
Flags uint32
+
+ // If MustBeDir is true, the renamed file must be a directory.
+ MustBeDir bool
}
// SetStatOptions contains options to VirtualFilesystem.SetStatAt(),
@@ -93,6 +164,58 @@ type SetStatOptions struct {
// == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask
// instead).
Stat linux.Statx
+
+ // NeedWritePerm indicates that write permission on the file is needed for
+ // this operation. This is needed for truncate(2) (note that ftruncate(2)
+ // does not require the same check--instead, it checks that the fd is
+ // writable).
+ NeedWritePerm bool
+}
+
+// BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt()
+// and FilesystemImpl.BoundEndpointAt().
+type BoundEndpointOptions struct {
+ // Addr is the path of the file whose socket endpoint is being retrieved.
+ // It is generally irrelevant: most endpoints are stored at a dentry that
+ // was created through a bind syscall, so the path can be stored on creation.
+ // However, if the endpoint was created in FilesystemImpl.BoundEndpointAt(),
+ // then we may not know what the original bind address was.
+ //
+ // For example, if connect(2) is called with address "foo" which corresponds
+ // a remote named socket in goferfs, we need to generate an endpoint wrapping
+ // that file. In this case, we can use Addr to set the endpoint address to
+ // "foo". Note that Addr is only a best-effort attempt--we still do not know
+ // the exact address that was used on the remote fs to bind the socket (it
+ // may have been "foo", "./foo", etc.).
+ Addr string
+}
+
+// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(),
+// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and
+// FileDescriptionImpl.Getxattr().
+type GetxattrOptions struct {
+ // Name is the name of the extended attribute to retrieve.
+ Name string
+
+ // Size is the maximum value size that the caller will tolerate. If the value
+ // is larger than size, getxattr methods may return ERANGE, but they are also
+ // free to ignore the hint entirely (i.e. the value returned may be larger
+ // than size). All size checking is done independently at the syscall layer.
+ Size uint64
+}
+
+// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(),
+// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and
+// FileDescriptionImpl.Setxattr().
+type SetxattrOptions struct {
+ // Name is the name of the extended attribute being mutated.
+ Name string
+
+ // Value is the extended attribute's new value.
+ Value string
+
+ // Flags contains flags as specified for setxattr/lsetxattr/fsetxattr(2).
+ Flags uint32
}
// StatOptions contains options to VirtualFilesystem.StatAt(),
@@ -114,6 +237,12 @@ type StatOptions struct {
Sync uint32
}
+// UmountOptions contains options to VirtualFilesystem.UmountAt().
+type UmountOptions struct {
+ // Flags contains flags as specified for umount2(2).
+ Flags uint32
+}
+
// WriteOptions contains options to FileDescription.PWrite(),
// FileDescriptionImpl.PWrite(), FileDescription.Write(), and
// FileDescriptionImpl.Write().
diff --git a/pkg/sentry/vfs/pathname.go b/pkg/sentry/vfs/pathname.go
new file mode 100644
index 000000000..e4da15009
--- /dev/null
+++ b/pkg/sentry/vfs/pathname.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+var fspathBuilderPool = sync.Pool{
+ New: func() interface{} {
+ return &fspath.Builder{}
+ },
+}
+
+func getFSPathBuilder() *fspath.Builder {
+ return fspathBuilderPool.Get().(*fspath.Builder)
+}
+
+func putFSPathBuilder(b *fspath.Builder) {
+ // No methods can be called on b after b.String(), so reset it to its zero
+ // value (as returned by fspathBuilderPool.New) instead.
+ *b = fspath.Builder{}
+ fspathBuilderPool.Put(b)
+}
+
+// PathnameWithDeleted returns an absolute pathname to vd, consistent with
+// Linux's d_path(). In particular, if vd.Dentry() has been disowned,
+// PathnameWithDeleted appends " (deleted)" to the returned pathname.
+func (vfs *VirtualFilesystem) PathnameWithDeleted(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ }()
+
+ origD := vd.dentry
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ // genericfstree.PrependPath() will have returned
+ // PrependPathAtVFSRootError in this case since it checks
+ // against vfsroot before mnt.root, but other implementations
+ // of FilesystemImpl.PrependPath() may return nil instead.
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ break loop
+ }
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ vd = nextVD
+ haveRef = true
+ // continue loop
+ case PrependPathSyntheticError:
+ // Skip prepending "/" and appending " (deleted)".
+ return b.String(), nil
+ case PrependPathAtVFSRootError, PrependPathAtNonMountRootError:
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if origD.IsDead() {
+ b.AppendString(" (deleted)")
+ }
+ return b.String(), nil
+}
+
+// PathnameReachable returns an absolute pathname to vd, consistent with
+// Linux's __d_path() (as used by seq_path_root()). If vfsroot.Ok() and vd is
+// not reachable from vfsroot, such that seq_path_root() would return SEQ_SKIP
+// (causing the entire containing entry to be skipped), PathnameReachable
+// returns ("", nil).
+func (vfs *VirtualFilesystem) PathnameReachable(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ }()
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ return "", nil
+ }
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ return "", nil
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ return b.String(), nil
+}
+
+// PathnameForGetcwd returns an absolute pathname to vd, consistent with
+// Linux's sys_getcwd().
+func (vfs *VirtualFilesystem) PathnameForGetcwd(ctx context.Context, vfsroot, vd VirtualDentry) (string, error) {
+ if vd.dentry.IsDead() {
+ return "", syserror.ENOENT
+ }
+
+ b := getFSPathBuilder()
+ defer putFSPathBuilder(b)
+ haveRef := false
+ defer func() {
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ }()
+ unreachable := false
+loop:
+ for {
+ err := vd.mount.fs.impl.PrependPath(ctx, vfsroot, vd, b)
+ switch err.(type) {
+ case nil:
+ if vd.mount == vfsroot.mount && vd.mount.root == vfsroot.dentry {
+ break loop
+ }
+ nextVD := vfs.getMountpointAt(ctx, vd.mount, vfsroot)
+ if !nextVD.Ok() {
+ unreachable = true
+ break loop
+ }
+ if haveRef {
+ vd.DecRef(ctx)
+ }
+ vd = nextVD
+ haveRef = true
+ case PrependPathAtVFSRootError:
+ break loop
+ case PrependPathAtNonMountRootError, PrependPathSyntheticError:
+ unreachable = true
+ break loop
+ default:
+ return "", err
+ }
+ }
+ b.PrependByte('/')
+ if unreachable {
+ b.PrependString("(unreachable)")
+ }
+ return b.String(), nil
+}
+
+// As of this writing, we do not have equivalents to:
+//
+// - d_absolute_path(), which returns EINVAL if (effectively) any call to
+// FilesystemImpl.PrependPath() would return PrependPathAtNonMountRootError.
+//
+// - dentry_path(), which does not walk up mounts (and only returns the path
+// relative to Filesystem root), but also appends "//deleted" for disowned
+// Dentries.
+//
+// These should be added as necessary.
diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go
index f8e74355c..33389c1df 100644
--- a/pkg/sentry/vfs/permissions.go
+++ b/pkg/sentry/vfs/permissions.go
@@ -15,8 +15,12 @@
package vfs
import (
+ "math"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -25,23 +29,44 @@ type AccessTypes uint16
// Bits in AccessTypes.
const (
+ MayExec AccessTypes = 1
+ MayWrite AccessTypes = 2
MayRead AccessTypes = 4
- MayWrite = 2
- MayExec = 1
)
+// OnlyRead returns true if access _only_ allows read.
+func (a AccessTypes) OnlyRead() bool {
+ return a == MayRead
+}
+
+// MayRead returns true if access allows read.
+func (a AccessTypes) MayRead() bool {
+ return a&MayRead != 0
+}
+
+// MayWrite returns true if access allows write.
+func (a AccessTypes) MayWrite() bool {
+ return a&MayWrite != 0
+}
+
+// MayExec returns true if access allows exec.
+func (a AccessTypes) MayExec() bool {
+ return a&MayExec != 0
+}
+
// GenericCheckPermissions checks that creds has the given access rights on a
// file with the given permissions, UID, and GID, subject to the rules of
-// fs/namei.c:generic_permission(). isDir is true if the file is a directory.
-func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir bool, mode uint16, kuid auth.KUID, kgid auth.KGID) error {
+// fs/namei.c:generic_permission().
+func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
// Check permission bits.
- perms := mode
+ perms := uint16(mode.Permissions())
if creds.EffectiveKUID == kuid {
perms >>= 6
} else if creds.InGroup(kgid) {
perms >>= 3
}
if uint16(ats)&perms == uint16(ats) {
+ // All permission bits match, access granted.
return nil
}
@@ -53,7 +78,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
}
// CAP_DAC_READ_SEARCH allows the caller to read and search arbitrary
// directories, and read arbitrary non-directory files.
- if (isDir && (ats&MayWrite == 0)) || ats == MayRead {
+ if (mode.IsDir() && !ats.MayWrite()) || ats.OnlyRead() {
if creds.HasCapability(linux.CAP_DAC_READ_SEARCH) {
return nil
}
@@ -61,7 +86,7 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
// CAP_DAC_OVERRIDE allows arbitrary access to directories, read/write
// access to non-directory files, and execute access to non-directory files
// for which at least one execute bit is set.
- if isDir || (ats&MayExec == 0) || (mode&0111 != 0) {
+ if mode.IsDir() || !ats.MayExec() || (mode.Permissions()&0111 != 0) {
if creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
return nil
}
@@ -69,32 +94,67 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, isDir boo
return syserror.EACCES
}
+// MayLink determines whether creating a hard link to a file with the given
+// mode, kuid, and kgid is permitted.
+//
+// This corresponds to Linux's fs/namei.c:may_linkat.
+func MayLink(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ // Source inode owner can hardlink all they like; otherwise, it must be a
+ // safe source.
+ if CanActAsOwner(creds, kuid) {
+ return nil
+ }
+
+ // Only regular files can be hard linked.
+ if mode.FileType() != linux.S_IFREG {
+ return syserror.EPERM
+ }
+
+ // Setuid files should not get pinned to the filesystem.
+ if mode&linux.S_ISUID != 0 {
+ return syserror.EPERM
+ }
+
+ // Executable setgid files should not get pinned to the filesystem, but we
+ // don't support S_IXGRP anyway.
+
+ // Hardlinking to unreadable or unwritable sources is dangerous.
+ if err := GenericCheckPermissions(creds, MayRead|MayWrite, mode, kuid, kgid); err != nil {
+ return syserror.EPERM
+ }
+ return nil
+}
+
// AccessTypesForOpenFlags returns the access types required to open a file
// with the given OpenOptions.Flags. Note that this is NOT the same thing as
// the set of accesses permitted for the opened file:
//
// - O_TRUNC causes MayWrite to be set in the returned AccessTypes (since it
-// mutates the file), but does not permit the opened to write to the file
+// mutates the file), but does not permit writing to the open file description
// thereafter.
//
// - "Linux reserves the special, nonstandard access mode 3 (binary 11) in
// flags to mean: check for read and write permission on the file and return a
// file descriptor that can't be used for reading or writing." - open(2). Thus
-// AccessTypesForOpenFlags returns MayRead|MayWrite in this case, but
-// filesystems are responsible for ensuring that access is denied.
+// AccessTypesForOpenFlags returns MayRead|MayWrite in this case.
//
// Use May{Read,Write}FileWithOpenFlags() for these checks instead.
-func AccessTypesForOpenFlags(flags uint32) AccessTypes {
- switch flags & linux.O_ACCMODE {
+func AccessTypesForOpenFlags(opts *OpenOptions) AccessTypes {
+ ats := AccessTypes(0)
+ if opts.FileExec {
+ ats |= MayExec
+ }
+
+ switch opts.Flags & linux.O_ACCMODE {
case linux.O_RDONLY:
- if flags&linux.O_TRUNC != 0 {
- return MayRead | MayWrite
+ if opts.Flags&linux.O_TRUNC != 0 {
+ return ats | MayRead | MayWrite
}
- return MayRead
+ return ats | MayRead
case linux.O_WRONLY:
- return MayWrite
+ return ats | MayWrite
default:
- return MayRead | MayWrite
+ return ats | MayRead | MayWrite
}
}
@@ -119,3 +179,108 @@ func MayWriteFileWithOpenFlags(flags uint32) bool {
return false
}
}
+
+// CheckSetStat checks that creds has permission to change the metadata of a
+// file with the given permissions, UID, and GID as specified by stat, subject
+// to the rules of Linux's fs/attr.c:setattr_prepare().
+func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error {
+ stat := &opts.Stat
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ limit, err := CheckLimit(ctx, 0, int64(stat.Size))
+ if err != nil {
+ return err
+ }
+ if limit < int64(stat.Size) {
+ return syserror.ErrExceedsFileSizeLimit
+ }
+ }
+ if stat.Mask&linux.STATX_MODE != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ return syserror.EPERM
+ }
+ // TODO(b/30815691): "If the calling process is not privileged (Linux:
+ // does not have the CAP_FSETID capability), and the group of the file
+ // does not match the effective group ID of the process or one of its
+ // supplementary group IDs, the S_ISGID bit will be turned off, but
+ // this will not cause an error to be returned." - chmod(2)
+ }
+ if stat.Mask&linux.STATX_UID != 0 {
+ if !((creds.EffectiveKUID == kuid && auth.KUID(stat.UID) == kuid) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if stat.Mask&linux.STATX_GID != 0 {
+ if !((creds.EffectiveKUID == kuid && creds.InGroup(auth.KGID(stat.GID))) ||
+ HasCapabilityOnFile(creds, linux.CAP_CHOWN, kuid, kgid)) {
+ return syserror.EPERM
+ }
+ }
+ if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) {
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
+ if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 {
+ if !CanActAsOwner(creds, kuid) {
+ if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec != linux.UTIME_NOW) ||
+ (stat.Mask&linux.STATX_CTIME != 0 && stat.Ctime.Nsec != linux.UTIME_NOW) {
+ return syserror.EPERM
+ }
+ if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// CheckDeleteSticky checks whether the sticky bit is set on a directory with
+// the given file mode, and if so, checks whether creds has permission to
+// remove a file owned by childKUID from a directory with the given mode.
+// CheckDeleteSticky is consistent with fs/linux.h:check_sticky().
+func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error {
+ if parentMode&linux.ModeSticky == 0 {
+ return nil
+ }
+ if CanActAsOwner(creds, childKUID) {
+ return nil
+ }
+ return syserror.EPERM
+}
+
+// CanActAsOwner returns true if creds can act as the owner of a file with the
+// given owning UID, consistent with Linux's
+// fs/inode.c:inode_owner_or_capable().
+func CanActAsOwner(creds *auth.Credentials, kuid auth.KUID) bool {
+ if creds.EffectiveKUID == kuid {
+ return true
+ }
+ return creds.HasCapability(linux.CAP_FOWNER) && creds.UserNamespace.MapFromKUID(kuid).Ok()
+}
+
+// HasCapabilityOnFile returns true if creds has the given capability with
+// respect to a file with the given owning UID and GID, consistent with Linux's
+// kernel/capability.c:capable_wrt_inode_uidgid().
+func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth.KUID, kgid auth.KGID) bool {
+ return creds.HasCapability(cp) && creds.UserNamespace.MapFromKUID(kuid).Ok() && creds.UserNamespace.MapFromKGID(kgid).Ok()
+}
+
+// CheckLimit enforces file size rlimits. It returns error if the write
+// operation must not proceed. Otherwise it returns the max length allowed to
+// without violating the limit.
+func CheckLimit(ctx context.Context, offset, size int64) (int64, error) {
+ fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur
+ if fileSizeLimit > math.MaxInt64 {
+ return size, nil
+ }
+ if offset >= int64(fileSizeLimit) {
+ return 0, syserror.ErrExceedsFileSizeLimit
+ }
+ remaining := int64(fileSizeLimit) - offset
+ if remaining < size {
+ return remaining, nil
+ }
+ return size, nil
+}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8d05c8583..3304372d9 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -16,11 +16,12 @@ package vfs
import (
"fmt"
- "sync"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -29,7 +30,9 @@ import (
//
// From the perspective of FilesystemImpl methods, a ResolvingPath represents a
// starting Dentry on the associated Filesystem (on which a reference is
-// already held) and a stream of path components relative to that Dentry.
+// already held), a stream of path components relative to that Dentry, and
+// elements of the invoking Context that are commonly required by
+// FilesystemImpl methods.
//
// ResolvingPath is loosely analogous to Linux's struct nameidata.
type ResolvingPath struct {
@@ -85,11 +88,11 @@ func init() {
// so error "constants" are really mutable vars, necessitating somewhat
// expensive interface object comparisons.
-type resolveMountRootError struct{}
+type resolveMountRootOrJumpError struct{}
// Error implements error.Error.
-func (resolveMountRootError) Error() string {
- return "resolving mount root"
+func (resolveMountRootOrJumpError) Error() string {
+ return "resolving mount root or jump"
}
type resolveMountPointError struct{}
@@ -112,57 +115,53 @@ var resolvingPathPool = sync.Pool{
},
}
-func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) (*ResolvingPath, error) {
- path, err := fspath.Parse(pop.Pathname)
- if err != nil {
- return nil, err
- }
+func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath {
rp := resolvingPathPool.Get().(*ResolvingPath)
rp.vfs = vfs
rp.root = pop.Root
rp.mount = pop.Start.mount
rp.start = pop.Start.dentry
- rp.pit = path.Begin
+ rp.pit = pop.Path.Begin
rp.flags = 0
if pop.FollowFinalSymlink {
rp.flags |= rpflagsFollowFinalSymlink
}
- rp.mustBeDir = path.Dir
- rp.mustBeDirOrig = path.Dir
+ rp.mustBeDir = pop.Path.Dir
+ rp.mustBeDirOrig = pop.Path.Dir
rp.symlinks = 0
rp.curPart = 0
rp.numOrigParts = 1
rp.creds = creds
- rp.parts[0] = path.Begin
- rp.origParts[0] = path.Begin
- return rp, nil
+ rp.parts[0] = pop.Path.Begin
+ rp.origParts[0] = pop.Path.Begin
+ return rp
}
-func (vfs *VirtualFilesystem) putResolvingPath(rp *ResolvingPath) {
+func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) {
rp.root = VirtualDentry{}
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = nil
rp.start = nil
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
resolvingPathPool.Put(rp)
}
-func (rp *ResolvingPath) decRefStartAndMount() {
+func (rp *ResolvingPath) decRefStartAndMount(ctx context.Context) {
if rp.flags&rpflagsHaveStartRef != 0 {
- rp.start.decRef(rp.mount.fs)
+ rp.start.DecRef(ctx)
}
if rp.flags&rpflagsHaveMountRef != 0 {
- rp.mount.decRef()
+ rp.mount.DecRef(ctx)
}
}
-func (rp *ResolvingPath) releaseErrorState() {
+func (rp *ResolvingPath) releaseErrorState(ctx context.Context) {
if rp.nextStart != nil {
- rp.nextStart.decRef(rp.nextMount.fs)
+ rp.nextStart.DecRef(ctx)
rp.nextStart = nil
}
if rp.nextMount != nil {
- rp.nextMount.decRef()
+ rp.nextMount.DecRef(ctx)
rp.nextMount = nil
}
}
@@ -232,19 +231,19 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
// Restart resets the stream of path components represented by rp to its state
// on entry to the current FilesystemImpl method.
-func (rp *ResolvingPath) Restart() {
+func (rp *ResolvingPath) Restart(ctx context.Context) {
rp.pit = rp.origParts[rp.numOrigParts-1]
rp.mustBeDir = rp.mustBeDirOrig
rp.symlinks = rp.symlinksOrig
rp.curPart = rp.numOrigParts - 1
copy(rp.parts[:], rp.origParts[:rp.numOrigParts])
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
}
func (rp *ResolvingPath) relpathCommit() {
@@ -255,88 +254,67 @@ func (rp *ResolvingPath) relpathCommit() {
rp.origParts[rp.curPart] = rp.pit
}
-// ResolveParent returns the VFS parent of d. It does not take a reference on
-// the returned Dentry.
-//
-// Preconditions: There are no concurrent mutators of d.
-//
-// Postconditions: If the returned error is nil, then the returned Dentry is
-// not nil.
-func (rp *ResolvingPath) ResolveParent(d *Dentry) (*Dentry, error) {
- var parent *Dentry
+// CheckRoot is called before resolving the parent of the Dentry d. If the
+// Dentry is contextually a VFS root, such that path resolution should treat
+// d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the
+// root of a non-root mount, such that path resolution should switch to another
+// Mount, CheckRoot returns (unspecified, non-nil error). Otherwise, path
+// resolution should resolve d's parent normally, and CheckRoot returns (false,
+// nil).
+func (rp *ResolvingPath) CheckRoot(ctx context.Context, d *Dentry) (bool, error) {
if d == rp.root.dentry && rp.mount == rp.root.mount {
- // At contextual VFS root.
- parent = d
+ // At contextual VFS root (due to e.g. chroot(2)).
+ return true, nil
} else if d == rp.mount.root {
// At mount root ...
- mnt, mntpt := rp.vfs.getMountpointAt(rp.mount, rp.root)
- if mnt != nil {
+ vd := rp.vfs.getMountpointAt(ctx, rp.mount, rp.root)
+ if vd.Ok() {
// ... of non-root mount.
- rp.nextMount = mnt
- rp.nextStart = mntpt
- return nil, resolveMountRootError{}
+ rp.nextMount = vd.mount
+ rp.nextStart = vd.dentry
+ return false, resolveMountRootOrJumpError{}
}
// ... of root mount.
- parent = d
- } else if d.parent == nil {
- // At filesystem root.
- parent = d
- } else {
- parent = d.parent
- }
- if parent.isMounted() {
- if mnt := rp.vfs.getMountAt(rp.mount, parent); mnt != nil {
- rp.nextMount = mnt
- return nil, resolveMountPointError{}
- }
+ return true, nil
}
- return parent, nil
+ return false, nil
}
-// ResolveChild returns the VFS child of d with the given name. It does not
-// take a reference on the returned Dentry. If no such child exists,
-// ResolveChild returns (nil, nil).
-//
-// Preconditions: There are no concurrent mutators of d.
-func (rp *ResolvingPath) ResolveChild(d *Dentry, name string) (*Dentry, error) {
- child := d.children[name]
- if child == nil {
- return nil, nil
+// CheckMount is called after resolving the parent or child of another Dentry
+// to d. If d is a mount point, such that path resolution should switch to
+// another Mount, CheckMount returns a non-nil error. Otherwise, CheckMount
+// returns nil.
+func (rp *ResolvingPath) CheckMount(ctx context.Context, d *Dentry) error {
+ if !d.isMounted() {
+ return nil
}
- if child.isMounted() {
- if mnt := rp.vfs.getMountAt(rp.mount, child); mnt != nil {
- rp.nextMount = mnt
- return nil, resolveMountPointError{}
- }
- }
- return child, nil
-}
-
-// ResolveComponent returns the Dentry reached by starting at d and resolving
-// the current path component in the stream represented by rp. It does not
-// advance the stream. It does not take a reference on the returned Dentry. If
-// no such Dentry exists, ResolveComponent returns (nil, nil).
-//
-// Preconditions: !rp.Done(). There are no concurrent mutators of d.
-func (rp *ResolvingPath) ResolveComponent(d *Dentry) (*Dentry, error) {
- switch pc := rp.Component(); pc {
- case ".":
- return d, nil
- case "..":
- return rp.ResolveParent(d)
- default:
- return rp.ResolveChild(d, pc)
+ if mnt := rp.vfs.getMountAt(ctx, rp.mount, d); mnt != nil {
+ rp.nextMount = mnt
+ return resolveMountPointError{}
}
+ return nil
}
// ShouldFollowSymlink returns true if, supposing that the current path
// component in pcs represents a symbolic link, the symbolic link should be
// followed.
//
+// If path is terminated with '/', the '/' is considered the last element and
+// any symlink before that is followed:
+// - For most non-creating walks, the last path component is handled by
+// fs/namei.c:lookup_last(), which sets LOOKUP_FOLLOW if the first byte
+// after the path component is non-NULL (which is only possible if it's '/')
+// and the path component is of type LAST_NORM.
+//
+// - For open/openat/openat2 without O_CREAT, the last path component is
+// handled by fs/namei.c:do_last(), which does the same, though without the
+// LAST_NORM check.
+//
// Preconditions: !rp.Done().
func (rp *ResolvingPath) ShouldFollowSymlink() bool {
- // Non-final symlinks are always followed.
- return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final()
+ // Non-final symlinks are always followed. Paths terminated with '/' are also
+ // always followed.
+ return rp.flags&rpflagsFollowFinalSymlink != 0 || !rp.Final() || rp.MustBeDir()
}
// HandleSymlink is called when the current path component is a symbolic link
@@ -345,29 +323,34 @@ func (rp *ResolvingPath) ShouldFollowSymlink() bool {
// symlink target and returns nil. Otherwise it returns a non-nil error.
//
// Preconditions: !rp.Done().
+//
+// Postconditions: If HandleSymlink returns a nil error, then !rp.Done().
func (rp *ResolvingPath) HandleSymlink(target string) error {
if rp.symlinks >= linux.MaxSymlinkTraversals {
return syserror.ELOOP
}
- targetPath, err := fspath.Parse(target)
- if err != nil {
- return err
+ if len(target) == 0 {
+ return syserror.ENOENT
}
rp.symlinks++
+ targetPath := fspath.Parse(target)
if targetPath.Absolute {
rp.absSymlinkTarget = targetPath
return resolveAbsSymlinkError{}
}
- if !targetPath.Begin.Ok() {
- panic(fmt.Sprintf("symbolic link has non-empty target %q that is both relative and has no path components?", target))
- }
// Consume the path component that represented the symlink.
rp.Advance()
// Prepend the symlink target to the relative path.
+ if checkInvariants {
+ if !targetPath.HasComponents() {
+ panic(fmt.Sprintf("non-empty pathname %q parsed to relative path with no components", target))
+ }
+ }
rp.relpathPrepend(targetPath)
return nil
}
+// Preconditions: path.HasComponents().
func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
if rp.pit.Ok() {
rp.parts[rp.curPart] = rp.pit
@@ -385,12 +368,33 @@ func (rp *ResolvingPath) relpathPrepend(path fspath.Path) {
}
}
-func (rp *ResolvingPath) handleError(err error) bool {
+// HandleJump is called when the current path component is a "magic" link to
+// the given VirtualDentry, like /proc/[pid]/fd/[fd]. If the calling Filesystem
+// method should continue path traversal, HandleMagicSymlink updates the path
+// component stream to reflect the magic link target and returns nil. Otherwise
+// it returns a non-nil error.
+//
+// Preconditions: !rp.Done().
+func (rp *ResolvingPath) HandleJump(target VirtualDentry) error {
+ if rp.symlinks >= linux.MaxSymlinkTraversals {
+ return syserror.ELOOP
+ }
+ rp.symlinks++
+ // Consume the path component that represented the magic link.
+ rp.Advance()
+ // Unconditionally return a resolveMountRootOrJumpError, even if the Mount
+ // isn't changing, to force restarting at the new Dentry.
+ target.IncRef()
+ rp.nextMount = target.mount
+ rp.nextStart = target.dentry
+ return resolveMountRootOrJumpError{}
+}
+
+func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool {
switch err.(type) {
- case resolveMountRootError:
- // Switch to the new Mount. We hold references on the Mount and Dentry
- // (from VFS.getMountpointAt()).
- rp.decRefStartAndMount()
+ case resolveMountRootOrJumpError:
+ // Switch to the new Mount. We hold references on the Mount and Dentry.
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.nextMount
rp.start = rp.nextStart
rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef
@@ -407,10 +411,9 @@ func (rp *ResolvingPath) handleError(err error) bool {
return true
case resolveMountPointError:
- // Switch to the new Mount. We hold a reference on the Mount (from
- // VFS.getMountAt()), but borrow the reference on the mount root from
- // the Mount.
- rp.decRefStartAndMount()
+ // Switch to the new Mount. We hold a reference on the Mount, but
+ // borrow the reference on the mount root from the Mount.
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.nextMount
rp.start = rp.nextMount.root
rp.flags = rp.flags&^rpflagsHaveStartRef | rpflagsHaveMountRef
@@ -421,12 +424,12 @@ func (rp *ResolvingPath) handleError(err error) bool {
// path.
rp.relpathCommit()
// Restart path resolution on the new Mount.
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
return true
case resolveAbsSymlinkError:
// Switch to the new Mount. References are borrowed from rp.root.
- rp.decRefStartAndMount()
+ rp.decRefStartAndMount(ctx)
rp.mount = rp.root.mount
rp.start = rp.root.dentry
rp.flags &^= rpflagsHaveMountRef | rpflagsHaveStartRef
@@ -438,7 +441,7 @@ func (rp *ResolvingPath) handleError(err error) bool {
// path, including the symlink target we just prepended.
rp.relpathCommit()
// Restart path resolution on the new Mount.
- rp.releaseErrorState()
+ rp.releaseErrorState(ctx)
return true
default:
@@ -447,6 +450,17 @@ func (rp *ResolvingPath) handleError(err error) bool {
}
}
+// canHandleError returns true if err is an error returned by rp.Resolve*()
+// that rp.handleError() may attempt to handle.
+func (rp *ResolvingPath) canHandleError(err error) bool {
+ switch err.(type) {
+ case resolveMountRootOrJumpError, resolveMountPointError, resolveAbsSymlinkError:
+ return true
+ default:
+ return false
+ }
+}
+
// MustBeDir returns true if the file traversed by rp must be a directory.
func (rp *ResolvingPath) MustBeDir() bool {
return rp.mustBeDir
diff --git a/pkg/sentry/vfs/syscalls.go b/pkg/sentry/vfs/syscalls.go
deleted file mode 100644
index abde0feaa..000000000
--- a/pkg/sentry/vfs/syscalls.go
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// PathOperation specifies the path operated on by a VFS method.
-//
-// PathOperation is passed to VFS methods by pointer to reduce memory copying:
-// it's somewhat large and should never escape. (Options structs are passed by
-// pointer to VFS and FileDescription methods for the same reason.)
-type PathOperation struct {
- // Root is the VFS root. References on Root are borrowed from the provider
- // of the PathOperation.
- //
- // Invariants: Root.Ok().
- Root VirtualDentry
-
- // Start is the starting point for the path traversal. References on Start
- // are borrowed from the provider of the PathOperation (i.e. the caller of
- // the VFS method to which the PathOperation was passed).
- //
- // Invariants: Start.Ok(). If Pathname.Absolute, then Start == Root.
- Start VirtualDentry
-
- // Path is the pathname traversed by this operation.
- Pathname string
-
- // If FollowFinalSymlink is true, and the Dentry traversed by the final
- // path component represents a symbolic link, the symbolic link should be
- // followed.
- FollowFinalSymlink bool
-}
-
-// GetDentryAt returns a VirtualDentry representing the given path, at which a
-// file must exist. A reference is taken on the returned VirtualDentry.
-func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return VirtualDentry{}, err
- }
- for {
- d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
- if err == nil {
- vd := VirtualDentry{
- mount: rp.mount,
- dentry: d,
- }
- rp.mount.incRef()
- vfs.putResolvingPath(rp)
- return vd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return VirtualDentry{}, err
- }
- }
-}
-
-// MkdirAt creates a directory at the given path.
-func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
- // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
- // also honored." - mkdir(2)
- opts.Mode &= 01777
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return err
- }
- for {
- err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return err
- }
- }
-}
-
-// MknodAt creates a file of the given mode at the given path. It returns an
-// error from the syserror package.
-func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil
- }
- for {
- if err = rp.mount.fs.impl.MknodAt(ctx, rp, *opts); err == nil {
- vfs.putResolvingPath(rp)
- return nil
- }
- // Handle mount traversals.
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return err
- }
- }
-}
-
-// OpenAt returns a FileDescription providing access to the file at the given
-// path. A reference is taken on the returned FileDescription.
-func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
- // Remove:
- //
- // - O_LARGEFILE, which we always report in FileDescription status flags
- // since only 64-bit architectures are supported at this time.
- //
- // - O_CLOEXEC, which affects file descriptors and therefore must be
- // handled outside of VFS.
- //
- // - Unknown flags.
- opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
- // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
- if opts.Flags&linux.O_SYNC != 0 {
- opts.Flags |= linux.O_DSYNC
- }
- // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
- // with O_DIRECTORY and a writable access mode (to ensure that it fails on
- // filesystem implementations that do not support it).
- if opts.Flags&linux.O_TMPFILE != 0 {
- if opts.Flags&linux.O_DIRECTORY == 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_CREAT != 0 {
- return nil, syserror.EINVAL
- }
- if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
- return nil, syserror.EINVAL
- }
- }
- // O_PATH causes most other flags to be ignored.
- if opts.Flags&linux.O_PATH != 0 {
- opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
- }
- // "On Linux, the following bits are also honored in mode: [S_ISUID,
- // S_ISGID, S_ISVTX]" - open(2)
- opts.Mode &= 07777
-
- if opts.Flags&linux.O_NOFOLLOW != 0 {
- pop.FollowFinalSymlink = false
- }
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return nil, err
- }
- if opts.Flags&linux.O_DIRECTORY != 0 {
- rp.mustBeDir = true
- rp.mustBeDirOrig = true
- }
- for {
- fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return fd, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return nil, err
- }
- }
-}
-
-// StatAt returns metadata for the file at the given path.
-func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
- rp, err := vfs.getResolvingPath(creds, pop)
- if err != nil {
- return linux.Statx{}, err
- }
- for {
- stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
- if err == nil {
- vfs.putResolvingPath(rp)
- return stat, nil
- }
- if !rp.handleError(err) {
- vfs.putResolvingPath(rp)
- return linux.Statx{}, err
- }
- }
-}
-
-// StatusFlags returns file description status flags.
-func (fd *FileDescription) StatusFlags(ctx context.Context) (uint32, error) {
- flags, err := fd.impl.StatusFlags(ctx)
- flags |= linux.O_LARGEFILE
- return flags, err
-}
-
-// SetStatusFlags sets file description status flags.
-func (fd *FileDescription) SetStatusFlags(ctx context.Context, flags uint32) error {
- return fd.impl.SetStatusFlags(ctx, flags)
-}
-
-// TODO:
-//
-// - VFS.SyncAllFilesystems() for sync(2)
-//
-// - Something for syncfs(2)
-//
-// - VFS.LinkAt()
-//
-// - VFS.ReadlinkAt()
-//
-// - VFS.RenameAt()
-//
-// - VFS.RmdirAt()
-//
-// - VFS.SetStatAt()
-//
-// - VFS.StatFSAt()
-//
-// - VFS.SymlinkAt()
-//
-// - VFS.UnlinkAt()
-//
-// - FileDescription.(almost everything)
diff --git a/pkg/sentry/vfs/testutil.go b/pkg/sentry/vfs/testutil.go
deleted file mode 100644
index 70b192ece..000000000
--- a/pkg/sentry/vfs/testutil.go
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs
-
-import (
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/syserror"
-)
-
-// FDTestFilesystemType is a test-only FilesystemType that produces Filesystems
-// for which all FilesystemImpl methods taking a path return EPERM. It is used
-// to produce Mounts and Dentries for testing of FileDescriptionImpls that do
-// not depend on their originating Filesystem.
-type FDTestFilesystemType struct{}
-
-// FDTestFilesystem is a test-only FilesystemImpl produced by
-// FDTestFilesystemType.
-type FDTestFilesystem struct {
- vfsfs Filesystem
-}
-
-// NewFilesystem implements FilesystemType.NewFilesystem.
-func (fstype FDTestFilesystemType) NewFilesystem(ctx context.Context, creds *auth.Credentials, source string, opts NewFilesystemOptions) (*Filesystem, *Dentry, error) {
- var fs FDTestFilesystem
- fs.vfsfs.Init(&fs)
- return &fs.vfsfs, fs.NewDentry(), nil
-}
-
-// Release implements FilesystemImpl.Release.
-func (fs *FDTestFilesystem) Release() {
-}
-
-// Sync implements FilesystemImpl.Sync.
-func (fs *FDTestFilesystem) Sync(ctx context.Context) error {
- return nil
-}
-
-// GetDentryAt implements FilesystemImpl.GetDentryAt.
-func (fs *FDTestFilesystem) GetDentryAt(ctx context.Context, rp *ResolvingPath, opts GetDentryOptions) (*Dentry, error) {
- return nil, syserror.EPERM
-}
-
-// LinkAt implements FilesystemImpl.LinkAt.
-func (fs *FDTestFilesystem) LinkAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry) error {
- return syserror.EPERM
-}
-
-// MkdirAt implements FilesystemImpl.MkdirAt.
-func (fs *FDTestFilesystem) MkdirAt(ctx context.Context, rp *ResolvingPath, opts MkdirOptions) error {
- return syserror.EPERM
-}
-
-// MknodAt implements FilesystemImpl.MknodAt.
-func (fs *FDTestFilesystem) MknodAt(ctx context.Context, rp *ResolvingPath, opts MknodOptions) error {
- return syserror.EPERM
-}
-
-// OpenAt implements FilesystemImpl.OpenAt.
-func (fs *FDTestFilesystem) OpenAt(ctx context.Context, rp *ResolvingPath, opts OpenOptions) (*FileDescription, error) {
- return nil, syserror.EPERM
-}
-
-// ReadlinkAt implements FilesystemImpl.ReadlinkAt.
-func (fs *FDTestFilesystem) ReadlinkAt(ctx context.Context, rp *ResolvingPath) (string, error) {
- return "", syserror.EPERM
-}
-
-// RenameAt implements FilesystemImpl.RenameAt.
-func (fs *FDTestFilesystem) RenameAt(ctx context.Context, rp *ResolvingPath, vd VirtualDentry, opts RenameOptions) error {
- return syserror.EPERM
-}
-
-// RmdirAt implements FilesystemImpl.RmdirAt.
-func (fs *FDTestFilesystem) RmdirAt(ctx context.Context, rp *ResolvingPath) error {
- return syserror.EPERM
-}
-
-// SetStatAt implements FilesystemImpl.SetStatAt.
-func (fs *FDTestFilesystem) SetStatAt(ctx context.Context, rp *ResolvingPath, opts SetStatOptions) error {
- return syserror.EPERM
-}
-
-// StatAt implements FilesystemImpl.StatAt.
-func (fs *FDTestFilesystem) StatAt(ctx context.Context, rp *ResolvingPath, opts StatOptions) (linux.Statx, error) {
- return linux.Statx{}, syserror.EPERM
-}
-
-// StatFSAt implements FilesystemImpl.StatFSAt.
-func (fs *FDTestFilesystem) StatFSAt(ctx context.Context, rp *ResolvingPath) (linux.Statfs, error) {
- return linux.Statfs{}, syserror.EPERM
-}
-
-// SymlinkAt implements FilesystemImpl.SymlinkAt.
-func (fs *FDTestFilesystem) SymlinkAt(ctx context.Context, rp *ResolvingPath, target string) error {
- return syserror.EPERM
-}
-
-// UnlinkAt implements FilesystemImpl.UnlinkAt.
-func (fs *FDTestFilesystem) UnlinkAt(ctx context.Context, rp *ResolvingPath) error {
- return syserror.EPERM
-}
-
-type fdTestDentry struct {
- vfsd Dentry
-}
-
-// NewDentry returns a new Dentry.
-func (fs *FDTestFilesystem) NewDentry() *Dentry {
- var d fdTestDentry
- d.vfsd.Init(&d)
- return &d.vfsd
-}
-
-// IncRef implements DentryImpl.IncRef.
-func (d *fdTestDentry) IncRef(vfsfs *Filesystem) {
-}
-
-// TryIncRef implements DentryImpl.TryIncRef.
-func (d *fdTestDentry) TryIncRef(vfsfs *Filesystem) bool {
- return true
-}
-
-// DecRef implements DentryImpl.DecRef.
-func (d *fdTestDentry) DecRef(vfsfs *Filesystem) {
-}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 4a8a69540..9c2420683 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -16,24 +16,47 @@
//
// Lock order:
//
-// Filesystem implementation locks
-// VirtualFilesystem.mountMu
+// EpollInstance.interestMu
+// FileDescription.epollMu
+// FilesystemImpl/FileDescriptionImpl locks
+// VirtualFilesystem.mountMu
+// Dentry.mu
+// Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry
+// VirtualFilesystem.filesystemsMu
+// EpollInstance.mu
+// Inotify.mu
+// Watches.mu
+// Inotify.evMu
// VirtualFilesystem.fsTypesMu
+//
+// Locking Dentry.mu in multiple Dentries requires holding
+// VirtualFilesystem.mountMu. Locking EpollInstance.interestMu in multiple
+// EpollInstances requires holding epollCycleMu.
package vfs
import (
- "sync"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A VirtualFilesystem (VFS for short) combines Filesystems in trees of Mounts.
//
// There is no analogue to the VirtualFilesystem type in Linux, as the
// equivalent state in Linux is global.
+//
+// +stateify savable
type VirtualFilesystem struct {
// mountMu serializes mount mutations.
//
// mountMu is analogous to Linux's namespace_sem.
- mountMu sync.RWMutex
+ mountMu sync.Mutex `state:"nosave"`
// mounts maps (mount parent, mount point) pairs to mounts. (Since mounts
// are uniquely namespaced, including mount parent in the key correctly
@@ -52,7 +75,7 @@ type VirtualFilesystem struct {
// mountpoints maps mount points to mounts at those points in all
// namespaces. mountpoints is protected by mountMu.
//
- // mountpoints is used to find mounts that must be unmounted due to
+ // mountpoints is used to find mounts that must be umounted due to
// removal of a mount point Dentry from another mount namespace. ("A file
// or directory that is a mount point in one namespace that is not a mount
// point in another namespace, may be renamed, unlinked, or removed
@@ -62,20 +85,701 @@ type VirtualFilesystem struct {
// mountpoints is analogous to Linux's mountpoint_hashtable.
mountpoints map[*Dentry]map[*Mount]struct{}
- // fsTypes contains all FilesystemTypes that are usable in the
- // VirtualFilesystem. fsTypes is protected by fsTypesMu.
- fsTypesMu sync.RWMutex
- fsTypes map[string]FilesystemType
+ // lastMountID is the last allocated mount ID. lastMountID is accessed
+ // using atomic memory operations.
+ lastMountID uint64
+
+ // anonMount is a Mount, not included in mounts or mountpoints,
+ // representing an anonFilesystem. anonMount is used to back
+ // VirtualDentries returned by VirtualFilesystem.NewAnonVirtualDentry().
+ // anonMount is immutable.
+ //
+ // anonMount is analogous to Linux's anon_inode_mnt.
+ anonMount *Mount
+
+ // devices contains all registered Devices. devices is protected by
+ // devicesMu.
+ devicesMu sync.RWMutex `state:"nosave"`
+ devices map[devTuple]*registeredDevice
+
+ // anonBlockDevMinor contains all allocated anonymous block device minor
+ // numbers. anonBlockDevMinorNext is a lower bound for the smallest
+ // unallocated anonymous block device number. anonBlockDevMinorNext and
+ // anonBlockDevMinor are protected by anonBlockDevMinorMu.
+ anonBlockDevMinorMu sync.Mutex `state:"nosave"`
+ anonBlockDevMinorNext uint32
+ anonBlockDevMinor map[uint32]struct{}
+
+ // fsTypes contains all registered FilesystemTypes. fsTypes is protected by
+ // fsTypesMu.
+ fsTypesMu sync.RWMutex `state:"nosave"`
+ fsTypes map[string]*registeredFilesystemType
+
+ // filesystems contains all Filesystems. filesystems is protected by
+ // filesystemsMu.
+ filesystemsMu sync.Mutex `state:"nosave"`
+ filesystems map[*Filesystem]struct{}
}
-// New returns a new VirtualFilesystem with no mounts or FilesystemTypes.
-func New() *VirtualFilesystem {
- vfs := &VirtualFilesystem{
- mountpoints: make(map[*Dentry]map[*Mount]struct{}),
- fsTypes: make(map[string]FilesystemType),
+// Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes.
+func (vfs *VirtualFilesystem) Init(ctx context.Context) error {
+ if vfs.mountpoints != nil {
+ panic("VFS already initialized")
}
+ vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{})
+ vfs.devices = make(map[devTuple]*registeredDevice)
+ vfs.anonBlockDevMinorNext = 1
+ vfs.anonBlockDevMinor = make(map[uint32]struct{})
+ vfs.fsTypes = make(map[string]*registeredFilesystemType)
+ vfs.filesystems = make(map[*Filesystem]struct{})
vfs.mounts.Init()
- return vfs
+
+ // Construct vfs.anonMount.
+ anonfsDevMinor, err := vfs.GetAnonBlockDevMinor()
+ if err != nil {
+ // This shouldn't be possible since anonBlockDevMinorNext was
+ // initialized to 1 above (no device numbers have been allocated yet).
+ panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err))
+ }
+ anonfs := anonFilesystem{
+ devMinor: anonfsDevMinor,
+ }
+ anonfs.vfsfs.Init(vfs, &anonFilesystemType{}, &anonfs)
+ defer anonfs.vfsfs.DecRef(ctx)
+ anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
+ if err != nil {
+ // We should not be passing any MountOptions that would cause
+ // construction of this mount to fail.
+ panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err))
+ }
+ vfs.anonMount = anonMount
+
+ return nil
+}
+
+// PathOperation specifies the path operated on by a VFS method.
+//
+// PathOperation is passed to VFS methods by pointer to reduce memory copying:
+// it's somewhat large and should never escape. (Options structs are passed by
+// pointer to VFS and FileDescription methods for the same reason.)
+type PathOperation struct {
+ // Root is the VFS root. References on Root are borrowed from the provider
+ // of the PathOperation.
+ //
+ // Invariants: Root.Ok().
+ Root VirtualDentry
+
+ // Start is the starting point for the path traversal. References on Start
+ // are borrowed from the provider of the PathOperation (i.e. the caller of
+ // the VFS method to which the PathOperation was passed).
+ //
+ // Invariants: Start.Ok(). If Path.Absolute, then Start == Root.
+ Start VirtualDentry
+
+ // Path is the pathname traversed by this operation.
+ Path fspath.Path
+
+ // If FollowFinalSymlink is true, and the Dentry traversed by the final
+ // path component represents a symbolic link, the symbolic link should be
+ // followed.
+ FollowFinalSymlink bool
+}
+
+// AccessAt checks whether a user with creds has access to the file at
+// the given path.
+func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credentials, ats AccessTypes, pop *PathOperation) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// GetDentryAt returns a VirtualDentry representing the given path, at which a
+// file must exist. A reference is taken on the returned VirtualDentry.
+func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetDentryOptions) (VirtualDentry, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ d, err := rp.mount.fs.impl.GetDentryAt(ctx, rp, *opts)
+ if err == nil {
+ vd := VirtualDentry{
+ mount: rp.mount,
+ dentry: d,
+ }
+ rp.mount.IncRef()
+ vfs.putResolvingPath(ctx, rp)
+ return vd, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return VirtualDentry{}, err
+ }
+ }
+}
+
+// Preconditions: pop.Path.Begin.Ok().
+func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (VirtualDentry, string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ parent, err := rp.mount.fs.impl.GetParentDentryAt(ctx, rp)
+ if err == nil {
+ parentVD := VirtualDentry{
+ mount: rp.mount,
+ dentry: parent,
+ }
+ rp.mount.IncRef()
+ name := rp.Component()
+ vfs.putResolvingPath(ctx, rp)
+ return parentVD, name, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.GetParentDentryAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return VirtualDentry{}, "", err
+ }
+ }
+}
+
+// LinkAt creates a hard link at newpop representing the existing file at
+// oldpop.
+func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation) error {
+ oldVD, err := vfs.GetDentryAt(ctx, creds, oldpop, &GetDentryOptions{})
+ if err != nil {
+ return err
+ }
+
+ if !newpop.Path.Begin.Ok() {
+ oldVD.DecRef(ctx)
+ if newpop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldVD.DecRef(ctx)
+ ctx.Warningf("VirtualFilesystem.LinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, newpop)
+ for {
+ err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ oldVD.DecRef(ctx)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.LinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ oldVD.DecRef(ctx)
+ return err
+ }
+ }
+}
+
+// MkdirAt creates a directory at the given path.
+func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MkdirAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+ // "Under Linux, apart from the permission bits, the S_ISVTX mode bit is
+ // also honored." - mkdir(2)
+ opts.Mode &= 0777 | linux.S_ISVTX
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MkdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// MknodAt creates a file of the given mode at the given path. It returns an
+// error from the syserror package.
+func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.MknodAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.MknodAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// OpenAt returns a FileDescription providing access to the file at the given
+// path. A reference is taken on the returned FileDescription.
+func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+ // Remove:
+ //
+ // - O_CLOEXEC, which affects file descriptors and therefore must be
+ // handled outside of VFS.
+ //
+ // - Unknown flags.
+ opts.Flags &= linux.O_ACCMODE | linux.O_CREAT | linux.O_EXCL | linux.O_NOCTTY | linux.O_TRUNC | linux.O_APPEND | linux.O_NONBLOCK | linux.O_DSYNC | linux.O_ASYNC | linux.O_DIRECT | linux.O_LARGEFILE | linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_NOATIME | linux.O_SYNC | linux.O_PATH | linux.O_TMPFILE
+ // Linux's __O_SYNC (which we call linux.O_SYNC) implies O_DSYNC.
+ if opts.Flags&linux.O_SYNC != 0 {
+ opts.Flags |= linux.O_DSYNC
+ }
+ // Linux's __O_TMPFILE (which we call linux.O_TMPFILE) must be specified
+ // with O_DIRECTORY and a writable access mode (to ensure that it fails on
+ // filesystem implementations that do not support it).
+ if opts.Flags&linux.O_TMPFILE != 0 {
+ if opts.Flags&linux.O_DIRECTORY == 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_CREAT != 0 {
+ return nil, syserror.EINVAL
+ }
+ if opts.Flags&linux.O_ACCMODE == linux.O_RDONLY {
+ return nil, syserror.EINVAL
+ }
+ }
+ // O_PATH causes most other flags to be ignored.
+ if opts.Flags&linux.O_PATH != 0 {
+ opts.Flags &= linux.O_DIRECTORY | linux.O_NOFOLLOW | linux.O_PATH
+ }
+ // "On Linux, the following bits are also honored in mode: [S_ISUID,
+ // S_ISGID, S_ISVTX]" - open(2)
+ opts.Mode &= 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+ if opts.Flags&linux.O_NOFOLLOW != 0 {
+ pop.FollowFinalSymlink = false
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ if opts.Flags&linux.O_DIRECTORY != 0 {
+ rp.mustBeDir = true
+ rp.mustBeDirOrig = true
+ }
+ for {
+ fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+
+ if opts.FileExec {
+ if fd.Mount().Flags.NoExec {
+ fd.DecRef(ctx)
+ return nil, syserror.EACCES
+ }
+
+ // Only a regular file can be executed.
+ stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
+ if err != nil {
+ fd.DecRef(ctx)
+ return nil, err
+ }
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef(ctx)
+ return nil, syserror.EACCES
+ }
+ }
+
+ fd.Dentry().InotifyWithParent(ctx, linux.IN_OPEN, 0, PathEvent)
+ return fd, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return nil, err
+ }
+ }
+}
+
+// ReadlinkAt returns the target of the symbolic link at the given path.
+func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return target, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return "", err
+ }
+ }
+}
+
+// RenameAt renames the file at oldpop to newpop.
+func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credentials, oldpop, newpop *PathOperation, opts *RenameOptions) error {
+ if !oldpop.Path.Begin.Ok() {
+ if oldpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if oldpop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RenameAt: source path can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ oldParentVD, oldName, err := vfs.getParentDirAndName(ctx, creds, oldpop)
+ if err != nil {
+ return err
+ }
+ if oldName == "." || oldName == ".." {
+ oldParentVD.DecRef(ctx)
+ return syserror.EBUSY
+ }
+
+ if !newpop.Path.Begin.Ok() {
+ oldParentVD.DecRef(ctx)
+ if newpop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if newpop.FollowFinalSymlink {
+ oldParentVD.DecRef(ctx)
+ ctx.Warningf("VirtualFilesystem.RenameAt: destination path can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, newpop)
+ renameOpts := *opts
+ if oldpop.Path.Dir {
+ renameOpts.MustBeDir = true
+ }
+ for {
+ err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ oldParentVD.DecRef(ctx)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RenameAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ oldParentVD.DecRef(ctx)
+ return err
+ }
+ }
+}
+
+// RmdirAt removes the directory at the given path.
+func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.RmdirAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.RmdirAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.RmdirAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// SetStatAt changes metadata for the file at the given path.
+func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetStatOptions) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// StatAt returns metadata for the file at the given path.
+func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *StatOptions) (linux.Statx, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return stat, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return linux.Statx{}, err
+ }
+ }
+}
+
+// StatFSAt returns metadata for the filesystem containing the file at the
+// given path.
+func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) (linux.Statfs, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return statfs, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return linux.Statfs{}, err
+ }
+ }
+}
+
+// SymlinkAt creates a symbolic link at the given path with the given target.
+func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EEXIST
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.SymlinkAt: file creation paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.SymlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// UnlinkAt deletes the non-directory file at the given path.
+func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return syserror.EBUSY
+ }
+ return syserror.ENOENT
+ }
+ if pop.FollowFinalSymlink {
+ ctx.Warningf("VirtualFilesystem.UnlinkAt: file deletion paths can't follow final symlink")
+ return syserror.EINVAL
+ }
+
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.UnlinkAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// BoundEndpointAt gets the bound endpoint at the given path, if one exists.
+func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) {
+ if !pop.Path.Begin.Ok() {
+ if pop.Path.Absolute {
+ return nil, syserror.ECONNREFUSED
+ }
+ return nil, syserror.ENOENT
+ }
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return bep, nil
+ }
+ if checkInvariants {
+ if rp.canHandleError(err) && rp.Done() {
+ panic(fmt.Sprintf("%T.BoundEndpointAt() consumed all path components and returned %v", rp.mount.fs.impl, err))
+ }
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return nil, err
+ }
+ }
+}
+
+// ListxattrAt returns all extended attribute names for the file at the given
+// path.
+func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return names, nil
+ }
+ if err == syserror.ENOTSUP {
+ // Linux doesn't actually return ENOTSUP in this case; instead,
+ // fs/xattr.c:vfs_listxattr() falls back to allowing the security
+ // subsystem to return security extended attributes, which by
+ // default don't exist.
+ vfs.putResolvingPath(ctx, rp)
+ return nil, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return nil, err
+ }
+ }
+}
+
+// GetxattrAt returns the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return val, nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return "", err
+ }
+ }
+}
+
+// SetxattrAt changes the value associated with the given extended attribute
+// for the file at the given path.
+func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// RemovexattrAt removes the given extended attribute from the file at rp.
+func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error {
+ rp := vfs.getResolvingPath(creds, pop)
+ for {
+ err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name)
+ if err == nil {
+ vfs.putResolvingPath(ctx, rp)
+ return nil
+ }
+ if !rp.handleError(ctx, err) {
+ vfs.putResolvingPath(ctx, rp)
+ return err
+ }
+ }
+}
+
+// SyncAllFilesystems has the semantics of Linux's sync(2).
+func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error {
+ fss := make(map[*Filesystem]struct{})
+ vfs.filesystemsMu.Lock()
+ for fs := range vfs.filesystems {
+ if !fs.TryIncRef() {
+ continue
+ }
+ fss[fs] = struct{}{}
+ }
+ vfs.filesystemsMu.Unlock()
+ var retErr error
+ for fs := range fss {
+ if err := fs.impl.Sync(ctx); err != nil && retErr == nil {
+ retErr = err
+ }
+ fs.DecRef(ctx)
+ }
+ return retErr
}
// A VirtualDentry represents a node in a VFS tree, by combining a Dentry
@@ -97,11 +801,21 @@ func New() *VirtualFilesystem {
// VirtualDentry methods require that a reference is held on the VirtualDentry.
//
// VirtualDentry is analogous to Linux's struct path.
+//
+// +stateify savable
type VirtualDentry struct {
mount *Mount
dentry *Dentry
}
+// MakeVirtualDentry creates a VirtualDentry.
+func MakeVirtualDentry(mount *Mount, dentry *Dentry) VirtualDentry {
+ return VirtualDentry{
+ mount: mount,
+ dentry: dentry,
+ }
+}
+
// Ok returns true if vd is not empty. It does not require that a reference is
// held.
func (vd VirtualDentry) Ok() bool {
@@ -111,15 +825,15 @@ func (vd VirtualDentry) Ok() bool {
// IncRef increments the reference counts on the Mount and Dentry represented
// by vd.
func (vd VirtualDentry) IncRef() {
- vd.mount.incRef()
- vd.dentry.incRef(vd.mount.fs)
+ vd.mount.IncRef()
+ vd.dentry.IncRef()
}
// DecRef decrements the reference counts on the Mount and Dentry represented
// by vd.
-func (vd VirtualDentry) DecRef() {
- vd.dentry.decRef(vd.mount.fs)
- vd.mount.decRef()
+func (vd VirtualDentry) DecRef(ctx context.Context) {
+ vd.dentry.DecRef(ctx)
+ vd.mount.DecRef(ctx)
}
// Mount returns the Mount associated with vd. It does not take a reference on
diff --git a/pkg/sentry/watchdog/BUILD b/pkg/sentry/watchdog/BUILD
index 4d8435265..1c5a1c9b6 100644
--- a/pkg/sentry/watchdog/BUILD
+++ b/pkg/sentry/watchdog/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "watchdog",
srcs = ["watchdog.go"],
- importpath = "gvisor.dev/gvisor/pkg/sentry/watchdog",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
@@ -13,5 +12,6 @@ go_library(
"//pkg/metric",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
+ "//pkg/sync",
],
)
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 145102c0d..748273366 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -32,7 +32,6 @@ package watchdog
import (
"bytes"
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -40,17 +39,48 @@ import (
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sync"
)
-// DefaultTimeout is a resonable timeout value for most applications.
-const DefaultTimeout = 3 * time.Minute
+// Opts configures the watchdog.
+type Opts struct {
+ // TaskTimeout is the amount of time to allow a task to execute the
+ // same syscall without blocking before it's declared stuck.
+ TaskTimeout time.Duration
+
+ // TaskTimeoutAction indicates what action to take when a stuck tasks
+ // is detected.
+ TaskTimeoutAction Action
+
+ // StartupTimeout is the amount of time to allow between watchdog
+ // creation and calling watchdog.Start.
+ StartupTimeout time.Duration
+
+ // StartupTimeoutAction indicates what action to take when
+ // watchdog.Start is not called within the timeout.
+ StartupTimeoutAction Action
+}
+
+// DefaultOpts is a default set of options for the watchdog.
+var DefaultOpts = Opts{
+ // Task timeout.
+ TaskTimeout: 3 * time.Minute,
+ TaskTimeoutAction: LogWarning,
+
+ // Startup timeout.
+ StartupTimeout: 30 * time.Second,
+ StartupTimeoutAction: LogWarning,
+}
// descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period
// is discounted from task's last update time. It's set high enough that small scheduling delays won't
// trigger it.
const descheduleThreshold = 1 * time.Second
-var stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+var (
+ stuckStartup = metric.MustCreateNewUint64Metric("/watchdog/stuck_startup_detected", true /* sync */, "Incremented once on startup watchdog timeout")
+ stuckTasks = metric.MustCreateNewUint64Metric("/watchdog/stuck_tasks_detected", true /* sync */, "Cumulative count of stuck tasks detected")
+)
// Amount of time to wait before dumping the stack to the log again when the same task(s) remains stuck.
var stackDumpSameTaskPeriod = time.Minute
@@ -61,6 +91,7 @@ type Action int
const (
// LogWarning logs warning message followed by stack trace.
LogWarning Action = iota
+
// Panic will do the same logging as LogWarning and panic().
Panic
)
@@ -80,17 +111,13 @@ func (a Action) String() string {
// Watchdog is the main watchdog class. It controls a goroutine that periodically
// analyses all tasks and reports if any of them appear to be stuck.
type Watchdog struct {
+ // Configuration options are embedded.
+ Opts
+
// period indicates how often to check all tasks. It's calculated based on
- // 'taskTimeout'.
+ // opts.TaskTimeout.
period time.Duration
- // taskTimeout is the amount of time to allow a task to execute the same syscall
- // without blocking before it's declared stuck.
- taskTimeout time.Duration
-
- // timeoutAction indicates what action to take when a stuck tasks is detected.
- timeoutAction Action
-
// k is where the tasks come from.
k *kernel.Kernel
@@ -113,8 +140,12 @@ type Watchdog struct {
// mu protects the fields below.
mu sync.Mutex
- // started is true if the watchdog has been started before.
- started bool
+ // running is true if the watchdog is running.
+ running bool
+
+ // startCalled is true if Start has ever been called. It remains true
+ // even if Stop is called.
+ startCalled bool
}
type offender struct {
@@ -122,58 +153,84 @@ type offender struct {
}
// New creates a new watchdog.
-func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog {
- // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much.
- period := taskTimeout / 4
- return &Watchdog{
- k: k,
- period: period,
- taskTimeout: taskTimeout,
- timeoutAction: a,
- offenders: make(map[*kernel.Task]*offender),
- stop: make(chan struct{}),
- done: make(chan struct{}),
+func New(k *kernel.Kernel, opts Opts) *Watchdog {
+ // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much.
+ period := opts.TaskTimeout / 4
+ w := &Watchdog{
+ Opts: opts,
+ k: k,
+ period: period,
+ offenders: make(map[*kernel.Task]*offender),
+ stop: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+
+ // Handle StartupTimeout if it exists.
+ if w.StartupTimeout > 0 {
+ log.Infof("Watchdog waiting %v for startup", w.StartupTimeout)
+ go w.waitForStart() // S/R-SAFE: watchdog is stopped buring save and restarted after restore.
}
+
+ return w
}
// Start starts the watchdog.
func (w *Watchdog) Start() {
- if w.taskTimeout == 0 {
- log.Infof("Watchdog disabled")
- return
- }
-
w.mu.Lock()
defer w.mu.Unlock()
- if w.started {
+ w.startCalled = true
+
+ if w.running {
return
}
+ if w.TaskTimeout == 0 {
+ log.Infof("Watchdog task timeout disabled")
+ return
+ }
w.lastRun = w.k.MonotonicClock().Now()
- log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction)
+ log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction)
go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
- w.started = true
+ w.running = true
}
// Stop requests the watchdog to stop and wait for it.
func (w *Watchdog) Stop() {
- if w.taskTimeout == 0 {
+ if w.TaskTimeout == 0 {
return
}
w.mu.Lock()
defer w.mu.Unlock()
- if !w.started {
+ if !w.running {
return
}
log.Infof("Stopping watchdog")
w.stop <- struct{}{}
<-w.done
- w.started = false
+ w.running = false
log.Infof("Watchdog stopped")
}
+// waitForStart waits for Start to be called and takes action if it does not
+// happen within the startup timeout.
+func (w *Watchdog) waitForStart() {
+ <-time.After(w.StartupTimeout)
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.startCalled {
+ // We are fine.
+ return
+ }
+
+ stuckStartup.Increment()
+
+ var buf bytes.Buffer
+ buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s", w.StartupTimeout))
+ w.doAction(w.StartupTimeoutAction, false, &buf)
+}
+
// loop is the main watchdog routine. It only returns when 'Stop()' is called.
func (w *Watchdog) loop() {
// Loop until someone stops it.
@@ -202,9 +259,9 @@ func (w *Watchdog) runTurn() {
select {
case <-done:
- case <-time.After(w.taskTimeout):
+ case <-time.After(w.TaskTimeout):
// Report if the watchdog is not making progress.
- // No one is wathching the watchdog watcher though.
+ // No one is watching the watchdog watcher though.
w.reportStuckWatchdog()
<-done
}
@@ -231,12 +288,14 @@ func (w *Watchdog) runTurn() {
if tsched.State == kernel.TaskGoroutineRunningSys {
lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick)))
elapsed := now.Sub(lastUpdateTime) - discount
- if elapsed > w.taskTimeout {
+ if elapsed > w.TaskTimeout {
tc, ok := w.offenders[t]
if !ok {
// New stuck task detected.
//
- // TODO(b/65849403): Tasks blocked doing IO may be considered stuck in kernel.
+ // Note that tasks blocked doing IO may be considered stuck in kernel,
+ // unless they are surrounded b
+ // Task.UninterruptibleSleepStart/Finish.
tc = &offender{lastUpdateTime: lastUpdateTime}
stuckTasks.Increment()
newTaskFound = true
@@ -261,28 +320,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
tid := w.k.TaskSet().Root.IDOfTask(t)
buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime)))
}
+
buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
- w.onStuckTask(newTaskFound, &buf)
+
+ // Force stack dump only if a new task is detected.
+ w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
}
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
- buf.WriteString("Watchdog goroutine is stuck:\n")
- w.onStuckTask(true, &buf)
+ buf.WriteString("Watchdog goroutine is stuck")
+ w.doAction(w.TaskTimeoutAction, false, &buf)
}
-func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
- switch w.timeoutAction {
+// doAction will take the given action. If the action is LogWarning, the stack
+// is not always dumped to the log to prevent log flooding. "forceStack"
+// guarantees that the stack will be dumped regardless.
+func (w *Watchdog) doAction(action Action, forceStack bool, msg *bytes.Buffer) {
+ switch action {
case LogWarning:
- // Dump stack only if a new task is detected or if it sometime has passed since
- // the last time a stack dump was generated.
- if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
+ // Dump stack only if forced or sometime has passed since the last time a
+ // stack dump was generated.
+ if !forceStack && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
msg.WriteString("\n...[stack dump skipped]...")
log.Warningf(msg.String())
- } else {
- log.TracebackAll(msg.String())
- w.lastStackDump = time.Now()
+ return
}
+ log.TracebackAll(msg.String())
+ w.lastStackDump = time.Now()
case Panic:
// Panic will skip over running tasks, which is likely the culprit here. So manually
@@ -300,6 +365,10 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
case <-metricsEmitted:
case <-time.After(1 * time.Second):
}
- panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ panic(fmt.Sprintf("%s\nStack for running G's are skipped while panicking.", msg.String()))
+
+ default:
+ panic(fmt.Sprintf("Unknown watchdog action %v", action))
+
}
}
diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD
new file mode 100644
index 000000000..f08599ebd
--- /dev/null
+++ b/pkg/shim/runsc/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "runsc",
+ srcs = [
+ "runsc.go",
+ "utils.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go
new file mode 100644
index 000000000..c5cf68efa
--- /dev/null
+++ b/pkg/shim/runsc/runsc.go
@@ -0,0 +1,514 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "syscall"
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+var Monitor runc.ProcessMonitor = runc.Monitor
+
+// DefaultCommand is the default command for Runsc.
+const DefaultCommand = "runsc"
+
+// Runsc is the client to the runsc cli.
+type Runsc struct {
+ Command string
+ PdeathSignal syscall.Signal
+ Setpgid bool
+ Root string
+ Log string
+ LogFormat runc.Format
+ Config map[string]string
+}
+
+// List returns all containers created inside the provided runsc root directory.
+func (r *Runsc) List(context context.Context) ([]*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "list", "--format=json"), false)
+ if err != nil {
+ return nil, err
+ }
+ var out []*runc.Container
+ if err := json.Unmarshal(data, &out); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+// State returns the state for the container provided by id.
+func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) {
+ data, err := cmdOutput(r.command(context, "state", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var c runc.Container
+ if err := json.Unmarshal(data, &c); err != nil {
+ return nil, err
+ }
+ return &c, nil
+}
+
+type CreateOpts struct {
+ runc.IO
+ ConsoleSocket runc.ConsoleSocket
+
+ // PidFile is a path to where a pid file should be created.
+ PidFile string
+
+ // UserLog is a path to where runsc user log should be generated.
+ UserLog string
+}
+
+func (o *CreateOpts) args() (out []string, err error) {
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.UserLog != "" {
+ out = append(out, "--user-log", o.UserLog)
+ }
+ return out, nil
+}
+
+// Create creates a new container and returns its pid if it was created successfully.
+func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error {
+ args := []string{"create", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+// Start will start an already created container.
+func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error {
+ cmd := r.command(context, "start", id)
+ if cio != nil {
+ cio.Set(cmd)
+ }
+
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if cio != nil {
+ if c, ok := cio.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return err
+}
+
+type waitResult struct {
+ ID string `json:"id"`
+ ExitStatus int `json:"exitStatus"`
+}
+
+// Wait will wait for a running container, and return its exit status.
+//
+// TODO(random-liu): Add exec process support.
+func (r *Runsc) Wait(context context.Context, id string) (int, error) {
+ data, err := cmdOutput(r.command(context, "wait", id), true)
+ if err != nil {
+ return 0, fmt.Errorf("%s: %s", err, data)
+ }
+ var res waitResult
+ if err := json.Unmarshal(data, &res); err != nil {
+ return 0, err
+ }
+ return res.ExitStatus, nil
+}
+
+type ExecOpts struct {
+ runc.IO
+ PidFile string
+ InternalPidFile string
+ ConsoleSocket runc.ConsoleSocket
+ Detach bool
+}
+
+func (o *ExecOpts) args() (out []string, err error) {
+ if o.ConsoleSocket != nil {
+ out = append(out, "--console-socket", o.ConsoleSocket.Path())
+ }
+ if o.Detach {
+ out = append(out, "--detach")
+ }
+ if o.PidFile != "" {
+ abs, err := filepath.Abs(o.PidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--pid-file", abs)
+ }
+ if o.InternalPidFile != "" {
+ abs, err := filepath.Abs(o.InternalPidFile)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, "--internal-pid-file", abs)
+ }
+ return out, nil
+}
+
+// Exec executes an additional process inside the container based on a full OCI
+// Process specification.
+func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error {
+ f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process")
+ if err != nil {
+ return err
+ }
+ defer os.Remove(f.Name())
+ err = json.NewEncoder(f).Encode(spec)
+ f.Close()
+ if err != nil {
+ return err
+ }
+ args := []string{"exec", "--process", f.Name()}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ if cmd.Stdout == nil && cmd.Stderr == nil {
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ if opts != nil && opts.IO != nil {
+ if c, ok := opts.IO.(runc.StartCloser); ok {
+ if err := c.CloseAfterStart(); err != nil {
+ return err
+ }
+ }
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+ return err
+}
+
+// Run runs the create, start, delete lifecycle of the container and returns
+// its exit status after it has exited.
+func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) {
+ args := []string{"run", "--bundle", bundle}
+ if opts != nil {
+ oargs, err := opts.args()
+ if err != nil {
+ return -1, err
+ }
+ args = append(args, oargs...)
+ }
+ cmd := r.command(context, append(args, id)...)
+ if opts != nil && opts.IO != nil {
+ opts.Set(cmd)
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return -1, err
+ }
+ return Monitor.Wait(cmd, ec)
+}
+
+type DeleteOpts struct {
+ Force bool
+}
+
+func (o *DeleteOpts) args() (out []string) {
+ if o.Force {
+ out = append(out, "--force")
+ }
+ return out
+}
+
+// Delete deletes the container.
+func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error {
+ args := []string{"delete"}
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id)...))
+}
+
+// KillOpts specifies options for killing a container and its processes.
+type KillOpts struct {
+ All bool
+ Pid int
+}
+
+func (o *KillOpts) args() (out []string) {
+ if o.All {
+ out = append(out, "--all")
+ }
+ if o.Pid != 0 {
+ out = append(out, "--pid", strconv.Itoa(o.Pid))
+ }
+ return out
+}
+
+// Kill sends the specified signal to the container.
+func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error {
+ args := []string{
+ "kill",
+ }
+ if opts != nil {
+ args = append(args, opts.args()...)
+ }
+ return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...))
+}
+
+// Stats return the stats for a container like cpu, memory, and I/O.
+func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) {
+ cmd := r.command(context, "events", "--stats", id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ var e runc.Event
+ if err := json.NewDecoder(rd).Decode(&e); err != nil {
+ return nil, err
+ }
+ return e.Stats, nil
+}
+
+// Events returns an event stream from runsc for a container with stats and OOM notifications.
+func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) {
+ cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id)
+ rd, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, err
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ rd.Close()
+ return nil, err
+ }
+ var (
+ dec = json.NewDecoder(rd)
+ c = make(chan *runc.Event, 128)
+ )
+ go func() {
+ defer func() {
+ close(c)
+ rd.Close()
+ Monitor.Wait(cmd, ec)
+ }()
+ for {
+ var e runc.Event
+ if err := dec.Decode(&e); err != nil {
+ if err == io.EOF {
+ return
+ }
+ e = runc.Event{
+ Type: "error",
+ Err: err,
+ }
+ }
+ c <- &e
+ }
+ }()
+ return c, nil
+}
+
+// Ps lists all the processes inside the container returning their pids.
+func (r *Runsc) Ps(context context.Context, id string) ([]int, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+ var pids []int
+ if err := json.Unmarshal(data, &pids); err != nil {
+ return nil, err
+ }
+ return pids, nil
+}
+
+// Top lists all the processes inside the container returning the full ps data.
+func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) {
+ data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true)
+ if err != nil {
+ return nil, fmt.Errorf("%s: %s", err, data)
+ }
+
+ topResults, err := runc.ParsePSOutput(data)
+ if err != nil {
+ return nil, fmt.Errorf("%s: ", err)
+ }
+ return topResults, nil
+}
+
+func (r *Runsc) args() []string {
+ var args []string
+ if r.Root != "" {
+ args = append(args, fmt.Sprintf("--root=%s", r.Root))
+ }
+ if r.Log != "" {
+ args = append(args, fmt.Sprintf("--log=%s", r.Log))
+ }
+ if r.LogFormat != "" {
+ args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat))
+ }
+ for k, v := range r.Config {
+ args = append(args, fmt.Sprintf("--%s=%s", k, v))
+ }
+ return args
+}
+
+// runOrError will run the provided command.
+//
+// If an error is encountered and neither Stdout or Stderr was set the error
+// will be returned in the format of <error>: <stderr>.
+func (r *Runsc) runOrError(cmd *exec.Cmd) error {
+ if cmd.Stdout != nil || cmd.Stderr != nil {
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return err
+ }
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+ return err
+ }
+ data, err := cmdOutput(cmd, true)
+ if err != nil {
+ return fmt.Errorf("%s: %s", err, data)
+ }
+ return nil
+}
+
+func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd {
+ command := r.Command
+ if command == "" {
+ command = DefaultCommand
+ }
+ cmd := exec.CommandContext(context, command, append(r.args(), args...)...)
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: r.Setpgid,
+ }
+ if r.PdeathSignal != 0 {
+ cmd.SysProcAttr.Pdeathsig = r.PdeathSignal
+ }
+
+ return cmd
+}
+
+func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) {
+ b := getBuf()
+ defer putBuf(b)
+
+ cmd.Stdout = b
+ if combined {
+ cmd.Stderr = b
+ }
+ ec, err := Monitor.Start(cmd)
+ if err != nil {
+ return nil, err
+ }
+
+ status, err := Monitor.Wait(cmd, ec)
+ if err == nil && status != 0 {
+ err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0])
+ }
+
+ return b.Bytes(), err
+}
diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go
new file mode 100644
index 000000000..c514b3bc7
--- /dev/null
+++ b/pkg/shim/runsc/utils.go
@@ -0,0 +1,44 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runsc
+
+import (
+ "bytes"
+ "strings"
+ "sync"
+)
+
+var bytesBufferPool = sync.Pool{
+ New: func() interface{} {
+ return bytes.NewBuffer(nil)
+ },
+}
+
+func getBuf() *bytes.Buffer {
+ return bytesBufferPool.Get().(*bytes.Buffer)
+}
+
+func putBuf(b *bytes.Buffer) {
+ b.Reset()
+ bytesBufferPool.Put(b)
+}
+
+// FormatLogPath parses runsc config, and fill in %ID% in the log path.
+func FormatLogPath(id string, config map[string]string) {
+ if path, ok := config["debug-log"]; ok {
+ config["debug-log"] = strings.Replace(path, "%ID%", id, -1)
+ }
+}
diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD
new file mode 100644
index 000000000..4377306af
--- /dev/null
+++ b/pkg/shim/v1/proc/BUILD
@@ -0,0 +1,36 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "proc",
+ srcs = [
+ "deleted_state.go",
+ "exec.go",
+ "exec_state.go",
+ "init.go",
+ "init_state.go",
+ "io.go",
+ "process.go",
+ "types.go",
+ "utils.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_go_runc//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go
new file mode 100644
index 000000000..d9b970c4d
--- /dev/null
+++ b/pkg/shim/v1/proc/deleted_state.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type deletedState struct{}
+
+func (*deletedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a deleted process.ss")
+}
+
+func (*deletedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a deleted process.ss")
+}
+
+func (*deletedState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound)
+}
+
+func (*deletedState) SetExited(status int) {}
+
+func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a deleted state")
+}
diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go
new file mode 100644
index 000000000..1d1d90488
--- /dev/null
+++ b/pkg/shim/v1/proc/exec.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+type execProcess struct {
+ wg sync.WaitGroup
+
+ execState execState
+
+ mu sync.Mutex
+ id string
+ console console.Console
+ io runc.IO
+ status int
+ exited time.Time
+ pid int
+ internalPid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ path string
+ spec specs.Process
+
+ parent *Init
+ waitBlock chan struct{}
+}
+
+func (e *execProcess) Wait() {
+ <-e.waitBlock
+}
+
+func (e *execProcess) ID() string {
+ return e.id
+}
+
+func (e *execProcess) Pid() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.pid
+}
+
+func (e *execProcess) ExitStatus() int {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.status
+}
+
+func (e *execProcess) ExitedAt() time.Time {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return e.exited
+}
+
+func (e *execProcess) SetExited(status int) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ e.execState.SetExited(status)
+}
+
+func (e *execProcess) setExited(status int) {
+ e.status = status
+ e.exited = time.Now()
+ e.parent.Platform.ShutdownConsole(context.Background(), e.console)
+ close(e.waitBlock)
+}
+
+func (e *execProcess) Delete(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Delete(ctx)
+}
+
+func (e *execProcess) delete(ctx context.Context) error {
+ e.wg.Wait()
+ if e.io != nil {
+ for _, c := range e.closers {
+ c.Close()
+ }
+ e.io.Close()
+ }
+ pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ // silently ignore error
+ os.Remove(pidfile)
+ internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ // silently ignore error
+ os.Remove(internalPidfile)
+ return nil
+}
+
+func (e *execProcess) Resize(ws console.WinSize) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Resize(ws)
+}
+
+func (e *execProcess) resize(ws console.WinSize) error {
+ if e.console == nil {
+ return nil
+ }
+ return e.console.Resize(ws)
+}
+
+func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Kill(ctx, sig, false)
+}
+
+func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error {
+ internalPid := e.internalPid
+ if internalPid != 0 {
+ if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{
+ Pid: internalPid,
+ }); err != nil {
+ // If this returns error, consider the process has
+ // already stopped.
+ //
+ // TODO: Fix after signal handling is fixed.
+ return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound)
+ }
+ }
+ return nil
+}
+
+func (e *execProcess) Stdin() io.Closer {
+ return e.stdin
+}
+
+func (e *execProcess) Stdio() stdio.Stdio {
+ return e.stdio
+}
+
+func (e *execProcess) Start(ctx context.Context) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ return e.execState.Start(ctx)
+}
+
+func (e *execProcess) start(ctx context.Context) (err error) {
+ var (
+ socket *runc.Socket
+ pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id))
+ internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id))
+ )
+ if e.stdio.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create runc console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if e.stdio.IsNull() {
+ if e.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil {
+ return fmt.Errorf("failed to create runc io pipes: %w", err)
+ }
+ }
+ opts := &runsc.ExecOpts{
+ PidFile: pidfile,
+ InternalPidFile: internalPidfile,
+ IO: e.io,
+ Detach: true,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ eventCh := e.parent.Monitor.Subscribe()
+ defer func() {
+ // Unsubscribe if an error is returned.
+ if err != nil {
+ e.parent.Monitor.Unsubscribe(eventCh)
+ }
+ }()
+ if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil {
+ close(e.waitBlock)
+ return e.parent.runtimeError(err, "OCI runtime exec failed")
+ }
+ if e.stdio.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err)
+ }
+ e.closers = append(e.closers, sc)
+ e.stdin = sc
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ } else if !e.stdio.IsNull() {
+ if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(opts.PidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err)
+ }
+ e.pid = pid
+ internalPid, err := runc.ReadPidFile(opts.InternalPidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err)
+ }
+ e.internalPid = internalPid
+ go func() {
+ defer e.parent.Monitor.Unsubscribe(eventCh)
+ for event := range eventCh {
+ if event.Pid == e.pid {
+ ExitCh <- Exit{
+ Timestamp: event.Timestamp,
+ ID: e.id,
+ Status: event.Status,
+ }
+ break
+ }
+ }
+ }()
+ return nil
+}
+
+func (e *execProcess) Status(ctx context.Context) (string, error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ // if we don't have a pid then the exec process has just been created
+ if e.pid == 0 {
+ return "created", nil
+ }
+ // if we have a pid and it can be signaled, the process is running
+ // TODO(random-liu): Use `runsc kill --pid`.
+ if err := unix.Kill(e.pid, 0); err == nil {
+ return "running", nil
+ }
+ // else if we have a pid but it can nolonger be signaled, it has stopped
+ return "stopped", nil
+}
diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go
new file mode 100644
index 000000000..4dcda8b44
--- /dev/null
+++ b/pkg/shim/v1/proc/exec_state.go
@@ -0,0 +1,154 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+)
+
+type execState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type execCreatedState struct {
+ p *execProcess
+}
+
+func (s *execCreatedState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.execState = &execRunningState{p: s.p}
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execCreatedState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execCreatedState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *execCreatedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execCreatedState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execRunningState struct {
+ p *execProcess
+}
+
+func (s *execRunningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.execState = &execStoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execRunningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *execRunningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process")
+}
+
+func (s *execRunningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process")
+}
+
+func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execRunningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+type execStoppedState struct {
+ p *execProcess
+}
+
+func (s *execStoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.execState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *execStoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *execStoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process")
+}
+
+func (s *execStoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *execStoppedState) SetExited(status int) {
+ // no op
+}
diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go
new file mode 100644
index 000000000..dab3123d6
--- /dev/null
+++ b/pkg/shim/v1/proc/init.go
@@ -0,0 +1,460 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+// InitPidFile name of the file that contains the init pid.
+const InitPidFile = "init.pid"
+
+// Init represents an initial process for a container.
+type Init struct {
+ wg sync.WaitGroup
+ initState initState
+
+ // mu is used to ensure that `Start()` and `Exited()` calls return in
+ // the right order when invoked in separate go routines. This is the
+ // case within the shim implementation as it makes use of the reaper
+ // interface.
+ mu sync.Mutex
+
+ waitBlock chan struct{}
+
+ WorkDir string
+
+ id string
+ Bundle string
+ console console.Console
+ Platform stdio.Platform
+ io runc.IO
+ runtime *runsc.Runsc
+ status int
+ exited time.Time
+ pid int
+ closers []io.Closer
+ stdin io.Closer
+ stdio stdio.Stdio
+ Rootfs string
+ IoUID int
+ IoGID int
+ Sandbox bool
+ UserLog string
+ Monitor ProcessMonitor
+}
+
+// NewRunsc returns a new runsc instance for a process.
+func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc {
+ if root == "" {
+ root = RunscRoot
+ }
+ return &runsc.Runsc{
+ Command: runtime,
+ PdeathSignal: syscall.SIGKILL,
+ Log: filepath.Join(path, "log.json"),
+ LogFormat: runc.JSON,
+ Root: filepath.Join(root, namespace),
+ Config: config,
+ }
+}
+
+// New returns a new init process.
+func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init {
+ p := &Init{
+ id: id,
+ runtime: runtime,
+ stdio: stdio,
+ status: 0,
+ waitBlock: make(chan struct{}),
+ }
+ p.initState = &createdState{p: p}
+ return p
+}
+
+// Create the process with the provided config.
+func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) {
+ var socket *runc.Socket
+ if r.Terminal {
+ if socket, err = runc.NewTempConsoleSocket(); err != nil {
+ return fmt.Errorf("failed to create OCI runtime console socket: %w", err)
+ }
+ defer socket.Close()
+ } else if hasNoIO(r) {
+ if p.io, err = runc.NewNullIO(); err != nil {
+ return fmt.Errorf("creating new NULL IO: %w", err)
+ }
+ } else {
+ if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil {
+ return fmt.Errorf("failed to create OCI runtime io pipes: %w", err)
+ }
+ }
+ pidFile := filepath.Join(p.Bundle, InitPidFile)
+ opts := &runsc.CreateOpts{
+ PidFile: pidFile,
+ }
+ if socket != nil {
+ opts.ConsoleSocket = socket
+ }
+ if p.Sandbox {
+ opts.IO = p.io
+ // UserLog is only useful for sandbox.
+ opts.UserLog = p.UserLog
+ }
+ if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil {
+ return p.runtimeError(err, "OCI runtime create failed")
+ }
+ if r.Stdin != "" {
+ sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err)
+ }
+ p.stdin = sc
+ p.closers = append(p.closers, sc)
+ }
+ ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ defer cancel()
+ if socket != nil {
+ console, err := socket.ReceiveMaster()
+ if err != nil {
+ return fmt.Errorf("failed to retrieve console master: %w", err)
+ }
+ console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg)
+ if err != nil {
+ return fmt.Errorf("failed to start console copy: %w", err)
+ }
+ p.console = console
+ } else if !hasNoIO(r) {
+ if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil {
+ return fmt.Errorf("failed to start io pipe copy: %w", err)
+ }
+ }
+ pid, err := runc.ReadPidFile(pidFile)
+ if err != nil {
+ return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err)
+ }
+ p.pid = pid
+ return nil
+}
+
+// Wait waits for the process to exit.
+func (p *Init) Wait() {
+ <-p.waitBlock
+}
+
+// ID returns the ID of the process.
+func (p *Init) ID() string {
+ return p.id
+}
+
+// Pid returns the PID of the process.
+func (p *Init) Pid() int {
+ return p.pid
+}
+
+// ExitStatus returns the exit status of the process.
+func (p *Init) ExitStatus() int {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.status
+}
+
+// ExitedAt returns the time when the process exited.
+func (p *Init) ExitedAt() time.Time {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.exited
+}
+
+// Status returns the status of the process.
+func (p *Init) Status(ctx context.Context) (string, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ c, err := p.runtime.State(ctx, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return "stopped", nil
+ }
+ return "", p.runtimeError(err, "OCI runtime state failed")
+ }
+ return p.convertStatus(c.Status), nil
+}
+
+// Start starts the init process.
+func (p *Init) Start(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Start(ctx)
+}
+
+func (p *Init) start(ctx context.Context) error {
+ var cio runc.IO
+ if !p.Sandbox {
+ cio = p.io
+ }
+ if err := p.runtime.Start(ctx, p.id, cio); err != nil {
+ return p.runtimeError(err, "OCI runtime start failed")
+ }
+ go func() {
+ status, err := p.runtime.Wait(context.Background(), p.id)
+ if err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id)
+ // TODO(random-liu): Handle runsc kill error.
+ if err := p.killAll(ctx); err != nil {
+ log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id)
+ }
+ status = internalErrorCode
+ }
+ ExitCh <- Exit{
+ Timestamp: time.Now(),
+ ID: p.id,
+ Status: status,
+ }
+ }()
+ return nil
+}
+
+// SetExited set the exit stauts of the init process.
+func (p *Init) SetExited(status int) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.initState.SetExited(status)
+}
+
+func (p *Init) setExited(status int) {
+ p.exited = time.Now()
+ p.status = status
+ p.Platform.ShutdownConsole(context.Background(), p.console)
+ close(p.waitBlock)
+}
+
+// Delete deletes the init process.
+func (p *Init) Delete(ctx context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Delete(ctx)
+}
+
+func (p *Init) delete(ctx context.Context) error {
+ p.killAll(ctx)
+ p.wg.Wait()
+ err := p.runtime.Delete(ctx, p.id, nil)
+ // ignore errors if a runtime has already deleted the process
+ // but we still hold metadata and pipes
+ //
+ // this is common during a checkpoint, runc will delete the container state
+ // after a checkpoint and the container will no longer exist within runc
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ err = nil
+ } else {
+ err = p.runtimeError(err, "failed to delete task")
+ }
+ }
+ if p.io != nil {
+ for _, c := range p.closers {
+ c.Close()
+ }
+ p.io.Close()
+ }
+ if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount")
+ if err == nil {
+ err = fmt.Errorf("failed rootfs umount: %w", err2)
+ }
+ }
+ return err
+}
+
+// Resize resizes the init processes console.
+func (p *Init) Resize(ws console.WinSize) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+func (p *Init) resize(ws console.WinSize) error {
+ if p.console == nil {
+ return nil
+ }
+ return p.console.Resize(ws)
+}
+
+// Kill kills the init process.
+func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Kill(ctx, signal, all)
+}
+
+func (p *Init) kill(context context.Context, signal uint32, all bool) error {
+ var (
+ killErr error
+ backoff = 100 * time.Millisecond
+ )
+ timeout := 1 * time.Second
+ for start := time.Now(); time.Now().Sub(start) < timeout; {
+ c, err := p.runtime.State(context, p.id)
+ if err != nil {
+ if strings.Contains(err.Error(), "does not exist") {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ return p.runtimeError(err, "OCI runtime state failed")
+ }
+ // For runsc, signal only works when container is running state.
+ // If the container is not in running state, directly return
+ // "no such process"
+ if p.convertStatus(c.Status) == "stopped" {
+ return fmt.Errorf("no such process: %w", errdefs.ErrNotFound)
+ }
+ killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{
+ All: all,
+ })
+ if killErr == nil {
+ return nil
+ }
+ time.Sleep(backoff)
+ backoff *= 2
+ }
+ return p.runtimeError(killErr, "kill timeout")
+}
+
+// KillAll kills all processes belonging to the init process.
+func (p *Init) KillAll(context context.Context) error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.killAll(context)
+}
+
+func (p *Init) killAll(context context.Context) error {
+ p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{
+ All: true,
+ })
+ // Ignore error handling for `runsc kill --all` for now.
+ // * If it doesn't return error, it is good;
+ // * If it returns error, consider the container has already stopped.
+ // TODO: Fix `runsc kill --all` error handling.
+ return nil
+}
+
+// Stdin returns the stdin of the process.
+func (p *Init) Stdin() io.Closer {
+ return p.stdin
+}
+
+// Runtime returns the OCI runtime configured for the init process.
+func (p *Init) Runtime() *runsc.Runsc {
+ return p.runtime
+}
+
+// Exec returns a new child process.
+func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ return p.initState.Exec(ctx, path, r)
+}
+
+// exec returns a new exec'd process.
+func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ // process exec request
+ var spec specs.Process
+ if err := json.Unmarshal(r.Spec.Value, &spec); err != nil {
+ return nil, err
+ }
+ spec.Terminal = r.Terminal
+
+ e := &execProcess{
+ id: r.ID,
+ path: path,
+ parent: p,
+ spec: spec,
+ stdio: stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ },
+ waitBlock: make(chan struct{}),
+ }
+ e.execState = &execCreatedState{p: e}
+ return e, nil
+}
+
+// Stdio returns the stdio of the process.
+func (p *Init) Stdio() stdio.Stdio {
+ return p.stdio
+}
+
+func (p *Init) runtimeError(rErr error, msg string) error {
+ if rErr == nil {
+ return nil
+ }
+
+ rMsg, err := getLastRuntimeError(p.runtime)
+ switch {
+ case err != nil:
+ return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err)
+ case rMsg == "":
+ return fmt.Errorf("%s: %w", msg, rErr)
+ default:
+ return fmt.Errorf("%s: %s", msg, rMsg)
+ }
+}
+
+func (p *Init) convertStatus(status string) string {
+ if status == "created" && !p.Sandbox && p.status == internalErrorCode {
+ // Treat start failure state for non-root container as stopped.
+ return "stopped"
+ }
+ return status
+}
+
+func withConditionalIO(c stdio.Stdio) runc.IOOpt {
+ return func(o *runc.IOOption) {
+ o.OpenStdin = c.Stdin != ""
+ o.OpenStdout = c.Stdout != ""
+ o.OpenStderr = c.Stderr != ""
+ }
+}
diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go
new file mode 100644
index 000000000..9233ecc85
--- /dev/null
+++ b/pkg/shim/v1/proc/init_state.go
@@ -0,0 +1,182 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/pkg/process"
+)
+
+type initState interface {
+ Resize(console.WinSize) error
+ Start(context.Context) error
+ Delete(context.Context) error
+ Exec(context.Context, string, *ExecConfig) (process.Process, error)
+ Kill(context.Context, uint32, bool) error
+ SetExited(int)
+}
+
+type createdState struct {
+ p *Init
+}
+
+func (s *createdState) transition(name string) error {
+ switch name {
+ case "running":
+ s.p.initState = &runningState{p: s.p}
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *createdState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *createdState) Start(ctx context.Context) error {
+ if err := s.p.start(ctx); err != nil {
+ // Containerd doesn't allow deleting container in created state.
+ // However, for gvisor, a non-root container in created state can
+ // only go to running state. If the container can't be started,
+ // it can only stay in created state, and never be deleted.
+ // To work around that, we treat non-root container in start failure
+ // state as stopped.
+ if !s.p.Sandbox {
+ s.p.io.Close()
+ s.p.setExited(internalErrorCode)
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+ }
+ return err
+ }
+ return s.transition("running")
+}
+
+func (s *createdState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *createdState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type runningState struct {
+ p *Init
+}
+
+func (s *runningState) transition(name string) error {
+ switch name {
+ case "stopped":
+ s.p.initState = &stoppedState{p: s.p}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *runningState) Resize(ws console.WinSize) error {
+ return s.p.resize(ws)
+}
+
+func (s *runningState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a running process.ss")
+}
+
+func (s *runningState) Delete(ctx context.Context) error {
+ return fmt.Errorf("cannot delete a running process.ss")
+}
+
+func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return s.p.kill(ctx, sig, all)
+}
+
+func (s *runningState) SetExited(status int) {
+ s.p.setExited(status)
+
+ if err := s.transition("stopped"); err != nil {
+ panic(err)
+ }
+}
+
+func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return s.p.exec(ctx, path, r)
+}
+
+type stoppedState struct {
+ p *Init
+}
+
+func (s *stoppedState) transition(name string) error {
+ switch name {
+ case "deleted":
+ s.p.initState = &deletedState{}
+ default:
+ return fmt.Errorf("invalid state transition %q to %q", stateName(s), name)
+ }
+ return nil
+}
+
+func (s *stoppedState) Resize(ws console.WinSize) error {
+ return fmt.Errorf("cannot resize a stopped container")
+}
+
+func (s *stoppedState) Start(ctx context.Context) error {
+ return fmt.Errorf("cannot start a stopped process.ss")
+}
+
+func (s *stoppedState) Delete(ctx context.Context) error {
+ if err := s.p.delete(ctx); err != nil {
+ return err
+ }
+ return s.transition("deleted")
+}
+
+func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error {
+ return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id)
+}
+
+func (s *stoppedState) SetExited(status int) {
+ // no op
+}
+
+func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) {
+ return nil, fmt.Errorf("cannot exec in a stopped state")
+}
diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go
new file mode 100644
index 000000000..34d825fb7
--- /dev/null
+++ b/pkg/shim/v1/proc/io.go
@@ -0,0 +1,162 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/fifo"
+ runc "github.com/containerd/go-runc"
+)
+
+// TODO(random-liu): This file can be a util.
+
+var bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+}
+
+func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error {
+ var sameFile *countingWriteCloser
+ for _, i := range []struct {
+ name string
+ dest func(wc io.WriteCloser, rc io.Closer)
+ }{
+ {
+ name: stdout,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil {
+ log.G(ctx).Warn("error copying stdout")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ }, {
+ name: stderr,
+ dest: func(wc io.WriteCloser, rc io.Closer) {
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil {
+ log.G(ctx).Warn("error copying stderr")
+ }
+ wg.Done()
+ wc.Close()
+ if rc != nil {
+ rc.Close()
+ }
+ }()
+ },
+ },
+ } {
+ ok, err := isFifo(i.name)
+ if err != nil {
+ return err
+ }
+ var (
+ fw io.WriteCloser
+ fr io.Closer
+ )
+ if ok {
+ if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ } else {
+ if sameFile != nil {
+ sameFile.count++
+ i.dest(sameFile, nil)
+ continue
+ }
+ if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err)
+ }
+ if stdout == stderr {
+ sameFile = &countingWriteCloser{
+ WriteCloser: fw,
+ count: 1,
+ }
+ }
+ }
+ i.dest(fw, fr)
+ }
+ if stdin == "" {
+ return nil
+ }
+ f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err)
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+
+ io.CopyBuffer(rio.Stdin(), f, *p)
+ rio.Stdin().Close()
+ f.Close()
+ }()
+ return nil
+}
+
+// countingWriteCloser masks io.Closer() until close has been invoked a certain number of times.
+type countingWriteCloser struct {
+ io.WriteCloser
+ count int64
+}
+
+func (c *countingWriteCloser) Close() error {
+ if atomic.AddInt64(&c.count, -1) > 0 {
+ return nil
+ }
+ return c.WriteCloser.Close()
+}
+
+// isFifo checks if a file is a fifo.
+//
+// If the file does not exist then it returns false.
+func isFifo(path string) (bool, error) {
+ stat, err := os.Stat(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+ }
+ if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe {
+ return true, nil
+ }
+ return false, nil
+}
diff --git a/test/root/testdata/simple.go b/pkg/shim/v1/proc/process.go
index 1cca53f0c..d462c3eef 100644
--- a/test/root/testdata/simple.go
+++ b/pkg/shim/v1/proc/process.go
@@ -1,10 +1,11 @@
+// Copyright 2018 The containerd Authors.
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
-// http://www.apache.org/licenses/LICENSE-2.0
+// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,30 +13,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testdata
+package proc
import (
- "encoding/json"
"fmt"
)
-// SimpleSpec returns a JSON config for a simple container that runs the
-// specified command in the specified image.
-func SimpleSpec(name, image string, cmd []string) string {
- cmds, err := json.Marshal(cmd)
- if err != nil {
- // This shouldn't happen.
- panic(err)
- }
- return fmt.Sprintf(`
-{
- "metadata": {
- "name": %q
- },
- "image": {
- "image": %q
- },
- "command": %s
+// RunscRoot is the path to the root runsc state directory.
+const RunscRoot = "/run/containerd/runsc"
+
+func stateName(v interface{}) string {
+ switch v.(type) {
+ case *runningState, *execRunningState:
+ return "running"
+ case *createdState, *execCreatedState:
+ return "created"
+ case *deletedState:
+ return "deleted"
+ case *stoppedState:
+ return "stopped"
}
-`, name, image, cmds)
+ panic(fmt.Errorf("invalid state %v", v))
}
diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go
new file mode 100644
index 000000000..2b0df4663
--- /dev/null
+++ b/pkg/shim/v1/proc/types.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "time"
+
+ runc "github.com/containerd/go-runc"
+ "github.com/gogo/protobuf/types"
+)
+
+// Mount holds filesystem mount configuration.
+type Mount struct {
+ Type string
+ Source string
+ Target string
+ Options []string
+}
+
+// CreateConfig hold task creation configuration.
+type CreateConfig struct {
+ ID string
+ Bundle string
+ Runtime string
+ Rootfs []Mount
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Options *types.Any
+}
+
+// ExecConfig holds exec creation configuration.
+type ExecConfig struct {
+ ID string
+ Terminal bool
+ Stdin string
+ Stdout string
+ Stderr string
+ Spec *types.Any
+}
+
+// Exit is the type of exit events.
+type Exit struct {
+ Timestamp time.Time
+ ID string
+ Status int
+}
+
+// ProcessMonitor monitors process exit changes.
+type ProcessMonitor interface {
+ // Subscribe to process exit changes
+ Subscribe() chan runc.Exit
+ // Unsubscribe to process exit changes
+ Unsubscribe(c chan runc.Exit)
+}
diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go
new file mode 100644
index 000000000..716de2f59
--- /dev/null
+++ b/pkg/shim/v1/proc/utils.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proc
+
+import (
+ "encoding/json"
+ "io"
+ "os"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+)
+
+const (
+ internalErrorCode = 128
+ bufferSize = 32
+)
+
+// ExitCh is the exit events channel for containers and exec processes
+// inside the sandbox.
+var ExitCh = make(chan Exit, bufferSize)
+
+// TODO(mlaventure): move to runc package?
+func getLastRuntimeError(r *runsc.Runsc) (string, error) {
+ if r.Log == "" {
+ return "", nil
+ }
+
+ f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400)
+ if err != nil {
+ return "", err
+ }
+
+ var (
+ errMsg string
+ log struct {
+ Level string
+ Msg string
+ Time time.Time
+ }
+ )
+
+ dec := json.NewDecoder(f)
+ for err = nil; err == nil; {
+ if err = dec.Decode(&log); err != nil && err != io.EOF {
+ return "", err
+ }
+ if log.Level == "error" {
+ errMsg = strings.TrimSpace(log.Msg)
+ }
+ }
+
+ return errMsg, nil
+}
+
+func copyFile(to, from string) error {
+ ff, err := os.Open(from)
+ if err != nil {
+ return err
+ }
+ defer ff.Close()
+ tt, err := os.Create(to)
+ if err != nil {
+ return err
+ }
+ defer tt.Close()
+
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ _, err = io.CopyBuffer(tt, ff, *p)
+ return err
+}
+
+func hasNoIO(r *CreateConfig) bool {
+ return r.Stdin == "" && r.Stdout == "" && r.Stderr == ""
+}
diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD
new file mode 100644
index 000000000..05c595bc9
--- /dev/null
+++ b/pkg/shim/v1/shim/BUILD
@@ -0,0 +1,40 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "shim",
+ srcs = [
+ "api.go",
+ "platform.go",
+ "service.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_google_grpc//codes:go_default_library",
+ "@org_golang_google_grpc//status:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go
new file mode 100644
index 000000000..5dd8ff172
--- /dev/null
+++ b/pkg/shim/v1/shim/api.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
+type TaskCreate = events.TaskCreate
+type TaskStart = events.TaskStart
+type TaskOOM = events.TaskOOM
+type TaskExit = events.TaskExit
+type TaskDelete = events.TaskDelete
+type TaskExecAdded = events.TaskExecAdded
+type TaskExecStarted = events.TaskExecStarted
diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go
new file mode 100644
index 000000000..f590f80ef
--- /dev/null
+++ b/pkg/shim/v1/shim/platform.go
@@ -0,0 +1,106 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initialize a single epoll fd to manage our consoles. `initPlatform` should
+// only be called once.
+func (s *Service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go
new file mode 100644
index 000000000..84a810cb2
--- /dev/null
+++ b/pkg/shim/v1/shim/service.go
@@ -0,0 +1,573 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package shim
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ shim "github.com/containerd/containerd/runtime/v1/shim/v1"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+)
+
+var (
+ empty = &types.Empty{}
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
+// Config contains shim specific configuration.
+type Config struct {
+ Path string
+ Namespace string
+ WorkDir string
+ RuntimeRoot string
+ RunscConfig map[string]string
+}
+
+// NewService returns a new shim service that can be used via GRPC.
+func NewService(config Config, publisher events.Publisher) (*Service, error) {
+ if config.Namespace == "" {
+ return nil, fmt.Errorf("shim namespace cannot be empty")
+ }
+ ctx := namespaces.WithNamespace(context.Background(), config.Namespace)
+ s := &Service{
+ config: config,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ }
+ go s.processExits()
+ if err := s.initPlatform(); err != nil {
+ return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// Service is the shim implementation of a remote shim over GRPC.
+type Service struct {
+ mu sync.Mutex
+
+ config Config
+ context context.Context
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ ec chan proc.Exit
+
+ // Filled by Create()
+ id string
+ bundle string
+}
+
+// Create creates a new initial process and container with the underlying OCI runtime.
+func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: r.Runtime,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ defer func() {
+ if err != nil {
+ if err2 := mount.UnmountAll(rootfs, 0); err2 != nil {
+ log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount")
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ s.config.Path,
+ s.config.WorkDir,
+ s.config.RuntimeRoot,
+ s.config.Namespace,
+ s.config.RunscConfig,
+ s.platform,
+ config,
+ )
+ if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+ pid := process.Pid()
+ s.processes[r.ID] = process
+ return &shim.CreateTaskResponse{
+ Pid: uint32(pid),
+ }, nil
+}
+
+// Start starts a process.
+func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ return &shim.StartResponse{
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, s.id)
+ s.mu.Unlock()
+ s.platform.Close()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// DeleteProcess deletes an exec'd process.
+func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) {
+ if r.ID == s.id {
+ return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess")
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ s.mu.Lock()
+ delete(s.processes, r.ID)
+ s.mu.Unlock()
+ return &shim.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+
+ if p := s.processes[r.ID]; p != nil {
+ s.mu.Unlock()
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID)
+ }
+
+ p := s.processes[s.id]
+ s.mu.Unlock()
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+
+ process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{
+ ID: r.ID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resises the terminal of a process.
+func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) {
+ if r.ID == "" {
+ return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided")
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &shim.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause pauses the container.
+func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume resumes the container.
+func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill kills a process with the provided signal.
+func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) {
+ if r.ID == "" {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+ }
+
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// ListPids returns all pids inside the container.
+func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &shim.ListPidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// ShimInfo returns shim information such as the shim's pid.
+func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) {
+ return &shim.ShimInfoResponse{
+ ShimPid: uint32(os.Getpid()),
+ }, nil
+}
+
+// Update updates a running container.
+func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) {
+ p, err := s.getExecProcess(r.ID)
+ if err != nil {
+ return nil, err
+ }
+ p.Wait()
+
+ return &shim.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *Service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *Service) allProcesses() []process.Process {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ res := make([]process.Process, 0, len(s.processes))
+ for _, p := range s.processes {
+ res = append(res, p)
+ }
+ return res
+}
+
+func (s *Service) checkProcesses(e proc.Exit) {
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ p, err := s.getInitProcess()
+ if err != nil {
+ return nil, err
+ }
+
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *Service) forward(publisher events.Publisher) {
+ for e := range s.events {
+ if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil {
+ log.G(s.context).WithError(err).Error("post event")
+ }
+ }
+}
+
+// getInitProcess returns the init process.
+func (s *Service) getInitProcess() (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[s.id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ return p, nil
+}
+
+// getExecProcess returns the given exec process.
+func (s *Service) getExecProcess(id string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ p := s.processes[id]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id)
+ }
+ return p, nil
+}
+
+func getTopic(ctx context.Context, e interface{}) string {
+ switch e.(type) {
+ case *TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *TaskStart:
+ return runtime.TaskStartEventTopic
+ case *TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *TaskExit:
+ return runtime.TaskExitEventTopic
+ case *TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) {
+ var options runctypes.CreateOptions
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ options = *v.(*runctypes.CreateOptions)
+ }
+
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+
+ runsc.FormatLogPath(r.ID, config)
+ rootfs := filepath.Join(path, "rootfs")
+ runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = utils.IsSandbox(spec)
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD
new file mode 100644
index 000000000..54a0aabb7
--- /dev/null
+++ b/pkg/shim/v1/utils/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "utils",
+ srcs = [
+ "annotations.go",
+ "utils.go",
+ "volumes.go",
+ ],
+ visibility = [
+ "//pkg/shim:__subpackages__",
+ "//shim:__subpackages__",
+ ],
+ deps = [
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
+)
+
+go_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["volumes_test.go"],
+ library = ":utils",
+ deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"],
+)
diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go
new file mode 100644
index 000000000..1e9d3f365
--- /dev/null
+++ b/pkg/shim/v1/utils/annotations.go
@@ -0,0 +1,25 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+// Annotations from the CRI annotations package.
+//
+// These are vendor due to import conflicts.
+const (
+ sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory"
+ containerTypeAnnotation = "io.kubernetes.cri.container-type"
+ containerTypeSandbox = "sandbox"
+ containerTypeContainer = "container"
+)
diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go
new file mode 100644
index 000000000..07e346654
--- /dev/null
+++ b/pkg/shim/v1/utils/utils.go
@@ -0,0 +1,56 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+// ReadSpec reads OCI spec from the bundle directory.
+func ReadSpec(bundle string) (*specs.Spec, error) {
+ f, err := os.Open(filepath.Join(bundle, "config.json"))
+ if err != nil {
+ return nil, err
+ }
+ b, err := ioutil.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ return nil, err
+ }
+ return &spec, nil
+}
+
+// IsSandbox checks whether a container is a sandbox container.
+func IsSandbox(spec *specs.Spec) bool {
+ t, ok := spec.Annotations[containerTypeAnnotation]
+ return !ok || t == containerTypeSandbox
+}
+
+// UserLogPath gets user log path from OCI annotation.
+func UserLogPath(spec *specs.Spec) string {
+ sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return ""
+ }
+ return filepath.Join(sandboxLogDir, "gvisor.log")
+}
diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go
new file mode 100644
index 000000000..52a428179
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes.go
@@ -0,0 +1,155 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "path/filepath"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+const volumeKeyPrefix = "dev.gvisor.spec.mount."
+
+var kubeletPodsDir = "/var/lib/kubelet/pods"
+
+// volumeName gets volume name from volume annotation key, example:
+// dev.gvisor.spec.mount.NAME.share
+func volumeName(k string) string {
+ return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0]
+}
+
+// volumeFieldName gets volume field name from volume annotation key, example:
+// `type` is the field of dev.gvisor.spec.mount.NAME.type
+func volumeFieldName(k string) string {
+ parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".")
+ return parts[len(parts)-1]
+}
+
+// podUID gets pod UID from the pod log path.
+func podUID(s *specs.Spec) (string, error) {
+ sandboxLogDir := s.Annotations[sandboxLogDirAnnotation]
+ if sandboxLogDir == "" {
+ return "", fmt.Errorf("no sandbox log path annotation")
+ }
+ fields := strings.Split(filepath.Base(sandboxLogDir), "_")
+ switch len(fields) {
+ case 1: // This is the old CRI logging path.
+ return fields[0], nil
+ case 3: // This is the new CRI logging path.
+ return fields[2], nil
+ }
+ return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir)
+}
+
+// isVolumeKey checks whether an annotation key is for volume.
+func isVolumeKey(k string) bool {
+ return strings.HasPrefix(k, volumeKeyPrefix)
+}
+
+// volumeSourceKey constructs the annotation key for volume source.
+func volumeSourceKey(volume string) string {
+ return volumeKeyPrefix + volume + ".source"
+}
+
+// volumePath searches the volume path in the kubelet pod directory.
+func volumePath(volume, uid string) (string, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume)
+ dirs, err := filepath.Glob(volumeSearchPath)
+ if err != nil {
+ return "", err
+ }
+ if len(dirs) != 1 {
+ return "", fmt.Errorf("unexpected matched volume list %v", dirs)
+ }
+ return dirs[0], nil
+}
+
+// isVolumePath checks whether a string is the volume path.
+func isVolumePath(volume, path string) (bool, error) {
+ // TODO: Support subpath when gvisor supports pod volume bind mount.
+ volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume)
+ return filepath.Match(volumeSearchPath, path)
+}
+
+// UpdateVolumeAnnotations add necessary OCI annotations for gvisor
+// volume optimization.
+func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error {
+ var (
+ uid string
+ err error
+ )
+ if IsSandbox(s) {
+ uid, err = podUID(s)
+ if err != nil {
+ // Skip if we can't get pod UID, because this doesn't work
+ // for containerd 1.1.
+ return nil
+ }
+ }
+ var updated bool
+ for k, v := range s.Annotations {
+ if !isVolumeKey(k) {
+ continue
+ }
+ if volumeFieldName(k) != "type" {
+ continue
+ }
+ volume := volumeName(k)
+ if uid != "" {
+ // This is a sandbox.
+ path, err := volumePath(volume, uid)
+ if err != nil {
+ return fmt.Errorf("get volume path for %q: %w", volume, err)
+ }
+ s.Annotations[volumeSourceKey(volume)] = path
+ updated = true
+ } else {
+ // This is a container.
+ for i := range s.Mounts {
+ // An error is returned for sandbox if source
+ // annotation is not successfully applied, so
+ // it is guaranteed that the source annotation
+ // for sandbox has already been successfully
+ // applied at this point.
+ //
+ // The volume name is unique inside a pod, so
+ // matching without podUID is fine here.
+ //
+ // TODO: Pass podUID down to shim for containers to do
+ // more accurate matching.
+ if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes {
+ // gVisor requires the container mount type to match
+ // sandbox mount type.
+ s.Mounts[i].Type = v
+ updated = true
+ }
+ }
+ }
+ }
+ if !updated {
+ return nil
+ }
+ // Update bundle.
+ b, err := json.Marshal(s)
+ if err != nil {
+ return err
+ }
+ return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666)
+}
diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go
new file mode 100644
index 000000000..3e02c6151
--- /dev/null
+++ b/pkg/shim/v1/utils/volumes_test.go
@@ -0,0 +1,308 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package utils
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "reflect"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+)
+
+func TestUpdateVolumeAnnotations(t *testing.T) {
+ dir, err := ioutil.TempDir("", "test-update-volume-annotations")
+ if err != nil {
+ t.Fatalf("create tempdir: %v", err)
+ }
+ defer os.RemoveAll(dir)
+ kubeletPodsDir = dir
+
+ const (
+ testPodUID = "testuid"
+ testVolumeName = "testvolume"
+ testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID
+ testLegacyLogDirPath = "/var/log/pods/" + testPodUID
+ )
+ testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName)
+
+ if err := os.MkdirAll(testVolumePath, 0755); err != nil {
+ t.Fatalf("Create test volume: %v", err)
+ }
+
+ for _, test := range []struct {
+ desc string
+ spec *specs.Spec
+ expected *specs.Spec
+ expectErr bool
+ expectUpdate bool
+ }{
+ {
+ desc: "volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "volume annotations for sandbox with legacy log path",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLegacyLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath,
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "tmpfs: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "tmpfs",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "bind: volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: testVolumePath,
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "container",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expectUpdate: true,
+ },
+ {
+ desc: "should not return error without pod log directory",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod",
+ "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs",
+ "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro",
+ },
+ },
+ },
+ {
+ desc: "should return error if volume path does not exist",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ "dev.gvisor.spec.mount.notexist.share": "pod",
+ "dev.gvisor.spec.mount.notexist.type": "tmpfs",
+ "dev.gvisor.spec.mount.notexist.options": "ro",
+ },
+ },
+ expectErr: true,
+ },
+ {
+ desc: "no volume annotations for sandbox",
+ spec: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ expected: &specs.Spec{
+ Annotations: map[string]string{
+ sandboxLogDirAnnotation: testLogDirPath,
+ containerTypeAnnotation: containerTypeSandbox,
+ },
+ },
+ },
+ {
+ desc: "no volume annotations for container",
+ spec: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ expected: &specs.Spec{
+ Mounts: []specs.Mount{
+ {
+ Destination: "/test",
+ Type: "bind",
+ Source: "/test",
+ Options: []string{"ro"},
+ },
+ {
+ Destination: "/random",
+ Type: "bind",
+ Source: "/random",
+ Options: []string{"ro"},
+ },
+ },
+ Annotations: map[string]string{
+ containerTypeAnnotation: containerTypeContainer,
+ },
+ },
+ },
+ } {
+ t.Run(test.desc, func(t *testing.T) {
+ bundle, err := ioutil.TempDir(dir, "test-bundle")
+ if err != nil {
+ t.Fatalf("Create test bundle: %v", err)
+ }
+ err = UpdateVolumeAnnotations(bundle, test.spec)
+ if test.expectErr {
+ if err == nil {
+ t.Fatal("Expected error, but got nil")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, test.spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, test.spec)
+ }
+ if test.expectUpdate {
+ b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json"))
+ if err != nil {
+ t.Fatalf("Read spec from bundle: %v", err)
+ }
+ var spec specs.Spec
+ if err := json.Unmarshal(b, &spec); err != nil {
+ t.Fatalf("Unmarshal spec: %v", err)
+ }
+ if !reflect.DeepEqual(test.expected, &spec) {
+ t.Fatalf("Expected %+v, got %+v", test.expected, &spec)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD
new file mode 100644
index 000000000..7e0a114a0
--- /dev/null
+++ b/pkg/shim/v2/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "v2",
+ srcs = [
+ "api.go",
+ "epoll.go",
+ "service.go",
+ "service_linux.go",
+ ],
+ visibility = ["//shim:__subpackages__"],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/proc",
+ "//pkg/shim/v1/utils",
+ "//pkg/shim/v2/options",
+ "//pkg/shim/v2/runtimeoptions",
+ "//runsc/specutils",
+ "@com_github_burntsushi_toml//:go_default_library",
+ "@com_github_containerd_cgroups//:go_default_library",
+ "@com_github_containerd_console//:go_default_library",
+ "@com_github_containerd_containerd//api/events:go_default_library",
+ "@com_github_containerd_containerd//api/types/task:go_default_library",
+ "@com_github_containerd_containerd//errdefs:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//log:go_default_library",
+ "@com_github_containerd_containerd//mount:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//pkg/process:go_default_library",
+ "@com_github_containerd_containerd//pkg/stdio:go_default_library",
+ "@com_github_containerd_containerd//runtime:go_default_library",
+ "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
+ "@com_github_containerd_containerd//runtime/v2/task:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_fifo//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go
new file mode 100644
index 000000000..dbe5c59f6
--- /dev/null
+++ b/pkg/shim/v2/api.go
@@ -0,0 +1,22 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "github.com/containerd/containerd/api/events"
+)
+
+type TaskOOM = events.TaskOOM
diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go
new file mode 100644
index 000000000..41232cca8
--- /dev/null
+++ b/pkg/shim/v2/epoll.go
@@ -0,0 +1,129 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "github.com/containerd/cgroups"
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/runtime"
+ "golang.org/x/sys/unix"
+)
+
+func newOOMEpoller(publisher events.Publisher) (*epoller, error) {
+ fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
+ if err != nil {
+ return nil, err
+ }
+ return &epoller{
+ fd: fd,
+ publisher: publisher,
+ set: make(map[uintptr]*item),
+ }, nil
+}
+
+type epoller struct {
+ mu sync.Mutex
+
+ fd int
+ publisher events.Publisher
+ set map[uintptr]*item
+}
+
+type item struct {
+ id string
+ cg cgroups.Cgroup
+}
+
+func (e *epoller) Close() error {
+ return unix.Close(e.fd)
+}
+
+func (e *epoller) run(ctx context.Context) {
+ var events [128]unix.EpollEvent
+ for {
+ select {
+ case <-ctx.Done():
+ e.Close()
+ return
+ default:
+ n, err := unix.EpollWait(e.fd, events[:], -1)
+ if err != nil {
+ if err == unix.EINTR || err == unix.EAGAIN {
+ continue
+ }
+ // Should not happen.
+ panic(fmt.Errorf("cgroups: epoll wait: %w", err))
+ }
+ for i := 0; i < n; i++ {
+ e.process(ctx, uintptr(events[i].Fd))
+ }
+ }
+ }
+}
+
+func (e *epoller) add(id string, cg cgroups.Cgroup) error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ fd, err := cg.OOMEventFD()
+ if err != nil {
+ return err
+ }
+ e.set[fd] = &item{
+ id: id,
+ cg: cg,
+ }
+ event := unix.EpollEvent{
+ Fd: int32(fd),
+ Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR,
+ }
+ return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event)
+}
+
+func (e *epoller) process(ctx context.Context, fd uintptr) {
+ flush(fd)
+ e.mu.Lock()
+ i, ok := e.set[fd]
+ if !ok {
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ if i.cg.State() == cgroups.Deleted {
+ e.mu.Lock()
+ delete(e.set, fd)
+ e.mu.Unlock()
+ unix.Close(int(fd))
+ return
+ }
+ if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{
+ ContainerID: i.id,
+ }); err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("publish OOM event: %w", err))
+ }
+}
+
+func flush(fd uintptr) error {
+ var buf [8]byte
+ _, err := unix.Read(int(fd), buf[:])
+ return err
+}
diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD
new file mode 100644
index 000000000..ca212e874
--- /dev/null
+++ b/pkg/shim/v2/options/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "options",
+ srcs = [
+ "options.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go
new file mode 100644
index 000000000..de09f2f79
--- /dev/null
+++ b/pkg/shim/v2/options/options.go
@@ -0,0 +1,33 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package options
+
+const OptionType = "io.containerd.runsc.v1.options"
+
+// Options is runtime options for io.containerd.runsc.v1.
+type Options struct {
+ // ShimCgroup is the cgroup the shim should be in.
+ ShimCgroup string `toml:"shim_cgroup"`
+ // IoUid is the I/O's pipes uid.
+ IoUid uint32 `toml:"io_uid"`
+ // IoUid is the I/O's pipes gid.
+ IoGid uint32 `toml:"io_gid"`
+ // BinaryName is the binary name of the runsc binary.
+ BinaryName string `toml:"binary_name"`
+ // Root is the runsc root directory.
+ Root string `toml:"root"`
+ // RunscConfig is a key/value map of all runsc flags.
+ RunscConfig map[string]string `toml:"runsc_config"`
+}
diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD
new file mode 100644
index 000000000..01716034c
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library", "proto_library")
+
+package(licenses = ["notice"])
+
+proto_library(
+ name = "api",
+ srcs = [
+ "runtimeoptions.proto",
+ ],
+)
+
+go_library(
+ name = "runtimeoptions",
+ srcs = ["runtimeoptions.go"],
+ visibility = ["//pkg/shim/v2:__pkg__"],
+ deps = [
+ "//pkg/shim/v2/runtimeoptions:api_go_proto",
+ "@com_github_gogo_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
new file mode 100644
index 000000000..1c1a0c5d1
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package runtimeoptions
+
+import (
+ proto "github.com/gogo/protobuf/proto"
+ pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto"
+)
+
+type Options = pb.Options
+
+func init() {
+ proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options")
+}
diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
new file mode 100644
index 000000000..edb19020a
--- /dev/null
+++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto
@@ -0,0 +1,25 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package runtimeoptions;
+
+// This is a version of the runtimeoptions CRI API that is vendored.
+//
+// Imported the full CRI package is a nightmare.
+message Options {
+ string type_url = 1;
+ string config_path = 2;
+}
diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go
new file mode 100644
index 000000000..1534152fc
--- /dev/null
+++ b/pkg/shim/v2/service.go
@@ -0,0 +1,824 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/BurntSushi/toml"
+ "github.com/containerd/cgroups"
+ "github.com/containerd/console"
+ "github.com/containerd/containerd/api/events"
+ "github.com/containerd/containerd/api/types/task"
+ "github.com/containerd/containerd/errdefs"
+ "github.com/containerd/containerd/log"
+ "github.com/containerd/containerd/mount"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/pkg/process"
+ "github.com/containerd/containerd/pkg/stdio"
+ "github.com/containerd/containerd/runtime"
+ "github.com/containerd/containerd/runtime/linux/runctypes"
+ "github.com/containerd/containerd/runtime/v2/shim"
+ taskAPI "github.com/containerd/containerd/runtime/v2/task"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/proc"
+ "gvisor.dev/gvisor/pkg/shim/v1/utils"
+ "gvisor.dev/gvisor/pkg/shim/v2/options"
+ "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+var (
+ empty = &types.Empty{}
+ bufPool = sync.Pool{
+ New: func() interface{} {
+ buffer := make([]byte, 32<<10)
+ return &buffer
+ },
+ }
+)
+
+var _ = (taskAPI.TaskService)(&service{})
+
+// configFile is the default config file name. For containerd 1.2,
+// we assume that a config.toml should exist in the runtime root.
+const configFile = "config.toml"
+
+// New returns a new shim service that can be used via GRPC.
+func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) {
+ ep, err := newOOMEpoller(publisher)
+ if err != nil {
+ return nil, err
+ }
+ go ep.run(ctx)
+ s := &service{
+ id: id,
+ context: ctx,
+ processes: make(map[string]process.Process),
+ events: make(chan interface{}, 128),
+ ec: proc.ExitCh,
+ oomPoller: ep,
+ cancel: cancel,
+ }
+ go s.processExits()
+ runsc.Monitor = reaper.Default
+ if err := s.initPlatform(); err != nil {
+ cancel()
+ return nil, fmt.Errorf("failed to initialized platform behavior: %w", err)
+ }
+ go s.forward(publisher)
+ return s, nil
+}
+
+// service is the shim implementation of a remote shim over GRPC.
+type service struct {
+ mu sync.Mutex
+
+ context context.Context
+ task process.Process
+ processes map[string]process.Process
+ events chan interface{}
+ platform stdio.Platform
+ opts options.Options
+ ec chan proc.Exit
+ oomPoller *epoller
+
+ id string
+ bundle string
+ cancel func()
+}
+
+func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) {
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ self, err := os.Executable()
+ if err != nil {
+ return nil, err
+ }
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ args := []string{
+ "-namespace", ns,
+ "-address", containerdAddress,
+ "-publish-binary", containerdBinary,
+ }
+ cmd := exec.Command(self, args...)
+ cmd.Dir = cwd
+ cmd.Env = append(os.Environ(), "GOMAXPROCS=2")
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Setpgid: true,
+ }
+ return cmd, nil
+}
+
+func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) {
+ cmd, err := newCommand(ctx, containerdBinary, containerdAddress)
+ if err != nil {
+ return "", err
+ }
+ address, err := shim.SocketAddress(ctx, id)
+ if err != nil {
+ return "", err
+ }
+ socket, err := shim.NewSocket(address)
+ if err != nil {
+ return "", err
+ }
+ defer socket.Close()
+ f, err := socket.File()
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+
+ if err := cmd.Start(); err != nil {
+ return "", err
+ }
+ defer func() {
+ if err != nil {
+ cmd.Process.Kill()
+ }
+ }()
+ // make sure to wait after start
+ go cmd.Wait()
+ if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil {
+ return "", err
+ }
+ if err := shim.WriteAddress("address", address); err != nil {
+ return "", err
+ }
+ if err := shim.SetScore(cmd.Process.Pid); err != nil {
+ return "", fmt.Errorf("failed to set OOM Score on shim: %w", err)
+ }
+ return address, nil
+}
+
+func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{
+ Force: true,
+ }); err != nil {
+ log.L.Printf("failed to remove runc container: %v", err)
+ }
+ if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ return &taskAPI.DeleteResponse{
+ ExitedAt: time.Now(),
+ ExitStatus: 128 + uint32(unix.SIGKILL),
+ }, nil
+}
+
+func (s *service) readRuntime(path string) (string, error) {
+ data, err := ioutil.ReadFile(filepath.Join(path, "runtime"))
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func (s *service) writeRuntime(path, runtime string) error {
+ return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600)
+}
+
+// Create creates a new initial process and container with the underlying OCI
+// runtime.
+func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("create namespace: %w", err)
+ }
+
+ // Read from root for now.
+ var opts options.Options
+ if r.Options != nil {
+ v, err := typeurl.UnmarshalAny(r.Options)
+ if err != nil {
+ return nil, err
+ }
+ var path string
+ switch o := v.(type) {
+ case *runctypes.CreateOptions: // containerd 1.2.x
+ opts.IoUid = o.IoUid
+ opts.IoGid = o.IoGid
+ opts.ShimCgroup = o.ShimCgroup
+ case *runctypes.RuncOptions: // containerd 1.2.x
+ root := proc.RunscRoot
+ if o.RuntimeRoot != "" {
+ root = o.RuntimeRoot
+ }
+
+ opts.BinaryName = o.Runtime
+
+ path = filepath.Join(root, configFile)
+ if _, err := os.Stat(path); err != nil {
+ if !os.IsNotExist(err) {
+ return nil, fmt.Errorf("stat config file %q: %w", path, err)
+ }
+ // A config file in runtime root is not required.
+ path = ""
+ }
+ case *runtimeoptions.Options: // containerd 1.3.x+
+ if o.ConfigPath == "" {
+ break
+ }
+ if o.TypeUrl != options.OptionType {
+ return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl)
+ }
+ path = o.ConfigPath
+ default:
+ return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl)
+ }
+ if path != "" {
+ if _, err = toml.DecodeFile(path, &opts); err != nil {
+ return nil, fmt.Errorf("decode config file %q: %w", path, err)
+ }
+ }
+ }
+
+ var mounts []proc.Mount
+ for _, m := range r.Rootfs {
+ mounts = append(mounts, proc.Mount{
+ Type: m.Type,
+ Source: m.Source,
+ Target: m.Target,
+ Options: m.Options,
+ })
+ }
+
+ rootfs := filepath.Join(r.Bundle, "rootfs")
+ if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ config := &proc.CreateConfig{
+ ID: r.ID,
+ Bundle: r.Bundle,
+ Runtime: opts.BinaryName,
+ Rootfs: mounts,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Options: r.Options,
+ }
+ if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ if err := mount.UnmountAll(rootfs, 0); err != nil {
+ log.L.Printf("failed to cleanup rootfs mount: %v", err)
+ }
+ }
+ }()
+ for _, rm := range mounts {
+ m := &mount.Mount{
+ Type: rm.Type,
+ Source: rm.Source,
+ Options: rm.Options,
+ }
+ if err := m.Mount(rootfs); err != nil {
+ return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err)
+ }
+ }
+ process, err := newInit(
+ ctx,
+ r.Bundle,
+ filepath.Join(r.Bundle, "work"),
+ ns,
+ s.platform,
+ config,
+ &opts,
+ rootfs,
+ )
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ if err := process.Create(ctx, config); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ // Save the main task id and bundle to the shim for additional
+ // requests.
+ s.id = r.ID
+ s.bundle = r.Bundle
+
+ // Set up OOM notification on the sandbox's cgroup. This is done on
+ // sandbox create since the sandbox process will be created here.
+ pid := process.Pid()
+ if pid > 0 {
+ cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid))
+ if err != nil {
+ return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err)
+ }
+ if err := s.oomPoller.add(s.id, cg); err != nil {
+ return nil, fmt.Errorf("add cg to OOM monitor: %w", err)
+ }
+ }
+ s.task = process
+ s.opts = opts
+ return &taskAPI.CreateTaskResponse{
+ Pid: uint32(process.Pid()),
+ }, nil
+
+}
+
+// Start starts a process.
+func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if err := p.Start(ctx); err != nil {
+ return nil, err
+ }
+ // TODO: Set the cgroup and oom notifications on restore.
+ // https://github.com/google/gvisor-containerd-shim/issues/58
+ return &taskAPI.StartResponse{
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Delete deletes the initial process and container.
+func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Delete(ctx); err != nil {
+ return nil, err
+ }
+ isTask := r.ExecID == ""
+ if !isTask {
+ s.mu.Lock()
+ delete(s.processes, r.ExecID)
+ s.mu.Unlock()
+ }
+ if isTask && s.platform != nil {
+ s.platform.Close()
+ }
+ return &taskAPI.DeleteResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ Pid: uint32(p.Pid()),
+ }, nil
+}
+
+// Exec spawns an additional process inside the container.
+func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) {
+ s.mu.Lock()
+ p := s.processes[r.ExecID]
+ s.mu.Unlock()
+ if p != nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID)
+ }
+ p = s.task
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{
+ ID: r.ExecID,
+ Terminal: r.Terminal,
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Spec: r.Spec,
+ })
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ s.mu.Lock()
+ s.processes[r.ExecID] = process
+ s.mu.Unlock()
+ return empty, nil
+}
+
+// ResizePty resizes the terminal of a process.
+func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ ws := console.WinSize{
+ Width: uint16(r.Width),
+ Height: uint16(r.Height),
+ }
+ if err := p.Resize(ws); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// State returns runtime state information for a process.
+func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ st, err := p.Status(ctx)
+ if err != nil {
+ return nil, err
+ }
+ status := task.StatusUnknown
+ switch st {
+ case "created":
+ status = task.StatusCreated
+ case "running":
+ status = task.StatusRunning
+ case "stopped":
+ status = task.StatusStopped
+ }
+ sio := p.Stdio()
+ return &taskAPI.StateResponse{
+ ID: p.ID(),
+ Bundle: s.bundle,
+ Pid: uint32(p.Pid()),
+ Status: status,
+ Stdin: sio.Stdin,
+ Stdout: sio.Stdout,
+ Stderr: sio.Stderr,
+ Terminal: sio.Terminal,
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+// Pause the container.
+func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Resume the container.
+func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Kill a process with the provided signal.
+func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ if err := p.Kill(ctx, r.Signal, r.All); err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ return empty, nil
+}
+
+// Pids returns all pids inside the container.
+func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) {
+ pids, err := s.getContainerPids(ctx, r.ID)
+ if err != nil {
+ return nil, errdefs.ToGRPC(err)
+ }
+ var processes []*task.ProcessInfo
+ for _, pid := range pids {
+ pInfo := task.ProcessInfo{
+ Pid: pid,
+ }
+ for _, p := range s.processes {
+ if p.Pid() == int(pid) {
+ d := &runctypes.ProcessDetails{
+ ExecID: p.ID(),
+ }
+ a, err := typeurl.MarshalAny(d)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err)
+ }
+ pInfo.Info = a
+ break
+ }
+ }
+ processes = append(processes, &pInfo)
+ }
+ return &taskAPI.PidsResponse{
+ Processes: processes,
+ }, nil
+}
+
+// CloseIO closes the I/O context of a process.
+func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if stdin := p.Stdin(); stdin != nil {
+ if err := stdin.Close(); err != nil {
+ return nil, fmt.Errorf("close stdin: %w", err)
+ }
+ }
+ return empty, nil
+}
+
+// Checkpoint checkpoints the container.
+func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Connect returns shim information such as the shim's pid.
+func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) {
+ var pid int
+ if s.task != nil {
+ pid = s.task.Pid()
+ }
+ return &taskAPI.ConnectResponse{
+ ShimPid: uint32(os.Getpid()),
+ TaskPid: uint32(pid),
+ }, nil
+}
+
+func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) {
+ s.cancel()
+ os.Exit(0)
+ return empty, nil
+}
+
+func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) {
+ path, err := os.Getwd()
+ if err != nil {
+ return nil, err
+ }
+ ns, err := namespaces.NamespaceRequired(ctx)
+ if err != nil {
+ return nil, err
+ }
+ runtime, err := s.readRuntime(path)
+ if err != nil {
+ return nil, err
+ }
+ rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil)
+ stats, err := rs.Stats(ctx, s.id)
+ if err != nil {
+ return nil, err
+ }
+
+ // gvisor currently (as of 2020-03-03) only returns the total memory
+ // usage and current PID value[0]. However, we copy the common fields here
+ // so that future updates will propagate correct information. We're
+ // using the cgroups.Metrics structure so we're returning the same type
+ // as runc.
+ //
+ // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81
+ data, err := typeurl.MarshalAny(&cgroups.Metrics{
+ CPU: &cgroups.CPUStat{
+ Usage: &cgroups.CPUUsage{
+ Total: stats.Cpu.Usage.Total,
+ Kernel: stats.Cpu.Usage.Kernel,
+ User: stats.Cpu.Usage.User,
+ PerCPU: stats.Cpu.Usage.Percpu,
+ },
+ Throttling: &cgroups.Throttle{
+ Periods: stats.Cpu.Throttling.Periods,
+ ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods,
+ ThrottledTime: stats.Cpu.Throttling.ThrottledTime,
+ },
+ },
+ Memory: &cgroups.MemoryStat{
+ Cache: stats.Memory.Cache,
+ Usage: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Usage.Limit,
+ Usage: stats.Memory.Usage.Usage,
+ Max: stats.Memory.Usage.Max,
+ Failcnt: stats.Memory.Usage.Failcnt,
+ },
+ Swap: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Swap.Limit,
+ Usage: stats.Memory.Swap.Usage,
+ Max: stats.Memory.Swap.Max,
+ Failcnt: stats.Memory.Swap.Failcnt,
+ },
+ Kernel: &cgroups.MemoryEntry{
+ Limit: stats.Memory.Kernel.Limit,
+ Usage: stats.Memory.Kernel.Usage,
+ Max: stats.Memory.Kernel.Max,
+ Failcnt: stats.Memory.Kernel.Failcnt,
+ },
+ KernelTCP: &cgroups.MemoryEntry{
+ Limit: stats.Memory.KernelTCP.Limit,
+ Usage: stats.Memory.KernelTCP.Usage,
+ Max: stats.Memory.KernelTCP.Max,
+ Failcnt: stats.Memory.KernelTCP.Failcnt,
+ },
+ },
+ Pids: &cgroups.PidsStat{
+ Current: stats.Pids.Current,
+ Limit: stats.Pids.Limit,
+ },
+ })
+ if err != nil {
+ return nil, err
+ }
+ return &taskAPI.StatsResponse{
+ Stats: data,
+ }, nil
+}
+
+// Update updates a running container.
+func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) {
+ return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented)
+}
+
+// Wait waits for a process to exit.
+func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) {
+ p, err := s.getProcess(r.ExecID)
+ if err != nil {
+ return nil, err
+ }
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created")
+ }
+ p.Wait()
+
+ return &taskAPI.WaitResponse{
+ ExitStatus: uint32(p.ExitStatus()),
+ ExitedAt: p.ExitedAt(),
+ }, nil
+}
+
+func (s *service) processExits() {
+ for e := range s.ec {
+ s.checkProcesses(e)
+ }
+}
+
+func (s *service) checkProcesses(e proc.Exit) {
+ // TODO(random-liu): Add `shouldKillAll` logic if container pid
+ // namespace is supported.
+ for _, p := range s.allProcesses() {
+ if p.ID() == e.ID {
+ if ip, ok := p.(*proc.Init); ok {
+ // Ensure all children are killed.
+ if err := ip.KillAll(s.context); err != nil {
+ log.G(s.context).WithError(err).WithField("id", ip.ID()).
+ Error("failed to kill init's children")
+ }
+ }
+ p.SetExited(e.Status)
+ s.events <- &events.TaskExit{
+ ContainerID: s.id,
+ ID: p.ID(),
+ Pid: uint32(p.Pid()),
+ ExitStatus: uint32(e.Status),
+ ExitedAt: p.ExitedAt(),
+ }
+ return
+ }
+ }
+}
+
+func (s *service) allProcesses() (o []process.Process) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for _, p := range s.processes {
+ o = append(o, p)
+ }
+ if s.task != nil {
+ o = append(o, s.task)
+ }
+ return o
+}
+
+func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) {
+ s.mu.Lock()
+ p := s.task
+ s.mu.Unlock()
+ if p == nil {
+ return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition)
+ }
+ ps, err := p.(*proc.Init).Runtime().Ps(ctx, id)
+ if err != nil {
+ return nil, err
+ }
+ pids := make([]uint32, 0, len(ps))
+ for _, pid := range ps {
+ pids = append(pids, uint32(pid))
+ }
+ return pids, nil
+}
+
+func (s *service) forward(publisher shim.Publisher) {
+ for e := range s.events {
+ ctx, cancel := context.WithTimeout(s.context, 5*time.Second)
+ err := publisher.Publish(ctx, getTopic(e), e)
+ cancel()
+ if err != nil {
+ // Should not happen.
+ panic(fmt.Errorf("post event: %w", err))
+ }
+ }
+}
+
+func (s *service) getProcess(execID string) (process.Process, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if execID == "" {
+ return s.task, nil
+ }
+ p := s.processes[execID]
+ if p == nil {
+ return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID)
+ }
+ return p, nil
+}
+
+func getTopic(e interface{}) string {
+ switch e.(type) {
+ case *events.TaskCreate:
+ return runtime.TaskCreateEventTopic
+ case *events.TaskStart:
+ return runtime.TaskStartEventTopic
+ case *events.TaskOOM:
+ return runtime.TaskOOMEventTopic
+ case *events.TaskExit:
+ return runtime.TaskExitEventTopic
+ case *events.TaskDelete:
+ return runtime.TaskDeleteEventTopic
+ case *events.TaskExecAdded:
+ return runtime.TaskExecAddedEventTopic
+ case *events.TaskExecStarted:
+ return runtime.TaskExecStartedEventTopic
+ default:
+ log.L.Printf("no topic for type %#v", e)
+ }
+ return runtime.TaskUnknownTopic
+}
+
+func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) {
+ spec, err := utils.ReadSpec(r.Bundle)
+ if err != nil {
+ return nil, fmt.Errorf("read oci spec: %w", err)
+ }
+ if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil {
+ return nil, fmt.Errorf("update volume annotations: %w", err)
+ }
+ runsc.FormatLogPath(r.ID, options.RunscConfig)
+ runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig)
+ p := proc.New(r.ID, runtime, stdio.Stdio{
+ Stdin: r.Stdin,
+ Stdout: r.Stdout,
+ Stderr: r.Stderr,
+ Terminal: r.Terminal,
+ })
+ p.Bundle = r.Bundle
+ p.Platform = platform
+ p.Rootfs = rootfs
+ p.WorkDir = workDir
+ p.IoUID = int(options.IoUid)
+ p.IoGID = int(options.IoGid)
+ p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox
+ p.UserLog = utils.UserLogPath(spec)
+ p.Monitor = reaper.Default
+ return p, nil
+}
diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go
new file mode 100644
index 000000000..1800ab90b
--- /dev/null
+++ b/pkg/shim/v2/service_linux.go
@@ -0,0 +1,108 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package v2
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/console"
+ "github.com/containerd/fifo"
+)
+
+type linuxPlatform struct {
+ epoller *console.Epoller
+}
+
+func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) {
+ if p.epoller == nil {
+ return nil, fmt.Errorf("uninitialized epoller")
+ }
+
+ epollConsole, err := p.epoller.Add(console)
+ if err != nil {
+ return nil, err
+ }
+
+ if stdin != "" {
+ in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0)
+ if err != nil {
+ return nil, err
+ }
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(epollConsole, in, *p)
+ }()
+ }
+
+ outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0)
+ if err != nil {
+ return nil, err
+ }
+ wg.Add(1)
+ go func() {
+ p := bufPool.Get().(*[]byte)
+ defer bufPool.Put(p)
+ io.CopyBuffer(outw, epollConsole, *p)
+ epollConsole.Close()
+ outr.Close()
+ outw.Close()
+ wg.Done()
+ }()
+ return epollConsole, nil
+}
+
+func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error {
+ if p.epoller == nil {
+ return fmt.Errorf("uninitialized epoller")
+ }
+ epollConsole, ok := cons.(*console.EpollConsole)
+ if !ok {
+ return fmt.Errorf("expected EpollConsole, got %#v", cons)
+ }
+ return epollConsole.Shutdown(p.epoller.CloseConsole)
+}
+
+func (p *linuxPlatform) Close() error {
+ return p.epoller.Close()
+}
+
+// initialize a single epoll fd to manage our consoles. `initPlatform` should
+// only be called once.
+func (s *service) initPlatform() error {
+ if s.platform != nil {
+ return nil
+ }
+ epoller, err := console.NewEpoller()
+ if err != nil {
+ return fmt.Errorf("failed to initialize epoller: %w", err)
+ }
+ s.platform = &linuxPlatform{
+ epoller: epoller,
+ }
+ go epoller.Wait()
+ return nil
+}
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index a23c86fb1..ae0fe1522 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -12,8 +11,8 @@ go_library(
"commit_noasm.go",
"sleep_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sleep",
visibility = ["//:sandbox"],
+ deps = ["//pkg/sync"],
)
go_test(
@@ -22,5 +21,5 @@ go_test(
srcs = [
"sleep_test.go",
],
- embed = [":sleep"],
+ library = ":sleep",
)
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
index 3af447fb9..f59061f37 100644
--- a/pkg/sleep/commit_noasm.go
+++ b/pkg/sleep/commit_noasm.go
@@ -28,15 +28,6 @@ import "sync/atomic"
// It is written in assembly because it is called from g0, so it doesn't have
// a race context.
func commitSleep(g uintptr, waitingG *uintptr) bool {
- for {
- // Check if the wait was aborted.
- if atomic.LoadUintptr(waitingG) == 0 {
- return false
- }
-
- // Try to store the G so that wakers know who to wake.
- if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
- return true
- }
- }
+ // Try to store the G so that wakers know who to wake.
+ return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
}
diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go
index 130806c86..1dd11707d 100644
--- a/pkg/sleep/sleep_test.go
+++ b/pkg/sleep/sleep_test.go
@@ -376,6 +376,31 @@ func TestRace(t *testing.T) {
}
}
+// TestRaceInOrder tests that multiple wakers can continuously send wake requests to
+// the sleeper and that the wakers are retrieved in the order asserted.
+func TestRaceInOrder(t *testing.T) {
+ w := make([]Waker, 10000)
+ s := Sleeper{}
+
+ // Associate each waker and start goroutines that will assert them.
+ for i := range w {
+ s.AddWaker(&w[i], i)
+ }
+ go func() {
+ for i := range w {
+ w[i].Assert()
+ }
+ }()
+
+ // Wait for all wake up notifications from all wakers.
+ for want := range w {
+ got, _ := s.Fetch(true)
+ if got != want {
+ t.Fatalf("got %d want %d", got, want)
+ }
+ }
+}
+
// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up
// from 4 wakers when at least one is already asserted.
func BenchmarkSleeperMultiSelect(b *testing.B) {
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index 8f5e60a25..118805492 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -13,7 +13,7 @@
// limitations under the License.
// +build go1.11
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -75,6 +75,8 @@ package sleep
import (
"sync/atomic"
"unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -299,20 +301,17 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
}
}
- for {
- // Nothing to do if there isn't a G waiting.
- g := atomic.LoadUintptr(&s.waitingG)
- if g == 0 {
- return
- }
+ // Nothing to do if there isn't a G waiting.
+ if atomic.LoadUintptr(&s.waitingG) == 0 {
+ return
+ }
- // Signal to the sleeper that a waker has been asserted.
- if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
- if g != preparingG {
- // We managed to get a G. Wake it up.
- goready(g, 0)
- }
- }
+ // Signal to the sleeper that a waker has been asserted.
+ switch g := atomic.SwapUintptr(&s.waitingG, 0); g {
+ case 0, preparingG:
+ default:
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
}
}
@@ -326,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
//
// This struct is thread-safe, that is, its methods can be called concurrently
// by multiple goroutines.
+//
+// Note, it is not safe to copy a Waker as its fields are modified by value
+// (the pointer fields are individually modified with atomic operations).
type Waker struct {
+ _ sync.NoCopy
+
// s is the sleeper that this waker can wake up. Only one sleeper at a
// time is allowed. This field can have three classes of values:
// nil -- the waker is not asserted: it either is not associated with
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index be93750bf..089b3bbef 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,11 +1,47 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
go_template_instance(
+ name = "pending_list",
+ out = "pending_list.go",
+ package = "state",
+ prefix = "pending",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "pendingMapper",
+ "Linker": "*pendingEntry",
+ },
+)
+
+go_template_instance(
+ name = "deferred_list",
+ out = "deferred_list.go",
+ package = "state",
+ prefix = "deferred",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectEncodeState",
+ "ElementMapper": "deferredMapper",
+ "Linker": "*deferredEntry",
+ },
+)
+
+go_template_instance(
+ name = "complete_list",
+ out = "complete_list.go",
+ package = "state",
+ prefix = "complete",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*objectDecodeState",
+ "Linker": "*objectDecodeState",
+ },
+)
+
+go_template_instance(
name = "addr_range",
out = "addr_range.go",
package = "state",
@@ -31,7 +67,7 @@ go_template_instance(
types = {
"Key": "uintptr",
"Range": "addrRange",
- "Value": "reflect.Value",
+ "Value": "*objectEncodeState",
"Functions": "addrSetFunctions",
},
)
@@ -41,38 +77,24 @@ go_library(
srcs = [
"addr_range.go",
"addr_set.go",
+ "complete_list.go",
"decode.go",
+ "decode_unsafe.go",
+ "deferred_list.go",
"encode.go",
"encode_unsafe.go",
- "map.go",
- "printer.go",
+ "pending_list.go",
"state.go",
+ "state_norace.go",
+ "state_race.go",
"stats.go",
+ "types.go",
],
- importpath = "gvisor.dev/gvisor/pkg/state",
+ marshal = False,
+ stateify = False,
visibility = ["//:sandbox"],
deps = [
- ":object_go_proto",
- "@com_github_golang_protobuf//proto:go_default_library",
+ "//pkg/log",
+ "//pkg/state/wire",
],
)
-
-proto_library(
- name = "object_proto",
- srcs = ["object.proto"],
- visibility = ["//:sandbox"],
-)
-
-go_proto_library(
- name = "object_go_proto",
- importpath = "gvisor.dev/gvisor/pkg/state/object_go_proto",
- proto = ":object_proto",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "state_test",
- timeout = "long",
- srcs = ["state_test.go"],
- embed = [":state"],
-)
diff --git a/pkg/state/README.md b/pkg/state/README.md
new file mode 100644
index 000000000..1aa401193
--- /dev/null
+++ b/pkg/state/README.md
@@ -0,0 +1,158 @@
+# State Encoding and Decoding
+
+The state package implements the encoding and decoding of data structures for
+`go_stateify`. This package is designed for use cases other than the standard
+encoding packages, e.g. `gob` and `json`. Principally:
+
+* This package operates on complex object graphs and accurately serializes and
+ restores all relationships. That is, you can have things like: intrusive
+ pointers, cycles, and pointer chains of arbitrary depths. These are not
+ handled appropriately by existing encoders. This is not an implementation
+ flaw: the formats themselves are not capable of representing these graphs,
+ as they can only generate directed trees.
+
+* This package allows installing order-dependent load callbacks and then
+ resolves that graph at load time, with cycle detection. Similarly, there is
+ no analogous feature possible in the standard encoders.
+
+* This package handles the resolution of interfaces, based on a registered
+ type name. For interface objects type information is saved in the serialized
+ format. This is generally true for `gob` as well, but it works differently.
+
+Here's an overview of how encoding and decoding works.
+
+## Encoding
+
+Encoding produces a `statefile`, which contains a list of chunks of the form
+`(header, payload)`. The payload can either be some raw data, or a series of
+encoded wire objects representing some object graph. All encoded objects are
+defined in the `wire` subpackage.
+
+Encoding of an object graph begins with `encodeState.Save`.
+
+### 1. Memory Map & Encoding
+
+To discover relationships between potentially interdependent data structures
+(for example, a struct may contain pointers to members of other data
+structures), the encoder first walks the object graph and constructs a memory
+map of the objects in the input graph. As this walk progresses, objects are
+queued in the `pending` list and items are placed on the `deferred` list as they
+are discovered. No single object will be encoded multiple times, but the
+discovered relationships between objects may change as more parts of the overall
+object graph are discovered.
+
+The encoder starts at the root object and recursively visits all reachable
+objects, recording the address ranges containing the underlying data for each
+object. This is stored as a segment set (`addrSet`), mapping address ranges to
+the of the object occupying the range; see `encodeState.values`. Note that there
+is special handling for zero-sized types and map objects during this process.
+
+Additionally, the encoder assigns each object a unique identifier which is used
+to indicate relationships between objects in the statefile; see `objectID` in
+`encode.go`.
+
+### 2. Type Serialization
+
+The enoder will subsequently serialize all information about discovered types,
+including field names. These are used during decoding to reconcile these types
+with other internally registered types.
+
+### 3. Object Serialization
+
+With a full address map, and all objects correctly encoded, all object encodings
+are serialized. The assigned `objectID`s aren't explicitly encoded in the
+statefile. The order of object messages in the stream determine their IDs.
+
+### Example
+
+Given the following data structure definitions:
+
+```go
+type system struct {
+ o *outer
+ i *inner
+}
+
+type outer struct {
+ a int64
+ cn *container
+}
+
+type container struct {
+ n uint64
+ elem *inner
+}
+
+type inner struct {
+ c container
+ x, y uint64
+}
+```
+
+Initialized like this:
+
+```go
+o := outer{
+ a: 10,
+ cn: nil,
+}
+i := inner{
+ x: 20,
+ y: 30,
+ c: container{},
+}
+s := system{
+ o: &o,
+ i: &i,
+}
+
+o.cn = &i.c
+o.cn.elem = &i
+
+```
+
+Encoding will produce an object stream like this:
+
+```
+g0r1 = struct{
+ i: g0r3,
+ o: g0r2,
+}
+g0r2 = struct{
+ a: 10,
+ cn: g0r3.c,
+}
+g0r3 = struct{
+ c: struct{
+ elem: g0r3,
+ n: 0u,
+ },
+ x: 20u,
+ y: 30u,
+}
+```
+
+Note how `g0r3.c` is correctly encoded as the underlying `container` object for
+`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i`
+being discovered after the pointer to it in `system.o.cn`. Also note that
+decoding isn't strictly reliant on the order of encoded object stream, as long
+as the relationship between objects are correctly encoded.
+
+## Decoding
+
+Decoding reads the statefile and reconstructs the object graph. Decoding begins
+in `decodeState.Load`. Decoding is performed in a single pass over the object
+stream in the statefile, and a subsequent pass over all deserialized objects is
+done to fire off all loading callbacks in the correctly defined order. Note that
+introducing cycles is possible here, but these are detected and an error will be
+returned.
+
+Decoding is relatively straight forward. For most primitive values, the decoder
+constructs an appropriate object and fills it with the values encoded in the
+statefile. Pointers need special handling, as they must point to a value
+allocated elsewhere. When values are constructed, the decoder indexes them by
+their `objectID`s in `decodeState.objectsByID`. The target of pointers are
+resolved by searching for the target in this index by their `objectID`; see
+`decodeState.register`. For pointers to values inside another value (fields in a
+pointer, elements of an array), the decoder uses the accessor path to walk to
+the appropriate location; see `walkChild`.
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 47e6b878a..c9971cdf6 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -16,28 +16,50 @@ package state
import (
"bytes"
- "encoding/binary"
- "errors"
+ "context"
"fmt"
- "io"
+ "math"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// objectState represents an object that may be in the process of being
+// internalCallback is a interface called on object completion.
+//
+// There are two implementations: objectDecodeState & userCallback.
+type internalCallback interface {
+ // source returns the dependent object. May be nil.
+ source() *objectDecodeState
+
+ // callbackRun executes the callback.
+ callbackRun()
+}
+
+// userCallback is an implementation of internalCallback.
+type userCallback func()
+
+// source implements internalCallback.source.
+func (userCallback) source() *objectDecodeState {
+ return nil
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (uc userCallback) callbackRun() {
+ uc()
+}
+
+// objectDecodeState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
-type objectState struct {
+type objectDecodeState struct {
// id is the id for this object.
- //
- // If this field is zero, then this is an anonymous (unregistered,
- // non-reference primitive) object. This is immutable.
- id uint64
+ id objectID
+
+ // typ is the id for this typeID. This may be zero if this is not a
+ // type-registered structure.
+ typ typeID
// obj is the object. This may or may not be valid yet, depending on
// whether complete returns true. However, regardless of whether the
@@ -56,69 +78,52 @@ type objectState struct {
// blockedBy is the number of dependencies this object has.
blockedBy int
- // blocking is a list of the objects blocked by this one.
- blocking []*objectState
+ // callbacksInline is inline storage for callbacks.
+ callbacksInline [2]internalCallback
// callbacks is a set of callbacks to execute on load.
- callbacks []func()
-
- // path is the decoding path to the object.
- path recoverable
-}
-
-// complete indicates the object is complete.
-func (os *objectState) complete() bool {
- return os.blockedBy == 0 && len(os.callbacks) == 0
-}
-
-// checkComplete checks for completion. If the object is complete, pending
-// callbacks will be executed and checkComplete will be called on downstream
-// objects (those depending on this one).
-func (os *objectState) checkComplete(stats *Stats) {
- if os.blockedBy > 0 {
- return
- }
- stats.Start(os.obj)
+ callbacks []internalCallback
- // Fire all callbacks.
- for _, fn := range os.callbacks {
- fn()
- }
- os.callbacks = nil
-
- // Clear all blocked objects.
- for _, other := range os.blocking {
- other.blockedBy--
- other.checkComplete(stats)
- }
- os.blocking = nil
- stats.Done()
+ completeEntry
}
-// waitFor queues a dependency on the given object.
-func (os *objectState) waitFor(other *objectState, callback func()) {
- os.blockedBy++
- other.blocking = append(other.blocking, os)
- if callback != nil {
- other.callbacks = append(other.callbacks, callback)
+// addCallback adds a callback to the objectDecodeState.
+func (ods *objectDecodeState) addCallback(ic internalCallback) {
+ if ods.callbacks == nil {
+ ods.callbacks = ods.callbacksInline[:0]
}
+ ods.callbacks = append(ods.callbacks, ic)
}
// findCycleFor returns when the given object is found in the blocking set.
-func (os *objectState) findCycleFor(target *objectState) []*objectState {
- for _, other := range os.blocking {
- if other == target {
- return []*objectState{target}
+func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState {
+ for _, ic := range ods.callbacks {
+ other := ic.source()
+ if other != nil && other == target {
+ return []*objectDecodeState{target}
} else if childList := other.findCycleFor(target); childList != nil {
return append(childList, other)
}
}
- return nil
+
+ // This should not occur.
+ Failf("no deadlock found?")
+ panic("unreachable")
}
// findCycle finds a dependency cycle.
-func (os *objectState) findCycle() []*objectState {
- return append(os.findCycleFor(os), os)
+func (ods *objectDecodeState) findCycle() []*objectDecodeState {
+ return append(ods.findCycleFor(ods), ods)
+}
+
+// source implements internalCallback.source.
+func (ods *objectDecodeState) source() *objectDecodeState {
+ return ods
+}
+
+// callbackRun implements internalCallback.callbackRun.
+func (ods *objectDecodeState) callbackRun() {
+ ods.blockedBy--
}
// decodeState is a graph of objects in the process of being decoded.
@@ -133,30 +138,69 @@ func (os *objectState) findCycle() []*objectState {
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
+ // r is the input stream.
+ r wire.Reader
+
+ // types is the type database.
+ types typeDecodeDatabase
+
// objectByID is the set of objects in progress.
- objectsByID map[uint64]*objectState
+ objectsByID []*objectDecodeState
// deferred are objects that have been read, by no interest has been
// registered yet. These will be decoded once interest in registered.
- deferred map[uint64]*pb.Object
+ deferred map[objectID]wire.Object
- // outstanding is the number of outstanding objects.
- outstanding uint32
+ // pending is the set of objects that are not yet complete.
+ pending completeList
- // r is the input stream.
- r io.Reader
-
- // stats is the passed stats object.
- stats *Stats
-
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
// lookup looks up an object in decodeState or returns nil if no such object
// has been previously registered.
-func (ds *decodeState) lookup(id uint64) *objectState {
- return ds.objectsByID[id]
+func (ds *decodeState) lookup(id objectID) *objectDecodeState {
+ if len(ds.objectsByID) < int(id) {
+ return nil
+ }
+ return ds.objectsByID[id-1]
+}
+
+// checkComplete checks for completion.
+func (ds *decodeState) checkComplete(ods *objectDecodeState) bool {
+ // Still blocked?
+ if ods.blockedBy > 0 {
+ return false
+ }
+
+ // Track stats if relevant.
+ if ods.callbacks != nil && ods.typ != 0 {
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ }
+
+ // Fire all callbacks.
+ for _, ic := range ods.callbacks {
+ ic.callbackRun()
+ }
+
+ // Mark completed.
+ cbs := ods.callbacks
+ ods.callbacks = nil
+ ds.pending.Remove(ods)
+
+ // Recursively check others.
+ for _, ic := range cbs {
+ if other := ic.source(); other != nil && other.blockedBy == 0 {
+ ds.checkComplete(other)
+ }
+ }
+
+ return true // All set.
}
// wait registers a dependency on an object.
@@ -164,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState {
// As a special case, we always allow _useable_ references back to the first
// decoding object because it may have fields that are already decoded. We also
// allow trivial self reference, since they can be handled internally.
-func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
+func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) {
switch id {
- case 0:
- // Nil pointer; nothing to wait for.
- fallthrough
case waiter.id:
// Trivial self reference.
fallthrough
@@ -180,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
return
}
+ // Mark as blocked.
+ waiter.blockedBy++
+
// No nil can be returned here.
- waiter.waitFor(ds.lookup(id), callback)
+ other := ds.lookup(id)
+ if callback != nil {
+ // Add the additional user callback.
+ other.addCallback(userCallback(callback))
+ }
+
+ // Mark waiter as unblocked.
+ other.addCallback(waiter)
}
// waitObject notes a blocking relationship.
-func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
- if rv, ok := p.Value.(*pb.Object_RefValue); ok {
+func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) {
+ if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 {
// Refs can encode pointers and maps.
- ds.wait(os, rv.RefValue, callback)
- } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok {
+ ds.wait(ods, objectID(rv.Root), callback)
+ } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 {
// See decodeObject; we need to wait for the array (if non-nil).
- ds.wait(os, sv.SliceValue.RefValue, callback)
- } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok {
+ ds.wait(ods, objectID(sv.Ref.Root), callback)
+ } else if iv, ok := encoded.(*wire.Interface); ok {
// It's an interface (wait recurisvely).
- ds.waitObject(os, iv.InterfaceValue.Value, callback)
+ ds.waitObject(ods, iv.Value, callback)
} else if callback != nil {
// Nothing to wait for: execute the callback immediately.
callback()
}
}
+// walkChild returns a child object from obj, given an accessor path. This is
+// the decode-side equivalent to traverse in encode.go.
+//
+// For the purposes of this function, a child object is either a field within a
+// struct or an array element, with one such indirection per element in
+// path. The returned value may be an unexported field, so it may not be
+// directly assignable. See unsafePointerTo.
+func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value {
+ // See wire.Ref.Dots. The path here is specified in reverse order.
+ for i := len(path) - 1; i >= 0; i-- {
+ switch pc := path[i].(type) {
+ case *wire.FieldName: // Must be a pointer.
+ if obj.Kind() != reflect.Struct {
+ Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.FieldByName(string(*pc))
+ case wire.Index: // Embedded.
+ if obj.Kind() != reflect.Array {
+ Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj)
+ }
+ obj = obj.Index(int(pc))
+ default:
+ panic("unreachable: switch should be exhaustive")
+ }
+ }
+ return obj
+}
+
// register registers a decode with a type.
//
// This type is only used to instantiate a new object if it has not been
-// registered previously.
-func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
- os, ok := ds.objectsByID[id]
- if ok {
- return os
+// registered previously. This depends on the type provided if none is
+// available in the object itself.
+func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value {
+ // Grow the objectsByID slice.
+ id := objectID(r.Root)
+ if len(ds.objectsByID) < int(id) {
+ ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...)
+ }
+
+ // Does this object already exist?
+ ods := ds.objectsByID[id-1]
+ if ods != nil {
+ return walkChild(r.Dots, ods.obj)
+ }
+
+ // Create the object.
+ if len(r.Dots) != 0 {
+ typ = ds.findType(r.Type)
+ }
+ v := reflect.New(typ)
+ ods = &objectDecodeState{
+ id: id,
+ obj: v.Elem(),
}
+ ds.objectsByID[id-1] = ods
+ ds.pending.PushBack(ods)
- // Record in the object index.
- if typ.Kind() == reflect.Map {
- os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()}
- } else {
- os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()}
+ // Process any deferred objects & callbacks.
+ if encoded, ok := ds.deferred[id]; ok {
+ delete(ds.deferred, id)
+ ds.decodeObject(ods, ods.obj, encoded)
}
- ds.objectsByID[id] = os
- if o, ok := ds.deferred[id]; ok {
- // There is a deferred object.
- delete(ds.deferred, id) // Free memory.
- ds.decodeObject(os, os.obj, o, "", nil)
- } else {
- // There is no deferred object.
- ds.outstanding++
+ return walkChild(r.Dots, ods.obj)
+}
+
+// objectDecoder is for decoding structs.
+type objectDecoder struct {
+ // ds is decodeState.
+ ds *decodeState
+
+ // ods is current object being decoded.
+ ods *objectDecodeState
+
+ // reconciledTypeEntry is the reconciled type information.
+ rte *reconciledTypeEntry
+
+ // encoded is the encoded object state.
+ encoded *wire.Struct
+}
+
+// load is helper for the public methods on Source.
+func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) {
+ // Note that we have reconciled the type and may remap the fields here
+ // to match what's expected by the decoder. The "slot" parameter here
+ // is in terms of the local type, where the fields in the encoded
+ // object are in terms of the wire object's type, which might be in a
+ // different order (but will have the same fields).
+ v := *od.encoded.Field(od.rte.FieldOrder[slot])
+ od.ds.decodeObject(od.ods, objPtr.Elem(), v)
+ if wait {
+ // Mark this individual object a blocker.
+ od.ds.waitObject(od.ods, v, fn)
}
+}
- return os
+// aterLoad implements Source.AfterLoad.
+func (od *objectDecoder) afterLoad(fn func()) {
+ // Queue the local callback; this will execute when all of the above
+ // data dependencies have been cleared.
+ od.ods.addCallback(userCallback(fn))
}
// decodeStruct decodes a struct value.
-func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
- // Set the fields.
- m := Map{newInternalMap(nil, ds, os)}
- defer internalMapPool.Put(m.internalMap)
- for _, field := range s.Fields {
- m.data = append(m.data, entry{
- name: field.Name,
- object: field.Value,
- })
- }
-
- // Sort the fields for efficient searching.
- //
- // Technically, these should already appear in sorted order in the
- // state ordering, so this cost is effectively a single scan to ensure
- // that the order is correct.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- }
-
- // Invoke the load; this will recursively decode other objects.
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the loader.
- fns.invokeLoad(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow anonymous empty structs.
- return
- } else {
+func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) {
+ if encoded.TypeID == 0 {
+ // Allow anonymous empty structs, but only if the encoded
+ // object also has no fields.
+ if encoded.Fields() == 0 && obj.NumField() == 0 {
+ return
+ }
+
// Propagate an error.
- panic(fmt.Errorf("unregistered type %s", obj.Type()))
+ Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name())
+ }
+
+ // Lookup the object type.
+ rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type())
+ ods.typ = typeID(encoded.TypeID)
+
+ // Invoke the loader.
+ od := objectDecoder{
+ ds: ds,
+ ods: ods,
+ rte: rte,
+ encoded: encoded,
+ }
+ ds.stats.start(ods.typ)
+ defer ds.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateLoad(Source{internal: od})
}
}
// decodeMap decodes a map value.
-func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
+func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) {
if obj.IsNil() {
+ // See pointerTo.
obj.Set(reflect.MakeMap(obj.Type()))
}
- for i := 0; i < len(m.Keys); i++ {
+ for i := 0; i < len(encoded.Keys); i++ {
// Decode the objects.
kv := reflect.New(obj.Type().Key()).Elem()
vv := reflect.New(obj.Type().Elem()).Elem()
- ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i)
- ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface())
- ds.waitObject(os, m.Keys[i], nil)
- ds.waitObject(os, m.Values[i], nil)
+ ds.decodeObject(ods, kv, encoded.Keys[i])
+ ds.decodeObject(ods, vv, encoded.Values[i])
+ ds.waitObject(ods, encoded.Keys[i], nil)
+ ds.waitObject(ods, encoded.Values[i], nil)
// Set in the map.
obj.SetMapIndex(kv, vv)
@@ -288,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map)
}
// decodeArray decodes an array value.
-func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
- if len(a.Contents) != obj.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents)))
+func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) {
+ if len(encoded.Contents) != obj.Len() {
+ Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents))
}
// Decode the contents into the array.
- for i := 0; i < len(a.Contents); i++ {
- ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i)
- ds.waitObject(os, a.Contents[i], nil)
+ for i := 0; i < len(encoded.Contents); i++ {
+ ds.decodeObject(ods, obj.Index(i), encoded.Contents[i])
+ ds.waitObject(ods, encoded.Contents[i], nil)
}
}
-// decodeInterface decodes an interface value.
-func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
- // Is this a nil value?
- if i.Type == "" {
- return // Just leave obj alone.
+// findType finds the type for the given wire.TypeSpecs.
+func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type {
+ switch x := t.(type) {
+ case wire.TypeID:
+ typ := ds.types.LookupType(typeID(x))
+ rte := ds.types.Lookup(typeID(x), typ)
+ return rte.LocalType
+ case *wire.TypeSpecPointer:
+ return reflect.PtrTo(ds.findType(x.Type))
+ case *wire.TypeSpecArray:
+ return reflect.ArrayOf(int(x.Count), ds.findType(x.Type))
+ case *wire.TypeSpecSlice:
+ return reflect.SliceOf(ds.findType(x.Type))
+ case *wire.TypeSpecMap:
+ return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value))
+ default:
+ // Should not happen.
+ Failf("unknown type %#v", t)
}
+ panic("unreachable")
+}
- // Get the dispatchable type. This may not be used if the given
- // reference has already been resolved, but if not we need to know the
- // type to create.
- t, ok := registeredTypes.lookupType(i.Type)
- if !ok {
- panic(fmt.Errorf("no valid type for %q", i.Type))
+// decodeInterface decodes an interface value.
+func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) {
+ if _, ok := encoded.Type.(wire.TypeSpecNil); ok {
+ // Special case; the nil object. Just decode directly, which
+ // will read nil from the wire (if encoded correctly).
+ ds.decodeObject(ods, obj, encoded.Value)
+ return
}
- if obj.Kind() != reflect.Map {
- // Set the obj to be the given typed value; this actually sets
- // obj to be a non-zero value -- namely, it inserts type
- // information. There's no need to do this for maps.
- obj.Set(reflect.Zero(t))
+ // We now need to resolve the actual type.
+ typ := ds.findType(encoded.Type)
+
+ // We need to imbue type information here, then we can proceed to
+ // decode normally. In order to avoid issues with setting value-types,
+ // we create a new non-interface version of this object. We will then
+ // set the interface object to be equal to whatever we decode.
+ origObj := obj
+ obj = reflect.New(typ).Elem()
+ defer origObj.Set(obj)
+
+ // With the object now having sufficient type information to actually
+ // have Set called on it, we can proceed to decode the value.
+ ds.decodeObject(ods, obj, encoded.Value)
+}
+
+// isFloatEq determines if x and y represent the same value.
+func isFloatEq(x float64, y float64) bool {
+ switch {
+ case math.IsNaN(x):
+ return math.IsNaN(y)
+ case math.IsInf(x, 1):
+ return math.IsInf(y, 1)
+ case math.IsInf(x, -1):
+ return math.IsInf(y, -1)
+ default:
+ return x == y
}
+}
- // Decode the dereferenced element; there is no need to wait here, as
- // the interface object shares the current object state.
- ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type)
+// isComplexEq determines if x and y represent the same value.
+func isComplexEq(x complex128, y complex128) bool {
+ return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y))
}
// decodeObject decodes a object value.
-func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
- ds.push(false, format, param)
- ds.stats.Add(obj)
- ds.stats.Start(obj)
-
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- obj.SetBool(x.BoolValue)
- case *pb.Object_StringValue:
- obj.SetString(string(x.StringValue))
- case *pb.Object_Int64Value:
- obj.SetInt(x.Int64Value)
- if obj.Int() != x.Int64Value {
- panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_Uint64Value:
- obj.SetUint(x.Uint64Value)
- if obj.Uint() != x.Uint64Value {
- panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
- }
- case *pb.Object_DoubleValue:
- obj.SetFloat(x.DoubleValue)
- if obj.Float() != x.DoubleValue {
- panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
+func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) {
+ switch x := encoded.(type) {
+ case wire.Nil: // Fast path: first.
+ // We leave obj alone here. That's because if obj represents an
+ // interface, it may have been imbued with type information in
+ // decodeInterface, and we don't want to destroy that.
+ case *wire.Ref:
+ // Nil pointers may be encoded in a "forceValue" context. For
+ // those we just leave it alone as the value will already be
+ // correct (nil).
+ if id := objectID(x.Root); id == 0 {
+ return
}
- case *pb.Object_RefValue:
- // Resolve the pointer itself, even though the object may not
- // be decoded yet. You need to use wait() in order to ensure
- // that is the case. See wait above, and Map.Barrier.
- if id := x.RefValue; id != 0 {
- // Decoding the interface should have imparted type
- // information, so from this point it's safe to resolve
- // and use this dynamic information for actually
- // creating the object in register.
- //
- // (For non-interfaces this is a no-op).
- dyntyp := reflect.TypeOf(obj.Interface())
- if dyntyp.Kind() == reflect.Map {
- // Remove the map object count here to avoid
- // double counting, as this object will be
- // counted again when it gets processed later.
- // We do not add a reference count as the
- // reference is artificial.
- ds.stats.Remove(obj)
- obj.Set(ds.register(id, dyntyp).obj)
- } else if dyntyp.Kind() == reflect.Ptr {
- ds.push(true /* dereference */, "", nil)
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
- ds.pop()
- } else {
- obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
+
+ // Note that if this is a map type, we go through a level of
+ // indirection to allow for map aliasing.
+ if obj.Kind() == reflect.Map {
+ v := ds.register(x, obj.Type())
+ if v.IsNil() {
+ // Note that we don't want to clobber the map
+ // if has already been decoded by decodeMap. We
+ // just make it so that we have a consistent
+ // reference when that eventually does happen.
+ v.Set(reflect.MakeMap(v.Type()))
}
- } else {
- // We leave obj alone here. That's because if obj
- // represents an interface, it may have been embued
- // with type information in decodeInterface, and we
- // don't want to destroy that information.
+ obj.Set(v)
+ return
}
- case *pb.Object_SliceValue:
- // It's okay to slice the array here, since the contents will
- // still be provided later on. These semantics are a bit
- // strange but they are handled in the Map.Barrier properly.
- //
- // The special semantics of zero ref apply here too.
- if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
- v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
- obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
+
+ // Normal assignment: authoritative only if no dots.
+ v := ds.register(x, obj.Type().Elem())
+ if v.IsValid() {
+ obj.Set(unsafePointerTo(v))
}
- case *pb.Object_ArrayValue:
- ds.decodeArray(os, obj, x.ArrayValue)
- case *pb.Object_StructValue:
- ds.decodeStruct(os, obj, x.StructValue)
- case *pb.Object_MapValue:
- ds.decodeMap(os, obj, x.MapValue)
- case *pb.Object_InterfaceValue:
- ds.decodeInterface(os, obj, x.InterfaceValue)
- case *pb.Object_ByteArrayValue:
- copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Uint16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]uint16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Bool:
+ obj.SetBool(bool(x))
+ case wire.Int:
+ obj.SetInt(int64(x))
+ if obj.Int() != int64(x) {
+ Failf("signed integer truncated from %v to %v", int64(x), obj.Int())
}
- for i := range s {
- t[i] = uint16(s[i])
+ case wire.Uint:
+ obj.SetUint(uint64(x))
+ if obj.Uint() != uint64(x) {
+ Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint())
}
- case *pb.Object_Uint32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := x.Int16ArrayValue.Values
- t := obj.Slice(0, obj.Len()).Interface().([]int16)
- if len(t) != len(s) {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
+ case wire.Float32:
+ obj.SetFloat(float64(x))
+ case wire.Float64:
+ obj.SetFloat(float64(x))
+ if !isFloatEq(obj.Float(), float64(x)) {
+ Failf("floating point number truncated from %v to %v", float64(x), obj.Float())
}
- for i := range s {
- t[i] = int16(s[i])
+ case *wire.Complex64:
+ obj.SetComplex(complex128(*x))
+ case *wire.Complex128:
+ obj.SetComplex(complex128(*x))
+ if !isComplexEq(obj.Complex(), complex128(*x)) {
+ Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex())
}
- case *pb.Object_Int32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
+ case *wire.String:
+ obj.SetString(string(*x))
+ case *wire.Slice:
+ // See *wire.Ref above; same applies.
+ if id := objectID(x.Ref.Root); id == 0 {
+ return
+ }
+ // Note that it's fine to slice the array here and assume that
+ // contents will still be filled in later on.
+ typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type.
+ v := ds.register(&x.Ref, typ)
+ obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity)))
+ case *wire.Array:
+ ds.decodeArray(ods, obj, x)
+ case *wire.Struct:
+ ds.decodeStruct(ods, obj, x)
+ case *wire.Map:
+ ds.decodeMap(ods, obj, x)
+ case *wire.Interface:
+ ds.decodeInterface(ods, obj, x)
default:
// Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
+ Failf("unknown object %#v for %q", encoded, obj.Type().Name())
}
-
- ds.stats.Done()
- ds.pop()
}
-func copyArray(dest reflect.Value, src reflect.Value) {
- if dest.Len() != src.Len() {
- panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len()))
- }
- reflect.Copy(dest, castSlice(src, dest.Type().Elem()))
-}
-
-// Deserialize deserializes the object state.
+// Load deserializes the object graph rooted at obj.
//
// This function may panic and should be run in safely().
-func (ds *decodeState) Deserialize(obj reflect.Value) {
- ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
- ds.outstanding = 1 // The root object.
+func (ds *decodeState) Load(obj reflect.Value) {
+ ds.stats.init()
+ defer ds.stats.fini(func(id typeID) string {
+ return ds.types.LookupName(id)
+ })
+
+ // Create the root object.
+ ds.objectsByID = append(ds.objectsByID, &objectDecodeState{
+ id: 1,
+ obj: obj,
+ })
+
+ // Read the number of objects.
+ lastID, object, err := ReadHeader(ds.r)
+ if err != nil {
+ Failf("header error: %w", err)
+ }
+ if !object {
+ Failf("object missing")
+ }
+
+ // Decode all objects.
+ var (
+ encoded wire.Object
+ ods *objectDecodeState
+ id = objectID(1)
+ tid = typeID(1)
+ )
+ if err := safely(func() {
+ // Decode all objects in the stream.
+ //
+ // Note that the structure of this decoding loop should match
+ // the raw decoding loop in printer.go.
+ for id <= objectID(lastID) {
+ // Unmarshal the object.
+ encoded = wire.Load(ds.r)
+
+ // Is this a type object? Handle inline.
+ if wt, ok := encoded.(*wire.Type); ok {
+ ds.types.Register(wt)
+ tid++
+ encoded = nil
+ continue
+ }
- // Decode all objects in the stream.
- //
- // See above, we never process objects while we have no outstanding
- // interests (other than the very first object).
- for id := uint64(1); ds.outstanding > 0; id++ {
- os := ds.lookup(id)
- ds.stats.Start(os.obj)
-
- o, err := ds.readObject()
- if err != nil {
- panic(err)
- }
+ // Actually resolve the object.
+ ods = ds.lookup(id)
+ if ods != nil {
+ // Decode the object.
+ ds.decodeObject(ods, ods.obj, encoded)
+ } else {
+ // If an object hasn't had interest registered
+ // previously or isn't yet valid, we deferred
+ // decoding until interest is registered.
+ ds.deferred[id] = encoded
+ }
- if os != nil {
- // Decode the object.
- ds.from = &os.path
- ds.decodeObject(os, os.obj, o, "", nil)
- ds.outstanding--
+ // For error handling.
+ ods = nil
+ encoded = nil
+ id++
+ }
+ }); err != nil {
+ // Include as much information as we can, taking into account
+ // the possible state transitions above.
+ if ods != nil {
+ Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err)
+ } else if encoded != nil {
+ Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err)
} else {
- // If an object hasn't had interest registered
- // previously, we deferred decoding until interest is
- // registered.
- ds.deferred[id] = o
+ Failf("general decoding error: %w", err)
}
-
- ds.stats.Done()
- }
-
- // Check the zero-length header at the end.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- panic(err)
- }
- if length != 0 {
- panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
- }
- if object {
- panic("expected non-object terminal")
}
// Check if we have any deferred objects.
- if count := len(ds.deferred); count > 0 {
- // Shoud not happen, not propagated as an error.
- panic(fmt.Sprintf("still have %d deferred objects", count))
- }
-
- // Scan and fire all callbacks.
- for _, os := range ds.objectsByID {
- os.checkComplete(ds.stats)
+ for id, encoded := range ds.deferred {
+ // Shoud never happen, the graph was bogus.
+ Failf("still have deferred objects: one is ID %d, %#v", id, encoded)
}
- // Check if we have any remaining dependency cycles.
- for _, os := range ds.objectsByID {
- if !os.complete() {
- // This must be the result of a dependency cycle.
- cycle := os.findCycle()
- var buf bytes.Buffer
- buf.WriteString("dependency cycle: {")
- for i, cycleOS := range cycle {
- if i > 0 {
- buf.WriteString(" => ")
+ // Scan and fire all callbacks. We iterate over the list of incomplete
+ // objects until all have been finished. We stop iterating if no
+ // objects become complete (there is a dependency cycle).
+ //
+ // Note that we iterate backwards here, because there will be a strong
+ // tendendcy for blocking relationships to go from earlier objects to
+ // later (deeper) objects in the graph. This will reduce the number of
+ // iterations required to finish all objects.
+ if err := safely(func() {
+ for ds.pending.Back() != nil {
+ thisCycle := false
+ for ods = ds.pending.Back(); ods != nil; {
+ if ds.checkComplete(ods) {
+ thisCycle = true
+ break
}
- buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type()))
+ ods = ods.Prev()
+ }
+ if !thisCycle {
+ break
}
- buf.WriteString("}")
- // Panic as an error; propagate to the caller.
- panic(errors.New(string(buf.Bytes())))
}
- }
-}
-
-type byteReader struct {
- io.Reader
-}
-
-// ReadByte implements io.ByteReader.
-func (br byteReader) ReadByte() (byte, error) {
- var b [1]byte
- n, err := br.Reader.Read(b[:])
- if n > 0 {
- return b[0], nil
- } else if err != nil {
- return 0, err
- } else {
- return 0, io.ErrUnexpectedEOF
+ }); err != nil {
+ Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err)
+ }
+
+ // Check if we have any remaining dependency cycles. If there are any
+ // objects left in the pending list, then it must be due to a cycle.
+ if ods := ds.pending.Front(); ods != nil {
+ // This must be the result of a dependency cycle.
+ cycle := ods.findCycle()
+ var buf bytes.Buffer
+ buf.WriteString("dependency cycle: {")
+ for i, cycleOS := range cycle {
+ if i > 0 {
+ buf.WriteString(" => ")
+ }
+ fmt.Fprintf(&buf, "%q", cycleOS.obj.Type())
+ }
+ buf.WriteString("}")
+ Failf("incomplete graph: %s", string(buf.Bytes()))
}
}
@@ -561,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) {
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
-func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
+func ReadHeader(r wire.Reader) (length uint64, object bool, err error) {
// Read the header.
- length, err = binary.ReadUvarint(byteReader{r})
+ err = safely(func() {
+ length = wire.LoadUint(r)
+ })
if err != nil {
- return
+ // On the header, pass raw I/O errors.
+ if sErr, ok := err.(*ErrState); ok {
+ return 0, false, sErr.Unwrap()
+ }
}
// Decode whether the object is valid.
- object = length&0x1 != 0
- length = length >> 1
+ object = length&objectFlag != 0
+ length &^= objectFlag
return
}
-
-// readObject reads an object from the stream.
-func (ds *decodeState) readObject() (*pb.Object, error) {
- // Read the header.
- length, object, err := ReadHeader(ds.r)
- if err != nil {
- return nil, err
- }
- if !object {
- return nil, fmt.Errorf("invalid object header")
- }
-
- // Read the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := ds.r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return nil, err
- }
- }
-
- // Unmarshal.
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return nil, err
- }
-
- return obj, nil
-}
diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go
new file mode 100644
index 000000000..d048f61a1
--- /dev/null
+++ b/pkg/state/decode_unsafe.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on
+// values representing unexported fields. This bypasses visibility, but not
+// type safety.
+func unsafePointerTo(obj reflect.Value) reflect.Value {
+ return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr()))
+}
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 5d9409a45..92fcad4e9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -15,433 +15,797 @@
package state
import (
- "container/list"
- "encoding/binary"
- "fmt"
- "io"
+ "context"
"reflect"
- "sort"
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
-// queuedObject is an object queued for encoding.
-type queuedObject struct {
- id uint64
- obj reflect.Value
- path recoverable
+// objectEncodeState the type and identity of an object occupying a memory
+// address range. This is the value type for addrSet, and the intrusive entry
+// for the pending and deferred lists.
+type objectEncodeState struct {
+ // id is the assigned ID for this object.
+ id objectID
+
+ // obj is the object value. Note that this may be replaced if we
+ // encounter an object that contains this object. When this happens (in
+ // resolve), we will update existing references approprately, below,
+ // and defer a re-encoding of the object.
+ obj reflect.Value
+
+ // encoded is the encoded value of this object. Note that this may not
+ // be up to date if this object is still in the deferred list.
+ encoded wire.Object
+
+ // how indicates whether this object should be encoded as a value. This
+ // is used only for deferred encoding.
+ how encodeStrategy
+
+ // refs are the list of reference objects used by other objects
+ // referring to this object. When the object is updated, these
+ // references may be updated directly and automatically.
+ refs []*wire.Ref
+
+ pendingEntry
+ deferredEntry
}
// encodeState is state used for encoding.
//
-// The encoding process is a breadth-first traversal of the object graph. The
-// inherent races and dependencies are much simpler than the decode case.
+// The encoding process constructs a representation of the in-memory graph of
+// objects before a single object is serialized. This is done to ensure that
+// all references can be fully disambiguated. See resolve for more details.
type encodeState struct {
- // lastID is the last object ID.
- //
- // See idsByObject for context. Because of the special zero encoding
- // used for reference values, the first ID must be 1.
- lastID uint64
+ // ctx is the encode context.
+ ctx context.Context
- // idsByObject is a set of objects, indexed via:
- //
- // reflect.ValueOf(x).UnsafeAddr
- //
- // This provides IDs for objects.
- idsByObject map[uintptr]uint64
+ // w is the output stream.
+ w wire.Writer
+
+ // types is the type database.
+ types typeEncodeDatabase
- // values stores values that span the addresses.
+ // lastID is the last allocated object ID.
+ lastID objectID
+
+ // values tracks the address ranges occupied by objects, along with the
+ // types of these objects. This is used to locate pointer targets,
+ // including pointers to fields within another type.
//
- // addrSet is a a generated type which efficiently stores ranges of
- // addresses. When encoding pointers, these ranges are filled in and
- // used to check for overlapping or conflicting pointers. This would
- // indicate a pointer to an field, or a non-type safe value, neither of
- // which are currently decodable.
+ // Multiple objects may overlap in memory iff the larger object fully
+ // contains the smaller one, and the type of the smaller object matches
+ // a field or array element's type at the appropriate offset. An
+ // arbitrary number of objects may be nested in this manner.
//
- // See the usage of values below for more context.
+ // Note that this does not track zero-sized objects, those are tracked
+ // by zeroValues below.
values addrSet
- // w is the output stream.
- w io.Writer
+ // zeroValues tracks zero-sized objects.
+ zeroValues map[reflect.Type]*objectEncodeState
- // pending is the list of objects to be serialized.
- //
- // This is a set of queuedObjects.
- pending list.List
+ // deferred is the list of objects to be encoded.
+ deferred deferredList
- // done is the a list of finished objects.
- //
- // This is kept to prevent garbage collection and address reuse.
- done list.List
+ // pendingTypes is the list of types to be serialized. Serialization
+ // will occur when all objects have been encoded, but before pending is
+ // serialized.
+ pendingTypes []wire.Type
- // stats is the passed stats object.
- stats *Stats
+ // pending is the list of objects to be serialized. Serialization does
+ // not actually occur until the full object graph is computed.
+ pending pendingList
- // recoverable is the panic recover facility.
- recoverable
+ // stats tracks time data.
+ stats Stats
}
-// register looks up an ID, registering if necessary.
+// isSameSizeParent returns true if child is a field value or element within
+// parent. Only a struct or array can have a child value.
+//
+// isSameSizeParent deals with objects like this:
//
-// If the object was not previously registered, it is enqueued to be serialized.
-// See the documentation for idsByObject for more information.
-func (es *encodeState) register(obj reflect.Value) uint64 {
- // It is not legal to call register for any non-pointer objects (see
- // below), so we panic with a recoverable error if this is a mismatch.
- if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map {
- panic(fmt.Errorf("non-pointer %#v registered", obj.Interface()))
+// struct child {
+// // fields..
+// }
+//
+// struct parent {
+// c child
+// }
+//
+// var p parent
+// record(&p.c)
+//
+// Here, &p and &p.c occupy the exact same address range.
+//
+// Or like this:
+//
+// struct child {
+// // fields
+// }
+//
+// var arr [1]parent
+// record(&arr[0])
+//
+// Similarly, &arr[0] and &arr[0].c have the exact same address range.
+//
+// Precondition: parent and child must occupy the same memory.
+func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool {
+ switch parent.Kind() {
+ case reflect.Struct:
+ for i := 0; i < parent.NumField(); i++ {
+ field := parent.Field(i)
+ if field.Type() == childType {
+ return true
+ }
+ // Recurse through any intermediate types.
+ if isSameSizeParent(field, childType) {
+ return true
+ }
+ // Does it make sense to keep going if the first field
+ // doesn't match? Yes, because there might be an
+ // arbitrary number of zero-sized fields before we get
+ // a match, and childType itself can be zero-sized.
+ }
+ return false
+ case reflect.Array:
+ // The only case where an array with more than one elements can
+ // return true is if childType is zero-sized. In such cases,
+ // it's ambiguous which element contains the match since a
+ // zero-sized child object fully fits in any of the zero-sized
+ // elements in an array... However since all elements are of
+ // the same type, we only need to check one element.
+ //
+ // For non-zero-sized childTypes, parent.Len() must be 1, but a
+ // combination of the precondition and an implicit comparison
+ // between the array element size and childType ensures this.
+ return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType)
+ default:
+ return false
}
+}
- addr := obj.Pointer()
- if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 {
- // For zero-sized objects, we always provide a unique ID.
- // That's because the runtime internally multiplexes pointers
- // to the same address. We can't be certain what the intent is
- // with pointers to zero-sized objects, so we just give them
- // all unique identities.
- } else if id, ok := es.idsByObject[addr]; ok {
- // Already registered.
- return id
- }
-
- // Ensure that the first ID given out is one. See note on lastID. The
- // ID zero is used to indicate nil values.
+// nextID returns the next valid ID.
+func (es *encodeState) nextID() objectID {
es.lastID++
- id := es.lastID
- es.idsByObject[addr] = id
- if obj.Kind() == reflect.Ptr {
- // Dereference and treat as a pointer.
- es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()})
-
- // Register this object at all addresses.
- typ := obj.Elem().Type()
- if size := typ.Size(); size > 0 {
- r := addrRange{addr, addr + size}
- if !es.values.IsEmptyRange(r) {
- old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable)
- panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path()))
+ return objectID(es.lastID)
+}
+
+// dummyAddr points to the dummy zero-sized address.
+var dummyAddr = reflect.ValueOf(new(struct{})).Pointer()
+
+// resolve records the address range occupied by an object.
+func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) {
+ addr := obj.Pointer()
+
+ // Is this a map pointer? Just record the single address. It is not
+ // possible to take any pointers into the map internals.
+ if obj.Kind() == reflect.Map {
+ if addr == 0 {
+ // Just leave the nil reference alone. This is fine, we
+ // may need to encode as a reference in this way. We
+ // return nil for our objectEncodeState so that anyone
+ // depending on this value knows there's nothing there.
+ return
+ }
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ // Ensure the map types match.
+ existing := seg.Value()
+ if existing.obj.Type() != obj.Type() {
+ Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj)
+ }
+
+ // No sense recording refs, maps may not be replaced by
+ // covering objects, they are maximal.
+ ref.Root = wire.Uint(existing.id)
+ return
+ }
+
+ // Record the map.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ how: encodeMapAsValue,
+ }
+ es.values.Add(addrRange{addr, addr + 1}, oes)
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+
+ // See above: no ref recording.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+
+ // If not a map, then the object must be a pointer.
+ if obj.Kind() != reflect.Ptr {
+ Failf("attempt to record non-map and non-pointer object %#v", obj)
+ }
+
+ obj = obj.Elem() // Value from here.
+
+ // Is this a zero-sized type?
+ typ := obj.Type()
+ size := typ.Size()
+ if size == 0 {
+ if addr == dummyAddr {
+ // Zero-sized objects point to a dummy byte within the
+ // runtime. There's no sense recording this in the
+ // address map. We add this to the dedicated
+ // zeroValues.
+ //
+ // Note that zero-sized objects must be *true*
+ // zero-sized objects. They cannot be part of some
+ // larger object. In that case, they are assigned a
+ // 1-byte address at the end of the object.
+ oes, ok := es.zeroValues[typ]
+ if !ok {
+ oes = &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ es.zeroValues[typ] = oes
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ }
+
+ // There's also no sense tracking back references. We
+ // know that this is a true zero-sized object, and not
+ // part of a larger container, so it will not change.
+ ref.Root = wire.Uint(oes.id)
+ return
+ }
+ size = 1 // See above.
+ }
+
+ // Calculate the container.
+ end := addr + size
+ r := addrRange{addr, end}
+ if seg, _ := es.values.Find(addr); seg.Ok() {
+ existing := seg.Value()
+ switch {
+ case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type():
+ // The object is a perfect match. Happy path. Avoid the
+ // traversal and just return directly. We don't need to
+ // encode the type information or any dots here.
+ ref.Root = wire.Uint(existing.id)
+ existing.refs = append(existing.refs, ref)
+ return
+
+ case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end):
+ // The previously registered object is larger than
+ // this, no need to update. But we expect some
+ // traversal below.
+
+ case seg.Start() == addr && seg.End() == end:
+ if !isSameSizeParent(obj, existing.obj.Type()) {
+ break // Needs traversal.
+ }
+ fallthrough // Needs update.
+
+ case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end):
+ // Update the object and redo the encoding.
+ old := existing.obj
+ existing.obj = obj
+ es.deferred.Remove(existing)
+ es.deferred.PushBack(existing)
+
+ // The previously registered object is superseded by
+ // this new object. We are guaranteed to not have any
+ // mergeable neighbours in this segment set.
+ if !raceEnabled {
+ seg.SetRangeUnchecked(r)
+ } else {
+ // Add extra paranoid. This will be statically
+ // removed at compile time unless a race build.
+ es.values.Remove(seg)
+ es.values.Add(r, existing)
+ seg = es.values.LowerBoundSegment(addr)
+ }
+
+ // Compute the traversal required & update references.
+ dots := traverse(obj.Type(), old.Type(), addr, seg.Start())
+ wt := es.findType(obj.Type())
+ for _, ref := range existing.refs {
+ ref.Dots = append(ref.Dots, dots...)
+ ref.Type = wt
}
- es.values.Add(r, reflect.ValueOf(es.recoverable.copy()))
+ default:
+ // There is a non-sensical overlap.
+ Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj)
}
+
+ // Compute the new reference, record and return it.
+ ref.Root = wire.Uint(existing.id)
+ ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr)
+ ref.Type = es.findType(obj.Type())
+ existing.refs = append(existing.refs, ref)
+ return
+ }
+
+ // The only remaining case is a pointer value that doesn't overlap with
+ // any registered addresses. Create a new entry for it, and start
+ // tracking the first reference we just created.
+ oes := &objectEncodeState{
+ id: es.nextID(),
+ obj: obj,
+ }
+ if !raceEnabled {
+ es.values.AddWithoutMerging(r, oes)
} else {
- // Push back the map itself; when maps are encoded from the
- // top-level, forceMap will be equal to true.
- es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()})
+ // Merges should never happen. This is just enabled extra
+ // sanity checks because the Merge function below will panic.
+ es.values.Add(r, oes)
+ }
+ es.pending.PushBack(oes)
+ es.deferred.PushBack(oes)
+ ref.Root = wire.Uint(oes.id)
+ oes.refs = append(oes.refs, ref)
+}
+
+// traverse searches for a target object within a root object, where the target
+// object is a struct field or array element within root, with potentially
+// multiple intervening types. traverse returns the set of field or element
+// traversals required to reach the target.
+//
+// Note that for efficiency, traverse returns the dots in the reverse order.
+// That is, the first traversal required will be the last element of the list.
+//
+// Precondition: The target object must lie completely within the range defined
+// by [rootAddr, rootAddr + sizeof(rootType)].
+func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot {
+ // Recursion base case: the types actually match.
+ if targetType == rootType && targetAddr == rootAddr {
+ return nil
}
- return id
+ switch rootType.Kind() {
+ case reflect.Struct:
+ offset := targetAddr - rootAddr
+ for i := rootType.NumField(); i > 0; i-- {
+ field := rootType.Field(i - 1)
+ // The first field from the end with an offset that is
+ // smaller than or equal to our address offset is where
+ // the target is located. Traverse from there.
+ if field.Offset <= offset {
+ dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr)
+ fieldName := wire.FieldName(field.Name)
+ return append(dots, &fieldName)
+ }
+ }
+ // Should never happen; the target should be reachable.
+ Failf("no field in root type %v contains target type %v", rootType, targetType)
+
+ case reflect.Array:
+ // Since arrays have homogenous types, all elements have the
+ // same size and we can compute where the target lives. This
+ // does not matter for the purpose of typing, but matters for
+ // the purpose of computing the address of the given index.
+ elemSize := int(rootType.Elem().Size())
+ n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down.
+ if rootType.Len() < n {
+ Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements",
+ targetType, targetAddr, rootType, rootAddr, rootType.Len())
+ }
+ dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr)
+ return append(dots, wire.Index(n))
+
+ default:
+ // For any other type, there's no possibility of aliasing so if
+ // the types didn't match earlier then we have an addresss
+ // collision which shouldn't be possible at this point.
+ Failf("traverse failed for root type %v and target type %v", rootType, targetType)
+ }
+ panic("unreachable")
}
// encodeMap encodes a map.
-func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map {
- var (
- keys []*pb.Object
- values []*pb.Object
- )
+func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) {
+ if obj.IsNil() {
+ // Because there is a difference between a nil map and an empty
+ // map, we need to not decode in the case of a truly nil map.
+ *dest = wire.Nil{}
+ return
+ }
+ l := obj.Len()
+ m := &wire.Map{
+ Keys: make([]wire.Object, l),
+ Values: make([]wire.Object, l),
+ }
+ *dest = m
for i, k := range obj.MapKeys() {
v := obj.MapIndex(k)
- kp := es.encodeObject(k, false, ".(key %d)", i)
- vp := es.encodeObject(v, false, "[%#v]", k.Interface())
- keys = append(keys, kp)
- values = append(values, vp)
+ // Map keys must be encoded using the full value because the
+ // type will be omitted after the first key.
+ es.encodeObject(k, encodeAsValue, &m.Keys[i])
+ es.encodeObject(v, encodeAsValue, &m.Values[i])
}
- return &pb.Map{Keys: keys, Values: values}
+}
+
+// objectEncoder is for encoding structs.
+type objectEncoder struct {
+ // es is encodeState.
+ es *encodeState
+
+ // encoded is the encoded struct.
+ encoded *wire.Struct
+}
+
+// save is called by the public methods on Sink.
+func (oe *objectEncoder) save(slot int, obj reflect.Value) {
+ fieldValue := oe.encoded.Field(slot)
+ oe.es.encodeObject(obj, encodeDefault, fieldValue)
}
// encodeStruct encodes a composite object.
-func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct {
- // Invoke the save.
- m := Map{newInternalMap(es, nil, nil)}
- defer internalMapPool.Put(m.internalMap)
+func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) {
+ // Ensure that the obj is addressable. There are two cases when it is
+ // not. First, is when this is dispatched via SaveValue. Second, when
+ // this is a map key as a struct. Either way, we need to make a copy to
+ // obtain an addressable value.
if !obj.CanAddr() {
- // Force it to a * type of the above; this involves a copy.
localObj := reflect.New(obj.Type())
localObj.Elem().Set(obj)
obj = localObj.Elem()
}
- fns, ok := registeredTypes.lookupFns(obj.Addr().Type())
- if ok {
- // Invoke the provided saver.
- fns.invokeSave(obj.Addr(), m)
- } else if obj.NumField() == 0 {
- // Allow unregistered anonymous, empty structs.
- return &pb.Struct{}
- } else {
- // Propagate an error.
- panic(fmt.Errorf("unregistered type %T", obj.Interface()))
- }
-
- // Sort the underlying slice, and check for duplicates. This is done
- // once instead of on each add, because performing this sort once is
- // far more efficient.
- if len(m.data) > 1 {
- sort.Slice(m.data, func(i, j int) bool {
- return m.data[i].name < m.data[j].name
- })
- for i := range m.data {
- if i > 0 && m.data[i-1].name == m.data[i].name {
- panic(fmt.Errorf("duplicate name %s", m.data[i].name))
- }
+
+ // Prepare the value.
+ s := &wire.Struct{}
+ *dest = s
+
+ // Look the type up in the database.
+ te, ok := es.types.Lookup(obj.Type())
+ if te == nil {
+ if obj.NumField() == 0 {
+ // Allow unregistered anonymous, empty structs. This
+ // will just return success without ever invoking the
+ // passed function. This uses the immutable EmptyStruct
+ // variable to prevent an allocation in this case.
+ //
+ // Note that this mechanism does *not* work for
+ // interfaces in general. So you can't dispatch
+ // non-registered empty structs via interfaces because
+ // then they can't be restored.
+ s.Alloc(0)
+ return
}
+ // We need a SaverLoader for struct types.
+ Failf("struct %T does not implement SaverLoader", obj.Interface())
}
-
- // Encode the resulting fields.
- fields := make([]*pb.Field, 0, len(m.data))
- for _, e := range m.data {
- fields = append(fields, &pb.Field{
- Name: e.name,
- Value: e.object,
- })
+ if !ok {
+ // Queue the type to be serialized.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
}
- // Return the encoded object.
- return &pb.Struct{Fields: fields}
+ // Invoke the provided saver.
+ s.TypeID = wire.TypeID(te.ID)
+ s.Alloc(len(te.Fields))
+ oe := objectEncoder{
+ es: es,
+ encoded: s,
+ }
+ es.stats.start(te.ID)
+ defer es.stats.done()
+ if sl, ok := obj.Addr().Interface().(SaverLoader); ok {
+ // Note: may be a registered empty struct which does not
+ // implement the saver/loader interfaces.
+ sl.StateSave(Sink{internal: oe})
+ }
}
// encodeArray encodes an array.
-func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array {
- var (
- contents []*pb.Object
- )
- for i := 0; i < obj.Len(); i++ {
- entry := es.encodeObject(obj.Index(i), false, "[%d]", i)
- contents = append(contents, entry)
- }
- return &pb.Array{Contents: contents}
+func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) {
+ l := obj.Len()
+ a := &wire.Array{
+ Contents: make([]wire.Object, l),
+ }
+ *dest = a
+ for i := 0; i < l; i++ {
+ // We need to encode the full value because arrays are encoded
+ // using the type information from only the first element.
+ es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i])
+ }
+}
+
+// findType recursively finds type information.
+func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec {
+ // First: check if this is a proper type. It's possible for pointers,
+ // slices, arrays, maps, etc to all have some different type.
+ te, ok := es.types.Lookup(typ)
+ if te != nil {
+ if !ok {
+ // See encodeStruct.
+ es.pendingTypes = append(es.pendingTypes, te.Type)
+ }
+ return wire.TypeID(te.ID)
+ }
+
+ switch typ.Kind() {
+ case reflect.Ptr:
+ return &wire.TypeSpecPointer{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Slice:
+ return &wire.TypeSpecSlice{
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Array:
+ return &wire.TypeSpecArray{
+ Count: wire.Uint(typ.Len()),
+ Type: es.findType(typ.Elem()),
+ }
+ case reflect.Map:
+ return &wire.TypeSpecMap{
+ Key: es.findType(typ.Key()),
+ Value: es.findType(typ.Elem()),
+ }
+ default:
+ // After potentially chasing many pointers, the
+ // ultimate type of the object is not known.
+ Failf("type %q is not known", typ)
+ }
+ panic("unreachable")
}
// encodeInterface encodes an interface.
-//
-// Precondition: the value is not nil.
-func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface {
- // Check for the nil interface.
- obj = reflect.ValueOf(obj.Interface())
+func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) {
+ // Dereference the object.
+ obj = obj.Elem()
if !obj.IsValid() {
- return &pb.Interface{
- Type: "", // left alone in decode.
- Value: &pb.Object{Value: &pb.Object_RefValue{0}},
+ // Special case: the nil object.
+ *dest = &wire.Interface{
+ Type: wire.TypeSpecNil{},
+ Value: wire.Nil{},
}
+ return
}
- // We have an interface value here. How do we save that? We
- // resolve the underlying type and save it as a dispatchable.
- typName, ok := registeredTypes.lookupName(obj.Type())
- if !ok {
- panic(fmt.Errorf("type %s is not registered", obj.Type()))
+
+ // Encode underlying object.
+ i := &wire.Interface{
+ Type: es.findType(obj.Type()),
}
+ *dest = i
+ es.encodeObject(obj, encodeAsValue, &i.Value)
+}
- // Encode the object again.
- return &pb.Interface{
- Type: typName,
- Value: es.encodeObject(obj, false, ".(%s)", typName),
+// isPrimitive returns true if this is a primitive object, or a composite
+// object composed entirely of primitives.
+func isPrimitiveZero(typ reflect.Type) bool {
+ switch typ.Kind() {
+ case reflect.Ptr:
+ // Pointers are always treated as primitive types because we
+ // won't encode directly from here. Returning true here won't
+ // prevent the object from being encoded correctly.
+ return true
+ case reflect.Bool:
+ return true
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return true
+ case reflect.Float32, reflect.Float64:
+ return true
+ case reflect.Complex64, reflect.Complex128:
+ return true
+ case reflect.String:
+ return true
+ case reflect.Slice:
+ // The slice itself a primitive, but not necessarily the array
+ // that points to. This is similar to a pointer.
+ return true
+ case reflect.Array:
+ // We cannot treat an array as a primitive, because it may be
+ // composed of structures or other things with side-effects.
+ return isPrimitiveZero(typ.Elem())
+ case reflect.Interface:
+ // Since we now that this type is the zero type, the interface
+ // value must be zero. Therefore this is primitive.
+ return true
+ case reflect.Struct:
+ return false
+ case reflect.Map:
+ // The isPrimitiveZero function is called only on zero-types to
+ // see if it's safe to serialize. Since a zero map has no
+ // elements, it is safe to treat as a primitive.
+ return true
+ default:
+ Failf("unknown type %q", typ.Name())
}
+ panic("unreachable")
}
-// encodeObject encodes an object.
-//
-// If mapAsValue is true, then a map will be encoded directly.
-func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) {
- es.push(false, format, param)
- es.stats.Add(obj)
- es.stats.Start(obj)
+// encodeStrategy is the strategy used for encodeObject.
+type encodeStrategy int
+const (
+ // encodeDefault means types are encoded normally as references.
+ encodeDefault encodeStrategy = iota
+
+ // encodeAsValue means that types will never take short-circuited and
+ // will always be encoded as a normal value.
+ encodeAsValue
+
+ // encodeMapAsValue means that even maps will be fully encoded.
+ encodeMapAsValue
+)
+
+// encodeObject encodes an object.
+func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) {
+ if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() {
+ *dest = wire.Nil{}
+ return
+ }
switch obj.Kind() {
+ case reflect.Ptr: // Fast path: first.
+ r := new(wire.Ref)
+ *dest = r
+ if obj.IsNil() {
+ // May be in an array or elsewhere such that a value is
+ // required. So we encode as a reference to the zero
+ // object, which does not exist. Note that this has to
+ // be handled correctly in the decode path as well.
+ return
+ }
+ es.resolve(obj, r)
case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}}
+ *dest = wire.Bool(obj.Bool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}}
+ *dest = wire.Int(obj.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}}
- case reflect.Float32, reflect.Float64:
- object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}}
+ *dest = wire.Uint(obj.Uint())
+ case reflect.Float32:
+ *dest = wire.Float32(obj.Float())
+ case reflect.Float64:
+ *dest = wire.Float64(obj.Float())
+ case reflect.Complex64:
+ c := wire.Complex64(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.Complex128:
+ c := wire.Complex128(obj.Complex())
+ *dest = &c // Needs alloc.
+ case reflect.String:
+ s := wire.String(obj.String())
+ *dest = &s // Needs alloc.
case reflect.Array:
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}}
- case reflect.Uint16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]uint16)
- t := make([]uint32, len(s))
- for i := range s {
- t[i] = uint32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}}
- case reflect.Uint32:
- object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}}
- case reflect.Uint64:
- object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Uintptr:
- object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}}
- case reflect.Int8:
- object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}}
- case reflect.Int16:
- // 16-bit slices are serialized as 32-bit slices.
- // See object.proto for details.
- s := pbSlice(obj).Interface().([]int16)
- t := make([]int32, len(s))
- for i := range s {
- t[i] = int32(s[i])
- }
- object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}}
- case reflect.Int32:
- object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}}
- case reflect.Int64:
- object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}}
- case reflect.Bool:
- object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}}
- case reflect.Float32:
- object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}}
- case reflect.Float64:
- object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}}
- default:
- object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}}
- }
+ es.encodeArray(obj, dest)
case reflect.Slice:
- if obj.IsNil() || obj.Cap() == 0 {
- // Handled specially in decode; store as nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- // Serialize a slice as the array plus length and capacity.
- object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{
- Capacity: uint32(obj.Cap()),
- Length: uint32(obj.Len()),
- RefValue: es.register(arrayFromSlice(obj)),
- }}}
+ s := &wire.Slice{
+ Capacity: wire.Uint(obj.Cap()),
+ Length: wire.Uint(obj.Len()),
}
- case reflect.String:
- object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}}
- case reflect.Ptr:
+ *dest = s
+ // Note that we do need to provide a wire.Slice type here as
+ // how is not encodeDefault. If this were the case, then it
+ // would have been caught by the IsZero check above and we
+ // would have just used wire.Nil{}.
if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else {
- es.push(true /* dereference */, "", nil)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
- es.pop()
+ return
}
+ // Slices need pointer resolution.
+ es.resolve(arrayFromSlice(obj), &s.Ref)
case reflect.Interface:
- // We don't check for IsNil here, as we want to encode type
- // information. The case of the empty interface (no type, no
- // value) is handled by encodeInteface.
- object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}}
+ es.encodeInterface(obj, dest)
case reflect.Struct:
- object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}}
+ es.encodeStruct(obj, dest)
case reflect.Map:
- if obj.IsNil() {
- // Handled specially in decode; store as a nil value.
- object = &pb.Object{Value: &pb.Object_RefValue{0}}
- } else if mapAsValue {
- // Encode the map directly.
- object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}}
- } else {
- // Encode a reference to the map.
- //
- // Remove the map object count here to avoid double
- // counting, as this object will be counted again when
- // it gets processed later. We do not add a reference
- // count as the reference is artificial.
- es.stats.Remove(obj)
- object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}}
+ if how == encodeMapAsValue {
+ es.encodeMap(obj, dest)
+ return
}
+ r := new(wire.Ref)
+ *dest = r
+ es.resolve(obj, r)
default:
- panic(fmt.Errorf("unknown primitive %#v", obj.Interface()))
+ Failf("unknown object %#v", obj.Interface())
+ panic("unreachable")
}
-
- es.stats.Done()
- es.pop()
- return
}
-// Serialize serializes the object state.
-//
-// This function may panic and should be run in safely().
-func (es *encodeState) Serialize(obj reflect.Value) {
- es.register(obj.Addr())
-
- // Pop off the list until we're done.
- for es.pending.Len() > 0 {
- e := es.pending.Front()
-
- // Extract the queued object.
- qo := e.Value.(queuedObject)
- es.stats.Start(qo.obj)
+// Save serializes the object graph rooted at obj.
+func (es *encodeState) Save(obj reflect.Value) {
+ es.stats.init()
+ defer es.stats.fini(func(id typeID) string {
+ return es.pendingTypes[id-1].Name
+ })
+
+ // Resolve the first object, which should queue a pile of additional
+ // objects on the pending list. All queued objects should be fully
+ // resolved, and we should be able to serialize after this call.
+ var root wire.Ref
+ es.resolve(obj.Addr(), &root)
+
+ // Encode the graph.
+ var oes *objectEncodeState
+ if err := safely(func() {
+ for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() {
+ // Remove and encode the object. Note that as a result
+ // of this encoding, the object may be enqueued on the
+ // deferred list yet again. That's expected, and why it
+ // is removed first.
+ es.deferred.Remove(oes)
+ es.encodeObject(oes.obj, oes.how, &oes.encoded)
+ }
+ }); err != nil {
+ // Include the object in the error message.
+ Failf("encoding error at object %#v: %w", oes.obj.Interface(), err)
+ }
- es.pending.Remove(e)
+ // Check that items are pending.
+ if es.pending.Front() == nil {
+ Failf("pending is empty?")
+ }
- es.from = &qo.path
- o := es.encodeObject(qo.obj, true, "", nil)
+ // Write the header with the number of objects. Note that there is no
+ // way that es.lastID could conflict with objectID, which would
+ // indicate that an impossibly large encoding.
+ if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil {
+ Failf("error writing header: %w", err)
+ }
- // Emit to our output stream.
- if err := es.writeObject(qo.id, o); err != nil {
- panic(err)
+ // Serialize all pending types and pending objects. Note that we don't
+ // bother removing from this list as we walk it because that just
+ // wastes time. It will not change after this point.
+ var id objectID
+ if err := safely(func() {
+ for _, wt := range es.pendingTypes {
+ // Encode the type.
+ wire.Save(es.w, &wt)
}
+ for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() {
+ id++ // First object is 1.
+ if oes.id != id {
+ Failf("expected id %d, got %d", id, oes.id)
+ }
- // Mark as done.
- es.done.PushBack(e)
- es.stats.Done()
+ // Marshall the object.
+ wire.Save(es.w, oes.encoded)
+ }
+ }); err != nil {
+ // Include the object and the error.
+ Failf("error serializing object %#v: %w", oes.encoded, err)
}
- // Write a zero-length terminal at the end; this is a sanity check
- // applied at decode time as well (see decode.go).
- if err := WriteHeader(es.w, 0, false); err != nil {
- panic(err)
+ // Check what we wrote.
+ if id != es.lastID {
+ Failf("expected %d objects, wrote %d", es.lastID, id)
}
}
+// objectFlag indicates that the length is a # of objects, rather than a raw
+// byte length. When this is set on a length header in the stream, it may be
+// decoded appropriately.
+const objectFlag uint64 = 1 << 63
+
// WriteHeader writes a header.
//
// Each object written to the statefile should be prefixed with a header. In
// order to generate statefiles that play nicely with debugging tools, raw
// writes should be prefixed with a header with object set to false and the
// appropriate length. This will allow tools to skip these regions.
-func WriteHeader(w io.Writer, length uint64, object bool) error {
- // The lowest-order bit encodes whether this is a valid object. This is
- // a purely internal convention, but allows the object flag to be
- // returned from ReadHeader.
- length = length << 1
+func WriteHeader(w wire.Writer, length uint64, object bool) error {
+ // Sanity check the length.
+ if length&objectFlag != 0 {
+ Failf("impossibly huge length: %d", length)
+ }
if object {
- length |= 0x1
+ length |= objectFlag
}
// Write a header.
- var hdr [32]byte
- encodedLen := binary.PutUvarint(hdr[:], length)
- for done := 0; done < encodedLen; {
- n, err := w.Write(hdr[done:encodedLen])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
-
- return nil
+ return safely(func() {
+ wire.SaveUint(w, length)
+ })
}
-// writeObject writes an object to the stream.
-func (es *encodeState) writeObject(id uint64, obj *pb.Object) error {
- // Marshal the proto.
- buf, err := proto.Marshal(obj)
- if err != nil {
- return err
- }
+// pendingMapper is for the pending list.
+type pendingMapper struct{}
- // Write the object header.
- if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil {
- return err
- }
+func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry }
- // Write the object.
- for done := 0; done < len(buf); {
- n, err := es.w.Write(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
+// deferredMapper is for the deferred list.
+type deferredMapper struct{}
- return nil
-}
+func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry }
// addrSetFunctions is used by addrSet.
type addrSetFunctions struct{}
@@ -454,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr {
return ^uintptr(0)
}
-func (addrSetFunctions) ClearValue(val *reflect.Value) {
+func (addrSetFunctions) ClearValue(val **objectEncodeState) {
+ *val = nil
}
-func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) {
- return val1, val1 == val2
+func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) {
+ if val1.obj == val2.obj {
+ // This, should never happen. It would indicate that the same
+ // object exists in two non-contiguous address ranges. Note
+ // that this assertion can only be triggered if the race
+ // detector is enabled.
+ Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj)
+ }
+ // Reject the merge.
+ return val1, false
}
-func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) {
- return val, val
+func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) {
+ // A split should never happen: we don't remove ranges.
+ Failf("unexpected split in addrSet @ %v: %#v", r, val.obj)
+ panic("unreachable")
}
diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go
index 457e6dbb7..e0dad83b4 100644
--- a/pkg/state/encode_unsafe.go
+++ b/pkg/state/encode_unsafe.go
@@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value {
reflect.ArrayOf(obj.Cap(), obj.Type().Elem()),
unsafe.Pointer(obj.Pointer()))
}
-
-// pbSlice returns a protobuf-supported slice of the array and erase the
-// original element type (which could be a defined type or non-supported type).
-func pbSlice(obj reflect.Value) reflect.Value {
- var typ reflect.Type
- switch obj.Type().Elem().Kind() {
- case reflect.Uint8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Uint16:
- typ = reflect.TypeOf(uint16(0))
- case reflect.Uint32:
- typ = reflect.TypeOf(uint32(0))
- case reflect.Uint64:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Uintptr:
- typ = reflect.TypeOf(uint64(0))
- case reflect.Int8:
- typ = reflect.TypeOf(byte(0))
- case reflect.Int16:
- typ = reflect.TypeOf(int16(0))
- case reflect.Int32:
- typ = reflect.TypeOf(int32(0))
- case reflect.Int64:
- typ = reflect.TypeOf(int64(0))
- case reflect.Bool:
- typ = reflect.TypeOf(bool(false))
- case reflect.Float32:
- typ = reflect.TypeOf(float32(0))
- case reflect.Float64:
- typ = reflect.TypeOf(float64(0))
- default:
- panic("slice element is not of basic value type")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), typ),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem().Slice(0, obj.Len())
-}
-
-func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value {
- if obj.Type().Elem().Size() != elemTyp.Size() {
- panic("cannot cast slice into other element type of different size")
- }
- return reflect.NewAt(
- reflect.ArrayOf(obj.Len(), elemTyp),
- unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()),
- ).Elem()
-}
diff --git a/pkg/state/map.go b/pkg/state/map.go
deleted file mode 100644
index 7e6fefed4..000000000
--- a/pkg/state/map.go
+++ /dev/null
@@ -1,221 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "fmt"
- "reflect"
- "sort"
- "sync"
-
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// entry is a single map entry.
-type entry struct {
- name string
- object *pb.Object
-}
-
-// internalMap is the internal Map state.
-//
-// These are recycled via a pool to avoid churn.
-type internalMap struct {
- // es is encodeState.
- es *encodeState
-
- // ds is decodeState.
- ds *decodeState
-
- // os is current object being decoded.
- //
- // This will always be nil during encode.
- os *objectState
-
- // data stores the encoded values.
- data []entry
-}
-
-var internalMapPool = sync.Pool{
- New: func() interface{} {
- return new(internalMap)
- },
-}
-
-// newInternalMap returns a cached map.
-func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap {
- m := internalMapPool.Get().(*internalMap)
- m.es = es
- m.ds = ds
- m.os = os
- if m.data != nil {
- m.data = m.data[:0]
- }
- return m
-}
-
-// Map is a generic state container.
-//
-// This is the object passed to Save and Load in order to store their state.
-//
-// Detailed documentation is available in individual methods.
-type Map struct {
- *internalMap
-}
-
-// Save adds the given object to the map.
-//
-// You should pass always pointers to the object you are saving. For example:
-//
-// type X struct {
-// A int
-// B *int
-// }
-//
-// func (x *X) Save(m Map) {
-// m.Save("A", &x.A)
-// m.Save("B", &x.B)
-// }
-//
-// func (x *X) Load(m Map) {
-// m.Load("A", &x.A)
-// m.Load("B", &x.B)
-// }
-func (m Map) Save(name string, objPtr interface{}) {
- m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s")
-}
-
-// SaveValue adds the given object value to the map.
-//
-// This should be used for values where pointers are not available, or casts
-// are required during Save/Load.
-//
-// For example, if we want to cast external package type P.Foo to int64:
-//
-// type X struct {
-// A P.Foo
-// }
-//
-// func (x *X) Save(m Map) {
-// m.SaveValue("A", int64(x.A))
-// }
-//
-// func (x *X) Load(m Map) {
-// m.LoadValue("A", new(int64), func(x interface{}) {
-// x.A = P.Foo(x.(int64))
-// })
-// }
-func (m Map) SaveValue(name string, obj interface{}) {
- m.save(name, reflect.ValueOf(obj), ".(value %s)")
-}
-
-// save is helper for the above. It takes the name of value to save the field
-// to, the field object (obj), and a format string that specifies how the
-// field's saving logic is dispatched from the struct (normal, value, etc.). The
-// format string should expect one string parameter, which is the name of the
-// field.
-func (m Map) save(name string, obj reflect.Value, format string) {
- if m.es == nil {
- // Not currently encoding.
- m.Failf("no encode state for %q", name)
- }
-
- // Attempt the encode.
- //
- // These are sorted at the end, after all objects are added and will be
- // sorted and checked for duplicates (see encodeStruct).
- m.data = append(m.data, entry{
- name: name,
- object: m.es.encodeObject(obj, false, format, name),
- })
-}
-
-// Load loads the given object from the map.
-//
-// See Save for an example.
-func (m Map) Load(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s")
-}
-
-// LoadWait loads the given objects from the map, and marks it as requiring all
-// AfterLoad executions to complete prior to running this object's AfterLoad.
-//
-// See Save for an example.
-func (m Map) LoadWait(name string, objPtr interface{}) {
- m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)")
-}
-
-// LoadValue loads the given object value from the map.
-//
-// See SaveValue for an example.
-func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) {
- o := reflect.ValueOf(objPtr)
- m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)")
-}
-
-// load is helper for the above. It takes the name of value to load the field
-// from, the target field pointer (objPtr), whether load completion of the
-// struct depends on the field's load completion (wait), the load completion
-// logic (fn), and a format string that specifies how the field's loading logic
-// is dispatched from the struct (normal, wait, value, etc.). The format string
-// should expect one string parameter, which is the name of the field.
-func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("no decode state for %q", name)
- }
-
- // Find the object.
- //
- // These are sorted up front (and should appear in the state file
- // sorted as well), so we can do a binary search here to ensure that
- // large structs don't behave badly.
- i := sort.Search(len(m.data), func(i int) bool {
- return m.data[i].name >= name
- })
- if i >= len(m.data) || m.data[i].name != name {
- // There is no data for this name?
- m.Failf("no data found for %q", name)
- }
-
- // Perform the decode.
- m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name)
- if wait {
- // Mark this individual object a blocker.
- m.ds.waitObject(m.os, m.data[i].object, fn)
- }
-}
-
-// Failf fails the save or restore with the provided message. Processing will
-// stop after calling Failf, as the state package uses a panic & recover
-// mechanism for state errors. You should defer any cleanup required.
-func (m Map) Failf(format string, args ...interface{}) {
- panic(fmt.Errorf(format, args...))
-}
-
-// AfterLoad schedules a function execution when all objects have been allocated
-// and their automated loading and customized load logic have been executed. fn
-// will not be executed until all of current object's dependencies' AfterLoad()
-// logic, if exist, have been executed.
-func (m Map) AfterLoad(fn func()) {
- if m.ds == nil {
- // Not currently decoding.
- m.Failf("not decoding")
- }
-
- // Queue the local callback; this will execute when all of the above
- // data dependencies have been cleared.
- m.os.callbacks = append(m.os.callbacks, fn)
-}
diff --git a/pkg/state/object.proto b/pkg/state/object.proto
deleted file mode 100644
index 952289069..000000000
--- a/pkg/state/object.proto
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-syntax = "proto3";
-
-package gvisor.state.statefile;
-
-// Slice is a slice value.
-message Slice {
- uint32 length = 1;
- uint32 capacity = 2;
- uint64 ref_value = 3;
-}
-
-// Array is an array value.
-message Array {
- repeated Object contents = 1;
-}
-
-// Map is a map value.
-message Map {
- repeated Object keys = 1;
- repeated Object values = 2;
-}
-
-// Interface is an interface value.
-message Interface {
- string type = 1;
- Object value = 2;
-}
-
-// Struct is a basic composite value.
-message Struct {
- repeated Field fields = 1;
-}
-
-// Field encodes a single field.
-message Field {
- string name = 1;
- Object value = 2;
-}
-
-// Uint16s encodes an uint16 array. To be used inside oneof structure.
-message Uint16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated uint32 values = 1;
-}
-
-// Uint32s encodes an uint32 array. To be used inside oneof structure.
-message Uint32s {
- repeated fixed32 values = 1;
-}
-
-// Uint64s encodes an uint64 array. To be used inside oneof structure.
-message Uint64s {
- repeated fixed64 values = 1;
-}
-
-// Uintptrs encodes an uintptr array. To be used inside oneof structure.
-message Uintptrs {
- repeated fixed64 values = 1;
-}
-
-// Int8s encodes an int8 array. To be used inside oneof structure.
-message Int8s {
- bytes values = 1;
-}
-
-// Int16s encodes an int16 array. To be used inside oneof structure.
-message Int16s {
- // There is no 16-bit type in protobuf so we use variable length 32-bit here.
- repeated int32 values = 1;
-}
-
-// Int32s encodes an int32 array. To be used inside oneof structure.
-message Int32s {
- repeated sfixed32 values = 1;
-}
-
-// Int64s encodes an int64 array. To be used inside oneof structure.
-message Int64s {
- repeated sfixed64 values = 1;
-}
-
-// Bools encodes a boolean array. To be used inside oneof structure.
-message Bools {
- repeated bool values = 1;
-}
-
-// Float64s encodes a float64 array. To be used inside oneof structure.
-message Float64s {
- repeated double values = 1;
-}
-
-// Float32s encodes a float32 array. To be used inside oneof structure.
-message Float32s {
- repeated float values = 1;
-}
-
-// Object are primitive encodings.
-//
-// Note that ref_value references an Object.id, below.
-message Object {
- oneof value {
- bool bool_value = 1;
- bytes string_value = 2;
- int64 int64_value = 3;
- uint64 uint64_value = 4;
- double double_value = 5;
- uint64 ref_value = 6;
- Slice slice_value = 7;
- Array array_value = 8;
- Interface interface_value = 9;
- Struct struct_value = 10;
- Map map_value = 11;
- bytes byte_array_value = 12;
- Uint16s uint16_array_value = 13;
- Uint32s uint32_array_value = 14;
- Uint64s uint64_array_value = 15;
- Uintptrs uintptr_array_value = 16;
- Int8s int8_array_value = 17;
- Int16s int16_array_value = 18;
- Int32s int32_array_value = 19;
- Int64s int64_array_value = 20;
- Bools bool_array_value = 21;
- Float64s float64_array_value = 22;
- Float32s float32_array_value = 23;
- }
-}
diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD
new file mode 100644
index 000000000..d053802f7
--- /dev/null
+++ b/pkg/state/pretty/BUILD
@@ -0,0 +1,13 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pretty",
+ srcs = ["pretty.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go
new file mode 100644
index 000000000..cf37aaa49
--- /dev/null
+++ b/pkg/state/pretty/pretty.go
@@ -0,0 +1,273 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pretty is a pretty-printer for state streams.
+package pretty
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+func formatRef(x *wire.Ref, graph uint64, html bool) string {
+ baseRef := fmt.Sprintf("g%dr%d", graph, x.Root)
+ fullRef := baseRef
+ if len(x.Dots) > 0 {
+ // See wire.Ref; Type valid if Dots non-zero.
+ typ, _ := formatType(x.Type, graph, html)
+ var buf strings.Builder
+ buf.WriteString("(*")
+ buf.WriteString(typ)
+ buf.WriteString(")(")
+ buf.WriteString(baseRef)
+ for _, component := range x.Dots {
+ switch v := component.(type) {
+ case *wire.FieldName:
+ buf.WriteString(".")
+ buf.WriteString(string(*v))
+ case wire.Index:
+ buf.WriteString(fmt.Sprintf("[%d]", v))
+ default:
+ panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component)))
+ }
+ }
+ buf.WriteString(")")
+ fullRef = buf.String()
+ }
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef)
+ }
+ return fullRef
+}
+
+func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) {
+ switch x := t.(type) {
+ case wire.TypeID:
+ base := fmt.Sprintf("g%dt%d", graph, x)
+ if html {
+ return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true
+ }
+ return fmt.Sprintf("%s", base), true
+ case wire.TypeSpecNil:
+ return "", false // Only nil type.
+ case *wire.TypeSpecPointer:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("(*%s)", element), true
+ case *wire.TypeSpecArray:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("[%d](%s)", x.Count, element), true
+ case *wire.TypeSpecSlice:
+ element, _ := formatType(x.Type, graph, html)
+ return fmt.Sprintf("([]%s)", element), true
+ case *wire.TypeSpecMap:
+ key, _ := formatType(x.Key, graph, html)
+ value, _ := formatType(x.Value, graph, html)
+ return fmt.Sprintf("(map[%s]%s)", key, value), true
+ default:
+ panic(fmt.Sprintf("unreachable: unknown type %T", t))
+ }
+}
+
+// format formats a single object, for pretty-printing. It also returns whether
+// the value is a non-zero value.
+func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) {
+ switch x := encoded.(type) {
+ case wire.Nil:
+ return "nil", false
+ case *wire.String:
+ return fmt.Sprintf("%q", *x), *x != ""
+ case *wire.Complex64:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Complex128:
+ return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0
+ case *wire.Ref:
+ return formatRef(x, graph, html), x.Root != 0
+ case *wire.Type:
+ tabs := "\n" + strings.Repeat("\t", depth)
+ items := make([]string, 0, len(x.Fields)+2)
+ items = append(items, fmt.Sprintf("type %s {", x.Name))
+ for i := 0; i < len(x.Fields); i++ {
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i]))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true // No zero value.
+ case *wire.Slice:
+ return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0
+ case *wire.Array:
+ if len(x.Contents) == 0 {
+ return "[]", false
+ }
+ items := make([]string, 0, len(x.Contents)+2)
+ zeros := make([]string, 0) // used to eliminate zero entries.
+ items = append(items, "[")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Contents); i++ {
+ item, ok := format(graph, depth+1, x.Contents[i], html)
+ if !ok {
+ zeros = append(zeros, fmt.Sprintf("\t%s,", item))
+ continue
+ }
+ if len(zeros) > 0 {
+ items = append(items, zeros...)
+ zeros = nil
+ }
+ items = append(items, fmt.Sprintf("\t%s,", item))
+ }
+ if len(zeros) > 0 {
+ items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
+ }
+ items = append(items, "]")
+ return strings.Join(items, tabs), len(zeros) < len(x.Contents)
+ case *wire.Struct:
+ typ, _ := formatType(x.TypeID, graph, html)
+ if x.Fields() == 0 {
+ return fmt.Sprintf("struct[%s]{}", typ), false
+ }
+ items := make([]string, 0, 2)
+ items = append(items, fmt.Sprintf("struct[%s]{", typ))
+ tabs := "\n" + strings.Repeat("\t", depth)
+ allZero := true
+ for i := 0; i < x.Fields(); i++ {
+ element, ok := format(graph, depth+1, *x.Field(i), html)
+ allZero = allZero && !ok
+ items = append(items, fmt.Sprintf("\t%d: %s,", i, element))
+ i++
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), !allZero
+ case *wire.Map:
+ if len(x.Keys) == 0 {
+ return "map{}", false
+ }
+ items := make([]string, 0, len(x.Keys)+2)
+ items = append(items, "map{")
+ tabs := "\n" + strings.Repeat("\t", depth)
+ for i := 0; i < len(x.Keys); i++ {
+ key, _ := format(graph, depth+1, x.Keys[i], html)
+ value, _ := format(graph, depth+1, x.Values[i], html)
+ items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
+ }
+ items = append(items, "}")
+ return strings.Join(items, tabs), true
+ case *wire.Interface:
+ typ, typOk := formatType(x.Type, graph, html)
+ element, elementOk := format(graph, depth+1, x.Value, html)
+ return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk
+ default:
+ // Must be a primitive; use reflection.
+ return fmt.Sprintf("%v", encoded), true
+ }
+}
+
+// printStream is the basic print implementation.
+func printStream(w io.Writer, r wire.Reader, html bool) (err error) {
+ // current graph ID.
+ var graph uint64
+
+ if html {
+ fmt.Fprintf(w, "<pre>")
+ defer fmt.Fprintf(w, "</pre>")
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if rErr, ok := r.(error); ok {
+ err = rErr // Override return.
+ return
+ }
+ panic(r) // Propagate.
+ }
+ }()
+
+ for {
+ // Find the first object to begin generation.
+ length, object, err := state.ReadHeader(r)
+ if err == io.EOF {
+ // Nothing else to do.
+ break
+ } else if err != nil {
+ return err
+ }
+ if !object {
+ graph++ // Increment the graph.
+ if length > 0 {
+ fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
+ io.Copy(ioutil.Discard, &io.LimitedReader{
+ R: r,
+ N: int64(length),
+ })
+ }
+ continue
+ }
+
+ // Read & unmarshal the object.
+ //
+ // Note that this loop must match the general structure of the
+ // loop in decode.go. But we don't register type information,
+ // etc. and just print the raw structures.
+ var (
+ oid uint64 = 1
+ tid uint64 = 1
+ )
+ for oid <= length {
+ // Unmarshal the object.
+ encoded := wire.Load(r)
+
+ // Is this a type?
+ if _, ok := encoded.(*wire.Type); ok {
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dt%d", graph, tid)
+ if html {
+ // See below.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ tid++
+ continue
+ }
+
+ // Format the node.
+ str, _ := format(graph, 0, encoded, html)
+ tag := fmt.Sprintf("g%dr%d", graph, oid)
+ if html {
+ // Create a little tag with an anchor next to it for linking.
+ tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">&#9875;</a>", tag, tag, tag)
+ }
+ if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
+ return err
+ }
+ oid++
+ }
+ }
+
+ return nil
+}
+
+// PrintText reads the stream from r and prints text to w.
+func PrintText(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, false /* html */)
+}
+
+// PrintHTML reads the stream from r and prints html to w.
+func PrintHTML(w io.Writer, r wire.Reader) error {
+ return printStream(w, r, true /* html */)
+}
diff --git a/pkg/state/printer.go b/pkg/state/printer.go
deleted file mode 100644
index 3ce18242f..000000000
--- a/pkg/state/printer.go
+++ /dev/null
@@ -1,251 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "fmt"
- "io"
- "io/ioutil"
- "reflect"
- "strings"
-
- "github.com/golang/protobuf/proto"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
-)
-
-// format formats a single object, for pretty-printing. It also returns whether
-// the value is a non-zero value.
-func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) {
- switch x := object.GetValue().(type) {
- case *pb.Object_BoolValue:
- return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false
- case *pb.Object_StringValue:
- return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0
- case *pb.Object_Int64Value:
- return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0
- case *pb.Object_Uint64Value:
- return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0
- case *pb.Object_DoubleValue:
- return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0
- case *pb.Object_RefValue:
- if x.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return ref, true
- case *pb.Object_SliceValue:
- if x.SliceValue.RefValue == 0 {
- return "nil", false
- }
- ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue)
- if html {
- ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref)
- }
- return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true
- case *pb.Object_ArrayValue:
- if len(x.ArrayValue.Contents) == 0 {
- return "[]", false
- }
- items := make([]string, 0, len(x.ArrayValue.Contents)+2)
- zeros := make([]string, 0) // used to eliminate zero entries.
- items = append(items, "[")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.ArrayValue.Contents); i++ {
- item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html)
- if ok {
- if len(zeros) > 0 {
- items = append(items, zeros...)
- zeros = nil
- }
- items = append(items, fmt.Sprintf("\t%s,", item))
- } else {
- zeros = append(zeros, fmt.Sprintf("\t%s,", item))
- }
- }
- if len(zeros) > 0 {
- items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros)))
- }
- items = append(items, "]")
- return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents)
- case *pb.Object_StructValue:
- if len(x.StructValue.Fields) == 0 {
- return "struct{}", false
- }
- items := make([]string, 0, len(x.StructValue.Fields)+2)
- items = append(items, "struct{")
- tabs := "\n" + strings.Repeat("\t", depth)
- allZero := true
- for _, field := range x.StructValue.Fields {
- element, ok := format(graph, depth+1, field.Value, html)
- allZero = allZero && !ok
- items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), !allZero
- case *pb.Object_MapValue:
- if len(x.MapValue.Keys) == 0 {
- return "map{}", false
- }
- items := make([]string, 0, len(x.MapValue.Keys)+2)
- items = append(items, "map{")
- tabs := "\n" + strings.Repeat("\t", depth)
- for i := 0; i < len(x.MapValue.Keys); i++ {
- key, _ := format(graph, depth+1, x.MapValue.Keys[i], html)
- value, _ := format(graph, depth+1, x.MapValue.Values[i], html)
- items = append(items, fmt.Sprintf("\t%s: %s,", key, value))
- }
- items = append(items, "}")
- return strings.Join(items, tabs), true
- case *pb.Object_InterfaceValue:
- if x.InterfaceValue.Type == "" {
- return "interface(nil){}", false
- }
- element, _ := format(graph, depth+1, x.InterfaceValue.Value, html)
- return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true
- case *pb.Object_ByteArrayValue:
- return printArray(reflect.ValueOf(x.ByteArrayValue))
- case *pb.Object_Uint16ArrayValue:
- return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values))
- case *pb.Object_Uint32ArrayValue:
- return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values))
- case *pb.Object_Uint64ArrayValue:
- return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values))
- case *pb.Object_UintptrArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
- case *pb.Object_Int8ArrayValue:
- return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
- case *pb.Object_Int16ArrayValue:
- return printArray(reflect.ValueOf(x.Int16ArrayValue.Values))
- case *pb.Object_Int32ArrayValue:
- return printArray(reflect.ValueOf(x.Int32ArrayValue.Values))
- case *pb.Object_Int64ArrayValue:
- return printArray(reflect.ValueOf(x.Int64ArrayValue.Values))
- case *pb.Object_BoolArrayValue:
- return printArray(reflect.ValueOf(x.BoolArrayValue.Values))
- case *pb.Object_Float64ArrayValue:
- return printArray(reflect.ValueOf(x.Float64ArrayValue.Values))
- case *pb.Object_Float32ArrayValue:
- return printArray(reflect.ValueOf(x.Float32ArrayValue.Values))
- }
-
- // Should not happen, but tolerate.
- return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true
-}
-
-// PrettyPrint reads the state stream from r, and pretty prints to w.
-func PrettyPrint(w io.Writer, r io.Reader, html bool) error {
- var (
- // current graph ID.
- graph uint64
-
- // current object ID.
- id uint64
- )
-
- if html {
- fmt.Fprintf(w, "<pre>")
- defer fmt.Fprintf(w, "</pre>")
- }
-
- for {
- // Find the first object to begin generation.
- length, object, err := ReadHeader(r)
- if err == io.EOF {
- // Nothing else to do.
- break
- } else if err != nil {
- return err
- }
- if !object {
- // Increment the graph number & reset the ID.
- graph++
- id = 0
- if length > 0 {
- fmt.Fprintf(w, "(%d bytes non-object data)\n", length)
- io.Copy(ioutil.Discard, &io.LimitedReader{
- R: r,
- N: int64(length),
- })
- }
- continue
- }
-
- // Read & unmarshal the object.
- buf := make([]byte, length)
- for done := 0; done < len(buf); {
- n, err := r.Read(buf[done:])
- done += n
- if n == 0 && err != nil {
- return err
- }
- }
- obj := new(pb.Object)
- if err := proto.Unmarshal(buf, obj); err != nil {
- return err
- }
-
- id++ // First object must be one.
- str, _ := format(graph, 0, obj, html)
- tag := fmt.Sprintf("g%dr%d", graph, id)
- if html {
- tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag)
- }
- if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func printArray(s reflect.Value) (string, bool) {
- zero := reflect.Zero(s.Type().Elem()).Interface()
- z := "0"
- switch s.Type().Elem().Kind() {
- case reflect.Bool:
- z = "false"
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- case reflect.Float32, reflect.Float64:
- default:
- return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true
- }
-
- zeros := 0
- items := make([]string, 0, s.Len())
- for i := 0; i <= s.Len(); i++ {
- if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) {
- zeros++
- continue
- }
- if zeros > 0 {
- if zeros <= 4 {
- for ; zeros > 0; zeros-- {
- items = append(items, z)
- }
- } else {
- items = append(items, fmt.Sprintf("(%d %ss)", zeros, z))
- zeros = 0
- }
- }
- if i < s.Len() {
- items = append(items, fmt.Sprintf("%v", s.Index(i).Interface()))
- }
- }
- return "[" + strings.Join(items, ",") + "]", zeros < s.Len()
-}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index d408ff84a..acb629969 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -31,285 +31,243 @@
// Uint64 default
// Float32 default
// Float64 default
-// Complex64 custom
-// Complex128 custom
+// Complex64 default
+// Complex128 default
// Array default
// Chan custom
// Func custom
-// Interface custom
-// Map default (*)
+// Interface default
+// Map default
// Ptr default
// Slice default
// String default
-// Struct custom
+// Struct custom (*) Unless zero-sized.
// UnsafePointer custom
//
-// (*) Maps are treated as value types by this package, even if they are
-// pointers internally. If you want to save two independent references
-// to the same map value, you must explicitly use a pointer to a map.
+// See README.md for an overview of how encoding and decoding works.
package state
import (
+ "context"
"fmt"
- "io"
"reflect"
"runtime"
- pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
+// objectID is a unique identifier assigned to each object to be serialized.
+// Each instance of an object is considered separately, i.e. if there are two
+// objects of the same type in the object graph being serialized, they'll be
+// assigned unique objectIDs.
+type objectID uint32
+
+// typeID is the identifier for a type. Types are serialized and tracked
+// alongside objects in order to avoid the overhead of encoding field names in
+// all objects.
+type typeID uint32
+
// ErrState is returned when an error is encountered during encode/decode.
type ErrState struct {
// err is the underlying error.
err error
- // path is the visit path from root to the current object.
- path string
-
// trace is the stack trace.
trace string
}
// Error returns a sensible description of the state error.
func (e *ErrState) Error() string {
- return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace)
+ return fmt.Sprintf("%v:\n%s", e.err, e.trace)
}
-// UnwrapErrState returns the underlying error in ErrState.
-//
-// If err is not *ErrState, err is returned directly.
-func UnwrapErrState(err error) error {
- if e, ok := err.(*ErrState); ok {
- return e.err
- }
- return err
+// Unwrap implements standard unwrapping.
+func (e *ErrState) Unwrap() error {
+ return e.err
}
// Save saves the given object state.
-func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) {
// Create the encoding state.
- es := &encodeState{
- idsByObject: make(map[uintptr]uint64),
- w: w,
- stats: stats,
+ es := encodeState{
+ ctx: ctx,
+ w: w,
+ types: makeTypeEncodeDatabase(),
+ zeroValues: make(map[reflect.Type]*objectEncodeState),
}
// Perform the encoding.
- return es.safely(func() {
- es.Serialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ es.Save(reflect.ValueOf(rootPtr).Elem())
})
+ return es.stats, err
}
// Load loads a checkpoint.
-func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) {
// Create the decoding state.
- ds := &decodeState{
- objectsByID: make(map[uint64]*objectState),
- deferred: make(map[uint64]*pb.Object),
- r: r,
- stats: stats,
+ ds := decodeState{
+ ctx: ctx,
+ r: r,
+ types: makeTypeDecodeDatabase(),
+ deferred: make(map[objectID]wire.Object),
}
// Attempt our decode.
- return ds.safely(func() {
- ds.Deserialize(reflect.ValueOf(rootPtr).Elem())
+ err := safely(func() {
+ ds.Load(reflect.ValueOf(rootPtr).Elem())
})
+ return ds.stats, err
}
-// Fns are the state dispatch functions.
-type Fns struct {
- // Save is a function like Save(concreteType, Map).
- Save interface{}
-
- // Load is a function like Load(concreteType, Map).
- Load interface{}
+// Sink is used for Type.StateSave.
+type Sink struct {
+ internal objectEncoder
}
-// Save executes the save function.
-func (fns *Fns) invokeSave(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// Save adds the given object to the map.
+//
+// You should pass always pointers to the object you are saving. For example:
+//
+// type X struct {
+// A int
+// B *int
+// }
+//
+// func (x *X) StateTypeInfo(m Sink) state.TypeInfo {
+// return state.TypeInfo{
+// Name: "pkg.X",
+// Fields: []string{
+// "A",
+// "B",
+// },
+// }
+// }
+//
+// func (x *X) StateSave(m Sink) {
+// m.Save(0, &x.A) // Field is A.
+// m.Save(1, &x.B) // Field is B.
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.Load(0, &x.A) // Field is A.
+// m.Load(1, &x.B) // Field is B.
+// }
+func (s Sink) Save(slot int, objPtr interface{}) {
+ s.internal.save(slot, reflect.ValueOf(objPtr).Elem())
}
-// Load executes the load function.
-func (fns *Fns) invokeLoad(obj reflect.Value, m Map) {
- reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)})
+// SaveValue adds the given object value to the map.
+//
+// This should be used for values where pointers are not available, or casts
+// are required during Save/Load.
+//
+// For example, if we want to cast external package type P.Foo to int64:
+//
+// func (x *X) StateSave(m Sink) {
+// m.SaveValue(0, "A", int64(x.A))
+// }
+//
+// func (x *X) StateLoad(m Source) {
+// m.LoadValue(0, new(int64), func(x interface{}) {
+// x.A = P.Foo(x.(int64))
+// })
+// }
+func (s Sink) SaveValue(slot int, obj interface{}) {
+ s.internal.save(slot, reflect.ValueOf(obj))
}
-// validateStateFn ensures types are correct.
-func validateStateFn(fn interface{}, typ reflect.Type) bool {
- fnTyp := reflect.TypeOf(fn)
- if fnTyp.Kind() != reflect.Func {
- return false
- }
- if fnTyp.NumIn() != 2 {
- return false
- }
- if fnTyp.NumOut() != 0 {
- return false
- }
- if fnTyp.In(0) != typ {
- return false
- }
- if fnTyp.In(1) != reflect.TypeOf(Map{}) {
- return false
- }
- return true
+// Context returns the context object provided at save time.
+func (s Sink) Context() context.Context {
+ return s.internal.es.ctx
}
-// Validate validates all state functions.
-func (fns *Fns) Validate(typ reflect.Type) bool {
- return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ)
+// Type is an interface that must be implemented by Struct objects. This allows
+// these objects to be serialized while minimizing runtime reflection required.
+//
+// All these methods can be automatically generated by the go_statify tool.
+type Type interface {
+ // StateTypeName returns the type's name.
+ //
+ // This is used for matching type information during encoding and
+ // decoding, as well as dynamic interface dispatch. This should be
+ // globally unique.
+ StateTypeName() string
+
+ // StateFields returns information about the type.
+ //
+ // Fields is the set of fields for the object. Calls to Sink.Save and
+ // Source.Load must be made in-order with respect to these fields.
+ //
+ // This will be called at most once per serialization.
+ StateFields() []string
}
-type typeDatabase struct {
- // nameToType is a forward lookup table.
- nameToType map[string]reflect.Type
-
- // typeToName is the reverse lookup table.
- typeToName map[reflect.Type]string
+// SaverLoader must be implemented by struct types.
+type SaverLoader interface {
+ // StateSave saves the state of the object to the given Map.
+ StateSave(Sink)
- // typeToFns is the function lookup table.
- typeToFns map[reflect.Type]Fns
+ // StateLoad loads the state of the object.
+ StateLoad(Source)
}
-// registeredTypes is a database used for SaveInterface and LoadInterface.
-var registeredTypes = typeDatabase{
- nameToType: make(map[string]reflect.Type),
- typeToName: make(map[reflect.Type]string),
- typeToFns: make(map[reflect.Type]Fns),
+// Source is used for Type.StateLoad.
+type Source struct {
+ internal objectDecoder
}
-// register registers a type under the given name. This will generally be
-// called via init() methods, and therefore uses panic to propagate errors.
-func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) {
- // We can't allow name collisions.
- if ot, ok := t.nameToType[name]; ok {
- panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name()))
- }
-
- // Or multiple registrations.
- if on, ok := t.typeToName[typ]; ok {
- panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on))
- }
-
- t.nameToType[name] = typ
- t.typeToName[typ] = name
- t.typeToFns[typ] = fns
+// Load loads the given object passed as a pointer..
+//
+// See Sink.Save for an example.
+func (s Source) Load(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), false, nil)
}
-// lookupType finds a type given a name.
-func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) {
- typ, ok := t.nameToType[name]
- return typ, ok
+// LoadWait loads the given objects from the map, and marks it as requiring all
+// AfterLoad executions to complete prior to running this object's AfterLoad.
+//
+// See Sink.Save for an example.
+func (s Source) LoadWait(slot int, objPtr interface{}) {
+ s.internal.load(slot, reflect.ValueOf(objPtr), true, nil)
}
-// lookupName finds a name given a type.
-func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) {
- name, ok := t.typeToName[typ]
- return name, ok
+// LoadValue loads the given object value from the map.
+//
+// See Sink.SaveValue for an example.
+func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) {
+ o := reflect.ValueOf(objPtr)
+ s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) })
}
-// lookupFns finds functions given a type.
-func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) {
- fns, ok := t.typeToFns[typ]
- return fns, ok
+// AfterLoad schedules a function execution when all objects have been
+// allocated and their automated loading and customized load logic have been
+// executed. fn will not be executed until all of current object's
+// dependencies' AfterLoad() logic, if exist, have been executed.
+func (s Source) AfterLoad(fn func()) {
+ s.internal.afterLoad(fn)
}
-// Register must be called for any interface implementation types that
-// implements Loader.
-//
-// Register should be called either immediately after startup or via init()
-// methods. Double registration of either names or types will result in a panic.
-//
-// No synchronization is provided; this should only be called in init.
-//
-// Example usage:
-//
-// state.Register("Foo", (*Foo)(nil), state.Fns{
-// Save: (*Foo).Save,
-// Load: (*Foo).Load,
-// })
-//
-func Register(name string, instance interface{}, fns Fns) {
- registeredTypes.register(name, reflect.TypeOf(instance), fns)
+// Context returns the context object provided at load time.
+func (s Source) Context() context.Context {
+ return s.internal.ds.ctx
}
// IsZeroValue checks if the given value is the zero value.
//
// This function is used by the stateify tool.
func IsZeroValue(val interface{}) bool {
- if val == nil {
- return true
- }
- return reflect.DeepEqual(val, reflect.Zero(reflect.TypeOf(val)).Interface())
-}
-
-// step captures one encoding / decoding step. On each step, there is up to one
-// choice made, which is captured by non-nil param. We intentionally do not
-// eagerly create the final path string, as that will only be needed upon panic.
-type step struct {
- // dereference indicate if the current object is obtained by
- // dereferencing a pointer.
- dereference bool
-
- // format is the formatting string that takes param below, if
- // non-nil. For example, in array indexing case, we have "[%d]".
- format string
-
- // param stores the choice made at the current encoding / decoding step.
- // For eaxmple, in array indexing case, param stores the index. When no
- // choice is made, e.g. dereference, param should be nil.
- param interface{}
+ return val == nil || reflect.ValueOf(val).Elem().IsZero()
}
-// recoverable is the state encoding / decoding panic recovery facility. It is
-// also used to store encoding / decoding steps as well as the reference to the
-// original queued object from which the current object is dispatched. The
-// complete encoding / decoding path is synthesised from the steps in all queued
-// objects leading to the current object.
-type recoverable struct {
- from *recoverable
- steps []step
+// Failf is a wrapper around panic that should be used to generate errors that
+// can be caught during saving and loading.
+func Failf(fmtStr string, v ...interface{}) {
+ panic(fmt.Errorf(fmtStr, v...))
}
-// push enters a new context level.
-func (sr *recoverable) push(dereference bool, format string, param interface{}) {
- sr.steps = append(sr.steps, step{dereference, format, param})
-}
-
-// pop exits the current context level.
-func (sr *recoverable) pop() {
- if len(sr.steps) <= 1 {
- return
- }
- sr.steps = sr.steps[:len(sr.steps)-1]
-}
-
-// path returns the complete encoding / decoding path from root. This is only
-// called upon panic.
-func (sr *recoverable) path() string {
- if sr.from == nil {
- return "root"
- }
- p := sr.from.path()
- for _, s := range sr.steps {
- if s.dereference {
- p = fmt.Sprintf("*(%s)", p)
- }
- if s.param == nil {
- p += s.format
- } else {
- p += fmt.Sprintf(s.format, s.param)
- }
- }
- return p
-}
-
-func (sr *recoverable) copy() recoverable {
- return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)}
-}
-
-// safely executes the given function, catching a panic and unpacking as an error.
+// safely executes the given function, catching a panic and unpacking as an
+// error.
//
// The error flow through the state package uses panic and recover. There are
// two important reasons for this:
@@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable {
// method doesn't add a lot of value. If there are specific error conditions
// that you'd like to handle, you should add appropriate functionality to
// objects themselves prior to calling Save() and Load().
-func (sr *recoverable) safely(fn func()) (err error) {
+func safely(fn func()) (err error) {
defer func() {
if r := recover(); r != nil {
+ if es, ok := r.(*ErrState); ok {
+ err = es // Propagate.
+ return
+ }
+
+ // Build a new state error.
es := new(ErrState)
if e, ok := r.(error); ok {
es.err = e
@@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) {
es.err = fmt.Errorf("%v", r)
}
- es.path = sr.path()
-
// Make a stack. We don't know how big it will be ahead
// of time, but want to make sure we get the whole
// thing. So we just do a stupid brute force approach.
diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go
new file mode 100644
index 000000000..4281aed6d
--- /dev/null
+++ b/pkg/state/state_norace.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package state
+
+var raceEnabled = false
diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go
new file mode 100644
index 000000000..8232981ce
--- /dev/null
+++ b/pkg/state/state_race.go
@@ -0,0 +1,19 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build race
+
+package state
+
+var raceEnabled = true
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
deleted file mode 100644
index 7c24bbcda..000000000
--- a/pkg/state/state_test.go
+++ /dev/null
@@ -1,720 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package state
-
-import (
- "bytes"
- "io/ioutil"
- "math"
- "reflect"
- "testing"
-)
-
-// TestCase is used to define a single success/failure testcase of
-// serialization of a set of objects.
-type TestCase struct {
- // Name is the name of the test case.
- Name string
-
- // Objects is the list of values to serialize.
- Objects []interface{}
-
- // Fail is whether the test case is supposed to fail or not.
- Fail bool
-}
-
-// runTest runs all testcases.
-func runTest(t *testing.T, tests []TestCase) {
- for _, test := range tests {
- t.Logf("TEST %s:", test.Name)
- for i, root := range test.Objects {
- t.Logf(" case#%d: %#v", i, root)
-
- // Save the passed object.
- saveBuffer := &bytes.Buffer{}
- saveObjectPtr := reflect.New(reflect.TypeOf(root))
- saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Save failed as expected: %v", err)
- continue
- }
-
- // Load a new copy of the object.
- loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
- t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
- continue
- } else if err != nil {
- t.Logf(" PASS: Load failed as expected: %v", err)
- continue
- }
-
- // Compare the values.
- loadedValue := loadObjectPtr.Elem().Interface()
- if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail {
- t.Errorf(" FAIL: Objects differs; got %#v", loadedValue)
- continue
- } else if !eq {
- t.Logf(" PASS: Object different as expected.")
- continue
- }
-
- // Everything went okay. Is that good?
- if test.Fail {
- t.Errorf(" FAIL: Unexpected success.")
- } else {
- t.Logf(" PASS: Success.")
- }
- }
- }
-}
-
-// dumbStruct is a struct which does not implement the loader/saver interface.
-// We expect that serialization of this struct will fail.
-type dumbStruct struct {
- A int
- B int
-}
-
-// smartStruct is a struct which does implement the loader/saver interface.
-// We expect that serialization of this struct will succeed.
-type smartStruct struct {
- A int
- B int
-}
-
-func (s *smartStruct) save(m Map) {
- m.Save("A", &s.A)
- m.Save("B", &s.B)
-}
-
-func (s *smartStruct) load(m Map) {
- m.Load("A", &s.A)
- m.Load("B", &s.B)
-}
-
-// valueLoadStruct uses a value load.
-type valueLoadStruct struct {
- v int
-}
-
-func (v *valueLoadStruct) save(m Map) {
- m.SaveValue("v", v.v)
-}
-
-func (v *valueLoadStruct) load(m Map) {
- m.LoadValue("v", new(int), func(value interface{}) {
- v.v = value.(int)
- })
-}
-
-// afterLoadStruct has an AfterLoad function.
-type afterLoadStruct struct {
- v int
-}
-
-func (a *afterLoadStruct) save(m Map) {
-}
-
-func (a *afterLoadStruct) load(m Map) {
- m.AfterLoad(func() {
- a.v++
- })
-}
-
-// genericContainer is a generic dispatcher.
-type genericContainer struct {
- v interface{}
-}
-
-func (g *genericContainer) save(m Map) {
- m.Save("v", &g.v)
-}
-
-func (g *genericContainer) load(m Map) {
- m.Load("v", &g.v)
-}
-
-// sliceContainer is a generic slice.
-type sliceContainer struct {
- v []interface{}
-}
-
-func (s *sliceContainer) save(m Map) {
- m.Save("v", &s.v)
-}
-
-func (s *sliceContainer) load(m Map) {
- m.Load("v", &s.v)
-}
-
-// mapContainer is a generic map.
-type mapContainer struct {
- v map[int]interface{}
-}
-
-func (mc *mapContainer) save(m Map) {
- m.Save("v", &mc.v)
-}
-
-func (mc *mapContainer) load(m Map) {
- // Some of the test cases below assume legacy behavior wherein maps
- // will automatically inherit dependencies.
- m.LoadWait("v", &mc.v)
-}
-
-// dumbMap is a map which does not implement the loader/saver interface.
-// Serialization of this map will default to the standard encode/decode logic.
-type dumbMap map[string]int
-
-// pointerStruct contains various pointers, shared and non-shared, and pointers
-// to pointers. We expect that serialization will respect the structure.
-type pointerStruct struct {
- A *int
- B *int
- C *int
- D *int
-
- AA **int
- BB **int
-}
-
-func (p *pointerStruct) save(m Map) {
- m.Save("A", &p.A)
- m.Save("B", &p.B)
- m.Save("C", &p.C)
- m.Save("D", &p.D)
- m.Save("AA", &p.AA)
- m.Save("BB", &p.BB)
-}
-
-func (p *pointerStruct) load(m Map) {
- m.Load("A", &p.A)
- m.Load("B", &p.B)
- m.Load("C", &p.C)
- m.Load("D", &p.D)
- m.Load("AA", &p.AA)
- m.Load("BB", &p.BB)
-}
-
-// testInterface is a trivial interface example.
-type testInterface interface {
- Foo()
-}
-
-// testImpl is a trivial implementation of testInterface.
-type testImpl struct {
-}
-
-// Foo satisfies testInterface.
-func (t *testImpl) Foo() {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) save(m Map) {
-}
-
-// testImpl is trivially serializable.
-func (t *testImpl) load(m Map) {
-}
-
-// testI demonstrates interface dispatching.
-type testI struct {
- I testInterface
-}
-
-func (t *testI) save(m Map) {
- m.Save("I", &t.I)
-}
-
-func (t *testI) load(m Map) {
- m.Load("I", &t.I)
-}
-
-// cycleStruct is used to implement basic cycles.
-type cycleStruct struct {
- c *cycleStruct
-}
-
-func (c *cycleStruct) save(m Map) {
- m.Save("c", &c.c)
-}
-
-func (c *cycleStruct) load(m Map) {
- m.Load("c", &c.c)
-}
-
-// badCycleStruct actually has deadlocking dependencies.
-//
-// This should pass if b.b = {nil|b} and fail otherwise.
-type badCycleStruct struct {
- b *badCycleStruct
-}
-
-func (b *badCycleStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *badCycleStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(func() {
- // This is not executable, since AfterLoad requires that the
- // object and all dependencies are complete. This should cause
- // a deadlock error during load.
- })
-}
-
-// emptyStructPointer points to an empty struct.
-type emptyStructPointer struct {
- nothing *struct{}
-}
-
-func (e *emptyStructPointer) save(m Map) {
- m.Save("nothing", &e.nothing)
-}
-
-func (e *emptyStructPointer) load(m Map) {
- m.Load("nothing", &e.nothing)
-}
-
-// truncateInteger truncates an integer.
-type truncateInteger struct {
- v int64
- v2 int32
-}
-
-func (t *truncateInteger) save(m Map) {
- t.v2 = int32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = int64(t.v2)
-}
-
-// truncateUnsignedInteger truncates an unsigned integer.
-type truncateUnsignedInteger struct {
- v uint64
- v2 uint32
-}
-
-func (t *truncateUnsignedInteger) save(m Map) {
- t.v2 = uint32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateUnsignedInteger) load(m Map) {
- m.Load("v", &t.v2)
- t.v = uint64(t.v2)
-}
-
-// truncateFloat truncates a floating point number.
-type truncateFloat struct {
- v float64
- v2 float32
-}
-
-func (t *truncateFloat) save(m Map) {
- t.v2 = float32(t.v)
- m.Save("v", &t.v)
-}
-
-func (t *truncateFloat) load(m Map) {
- m.Load("v", &t.v2)
- t.v = float64(t.v2)
-}
-
-func TestTypes(t *testing.T) {
- // x and y are basic integers, while xp points to x.
- x := 1
- y := 2
- xp := &x
-
- // cs is a single object cycle.
- cs := cycleStruct{nil}
- cs.c = &cs
-
- // cs1 and cs2 are in a two object cycle.
- cs1 := cycleStruct{nil}
- cs2 := cycleStruct{nil}
- cs1.c = &cs2
- cs2.c = &cs1
-
- // bs is a single object cycle.
- bs := badCycleStruct{nil}
- bs.b = &bs
-
- // bs2 and bs2 are in a deadlocking cycle.
- bs1 := badCycleStruct{nil}
- bs2 := badCycleStruct{nil}
- bs1.b = &bs2
- bs2.b = &bs1
-
- // regular nils.
- var (
- nilmap dumbMap
- nilslice []byte
- )
-
- // embed points to embedded fields.
- embed1 := pointerStruct{}
- embed1.AA = &embed1.A
- embed2 := pointerStruct{}
- embed2.BB = &embed2.B
-
- // es1 contains two structs pointing to the same empty struct.
- es := emptyStructPointer{new(struct{})}
- es1 := []emptyStructPointer{es, es}
-
- tests := []TestCase{
- {
- Name: "bool",
- Objects: []interface{}{
- true,
- false,
- },
- },
- {
- Name: "integers",
- Objects: []interface{}{
- int(0),
- int(1),
- int(-1),
- int8(0),
- int8(1),
- int8(-1),
- int16(0),
- int16(1),
- int16(-1),
- int32(0),
- int32(1),
- int32(-1),
- int64(0),
- int64(1),
- int64(-1),
- },
- },
- {
- Name: "unsigned integers",
- Objects: []interface{}{
- uint(0),
- uint(1),
- uint8(0),
- uint8(1),
- uint16(0),
- uint16(1),
- uint32(1),
- uint64(0),
- uint64(1),
- },
- },
- {
- Name: "strings",
- Objects: []interface{}{
- "",
- "foo",
- "bar",
- "\xa0",
- },
- },
- {
- Name: "slices",
- Objects: []interface{}{
- []int{-1, 0, 1},
- []*int{&x, &x, &x},
- []int{1, 2, 3}[0:1],
- []int{1, 2, 3}[1:2],
- make([]byte, 32),
- make([]byte, 32)[:16],
- make([]byte, 32)[:16:20],
- nilslice,
- },
- },
- {
- Name: "arrays",
- Objects: []interface{}{
- &[1048576]bool{false, true, false, true},
- &[1048576]uint8{0, 1, 2, 3},
- &[1048576]byte{0, 1, 2, 3},
- &[1048576]uint16{0, 1, 2, 3},
- &[1048576]uint{0, 1, 2, 3},
- &[1048576]uint32{0, 1, 2, 3},
- &[1048576]uint64{0, 1, 2, 3},
- &[1048576]uintptr{0, 1, 2, 3},
- &[1048576]int8{0, -1, -2, -3},
- &[1048576]int16{0, -1, -2, -3},
- &[1048576]int32{0, -1, -2, -3},
- &[1048576]int64{0, -1, -2, -3},
- &[1048576]float32{0, 1.1, 2.2, 3.3},
- &[1048576]float64{0, 1.1, 2.2, 3.3},
- },
- },
- {
- Name: "pointers",
- Objects: []interface{}{
- &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp},
- &pointerStruct{},
- },
- },
- {
- Name: "empty struct",
- Objects: []interface{}{
- struct{}{},
- },
- },
- {
- Name: "unenlightened structs",
- Objects: []interface{}{
- &dumbStruct{A: 1, B: 2},
- },
- Fail: true,
- },
- {
- Name: "enlightened structs",
- Objects: []interface{}{
- &smartStruct{A: 1, B: 2},
- },
- },
- {
- Name: "load-hooks",
- Objects: []interface{}{
- &afterLoadStruct{v: 1},
- &valueLoadStruct{v: 1},
- &genericContainer{v: &afterLoadStruct{v: 1}},
- &genericContainer{v: &valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
- },
- },
- {
- Name: "maps",
- Objects: []interface{}{
- dumbMap{"a": -1, "b": 0, "c": 1},
- map[smartStruct]int{{}: 0, {A: 1}: 1},
- nilmap,
- &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}},
- },
- },
- {
- Name: "interfaces",
- Objects: []interface{}{
- &testI{&testImpl{}},
- &testI{nil},
- &testI{(*testImpl)(nil)},
- },
- },
- {
- Name: "unregistered-interfaces",
- Objects: []interface{}{
- &genericContainer{v: afterLoadStruct{v: 1}},
- &genericContainer{v: valueLoadStruct{v: 1}},
- &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}},
- &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}},
- &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}},
- },
- Fail: true,
- },
- {
- Name: "cycles",
- Objects: []interface{}{
- &cs,
- &cs1,
- &cycleStruct{&cs1},
- &cycleStruct{&cs},
- &badCycleStruct{nil},
- &bs,
- },
- },
- {
- Name: "deadlock",
- Objects: []interface{}{
- &bs1,
- },
- Fail: true,
- },
- {
- Name: "embed",
- Objects: []interface{}{
- &embed1,
- &embed2,
- },
- Fail: true,
- },
- {
- Name: "empty structs",
- Objects: []interface{}{
- new(struct{}),
- es,
- es1,
- },
- },
- {
- Name: "truncated okay",
- Objects: []interface{}{
- &truncateInteger{v: 1},
- &truncateUnsignedInteger{v: 1},
- &truncateFloat{v: 1.0},
- },
- },
- {
- Name: "truncated bad",
- Objects: []interface{}{
- &truncateInteger{v: math.MaxInt32 + 1},
- &truncateUnsignedInteger{v: math.MaxUint32 + 1},
- &truncateFloat{v: math.MaxFloat32 * 2},
- },
- Fail: true,
- },
- }
-
- runTest(t, tests)
-}
-
-// benchStruct is used for benchmarking.
-type benchStruct struct {
- b *benchStruct
-
- // Dummy data is included to ensure that these objects are large.
- // This is to detect possible regression when registering objects.
- _ [4096]byte
-}
-
-func (b *benchStruct) save(m Map) {
- m.Save("b", &b.b)
-}
-
-func (b *benchStruct) load(m Map) {
- m.LoadWait("b", &b.b)
- m.AfterLoad(b.afterLoad)
-}
-
-func (b *benchStruct) afterLoad() {
- // Do nothing, just force scheduling.
-}
-
-// buildObject builds a benchmark object.
-func buildObject(n int) (b *benchStruct) {
- for i := 0; i < n; i++ {
- b = &benchStruct{b: b}
- }
- return
-}
-
-func BenchmarkEncoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var stats Stats
- b.StartTimer()
- if err := Save(ioutil.Discard, bs, &stats); err != nil {
- b.Errorf("save failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func BenchmarkDecoding(b *testing.B) {
- b.StopTimer()
- bs := buildObject(b.N)
- var newBS benchStruct
- buf := &bytes.Buffer{}
- if err := Save(buf, bs, nil); err != nil {
- b.Errorf("save failed: %v", err)
- }
- var stats Stats
- b.StartTimer()
- if err := Load(buf, &newBS, &stats); err != nil {
- b.Errorf("load failed: %v", err)
- }
- b.StopTimer()
- if b.N > 1000 {
- b.Logf("breakdown (n=%d): %s", b.N, &stats)
- }
-}
-
-func init() {
- Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{
- Save: (*smartStruct).save,
- Load: (*smartStruct).load,
- })
- Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{
- Save: (*afterLoadStruct).save,
- Load: (*afterLoadStruct).load,
- })
- Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{
- Save: (*valueLoadStruct).save,
- Load: (*valueLoadStruct).load,
- })
- Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{
- Save: (*genericContainer).save,
- Load: (*genericContainer).load,
- })
- Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{
- Save: (*sliceContainer).save,
- Load: (*sliceContainer).load,
- })
- Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{
- Save: (*mapContainer).save,
- Load: (*mapContainer).load,
- })
- Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{
- Save: (*pointerStruct).save,
- Load: (*pointerStruct).load,
- })
- Register("stateTest.testImpl", (*testImpl)(nil), Fns{
- Save: (*testImpl).save,
- Load: (*testImpl).load,
- })
- Register("stateTest.testI", (*testI)(nil), Fns{
- Save: (*testI).save,
- Load: (*testI).load,
- })
- Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{
- Save: (*cycleStruct).save,
- Load: (*cycleStruct).load,
- })
- Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{
- Save: (*badCycleStruct).save,
- Load: (*badCycleStruct).load,
- })
- Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{
- Save: (*emptyStructPointer).save,
- Load: (*emptyStructPointer).load,
- })
- Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{
- Save: (*truncateInteger).save,
- Load: (*truncateInteger).load,
- })
- Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{
- Save: (*truncateUnsignedInteger).save,
- Load: (*truncateUnsignedInteger).load,
- })
- Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{
- Save: (*truncateFloat).save,
- Load: (*truncateFloat).load,
- })
- Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{
- Save: (*benchStruct).save,
- Load: (*benchStruct).load,
- })
-}
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index 8a865d229..d6c89c7e9 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -1,16 +1,15 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "statefile",
srcs = ["statefile.go"],
- importpath = "gvisor.dev/gvisor/pkg/state/statefile",
visibility = ["//:sandbox"],
deps = [
"//pkg/binary",
"//pkg/compressio",
+ "//pkg/state/wire",
],
)
@@ -18,6 +17,6 @@ go_test(
name = "statefile_test",
size = "small",
srcs = ["statefile_test.go"],
- embed = [":statefile"],
+ library = ":statefile",
deps = ["//pkg/compressio"],
)
diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go
index c0f4c4954..bdfb800fb 100644
--- a/pkg/state/statefile/statefile.go
+++ b/pkg/state/statefile/statefile.go
@@ -57,6 +57,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/compressio"
+ "gvisor.dev/gvisor/pkg/state/wire"
)
// keySize is the AES-256 key length.
@@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size
// ErrMetadataInvalid is returned if passed metadata is invalid.
var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _")
+// WriteCloser is an io.Closer and wire.Writer.
+type WriteCloser interface {
+ wire.Writer
+ io.Closer
+}
+
// NewWriter returns a state data writer for a statefile.
//
// Note that the returned WriteCloser must be closed.
-func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) {
+func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) {
if metadata == nil {
metadata = make(map[string]string)
}
@@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) {
}
// NewReader returns a reader for a statefile.
-func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
+func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) {
// Read the metadata with the hash.
h := hmac.New(sha256.New, key)
metadata, err := metadata(r, h)
@@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) {
}
// Wrap in compression.
- rc, err := compressio.NewReader(r, key)
+ cr, err := compressio.NewReader(r, key)
if err != nil {
return nil, nil, err
}
- return rc, metadata, nil
+ return cr, metadata, nil
}
diff --git a/pkg/state/stats.go b/pkg/state/stats.go
index eb51cda47..eaec664a1 100644
--- a/pkg/state/stats.go
+++ b/pkg/state/stats.go
@@ -17,7 +17,6 @@ package state
import (
"bytes"
"fmt"
- "reflect"
"sort"
"time"
)
@@ -35,92 +34,81 @@ type statEntry struct {
// All exported receivers accept nil.
type Stats struct {
// byType contains a breakdown of time spent by type.
- byType map[reflect.Type]*statEntry
+ //
+ // This is indexed *directly* by typeID, including zero.
+ byType []statEntry
// stack contains objects in progress.
- stack []reflect.Type
+ stack []typeID
+
+ // names contains type names.
+ //
+ // This is also indexed *directly* by typeID, including zero, which we
+ // hard-code as "state.default". This is only resolved by calling fini
+ // on the stats object.
+ names []string
// last is the last start time.
last time.Time
}
-// sample adds the samples to the given object.
-func (s *Stats) sample(typ reflect.Type) {
- now := time.Now()
- s.byType[typ].total += now.Sub(s.last)
- s.last = now
+// init initializes statistics.
+func (s *Stats) init() {
+ s.last = time.Now()
+ s.stack = append(s.stack, 0)
}
-// Add adds a sample count.
-func (s *Stats) Add(obj reflect.Value) {
- if s == nil {
- return
- }
- if s.byType == nil {
- s.byType = make(map[reflect.Type]*statEntry)
- }
- typ := obj.Type()
- entry, ok := s.byType[typ]
- if !ok {
- entry = new(statEntry)
- s.byType[typ] = entry
+// fini finalizes statistics.
+func (s *Stats) fini(resolve func(id typeID) string) {
+ s.done()
+
+ // Resolve all type names.
+ s.names = make([]string, len(s.byType))
+ s.names[0] = "state.default" // See above.
+ for id := typeID(1); int(id) < len(s.names); id++ {
+ s.names[id] = resolve(id)
}
- entry.count++
}
-// Remove removes a sample count. It should only be called after a previous
-// Add().
-func (s *Stats) Remove(obj reflect.Value) {
- if s == nil {
- return
+// sample adds the samples to the given object.
+func (s *Stats) sample(id typeID) {
+ now := time.Now()
+ if len(s.byType) <= int(id) {
+ // Allocate all the missing entries in one fell swoop.
+ s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...)
}
- typ := obj.Type()
- entry := s.byType[typ]
- entry.count--
+ s.byType[id].total += now.Sub(s.last)
+ s.last = now
}
-// Start starts a sample.
-func (s *Stats) Start(obj reflect.Value) {
- if s == nil {
- return
- }
- if len(s.stack) > 0 {
- last := s.stack[len(s.stack)-1]
- s.sample(last)
- } else {
- // First time sample.
- s.last = time.Now()
- }
- s.stack = append(s.stack, obj.Type())
+// start starts a sample.
+func (s *Stats) start(id typeID) {
+ last := s.stack[len(s.stack)-1]
+ s.sample(last)
+ s.stack = append(s.stack, id)
}
-// Done finishes the current sample.
-func (s *Stats) Done() {
- if s == nil {
- return
- }
+// done finishes the current sample.
+func (s *Stats) done() {
last := s.stack[len(s.stack)-1]
s.sample(last)
+ s.byType[last].count++
s.stack = s.stack[:len(s.stack)-1]
}
type sliceEntry struct {
- typ reflect.Type
+ name string
entry *statEntry
}
// String returns a table representation of the stats.
func (s *Stats) String() string {
- if s == nil || len(s.byType) == 0 {
- return "(no data)"
- }
-
// Build a list of stat entries.
ss := make([]sliceEntry, 0, len(s.byType))
- for typ, entry := range s.byType {
+ for id := 0; id < len(s.names); id++ {
ss = append(ss, sliceEntry{
- typ: typ,
- entry: entry,
+ name: s.names[id],
+ entry: &s.byType[id],
})
}
@@ -136,17 +124,22 @@ func (s *Stats) String() string {
total time.Duration
)
buf.WriteString("\n")
- buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type"))
- buf.WriteString("-------------+----------+----------+-------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type"))
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
for _, se := range ss {
+ if se.entry.count == 0 {
+ // Since we store all types linearly, we are not
+ // guaranteed that any entry actually has time.
+ continue
+ }
count += se.entry.count
total += se.entry.total
per := se.entry.total / time.Duration(se.entry.count)
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n",
- se.entry.total, se.entry.count, per, se.typ.String()))
+ buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n",
+ se.entry.total, se.entry.count, per, se.name))
}
- buf.WriteString("-------------+----------+----------+-------------\n")
- buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]",
+ buf.WriteString("-----------------+----------+------------------+----------------\n")
+ buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]",
total, count, total/time.Duration(count)))
return string(buf.Bytes())
}
diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD
new file mode 100644
index 000000000..9297cafbe
--- /dev/null
+++ b/pkg/state/tests/BUILD
@@ -0,0 +1,43 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tests",
+ srcs = [
+ "array.go",
+ "bench.go",
+ "integer.go",
+ "load.go",
+ "map.go",
+ "register.go",
+ "struct.go",
+ "tests.go",
+ ],
+ deps = [
+ "//pkg/state",
+ "//pkg/state/pretty",
+ ],
+)
+
+go_test(
+ name = "tests_test",
+ size = "small",
+ srcs = [
+ "array_test.go",
+ "bench_test.go",
+ "bool_test.go",
+ "float_test.go",
+ "integer_test.go",
+ "load_test.go",
+ "map_test.go",
+ "register_test.go",
+ "string_test.go",
+ "struct_test.go",
+ ],
+ library = ":tests",
+ deps = [
+ "//pkg/state",
+ "//pkg/state/wire",
+ ],
+)
diff --git a/test/root/testdata/sandbox.go b/pkg/state/tests/array.go
index 0db210370..0972a80e7 100644
--- a/test/root/testdata/sandbox.go
+++ b/pkg/state/tests/array.go
@@ -12,19 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testdata
+package tests
-// Sandbox is a default JSON config for a sandbox.
-const Sandbox = `
-{
- "metadata": {
- "name": "default-sandbox",
- "namespace": "default",
- "attempt": 1,
- "uid": "hdishd83djaidwnduwk28bcsb"
- },
- "linux": {
- },
- "log_directory": "/tmp"
+// +stateify savable
+type arrayContainer struct {
+ v [1]interface{}
+}
+
+// +stateify savable
+type arrayPtrContainer struct {
+ v *[1]interface{}
+}
+
+// +stateify savable
+type sliceContainer struct {
+ v []interface{}
+}
+
+// +stateify savable
+type slicePtrContainer struct {
+ v *[]interface{}
}
-`
diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go
new file mode 100644
index 000000000..a347b2947
--- /dev/null
+++ b/pkg/state/tests/array_test.go
@@ -0,0 +1,134 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allArrayPrimitives = []interface{}{
+ [1]bool{},
+ [1]bool{true},
+ [2]bool{false, true},
+ [1]int{},
+ [1]int{1},
+ [2]int{0, 1},
+ [1]int8{},
+ [1]int8{1},
+ [2]int8{0, 1},
+ [1]int16{},
+ [1]int16{1},
+ [2]int16{0, 1},
+ [1]int32{},
+ [1]int32{1},
+ [2]int32{0, 1},
+ [1]int64{},
+ [1]int64{1},
+ [2]int64{0, 1},
+ [1]uint{},
+ [1]uint{1},
+ [2]uint{0, 1},
+ [1]uintptr{},
+ [1]uintptr{1},
+ [2]uintptr{0, 1},
+ [1]uint8{},
+ [1]uint8{1},
+ [2]uint8{0, 1},
+ [1]uint16{},
+ [1]uint16{1},
+ [2]uint16{0, 1},
+ [1]uint32{},
+ [1]uint32{1},
+ [2]uint32{0, 1},
+ [1]uint64{},
+ [1]uint64{1},
+ [2]uint64{0, 1},
+ [1]string{},
+ [1]string{""},
+ [1]string{nonEmptyString},
+ [2]string{"", nonEmptyString},
+}
+
+func TestArrayPrimitives(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allArrayPrimitives))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives))))
+}
+
+func TestSlices(t *testing.T) {
+ var allSlices = flatten(
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ return v.Slice(0, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the pure "nil" value for the slice.
+ return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true
+ }
+ return v.Slice(1, v.Len()).Interface(), true
+ }),
+ filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o)).Elem()
+ v.Set(reflect.ValueOf(o))
+ if v.Len() == 0 {
+ // Return the zero-valued slice.
+ return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true
+ }
+ return v.Slice(0, v.Len()-1).Interface(), true
+ }),
+ )
+ runTestCases(t, false, "plain", allSlices)
+ runTestCases(t, false, "pointers", pointersTo(allSlices))
+ runTestCases(t, false, "interfaces", interfacesTo(allSlices))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices)))
+}
+
+func TestArrayContainers(t *testing.T) {
+ var (
+ emptyArray [1]interface{}
+ fullArray [1]interface{}
+ )
+ fullArray[0] = &emptyArray
+ runTestCases(t, false, "", []interface{}{
+ arrayContainer{v: emptyArray},
+ arrayContainer{v: fullArray},
+ arrayPtrContainer{v: nil},
+ arrayPtrContainer{v: &emptyArray},
+ arrayPtrContainer{v: &fullArray},
+ })
+}
+
+func TestSliceContainers(t *testing.T) {
+ var (
+ nilSlice []interface{}
+ emptySlice = make([]interface{}, 0)
+ fullSlice = []interface{}{nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ sliceContainer{v: nilSlice},
+ sliceContainer{v: emptySlice},
+ sliceContainer{v: fullSlice},
+ slicePtrContainer{v: nil},
+ slicePtrContainer{v: &nilSlice},
+ slicePtrContainer{v: &emptySlice},
+ slicePtrContainer{v: &fullSlice},
+ })
+}
diff --git a/test/root/testdata/httpd.go b/pkg/state/tests/bench.go
index 45d5e33d4..40869cdfb 100644
--- a/test/root/testdata/httpd.go
+++ b/pkg/state/tests/bench.go
@@ -12,21 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testdata
+package tests
-// Httpd is a JSON config for an httpd container.
-const Httpd = `
-{
- "metadata": {
- "name": "httpd"
- },
- "image":{
- "image": "httpd"
- },
- "mounts": [
- ],
- "linux": {
- },
- "log_path": "httpd.log"
+// +stateify savable
+type benchStruct struct {
+ B *benchStruct // Must be exported for gob.
+}
+
+func (b *benchStruct) afterLoad() {
+ // Do nothing, just force scheduling.
}
-`
diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go
new file mode 100644
index 000000000..7e102c907
--- /dev/null
+++ b/pkg/state/tests/bench_test.go
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "bytes"
+ "context"
+ "encoding/gob"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// buildPtrObject builds a benchmark object.
+func buildPtrObject(n int) interface{} {
+ b := new(benchStruct)
+ for i := 0; i < n; i++ {
+ b = &benchStruct{B: b}
+ }
+ return b
+}
+
+// buildMapObject builds a benchmark object.
+func buildMapObject(n int) interface{} {
+ b := new(benchStruct)
+ m := make(map[int]*benchStruct)
+ for i := 0; i < n; i++ {
+ m[i] = b
+ }
+ return &m
+}
+
+// buildSliceObject builds a benchmark object.
+func buildSliceObject(n int) interface{} {
+ b := new(benchStruct)
+ s := make([]*benchStruct, 0, n)
+ for i := 0; i < n; i++ {
+ s = append(s, b)
+ }
+ return &s
+}
+
+var allObjects = map[string]struct {
+ New func(int) interface{}
+}{
+ "ptr": {
+ New: buildPtrObject,
+ },
+ "map": {
+ New: buildMapObject,
+ },
+ "slice": {
+ New: buildSliceObject,
+ },
+}
+
+func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) {
+ // maxSize is the maximum size of an individual object below. For an N
+ // larger than this, we start to return multiple objects.
+ const maxSize = 1024
+ if n <= maxSize {
+ return 1, fn(n)
+ }
+ iters = (n + maxSize - 1) / maxSize
+ return iters, fn(maxSize)
+}
+
+// gobSave is a version of save using gob (no stats available).
+func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) {
+ enc := gob.NewEncoder(w)
+ err = enc.Encode(v)
+ return
+}
+
+// gobLoad is a version of load using gob (no stats available).
+func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) {
+ dec := gob.NewDecoder(r)
+ err = dec.Decode(v)
+ return
+}
+
+var allAlgos = map[string]struct {
+ Save func(context.Context, wire.Writer, interface{}) (state.Stats, error)
+ Load func(context.Context, wire.Reader, interface{}) (state.Stats, error)
+ MaxPtr int
+}{
+ "state": {
+ Save: state.Save,
+ Load: state.Load,
+ },
+ "gob": {
+ Save: gobSave,
+ Load: gobLoad,
+ },
+}
+
+func BenchmarkEncoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ b.ReportAllocs()
+ b.StartTimer()
+ for i := 0; i < n; i++ {
+ if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
+
+func BenchmarkDecoding(b *testing.B) {
+ for objName, objInfo := range allObjects {
+ for algoName, algoInfo := range allAlgos {
+ b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) {
+ b.StopTimer()
+ n, v := buildObjects(b.N, objInfo.New)
+ buf := new(bytes.Buffer)
+ if _, err := algoInfo.Save(context.Background(), buf, v); err != nil {
+ b.Errorf("save failed: %v", err)
+ }
+ b.ReportAllocs()
+ b.StartTimer()
+ var r bytes.Reader
+ for i := 0; i < n; i++ {
+ r.Reset(buf.Bytes())
+ if _, err := algoInfo.Load(context.Background(), &r, v); err != nil {
+ b.Errorf("load failed: %v", err)
+ }
+ }
+ b.StopTimer()
+ })
+ }
+ }
+}
diff --git a/test/root/testdata/busybox.go b/pkg/state/tests/bool_test.go
index e4dbd2843..e17cfacf9 100644
--- a/test/root/testdata/busybox.go
+++ b/pkg/state/tests/bool_test.go
@@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testdata
+package tests
-// MountOverSymlink is a JSON config for a container that /etc/resolv.conf is a
-// symlink to /tmp/resolv.conf.
-var MountOverSymlink = `
-{
- "metadata": {
- "name": "busybox"
- },
- "image": {
- "image": "k8s.gcr.io/busybox"
- },
- "command": [
- "sleep",
- "1000"
- ]
+import (
+ "testing"
+)
+
+var allBools = []bool{
+ true,
+ false,
+}
+
+func TestBool(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allBools))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allBools)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools))))
}
-`
diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go
new file mode 100644
index 000000000..3e89edd9c
--- /dev/null
+++ b/pkg/state/tests/float_test.go
@@ -0,0 +1,118 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var safeFloat32s = []float32{
+ float32(0.0),
+ float32(1.0),
+ float32(-1.0),
+ float32(math.Inf(1)),
+ float32(math.Inf(-1)),
+}
+
+var allFloat32s = append(safeFloat32s, float32(math.NaN()))
+
+var safeFloat64s = []float64{
+ float64(0.0),
+ float64(1.0),
+ float64(-1.0),
+ math.Inf(1),
+ math.Inf(-1),
+}
+
+var allFloat64s = append(safeFloat64s, math.NaN())
+
+func TestFloat(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allFloat32s,
+ allFloat64s,
+ ))
+ // See checkEqual for why NaNs are missing.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ )))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(
+ safeFloat32s,
+ safeFloat64s,
+ ))))
+}
+
+const onlyDouble float64 = 1.0000000000000002
+
+func TestFloatTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingFloat32{save: onlyDouble},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingFloat32{save: 1.0},
+ })
+}
+
+var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} {
+ return complex(i.(float32), j.(float32))
+})
+
+var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} {
+ return complex(i.(float64), j.(float64))
+})
+
+func TestComplex(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(
+ allComplex64s,
+ allComplex128s,
+ ))
+ // See TestFloat; same issue.
+ runTestCases(t, false, "pointers", pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacse", interfacesTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ )))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten(
+ safeComplex64s,
+ safeComplex128s,
+ ))))
+}
+
+func TestComplexTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingComplex64{save: complex(onlyDouble, onlyDouble)},
+ truncatingComplex64{save: complex(1.0, onlyDouble)},
+ truncatingComplex64{save: complex(onlyDouble, 1.0)},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingComplex64{save: complex(1.0, 1.0)},
+ })
+}
diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go
new file mode 100644
index 000000000..ca403eed1
--- /dev/null
+++ b/pkg/state/tests/integer.go
@@ -0,0 +1,163 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// +stateify type
+type truncatingUint8 struct {
+ save uint64
+ load uint8 `state:"nosave"`
+}
+
+func (t *truncatingUint8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint8)(nil)
+
+// +stateify type
+type truncatingUint16 struct {
+ save uint64
+ load uint16 `state:"nosave"`
+}
+
+func (t *truncatingUint16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint16)(nil)
+
+// +stateify type
+type truncatingUint32 struct {
+ save uint64
+ load uint32 `state:"nosave"`
+}
+
+func (t *truncatingUint32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingUint32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = uint64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingUint32)(nil)
+
+// +stateify type
+type truncatingInt8 struct {
+ save int64
+ load int8 `state:"nosave"`
+}
+
+func (t *truncatingInt8) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt8) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt8)(nil)
+
+// +stateify type
+type truncatingInt16 struct {
+ save int64
+ load int16 `state:"nosave"`
+}
+
+func (t *truncatingInt16) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt16) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt16)(nil)
+
+// +stateify type
+type truncatingInt32 struct {
+ save int64
+ load int32 `state:"nosave"`
+}
+
+func (t *truncatingInt32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingInt32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = int64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingInt32)(nil)
+
+// +stateify type
+type truncatingFloat32 struct {
+ save float64
+ load float32 `state:"nosave"`
+}
+
+func (t *truncatingFloat32) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingFloat32) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = float64(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingFloat32)(nil)
+
+// +stateify type
+type truncatingComplex64 struct {
+ save complex128
+ load complex64 `state:"nosave"`
+}
+
+func (t *truncatingComplex64) StateSave(m state.Sink) {
+ m.Save(0, &t.save)
+}
+
+func (t *truncatingComplex64) StateLoad(m state.Source) {
+ m.Load(0, &t.load)
+ t.save = complex128(t.load)
+ t.load = 0
+}
+
+var _ state.SaverLoader = (*truncatingComplex64)(nil)
diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go
new file mode 100644
index 000000000..d3931c952
--- /dev/null
+++ b/pkg/state/tests/integer_test.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "math"
+ "testing"
+)
+
+var (
+ allIntTs = []int{-1, 0, 1}
+ allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8}
+ allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16}
+ allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32}
+ allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64}
+ allUintTs = []uint{0, 1}
+ allUintptrs = []uintptr{0, 1, ^uintptr(0)}
+ allUint8s = []uint8{0, 1, math.MaxUint8}
+ allUint16s = []uint16{0, 1, math.MaxUint16}
+ allUint32s = []uint32{0, 1, math.MaxUint32}
+ allUint64s = []uint64{0, 1, math.MaxUint64}
+)
+
+var allInts = flatten(
+ allIntTs,
+ allInt8s,
+ allInt16s,
+ allInt32s,
+ allInt64s,
+)
+
+var allUints = flatten(
+ allUintTs,
+ allUintptrs,
+ allUint8s,
+ allUint16s,
+ allUint32s,
+ allUint64s,
+)
+
+func TestInt(t *testing.T) {
+ runTestCases(t, false, "plain", allInts)
+ runTestCases(t, false, "pointers", pointersTo(allInts))
+ runTestCases(t, false, "interfaces", interfacesTo(allInts))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts)))
+}
+
+func TestIntTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingInt8{save: math.MinInt8 - 1},
+ truncatingInt16{save: math.MinInt16 - 1},
+ truncatingInt32{save: math.MinInt32 - 1},
+ truncatingInt8{save: math.MaxInt8 + 1},
+ truncatingInt16{save: math.MaxInt16 + 1},
+ truncatingInt32{save: math.MaxInt32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingInt8{save: 1},
+ truncatingInt16{save: 1},
+ truncatingInt32{save: 1},
+ })
+}
+
+func TestUint(t *testing.T) {
+ runTestCases(t, false, "plain", allUints)
+ runTestCases(t, false, "pointers", pointersTo(allUints))
+ runTestCases(t, false, "interfaces", interfacesTo(allUints))
+ runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints)))
+}
+
+func TestUintTruncation(t *testing.T) {
+ runTestCases(t, true, "pass", []interface{}{
+ truncatingUint8{save: math.MaxUint8 + 1},
+ truncatingUint16{save: math.MaxUint16 + 1},
+ truncatingUint32{save: math.MaxUint32 + 1},
+ })
+ runTestCases(t, false, "fail", []interface{}{
+ truncatingUint8{save: 1},
+ truncatingUint16{save: 1},
+ truncatingUint32{save: 1},
+ })
+}
diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go
new file mode 100644
index 000000000..a8350c0f3
--- /dev/null
+++ b/pkg/state/tests/load.go
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type genericContainer struct {
+ v interface{}
+}
+
+// +stateify savable
+type afterLoadStruct struct {
+ v int `state:"nosave"`
+}
+
+func (a *afterLoadStruct) afterLoad() {
+ a.v++
+}
+
+// +stateify savable
+type valueLoadStruct struct {
+ v int `state:".(int64)"`
+}
+
+func (v *valueLoadStruct) saveV() int64 {
+ return int64(v.v) // Save as int64.
+}
+
+func (v *valueLoadStruct) loadV(value int64) {
+ v.v = int(value) // Load as int.
+}
+
+// +stateify savable
+type cycleStruct struct {
+ c *cycleStruct
+}
+
+// +stateify savable
+type badCycleStruct struct {
+ b *badCycleStruct `state:"wait"`
+}
+
+func (b *badCycleStruct) afterLoad() {
+ if b.b != b {
+ // This is not executable, since AfterLoad requires that the
+ // object and all dependencies are complete. This should cause
+ // a deadlock error during load.
+ panic("badCycleStruct.afterLoad called")
+ }
+}
diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go
new file mode 100644
index 000000000..1e9794296
--- /dev/null
+++ b/pkg/state/tests/load_test.go
@@ -0,0 +1,70 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+func TestLoadHooks(t *testing.T) {
+ runTestCases(t, false, "load-hooks", []interface{}{
+ &afterLoadStruct{v: 1},
+ &valueLoadStruct{v: 1},
+ &genericContainer{v: &afterLoadStruct{v: 1}},
+ &genericContainer{v: &valueLoadStruct{v: 1}},
+ &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}},
+ &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}},
+ &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}},
+ })
+}
+
+func TestCycles(t *testing.T) {
+ // cs is a single object cycle.
+ cs := cycleStruct{nil}
+ cs.c = &cs
+
+ // cs1 and cs2 are in a two object cycle.
+ cs1 := cycleStruct{nil}
+ cs2 := cycleStruct{nil}
+ cs1.c = &cs2
+ cs2.c = &cs1
+
+ runTestCases(t, false, "cycles", []interface{}{
+ cs,
+ cs1,
+ })
+}
+
+func TestDeadlock(t *testing.T) {
+ // bs is a single object cycle. This does not cause deadlock because an
+ // object cannot wait for itself.
+ bs := badCycleStruct{nil}
+ bs.b = &bs
+
+ runTestCases(t, false, "self", []interface{}{
+ &bs,
+ })
+
+ // bs2 and bs2 are in a deadlocking cycle.
+ bs1 := badCycleStruct{nil}
+ bs2 := badCycleStruct{nil}
+ bs1.b = &bs2
+ bs2.b = &bs1
+
+ runTestCases(t, true, "deadlock", []interface{}{
+ &bs1,
+ })
+}
diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go
new file mode 100644
index 000000000..db4e548f1
--- /dev/null
+++ b/pkg/state/tests/map.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+// +stateify savable
+type mapContainer struct {
+ v map[int]interface{}
+}
+
+// +stateify savable
+type mapPtrContainer struct {
+ v *map[int]interface{}
+}
+
+// +stateify savable
+type registeredMapStruct struct{}
diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go
new file mode 100644
index 000000000..92bf0fc01
--- /dev/null
+++ b/pkg/state/tests/map_test.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "reflect"
+ "testing"
+)
+
+var allMapPrimitives = []interface{}{
+ bool(true),
+ int(1),
+ int8(1),
+ int16(1),
+ int32(1),
+ int64(1),
+ uint(1),
+ uintptr(1),
+ uint8(1),
+ uint16(1),
+ uint32(1),
+ uint64(1),
+ string(""),
+ registeredMapStruct{},
+}
+
+var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives))
+
+var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives))
+
+var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} {
+ m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2)))
+ m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2)))
+ return m.Interface()
+})
+
+func TestMapAliasing(t *testing.T) {
+ v := make(map[int]int)
+ ptrToV := &v
+ aliases := []map[int]int{v, v}
+ runTestCases(t, false, "", []interface{}{ptrToV, aliases})
+}
+
+func TestMapsEmpty(t *testing.T) {
+ runTestCases(t, false, "plain", emptyMaps)
+ runTestCases(t, false, "pointers", pointersTo(emptyMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(emptyMaps))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps)))
+}
+
+func TestMapsFull(t *testing.T) {
+ runTestCases(t, false, "plain", fullMaps)
+ runTestCases(t, false, "pointers", pointersTo(fullMaps))
+ runTestCases(t, false, "interfaces", interfacesTo(fullMaps))
+ runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps)))
+}
+
+func TestMapContainers(t *testing.T) {
+ var (
+ nilMap map[int]interface{}
+ emptyMap = make(map[int]interface{})
+ fullMap = map[int]interface{}{0: nil}
+ )
+ runTestCases(t, false, "", []interface{}{
+ mapContainer{v: nilMap},
+ mapContainer{v: emptyMap},
+ mapContainer{v: fullMap},
+ mapPtrContainer{v: nil},
+ mapPtrContainer{v: &nilMap},
+ mapPtrContainer{v: &emptyMap},
+ mapPtrContainer{v: &fullMap},
+ })
+}
diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/state/tests/register.go
index 5d4fd4dac..074d86315 100644
--- a/pkg/sentry/socket/rpcinet/rpcinet.go
+++ b/pkg/state/tests/register.go
@@ -12,5 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package rpcinet implements sockets using an RPC for each syscall.
-package rpcinet
+package tests
+
+// +stateify savable
+type alreadyRegisteredStruct struct{}
+
+// +stateify savable
+type alreadyRegisteredOther int
diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go
new file mode 100644
index 000000000..c829753cc
--- /dev/null
+++ b/pkg/state/tests/register_test.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+// faker calls itself whatever is in the name field.
+type faker struct {
+ Name string
+ Fields []string
+}
+
+func (f *faker) StateTypeName() string {
+ return f.Name
+}
+
+func (f *faker) StateFields() []string {
+ return f.Fields
+}
+
+// fakerWithSaverLoader has all it needs.
+type fakerWithSaverLoader struct {
+ faker
+}
+
+func (f *fakerWithSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerWithSaverLoader) StateLoad(m state.Source) {}
+
+// fakerOther calls itself .. uh, itself?
+type fakerOther string
+
+func (f *fakerOther) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOther) StateFields() []string {
+ return nil
+}
+
+func newFakerOther(name string) *fakerOther {
+ f := fakerOther(name)
+ return &f
+}
+
+// fakerOtherBadFields returns non-nil fields.
+type fakerOtherBadFields string
+
+func (f *fakerOtherBadFields) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherBadFields) StateFields() []string {
+ return []string{string(*f)}
+}
+
+func newFakerOtherBadFields(name string) *fakerOtherBadFields {
+ f := fakerOtherBadFields(name)
+ return &f
+}
+
+// fakerOtherSaverLoader implements SaverLoader methods.
+type fakerOtherSaverLoader string
+
+func (f *fakerOtherSaverLoader) StateTypeName() string {
+ return string(*f)
+}
+
+func (f *fakerOtherSaverLoader) StateFields() []string {
+ return nil
+}
+
+func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {}
+
+func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {}
+
+func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader {
+ f := fakerOtherSaverLoader(name)
+ return &f
+}
+
+func TestRegisterPrimitives(t *testing.T) {
+ for _, typeName := range []string{
+ "int",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "uint",
+ "uintptr",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ "string",
+ } {
+ t.Run("struct/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(&faker{
+ Name: typeName,
+ })
+ })
+ t.Run("other/"+typeName, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering type %q did not panic", typeName)
+ }
+ }()
+ state.Register(newFakerOther(typeName))
+ })
+ }
+}
+
+func TestRegisterBad(t *testing.T) {
+ const (
+ goodName = "foo"
+ firstField = "a"
+ secondField = "b"
+ )
+ for name, object := range map[string]state.Type{
+ "non-struct-with-fields": newFakerOtherBadFields(goodName),
+ "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName),
+ "struct-without-saverloader": &faker{Name: goodName},
+ "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()),
+ "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()),
+ "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}},
+ "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}},
+ "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}},
+ "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}},
+ "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}},
+ "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}},
+ } {
+ t.Run(name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Registering object %#v did not panic", object)
+ }
+ }()
+ state.Register(object)
+ })
+
+ }
+}
diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go
new file mode 100644
index 000000000..44f5a562c
--- /dev/null
+++ b/pkg/state/tests/string_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+)
+
+const nonEmptyString = "hello world"
+
+var allStrings = []string{
+ "",
+ nonEmptyString,
+ "\\0",
+}
+
+func TestString(t *testing.T) {
+ runTestCases(t, false, "plain", flatten(allStrings))
+ runTestCases(t, false, "pointers", pointersTo(flatten(allStrings)))
+ runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings)))
+ runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings))))
+}
diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go
new file mode 100644
index 000000000..bd2c2b399
--- /dev/null
+++ b/pkg/state/tests/struct.go
@@ -0,0 +1,65 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+type unregisteredEmptyStruct struct{}
+
+// typeOnlyEmptyStruct just implements the state.Type interface.
+type typeOnlyEmptyStruct struct{}
+
+func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" }
+
+func (*typeOnlyEmptyStruct) StateFields() []string { return nil }
+
+// +stateify savable
+type savableEmptyStruct struct{}
+
+// +stateify savable
+type emptyStructPointer struct {
+ nothing *struct{}
+}
+
+// +stateify savable
+type outerSame struct {
+ inner inner
+}
+
+// +stateify savable
+type outerFieldFirst struct {
+ inner inner
+ v int64
+}
+
+// +stateify savable
+type outerFieldSecond struct {
+ v int64
+ inner inner
+}
+
+// +stateify savable
+type outerArray struct {
+ inner [2]inner
+}
+
+// +stateify savable
+type inner struct {
+ v int64
+}
+
+// +stateify savable
+type system struct {
+ v1 interface{}
+ v2 interface{}
+}
diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go
new file mode 100644
index 000000000..de9d17aa7
--- /dev/null
+++ b/pkg/state/tests/struct_test.go
@@ -0,0 +1,89 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tests
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+)
+
+func TestEmptyStruct(t *testing.T) {
+ runTestCases(t, false, "plain", []interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ })
+ runTestCases(t, false, "pointers", pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{
+ // Only registered types can be dispatched via interfaces. All
+ // other types should fail, even if it is the empty struct.
+ savableEmptyStruct{},
+ }))
+ runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ }))
+ runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{
+ savableEmptyStruct{},
+ })))
+ runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{
+ unregisteredEmptyStruct{},
+ typeOnlyEmptyStruct{},
+ })))
+
+ // Ensuring empty struct aliasing works.
+ es := emptyStructPointer{new(struct{})}
+ runTestCases(t, false, "empty-struct-pointers", []interface{}{
+ emptyStructPointer{},
+ es,
+ []emptyStructPointer{es, es}, // Same pointer.
+ })
+}
+
+func TestRegisterTypeOnlyStruct(t *testing.T) {
+ defer func() {
+ if r := recover(); r == nil {
+ t.Errorf("Register did not panic")
+ }
+ }()
+ state.Register((*typeOnlyEmptyStruct)(nil))
+}
+
+func TestEmbeddedPointers(t *testing.T) {
+ var (
+ ofs outerSame
+ of1 outerFieldFirst
+ of2 outerFieldSecond
+ oa outerArray
+ )
+
+ runTestCases(t, false, "embedded-pointers", []interface{}{
+ system{&ofs, &ofs.inner},
+ system{&ofs.inner, &ofs},
+ system{&of1, &of1.inner},
+ system{&of1.inner, &of1},
+ system{&of2, &of2.inner},
+ system{&of2.inner, &of2},
+ system{&oa, &oa.inner[0]},
+ system{&oa, &oa.inner[1]},
+ system{&oa.inner[0], &oa},
+ system{&oa.inner[1], &oa},
+ })
+}
diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go
new file mode 100644
index 000000000..435a0e9db
--- /dev/null
+++ b/pkg/state/tests/tests.go
@@ -0,0 +1,215 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tests tests the state packages.
+package tests
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "math"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+)
+
+// discard is an implementation of wire.Writer.
+type discard struct{}
+
+// Write implements wire.Writer.Write.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte implements wire.Writer.WriteByte.
+func (discard) WriteByte(byte) error { return nil }
+
+// checkEqual checks if two objects are equal.
+//
+// N.B. This only handles one level of dereferences for NaN. Otherwise we
+// would need to fork the entire implementation of reflect.DeepEqual.
+func checkEqual(root, loadedValue interface{}) bool {
+ if reflect.DeepEqual(root, loadedValue) {
+ return true
+ }
+
+ // NaN is not equal to itself. We handle the case of raw floating point
+ // primitives here, but don't handle this case nested.
+ rf32, ok1 := root.(float32)
+ lf32, ok2 := loadedValue.(float32)
+ if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
+ return true
+ }
+ rf64, ok1 := root.(float64)
+ lf64, ok2 := loadedValue.(float64)
+ if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
+ return true
+ }
+
+ // Same real for complex numbers.
+ rc64, ok1 := root.(complex64)
+ lc64, ok2 := root.(complex64)
+ if ok1 && ok2 {
+ return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
+ }
+ rc128, ok1 := root.(complex128)
+ lc128, ok2 := root.(complex128)
+ if ok1 && ok2 {
+ return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
+ }
+
+ return false
+}
+
+// runTestCases runs a test for each object in objects.
+func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) {
+ t.Helper()
+ for i, root := range objects {
+ t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
+ t.Logf("Original object:\n%#v", root)
+
+ // Save the passed object.
+ saveBuffer := &bytes.Buffer{}
+ saveObjectPtr := reflect.New(reflect.TypeOf(root))
+ saveObjectPtr.Elem().Set(reflect.ValueOf(root))
+ saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Save failed unexpectedly: %v", err)
+ }
+
+ // Dump the serialized proto to aid with debugging.
+ var ppBuf bytes.Buffer
+ t.Logf("Raw state:\n%v", saveBuffer.Bytes())
+ if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // We don't count this as a test failure if we
+ // have shouldFail set, but we will count as a
+ // failure if we were not expecting to fail.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
+ }
+ }
+ if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
+ // See above.
+ if !shouldFail {
+ t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
+ }
+ }
+ t.Logf("Encoded state:\n%s", ppBuf.String())
+ t.Logf("Save stats:\n%s", saveStats.String())
+
+ // Load a new copy of the object.
+ loadObjectPtr := reflect.New(reflect.TypeOf(root))
+ loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
+ if err != nil {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Load failed unexpectedly: %v", err)
+ }
+
+ // Compare the values.
+ loadedValue := loadObjectPtr.Elem().Interface()
+ if !checkEqual(root, loadedValue) {
+ if shouldFail {
+ return
+ }
+ t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue)
+ }
+
+ // Everything went okay. Is that good?
+ if shouldFail {
+ t.Fatalf("This test was expected to fail, but didn't.")
+ }
+ t.Logf("Load stats:\n%s", loadStats.String())
+
+ // Truncate half the bytes in the byte stream,
+ // and ensure that we can't restore. Then
+ // truncate only the final byte and ensure that
+ // we can't restore.
+ l := saveBuffer.Len()
+ halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
+ if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with half bytes succeeded unexpectedly.")
+ }
+ missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
+ if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
+ t.Errorf("Load with missing byte succeeded unexpectedly.")
+ }
+ })
+ }
+}
+
+// convert converts the slice to an []interface{}.
+func convert(v interface{}) (r []interface{}) {
+ s := reflect.ValueOf(v) // Must be slice.
+ for i := 0; i < s.Len(); i++ {
+ r = append(r, s.Index(i).Interface())
+ }
+ return r
+}
+
+// flatten flattens multiple slices.
+func flatten(vs ...interface{}) (r []interface{}) {
+ for _, v := range vs {
+ r = append(r, convert(v)...)
+ }
+ return r
+}
+
+// filter maps from one slice to another.
+func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
+ s := reflect.ValueOf(vs)
+ for i := 0; i < s.Len(); i++ {
+ v, ok := fn(s.Index(i).Interface())
+ if ok {
+ r = append(r, v)
+ }
+ }
+ return r
+}
+
+// combine combines objects in two slices as specified.
+func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
+ s1 := reflect.ValueOf(v1)
+ s2 := reflect.ValueOf(v2)
+ for i := 0; i < s1.Len(); i++ {
+ for j := 0; j < s2.Len(); j++ {
+ // Combine using the given function.
+ r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
+ }
+ }
+ return r
+}
+
+// pointersTo is a filter function that returns pointers.
+func pointersTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ v := reflect.New(reflect.TypeOf(o))
+ v.Elem().Set(reflect.ValueOf(o))
+ return v.Interface(), true
+ })
+}
+
+// interfacesTo is a filter function that returns interface objects.
+func interfacesTo(vs interface{}) []interface{} {
+ return filter(vs, func(o interface{}) (interface{}, bool) {
+ var v [1]interface{}
+ v[0] = o
+ return v, true
+ })
+}
diff --git a/pkg/state/types.go b/pkg/state/types.go
new file mode 100644
index 000000000..215ef80f8
--- /dev/null
+++ b/pkg/state/types.go
@@ -0,0 +1,361 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package state
+
+import (
+ "reflect"
+ "sort"
+
+ "gvisor.dev/gvisor/pkg/state/wire"
+)
+
+// assertValidType asserts that the type is valid.
+func assertValidType(name string, fields []string) {
+ if name == "" {
+ Failf("type has empty name")
+ }
+ fieldsCopy := make([]string, len(fields))
+ for i := 0; i < len(fields); i++ {
+ if fields[i] == "" {
+ Failf("field has empty name for type %q", name)
+ }
+ fieldsCopy[i] = fields[i]
+ }
+ sort.Slice(fieldsCopy, func(i, j int) bool {
+ return fieldsCopy[i] < fieldsCopy[j]
+ })
+ for i := range fieldsCopy {
+ if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] {
+ Failf("duplicate field %q for type %s", fieldsCopy[i], name)
+ }
+ }
+}
+
+// typeEntry is an entry in the typeDatabase.
+type typeEntry struct {
+ ID typeID
+ wire.Type
+}
+
+// reconciledTypeEntry is a reconciled entry in the typeDatabase.
+type reconciledTypeEntry struct {
+ wire.Type
+ LocalType reflect.Type
+ FieldOrder []int
+}
+
+// typeEncodeDatabase is an internal TypeInfo database for encoding.
+type typeEncodeDatabase struct {
+ // byType maps by type to the typeEntry.
+ byType map[reflect.Type]*typeEntry
+
+ // lastID is the last used ID.
+ lastID typeID
+}
+
+// makeTypeEncodeDatabase makes a typeDatabase.
+func makeTypeEncodeDatabase() typeEncodeDatabase {
+ return typeEncodeDatabase{
+ byType: make(map[reflect.Type]*typeEntry),
+ }
+}
+
+// typeDecodeDatabase is an internal TypeInfo database for decoding.
+type typeDecodeDatabase struct {
+ // byID maps by ID to type.
+ byID []*reconciledTypeEntry
+
+ // pending are entries that are pending validation by Lookup. These
+ // will be reconciled with actual objects. Note that these will also be
+ // used to lookup types by name, since they may not be reconciled and
+ // there's little value to deleting from this map.
+ pending []*wire.Type
+}
+
+// makeTypeDecodeDatabase makes a typeDatabase.
+func makeTypeDecodeDatabase() typeDecodeDatabase {
+ return typeDecodeDatabase{}
+}
+
+// lookupNameFields extracts the name and fields from an object.
+func lookupNameFields(typ reflect.Type) (string, []string, bool) {
+ v := reflect.Zero(reflect.PtrTo(typ)).Interface()
+ t, ok := v.(Type)
+ if !ok {
+ // Is this a primitive?
+ if typ.Kind() == reflect.Interface {
+ return interfaceType, nil, true
+ }
+ name := typ.Name()
+ if _, ok := primitiveTypeDatabase[name]; !ok {
+ // This is not a known type, and not a primitive. The
+ // encoder may proceed for anonymous empty structs, or
+ // it may deference the type pointer and try again.
+ return "", nil, false
+ }
+ return name, nil, true
+ }
+ // Extract the name from the object.
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ return name, fields, true
+}
+
+// Lookup looks up or registers the given object.
+//
+// The bool indicates whether this is an existing entry: false means the entry
+// did not exist, and true means the entry did exist. If this bool is false and
+// the returned typeEntry are nil, then the obj did not implement the Type
+// interface.
+func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) {
+ te, ok := tdb.byType[typ]
+ if !ok {
+ // Lookup the type information.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs may still be encoded, so let the
+ // caller decide what to do from here.
+ return nil, false
+ }
+
+ // Register the new type.
+ tdb.lastID++
+ te = &typeEntry{
+ ID: tdb.lastID,
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ }
+
+ // All done.
+ tdb.byType[typ] = te
+ return te, false
+ }
+ return te, true
+}
+
+// Register adds a typeID entry.
+func (tbd *typeDecodeDatabase) Register(typ *wire.Type) {
+ assertValidType(typ.Name, typ.Fields)
+ tbd.pending = append(tbd.pending, typ)
+}
+
+// LookupName looks up the type name by ID.
+func (tbd *typeDecodeDatabase) LookupName(id typeID) string {
+ if len(tbd.pending) < int(id) {
+ // This is likely an encoder error?
+ Failf("type ID %d not available", id)
+ }
+ return tbd.pending[id-1].Name
+}
+
+// LookupType looks up the type by ID.
+func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type {
+ name := tbd.LookupName(id)
+ typ, ok := globalTypeDatabase[name]
+ if !ok {
+ // If not available, see if it's primitive.
+ typ, ok = primitiveTypeDatabase[name]
+ if !ok && name == interfaceType {
+ // Matches the built-in interface type.
+ var i interface{}
+ return reflect.TypeOf(&i).Elem()
+ }
+ if !ok {
+ // The type is perhaps not registered?
+ Failf("type name %q is not available", name)
+ }
+ return typ // Primitive type.
+ }
+ return typ // Registered type.
+}
+
+// singleFieldOrder defines the field order for a single field.
+var singleFieldOrder = []int{0}
+
+// Lookup looks up or registers the given object.
+//
+// First, the typeID is searched to see if this has already been appropriately
+// reconciled. If no, then a reconcilation will take place that may result in a
+// field ordering. If a nil reconciledTypeEntry is returned from this method,
+// then the object does not support the Type interface.
+//
+// This method never returns nil.
+func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry {
+ if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil {
+ // Already reconciled.
+ return tbd.byID[id-1]
+ }
+ // The ID has not been reconciled yet. That's fine. We need to make
+ // sure it aligns with the current provided object.
+ if len(tbd.pending) < int(id) {
+ // This id was never registered. Probably an encoder error?
+ Failf("typeDatabase does not contain id %d", id)
+ }
+ // Extract the pending info.
+ pending := tbd.pending[id-1]
+ // Grow the byID list.
+ if len(tbd.byID) < int(id) {
+ tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...)
+ }
+ // Reconcile the type.
+ name, fields, ok := lookupNameFields(typ)
+ if !ok {
+ // Empty structs are decoded only when the type is nil. Since
+ // this isn't the case, we fail here.
+ Failf("unsupported type %q during decode; can't reconcile", pending.Name)
+ }
+ if name != pending.Name {
+ // Are these the same type? Print a helpful message as this may
+ // actually happen in practice if types change.
+ Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)",
+ id, name, fields, pending.Name, pending.Fields)
+ }
+ rte := &reconciledTypeEntry{
+ Type: wire.Type{
+ Name: name,
+ Fields: fields,
+ },
+ LocalType: typ,
+ }
+ // If there are zero or one fields, then we skip allocating the field
+ // slice. There is special handling for decoding in this case. If the
+ // field name does not match, it will be caught in the general purpose
+ // code below.
+ if len(fields) != len(pending.Fields) {
+ Failf("type %q contains different fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ if len(fields) == 0 {
+ tbd.byID[id-1] = rte // Save.
+ return rte
+ }
+ if len(fields) == 1 && fields[0] == pending.Fields[0] {
+ tbd.byID[id-1] = rte // Save.
+ rte.FieldOrder = singleFieldOrder
+ return rte
+ }
+ // For each field in the current object's information, match it to a
+ // field in the destination object. We know from the assertion above
+ // and the insertion on insertion to pending that neither field
+ // contains any duplicates.
+ fieldOrder := make([]int, len(fields))
+ for i, name := range fields {
+ fieldOrder[i] = -1 // Sentinel.
+ // Is it an exact match?
+ if pending.Fields[i] == name {
+ fieldOrder[i] = i
+ continue
+ }
+ // Find the matching field.
+ for j, otherName := range pending.Fields {
+ if name == otherName {
+ fieldOrder[i] = j
+ break
+ }
+ }
+ if fieldOrder[i] == -1 {
+ // The type name matches but we are lacking some common fields.
+ Failf("type %q has mismatched fields: %v (decode) and %v (encode)",
+ name, fields, pending.Fields)
+ }
+ }
+ // The type has been reeconciled.
+ rte.FieldOrder = fieldOrder
+ tbd.byID[id-1] = rte
+ return rte
+}
+
+// interfaceType defines all interfaces.
+const interfaceType = "interface"
+
+// primitiveTypeDatabase is a set of fixed types.
+var primitiveTypeDatabase = func() map[string]reflect.Type {
+ r := make(map[string]reflect.Type)
+ for _, t := range []reflect.Type{
+ reflect.TypeOf(false),
+ reflect.TypeOf(int(0)),
+ reflect.TypeOf(int8(0)),
+ reflect.TypeOf(int16(0)),
+ reflect.TypeOf(int32(0)),
+ reflect.TypeOf(int64(0)),
+ reflect.TypeOf(uint(0)),
+ reflect.TypeOf(uintptr(0)),
+ reflect.TypeOf(uint8(0)),
+ reflect.TypeOf(uint16(0)),
+ reflect.TypeOf(uint32(0)),
+ reflect.TypeOf(uint64(0)),
+ reflect.TypeOf(""),
+ reflect.TypeOf(float32(0.0)),
+ reflect.TypeOf(float64(0.0)),
+ reflect.TypeOf(complex64(0.0)),
+ reflect.TypeOf(complex128(0.0)),
+ } {
+ r[t.Name()] = t
+ }
+ return r
+}()
+
+// globalTypeDatabase is used for dispatching interfaces on decode.
+var globalTypeDatabase = map[string]reflect.Type{}
+
+// Register registers a type.
+//
+// This must be called on init and only done once.
+func Register(t Type) {
+ name := t.StateTypeName()
+ fields := t.StateFields()
+ assertValidType(name, fields)
+ // Register must always be called on pointers.
+ typ := reflect.TypeOf(t)
+ if typ.Kind() != reflect.Ptr {
+ Failf("Register must be called on pointers")
+ }
+ typ = typ.Elem()
+ if typ.Kind() == reflect.Struct {
+ // All registered structs must implement SaverLoader. We allow
+ // the registration is non-struct types with just the Type
+ // interface, but we need to call StateSave/StateLoad methods
+ // on aggregate types.
+ if _, ok := t.(SaverLoader); !ok {
+ Failf("struct %T does not implement SaverLoader", t)
+ }
+ } else {
+ // Non-structs must not have any fields. We don't support
+ // calling StateSave/StateLoad methods on any non-struct types.
+ // If custom behavior is required, these types should be
+ // wrapped in a structure of some kind.
+ if len(fields) != 0 {
+ Failf("non-struct %T has non-zero fields %v", t, fields)
+ }
+ // We don't allow non-structs to implement StateSave/StateLoad
+ // methods, because they won't be called and it's confusing.
+ if _, ok := t.(SaverLoader); ok {
+ Failf("non-struct %T implements SaverLoader", t)
+ }
+ }
+ if _, ok := primitiveTypeDatabase[name]; ok {
+ Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t)
+ }
+ if _, ok := globalTypeDatabase[name]; ok {
+ Failf("conflicting globalTypeDatabase entries for %T: name conflict", t)
+ }
+ if name == interfaceType {
+ Failf("conflicting name for %T: matches interfaceType", t)
+ }
+ globalTypeDatabase[name] = typ
+}
diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD
new file mode 100644
index 000000000..311b93dcb
--- /dev/null
+++ b/pkg/state/wire/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "wire",
+ srcs = ["wire.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//:sandbox"],
+ deps = ["//pkg/gohacks"],
+)
diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go
new file mode 100644
index 000000000..93dee6740
--- /dev/null
+++ b/pkg/state/wire/wire.go
@@ -0,0 +1,970 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package wire contains a few basic types that can be composed to serialize
+// graph information for the state package. This package defines the wire
+// protocol.
+//
+// Note that these types are careful about how they implement the relevant
+// interfaces (either value receiver or pointer receiver), so that native-sized
+// types, such as integers and simple pointers, can fit inside the interface
+// object.
+//
+// This package also uses panic as control flow, so called should be careful to
+// wrap calls in appropriate handlers.
+//
+// Testing for this package is driven by the state test package.
+package wire
+
+import (
+ "fmt"
+ "io"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
+)
+
+// Reader is the required reader interface.
+type Reader interface {
+ io.Reader
+ ReadByte() (byte, error)
+}
+
+// Writer is the required writer interface.
+type Writer interface {
+ io.Writer
+ WriteByte(byte) error
+}
+
+// readFull is a utility. The equivalent is not needed for Write, but the API
+// contract dictates that it must always complete all bytes given or return an
+// error.
+func readFull(r io.Reader, p []byte) {
+ for done := 0; done < len(p); {
+ n, err := r.Read(p[done:])
+ done += n
+ if n == 0 && err != nil {
+ panic(err)
+ }
+ }
+}
+
+// Object is a generic object.
+type Object interface {
+ // save saves the given object.
+ //
+ // Panic is used for error control flow.
+ save(Writer)
+
+ // load loads a new object of the given type.
+ //
+ // Panic is used for error control flow.
+ load(Reader) Object
+}
+
+// Bool is a boolean.
+type Bool bool
+
+// loadBool loads an object of type Bool.
+func loadBool(r Reader) Bool {
+ b := loadUint(r)
+ return Bool(b == 1)
+}
+
+// save implements Object.save.
+func (b Bool) save(w Writer) {
+ var v Uint
+ if b {
+ v = 1
+ } else {
+ v = 0
+ }
+ v.save(w)
+}
+
+// load implements Object.load.
+func (Bool) load(r Reader) Object { return loadBool(r) }
+
+// Int is a signed integer.
+//
+// This uses varint encoding.
+type Int int64
+
+// loadInt loads an object of type Int.
+func loadInt(r Reader) Int {
+ u := loadUint(r)
+ x := Int(u >> 1)
+ if u&1 != 0 {
+ x = ^x
+ }
+ return x
+}
+
+// save implements Object.save.
+func (i Int) save(w Writer) {
+ u := Uint(i) << 1
+ if i < 0 {
+ u = ^u
+ }
+ u.save(w)
+}
+
+// load implements Object.load.
+func (Int) load(r Reader) Object { return loadInt(r) }
+
+// Uint is an unsigned integer.
+type Uint uint64
+
+// loadUint loads an object of type Uint.
+func loadUint(r Reader) Uint {
+ var (
+ u Uint
+ s uint
+ )
+ for i := 0; i <= 9; i++ {
+ b, err := r.ReadByte()
+ if err != nil {
+ panic(err)
+ }
+ if b < 0x80 {
+ if i == 9 && b > 1 {
+ panic("overflow")
+ }
+ u |= Uint(b) << s
+ return u
+ }
+ u |= Uint(b&0x7f) << s
+ s += 7
+ }
+ panic("unreachable")
+}
+
+// save implements Object.save.
+func (u Uint) save(w Writer) {
+ for u >= 0x80 {
+ if err := w.WriteByte(byte(u) | 0x80); err != nil {
+ panic(err)
+ }
+ u >>= 7
+ }
+ if err := w.WriteByte(byte(u)); err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (Uint) load(r Reader) Object { return loadUint(r) }
+
+// Float32 is a 32-bit floating point number.
+type Float32 float32
+
+// loadFloat32 loads an object of type Float32.
+func loadFloat32(r Reader) Float32 {
+ n := loadUint(r)
+ return Float32(math.Float32frombits(uint32(n)))
+}
+
+// save implements Object.save.
+func (f Float32) save(w Writer) {
+ n := Uint(math.Float32bits(float32(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float32) load(r Reader) Object { return loadFloat32(r) }
+
+// Float64 is a 64-bit floating point number.
+type Float64 float64
+
+// loadFloat64 loads an object of type Float64.
+func loadFloat64(r Reader) Float64 {
+ n := loadUint(r)
+ return Float64(math.Float64frombits(uint64(n)))
+}
+
+// save implements Object.save.
+func (f Float64) save(w Writer) {
+ n := Uint(math.Float64bits(float64(f)))
+ n.save(w)
+}
+
+// load implements Object.load.
+func (Float64) load(r Reader) Object { return loadFloat64(r) }
+
+// Complex64 is a 64-bit complex number.
+type Complex64 complex128
+
+// loadComplex64 loads an object of type Complex64.
+func loadComplex64(r Reader) Complex64 {
+ re := loadFloat32(r)
+ im := loadFloat32(r)
+ return Complex64(complex(float32(re), float32(im)))
+}
+
+// save implements Object.save.
+func (c *Complex64) save(w Writer) {
+ re := Float32(real(*c))
+ im := Float32(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex64) load(r Reader) Object {
+ c := loadComplex64(r)
+ return &c
+}
+
+// Complex128 is a 128-bit complex number.
+type Complex128 complex128
+
+// loadComplex128 loads an object of type Complex128.
+func loadComplex128(r Reader) Complex128 {
+ re := loadFloat64(r)
+ im := loadFloat64(r)
+ return Complex128(complex(float64(re), float64(im)))
+}
+
+// save implements Object.save.
+func (c *Complex128) save(w Writer) {
+ re := Float64(real(*c))
+ im := Float64(imag(*c))
+ re.save(w)
+ im.save(w)
+}
+
+// load implements Object.load.
+func (*Complex128) load(r Reader) Object {
+ c := loadComplex128(r)
+ return &c
+}
+
+// String is a string.
+type String string
+
+// loadString loads an object of type String.
+func loadString(r Reader) String {
+ l := loadUint(r)
+ p := make([]byte, l)
+ readFull(r, p)
+ return String(gohacks.StringFromImmutableBytes(p))
+}
+
+// save implements Object.save.
+func (s *String) save(w Writer) {
+ l := Uint(len(*s))
+ l.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*s))
+ _, err := w.Write(p) // Must write all bytes.
+ if err != nil {
+ panic(err)
+ }
+}
+
+// load implements Object.load.
+func (*String) load(r Reader) Object {
+ s := loadString(r)
+ return &s
+}
+
+// Dot is a kind of reference: one of Index and FieldName.
+type Dot interface {
+ isDot()
+}
+
+// Index is a reference resolution.
+type Index uint32
+
+func (Index) isDot() {}
+
+// FieldName is a reference resolution.
+type FieldName string
+
+func (*FieldName) isDot() {}
+
+// Ref is a reference to an object.
+type Ref struct {
+ // Root is the root object.
+ Root Uint
+
+ // Dots is the set of traversals required from the Root object above.
+ // Note that this will be stored in reverse order for efficiency.
+ Dots []Dot
+
+ // Type is the base type for the root object. This is non-nil iff Dots
+ // is non-zero length (that is, this is a complex reference). This is
+ // not *strictly* necessary, but can be used to simplify decoding.
+ Type TypeSpec
+}
+
+// loadRef loads an object of type Ref (abstract).
+func loadRef(r Reader) Ref {
+ ref := Ref{
+ Root: loadUint(r),
+ }
+ l := loadUint(r)
+ ref.Dots = make([]Dot, l)
+ for i := 0; i < int(l); i++ {
+ // Disambiguate between an Index (non-negative) and a field
+ // name (negative). This does some space and avoids a dedicate
+ // loadDot function. See Ref.save for the other side.
+ d := loadInt(r)
+ if d >= 0 {
+ ref.Dots[i] = Index(d)
+ continue
+ }
+ p := make([]byte, -d)
+ readFull(r, p)
+ fieldName := FieldName(gohacks.StringFromImmutableBytes(p))
+ ref.Dots[i] = &fieldName
+ }
+ if l != 0 {
+ // Only if dots is non-zero.
+ ref.Type = loadTypeSpec(r)
+ }
+ return ref
+}
+
+// save implements Object.save.
+func (r *Ref) save(w Writer) {
+ r.Root.save(w)
+ l := Uint(len(r.Dots))
+ l.save(w)
+ for _, d := range r.Dots {
+ // See LoadRef. We use non-negative numbers to encode Index
+ // objects and negative numbers to encode field lengths.
+ switch x := d.(type) {
+ case Index:
+ i := Int(x)
+ i.save(w)
+ case *FieldName:
+ d := Int(-len(*x))
+ d.save(w)
+ p := gohacks.ImmutableBytesFromString(string(*x))
+ if _, err := w.Write(p); err != nil {
+ panic(err)
+ }
+ default:
+ panic("unknown dot implementation")
+ }
+ }
+ if l != 0 {
+ // See above.
+ saveTypeSpec(w, r.Type)
+ }
+}
+
+// load implements Object.load.
+func (*Ref) load(r Reader) Object {
+ ref := loadRef(r)
+ return &ref
+}
+
+// Nil is a primitive zero value of any type.
+type Nil struct{}
+
+// loadNil loads an object of type Nil.
+func loadNil(r Reader) Nil {
+ return Nil{}
+}
+
+// save implements Object.save.
+func (Nil) save(w Writer) {}
+
+// load implements Object.load.
+func (Nil) load(r Reader) Object { return loadNil(r) }
+
+// Slice is a slice value.
+type Slice struct {
+ Length Uint
+ Capacity Uint
+ Ref Ref
+}
+
+// loadSlice loads an object of type Slice.
+func loadSlice(r Reader) Slice {
+ return Slice{
+ Length: loadUint(r),
+ Capacity: loadUint(r),
+ Ref: loadRef(r),
+ }
+}
+
+// save implements Object.save.
+func (s *Slice) save(w Writer) {
+ s.Length.save(w)
+ s.Capacity.save(w)
+ s.Ref.save(w)
+}
+
+// load implements Object.load.
+func (*Slice) load(r Reader) Object {
+ s := loadSlice(r)
+ return &s
+}
+
+// Array is an array value.
+type Array struct {
+ Contents []Object
+}
+
+// loadArray loads an object of type Array.
+func loadArray(r Reader) Array {
+ l := loadUint(r)
+ if l == 0 {
+ // Note that there isn't a single object available to encode
+ // the type of, so we need this additional branch.
+ return Array{}
+ }
+ // All the objects here have the same type, so use dynamic dispatch
+ // only once. All other objects will automatically take the same type
+ // as the first object.
+ contents := make([]Object, l)
+ v := Load(r)
+ contents[0] = v
+ for i := 1; i < int(l); i++ {
+ contents[i] = v.load(r)
+ }
+ return Array{
+ Contents: contents,
+ }
+}
+
+// save implements Object.save.
+func (a *Array) save(w Writer) {
+ l := Uint(len(a.Contents))
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, a.Contents[0])
+ for i := 1; i < int(l); i++ {
+ a.Contents[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Array) load(r Reader) Object {
+ a := loadArray(r)
+ return &a
+}
+
+// Map is a map value.
+type Map struct {
+ Keys []Object
+ Values []Object
+}
+
+// loadMap loads an object of type Map.
+func loadMap(r Reader) Map {
+ l := loadUint(r)
+ if l == 0 {
+ // See LoadArray.
+ return Map{}
+ }
+ // See type dispatch notes in Array.
+ keys := make([]Object, l)
+ values := make([]Object, l)
+ k := Load(r)
+ v := Load(r)
+ keys[0] = k
+ values[0] = v
+ for i := 1; i < int(l); i++ {
+ keys[i] = k.load(r)
+ values[i] = v.load(r)
+ }
+ return Map{
+ Keys: keys,
+ Values: values,
+ }
+}
+
+// save implements Object.save.
+func (m *Map) save(w Writer) {
+ l := Uint(len(m.Keys))
+ if int(l) != len(m.Values) {
+ panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values)))
+ }
+ l.save(w)
+ if l == 0 {
+ // See LoadArray.
+ return
+ }
+ // See above.
+ Save(w, m.Keys[0])
+ Save(w, m.Values[0])
+ for i := 1; i < int(l); i++ {
+ m.Keys[i].save(w)
+ m.Values[i].save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Map) load(r Reader) Object {
+ m := loadMap(r)
+ return &m
+}
+
+// TypeSpec is a type dereference.
+type TypeSpec interface {
+ isTypeSpec()
+}
+
+// TypeID is a concrete type ID.
+type TypeID Uint
+
+func (TypeID) isTypeSpec() {}
+
+// TypeSpecPointer is a pointer type.
+type TypeSpecPointer struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecPointer) isTypeSpec() {}
+
+// TypeSpecArray is an array type.
+type TypeSpecArray struct {
+ Count Uint
+ Type TypeSpec
+}
+
+func (*TypeSpecArray) isTypeSpec() {}
+
+// TypeSpecSlice is a slice type.
+type TypeSpecSlice struct {
+ Type TypeSpec
+}
+
+func (*TypeSpecSlice) isTypeSpec() {}
+
+// TypeSpecMap is a map type.
+type TypeSpecMap struct {
+ Key TypeSpec
+ Value TypeSpec
+}
+
+func (*TypeSpecMap) isTypeSpec() {}
+
+// TypeSpecNil is an empty type.
+type TypeSpecNil struct{}
+
+func (TypeSpecNil) isTypeSpec() {}
+
+// TypeSpec types.
+//
+// These use a distinct encoding on the wire, as they are used only in the
+// interface object. They are decoded through the dedicated loadTypeSpec and
+// saveTypeSpec functions.
+const (
+ typeSpecTypeID Uint = iota
+ typeSpecPointer
+ typeSpecArray
+ typeSpecSlice
+ typeSpecMap
+ typeSpecNil
+)
+
+// loadTypeSpec loads TypeSpec values.
+func loadTypeSpec(r Reader) TypeSpec {
+ switch hdr := loadUint(r); hdr {
+ case typeSpecTypeID:
+ return TypeID(loadUint(r))
+ case typeSpecPointer:
+ return &TypeSpecPointer{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecArray:
+ return &TypeSpecArray{
+ Count: loadUint(r),
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecSlice:
+ return &TypeSpecSlice{
+ Type: loadTypeSpec(r),
+ }
+ case typeSpecMap:
+ return &TypeSpecMap{
+ Key: loadTypeSpec(r),
+ Value: loadTypeSpec(r),
+ }
+ case typeSpecNil:
+ return TypeSpecNil{}
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// saveTypeSpec saves TypeSpec values.
+func saveTypeSpec(w Writer, t TypeSpec) {
+ switch x := t.(type) {
+ case TypeID:
+ typeSpecTypeID.save(w)
+ Uint(x).save(w)
+ case *TypeSpecPointer:
+ typeSpecPointer.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecArray:
+ typeSpecArray.save(w)
+ x.Count.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecSlice:
+ typeSpecSlice.save(w)
+ saveTypeSpec(w, x.Type)
+ case *TypeSpecMap:
+ typeSpecMap.save(w)
+ saveTypeSpec(w, x.Key)
+ saveTypeSpec(w, x.Value)
+ case TypeSpecNil:
+ typeSpecNil.save(w)
+ default:
+ // This should not happen?
+ panic(fmt.Errorf("unknown type %T", t))
+ }
+}
+
+// Interface is an interface value.
+type Interface struct {
+ Type TypeSpec
+ Value Object
+}
+
+// loadInterface loads an object of type Interface.
+func loadInterface(r Reader) Interface {
+ return Interface{
+ Type: loadTypeSpec(r),
+ Value: Load(r),
+ }
+}
+
+// save implements Object.save.
+func (i *Interface) save(w Writer) {
+ saveTypeSpec(w, i.Type)
+ Save(w, i.Value)
+}
+
+// load implements Object.load.
+func (*Interface) load(r Reader) Object {
+ i := loadInterface(r)
+ return &i
+}
+
+// Type is type information.
+type Type struct {
+ Name string
+ Fields []string
+}
+
+// loadType loads an object of type Type.
+func loadType(r Reader) Type {
+ name := string(loadString(r))
+ l := loadUint(r)
+ fields := make([]string, l)
+ for i := 0; i < int(l); i++ {
+ fields[i] = string(loadString(r))
+ }
+ return Type{
+ Name: name,
+ Fields: fields,
+ }
+}
+
+// save implements Object.save.
+func (t *Type) save(w Writer) {
+ s := String(t.Name)
+ s.save(w)
+ l := Uint(len(t.Fields))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ s := String(t.Fields[i])
+ s.save(w)
+ }
+}
+
+// load implements Object.load.
+func (*Type) load(r Reader) Object {
+ t := loadType(r)
+ return &t
+}
+
+// multipleObjects is a special type for serializing multiple objects.
+type multipleObjects []Object
+
+// loadMultipleObjects loads a series of objects.
+func loadMultipleObjects(r Reader) multipleObjects {
+ l := loadUint(r)
+ m := make(multipleObjects, l)
+ for i := 0; i < int(l); i++ {
+ m[i] = Load(r)
+ }
+ return m
+}
+
+// save implements Object.save.
+func (m *multipleObjects) save(w Writer) {
+ l := Uint(len(*m))
+ l.save(w)
+ for i := 0; i < int(l); i++ {
+ Save(w, (*m)[i])
+ }
+}
+
+// load implements Object.load.
+func (*multipleObjects) load(r Reader) Object {
+ m := loadMultipleObjects(r)
+ return &m
+}
+
+// noObjects represents no objects.
+type noObjects struct{}
+
+// loadNoObjects loads a sentinel.
+func loadNoObjects(r Reader) noObjects { return noObjects{} }
+
+// save implements Object.save.
+func (noObjects) save(w Writer) {}
+
+// load implements Object.load.
+func (noObjects) load(r Reader) Object { return loadNoObjects(r) }
+
+// Struct is a basic composite value.
+type Struct struct {
+ TypeID TypeID
+ fields Object // Optionally noObjects or *multipleObjects.
+}
+
+// Field returns a pointer to the given field slot.
+//
+// This must be called after Alloc.
+func (s *Struct) Field(i int) *Object {
+ if fields, ok := s.fields.(*multipleObjects); ok {
+ return &((*fields)[i])
+ }
+ if _, ok := s.fields.(noObjects); ok {
+ // Alloc may be optionally called; can't call twice.
+ panic("Field called inappropriately, wrong Alloc?")
+ }
+ return &s.fields
+}
+
+// Alloc allocates the given number of fields.
+//
+// This must be called before Add and Save.
+//
+// Precondition: slots must be positive.
+func (s *Struct) Alloc(slots int) {
+ switch {
+ case slots == 0:
+ s.fields = noObjects{}
+ case slots == 1:
+ // Leave it alone.
+ case slots > 1:
+ fields := make(multipleObjects, slots)
+ s.fields = &fields
+ default:
+ // Violates precondition.
+ panic(fmt.Sprintf("Alloc called with negative slots %d?", slots))
+ }
+}
+
+// Fields returns the number of fields.
+func (s *Struct) Fields() int {
+ switch x := s.fields.(type) {
+ case *multipleObjects:
+ return len(*x)
+ case noObjects:
+ return 0
+ default:
+ return 1
+ }
+}
+
+// loadStruct loads an object of type Struct.
+func loadStruct(r Reader) Struct {
+ return Struct{
+ TypeID: TypeID(loadUint(r)),
+ fields: Load(r),
+ }
+}
+
+// save implements Object.save.
+//
+// Precondition: Alloc must have been called, and the fields all filled in
+// appropriately. See Alloc and Add for more details.
+func (s *Struct) save(w Writer) {
+ Uint(s.TypeID).save(w)
+ Save(w, s.fields)
+}
+
+// load implements Object.load.
+func (*Struct) load(r Reader) Object {
+ s := loadStruct(r)
+ return &s
+}
+
+// Object types.
+//
+// N.B. Be careful about changing the order or introducing new elements in the
+// middle here. This is part of the wire format and shouldn't change.
+const (
+ typeBool Uint = iota
+ typeInt
+ typeUint
+ typeFloat32
+ typeFloat64
+ typeNil
+ typeRef
+ typeString
+ typeSlice
+ typeArray
+ typeMap
+ typeStruct
+ typeNoObjects
+ typeMultipleObjects
+ typeInterface
+ typeComplex64
+ typeComplex128
+ typeType
+)
+
+// Save saves the given object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Save(w Writer, obj Object) {
+ switch x := obj.(type) {
+ case Bool:
+ typeBool.save(w)
+ x.save(w)
+ case Int:
+ typeInt.save(w)
+ x.save(w)
+ case Uint:
+ typeUint.save(w)
+ x.save(w)
+ case Float32:
+ typeFloat32.save(w)
+ x.save(w)
+ case Float64:
+ typeFloat64.save(w)
+ x.save(w)
+ case Nil:
+ typeNil.save(w)
+ x.save(w)
+ case *Ref:
+ typeRef.save(w)
+ x.save(w)
+ case *String:
+ typeString.save(w)
+ x.save(w)
+ case *Slice:
+ typeSlice.save(w)
+ x.save(w)
+ case *Array:
+ typeArray.save(w)
+ x.save(w)
+ case *Map:
+ typeMap.save(w)
+ x.save(w)
+ case *Struct:
+ typeStruct.save(w)
+ x.save(w)
+ case noObjects:
+ typeNoObjects.save(w)
+ x.save(w)
+ case *multipleObjects:
+ typeMultipleObjects.save(w)
+ x.save(w)
+ case *Interface:
+ typeInterface.save(w)
+ x.save(w)
+ case *Type:
+ typeType.save(w)
+ x.save(w)
+ case *Complex64:
+ typeComplex64.save(w)
+ x.save(w)
+ case *Complex128:
+ typeComplex128.save(w)
+ x.save(w)
+ default:
+ panic(fmt.Errorf("unknown type: %#v", obj))
+ }
+}
+
+// Load loads a new object.
+//
+// +checkescape all
+//
+// N.B. This function will panic on error.
+func Load(r Reader) Object {
+ switch hdr := loadUint(r); hdr {
+ case typeBool:
+ return loadBool(r)
+ case typeInt:
+ return loadInt(r)
+ case typeUint:
+ return loadUint(r)
+ case typeFloat32:
+ return loadFloat32(r)
+ case typeFloat64:
+ return loadFloat64(r)
+ case typeNil:
+ return loadNil(r)
+ case typeRef:
+ return ((*Ref)(nil)).load(r) // Escapes.
+ case typeString:
+ return ((*String)(nil)).load(r) // Escapes.
+ case typeSlice:
+ return ((*Slice)(nil)).load(r) // Escapes.
+ case typeArray:
+ return ((*Array)(nil)).load(r) // Escapes.
+ case typeMap:
+ return ((*Map)(nil)).load(r) // Escapes.
+ case typeStruct:
+ return ((*Struct)(nil)).load(r) // Escapes.
+ case typeNoObjects: // Special for struct.
+ return loadNoObjects(r)
+ case typeMultipleObjects: // Special for struct.
+ return ((*multipleObjects)(nil)).load(r) // Escapes.
+ case typeInterface:
+ return ((*Interface)(nil)).load(r) // Escapes.
+ case typeComplex64:
+ return ((*Complex64)(nil)).load(r) // Escapes.
+ case typeComplex128:
+ return ((*Complex128)(nil)).load(r) // Escapes.
+ case typeType:
+ return ((*Type)(nil)).load(r) // Escapes.
+ default:
+ // This is not a valid stream?
+ panic(fmt.Errorf("unknown header: %d", hdr))
+ }
+}
+
+// LoadUint loads a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func LoadUint(r Reader) uint64 {
+ return uint64(loadUint(r))
+}
+
+// SaveUint saves a single unsigned integer.
+//
+// N.B. This function will panic on error.
+func SaveUint(w Writer, v uint64) {
+ Uint(v).save(w)
+}
diff --git a/third_party/gvsync/BUILD b/pkg/sync/BUILD
index 7d6d59c48..4d47207f7 100644
--- a/third_party/gvsync/BUILD
+++ b/pkg/sync/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template")
package(
@@ -28,26 +28,29 @@ go_template(
)
go_library(
- name = "gvsync",
+ name = "sync",
srcs = [
- "downgradable_rwmutex_1_12_unsafe.go",
- "downgradable_rwmutex_1_13_unsafe.go",
- "downgradable_rwmutex_unsafe.go",
- "gvsync.go",
+ "aliases.go",
"memmove_unsafe.go",
+ "mutex_unsafe.go",
+ "nocopy.go",
"norace_unsafe.go",
"race_unsafe.go",
+ "rwmutex_unsafe.go",
"seqcount.go",
+ "sync.go",
],
- importpath = "gvisor.dev/gvisor/third_party/gvsync",
+ marshal = False,
+ stateify = False,
)
go_test(
- name = "gvsync_test",
+ name = "sync_test",
size = "small",
srcs = [
- "downgradable_rwmutex_test.go",
+ "mutex_test.go",
+ "rwmutex_test.go",
"seqcount_test.go",
],
- embed = [":gvsync"],
+ library = ":sync",
)
diff --git a/third_party/gvsync/LICENSE b/pkg/sync/LICENSE
index 6a66aea5e..6a66aea5e 100644
--- a/third_party/gvsync/LICENSE
+++ b/pkg/sync/LICENSE
diff --git a/third_party/gvsync/README.md b/pkg/sync/README.md
index fcc7e6f44..2183c4e20 100644
--- a/third_party/gvsync/README.md
+++ b/pkg/sync/README.md
@@ -1,3 +1,5 @@
+# Syncutil
+
This package provides additional synchronization primitives not provided by the
Go stdlib 'sync' package. It is partially derived from the upstream 'sync'
-package.
+package from go1.10.
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
new file mode 100644
index 000000000..0d4316254
--- /dev/null
+++ b/pkg/sync/aliases.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+)
+
+// Aliases of standard library types.
+type (
+ // Cond is an alias of sync.Cond.
+ Cond = sync.Cond
+
+ // Locker is an alias of sync.Locker.
+ Locker = sync.Locker
+
+ // Once is an alias of sync.Once.
+ Once = sync.Once
+
+ // Pool is an alias of sync.Pool.
+ Pool = sync.Pool
+
+ // WaitGroup is an alias of sync.WaitGroup.
+ WaitGroup = sync.WaitGroup
+
+ // Map is an alias of sync.Map.
+ Map = sync.Map
+)
+
+// NewCond is a wrapper around sync.NewCond.
+func NewCond(l Locker) *Cond {
+ return sync.NewCond(l)
+}
diff --git a/third_party/gvsync/atomicptr_unsafe.go b/pkg/sync/atomicptr_unsafe.go
index 525c4beed..525c4beed 100644
--- a/third_party/gvsync/atomicptr_unsafe.go
+++ b/pkg/sync/atomicptr_unsafe.go
diff --git a/third_party/gvsync/atomicptrtest/BUILD b/pkg/sync/atomicptrtest/BUILD
index 447ecf96a..e97553254 100644
--- a/third_party/gvsync/atomicptrtest/BUILD
+++ b/pkg/sync/atomicptrtest/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -8,7 +8,7 @@ go_template_instance(
out = "atomicptr_int_unsafe.go",
package = "atomicptr",
suffix = "Int",
- template = "//third_party/gvsync:generic_atomicptr",
+ template = "//pkg/sync:generic_atomicptr",
types = {
"Value": "int",
},
@@ -17,12 +17,11 @@ go_template_instance(
go_library(
name = "atomicptr",
srcs = ["atomicptr_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/third_party/gvsync/atomicptr",
)
go_test(
name = "atomicptr_test",
size = "small",
srcs = ["atomicptr_test.go"],
- embed = [":atomicptr"],
+ library = ":atomicptr",
)
diff --git a/third_party/gvsync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptrtest/atomicptr_test.go
index 8fdc5112e..8fdc5112e 100644
--- a/third_party/gvsync/atomicptrtest/atomicptr_test.go
+++ b/pkg/sync/atomicptrtest/atomicptr_test.go
diff --git a/third_party/gvsync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go
index 84b69f215..1d7780695 100644
--- a/third_party/gvsync/memmove_unsafe.go
+++ b/pkg/sync/memmove_unsafe.go
@@ -4,11 +4,11 @@
// license that can be found in the LICENSE file.
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
-package gvsync
+package sync
import (
"unsafe"
diff --git a/pkg/sync/mutex_test.go b/pkg/sync/mutex_test.go
new file mode 100644
index 000000000..0838248b4
--- /dev/null
+++ b/pkg/sync/mutex_test.go
@@ -0,0 +1,71 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sync
+
+import (
+ "sync"
+ "testing"
+ "unsafe"
+)
+
+// TestStructSize verifies that syncMutex's size hasn't drifted from the
+// standard library's version.
+//
+// The correctness of this package relies on these remaining in sync.
+func TestStructSize(t *testing.T) {
+ const (
+ got = unsafe.Sizeof(syncMutex{})
+ want = unsafe.Sizeof(sync.Mutex{})
+ )
+ if got != want {
+ t.Errorf("got sizeof(syncMutex) = %d, want = sizeof(sync.Mutex) = %d", got, want)
+ }
+}
+
+// TestFieldValues verifies that the semantics of syncMutex.state from the
+// standard library's implementation.
+//
+// The correctness of this package relies on these remaining in sync.
+func TestFieldValues(t *testing.T) {
+ var m Mutex
+ m.Lock()
+ if got := *m.state(); got != mutexLocked {
+ t.Errorf("got locked sync.Mutex.state = %d, want = %d", got, mutexLocked)
+ }
+ m.Unlock()
+ if got := *m.state(); got != mutexUnlocked {
+ t.Errorf("got unlocked sync.Mutex.state = %d, want = %d", got, mutexUnlocked)
+ }
+}
+
+func TestDoubleTryLock(t *testing.T) {
+ var m Mutex
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if m.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockAfterLock(t *testing.T) {
+ var m Mutex
+ m.Lock()
+ if m.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockUnlock(t *testing.T) {
+ var m Mutex
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ m.Unlock()
+ if !m.TryLock() {
+ t.Fatal("failed to aquire lock after unlock")
+ }
+}
diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go
new file mode 100644
index 000000000..dc034d561
--- /dev/null
+++ b/pkg/sync/mutex_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.13
+// +build !go1.16
+
+// When updating the build constraint (above), check that syncMutex matches the
+// standard library sync.Mutex definition.
+
+package sync
+
+import (
+ "sync"
+ "sync/atomic"
+ "unsafe"
+)
+
+// Mutex is a try lock.
+type Mutex struct {
+ sync.Mutex
+}
+
+type syncMutex struct {
+ state int32
+ sema uint32
+}
+
+func (m *Mutex) state() *int32 {
+ return &(*syncMutex)(unsafe.Pointer(&m.Mutex)).state
+}
+
+const (
+ mutexUnlocked = 0
+ mutexLocked = 1
+)
+
+// TryLock tries to aquire the mutex. It returns true if it succeeds and false
+// otherwise. TryLock does not block.
+func (m *Mutex) TryLock() bool {
+ if atomic.CompareAndSwapInt32(m.state(), mutexUnlocked, mutexLocked) {
+ if RaceEnabled {
+ RaceAcquire(unsafe.Pointer(&m.Mutex))
+ }
+ return true
+ }
+ return false
+}
diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go
new file mode 100644
index 000000000..722b29501
--- /dev/null
+++ b/pkg/sync/nocopy.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sync
+
+// NoCopy may be embedded into structs which must not be copied
+// after the first use.
+//
+// See https://golang.org/issues/8005#issuecomment-190753527
+// for details.
+type NoCopy struct{}
+
+// Lock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Lock() {}
+
+// Unlock is a no-op used by -copylocks checker from `go vet`.
+func (*NoCopy) Unlock() {}
diff --git a/third_party/gvsync/norace_unsafe.go b/pkg/sync/norace_unsafe.go
index e3852db8c..006055dd6 100644
--- a/third_party/gvsync/norace_unsafe.go
+++ b/pkg/sync/norace_unsafe.go
@@ -5,7 +5,7 @@
// +build !race
-package gvsync
+package sync
import (
"unsafe"
diff --git a/third_party/gvsync/race_unsafe.go b/pkg/sync/race_unsafe.go
index 13c02a830..31d8fa9a6 100644
--- a/third_party/gvsync/race_unsafe.go
+++ b/pkg/sync/race_unsafe.go
@@ -5,7 +5,7 @@
// +build race
-package gvsync
+package sync
import (
"runtime"
diff --git a/third_party/gvsync/downgradable_rwmutex_test.go b/pkg/sync/rwmutex_test.go
index 40c384b8b..ce667e825 100644
--- a/third_party/gvsync/downgradable_rwmutex_test.go
+++ b/pkg/sync/rwmutex_test.go
@@ -9,7 +9,7 @@
// addition of downgradingWriter and the renaming of num_iterations to
// numIterations to shut up Golint.
-package gvsync
+package sync
import (
"fmt"
@@ -18,7 +18,7 @@ import (
"testing"
)
-func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) {
+func parallelReader(m *RWMutex, clocked, cunlock, cdone chan bool) {
m.RLock()
clocked <- true
<-cunlock
@@ -28,7 +28,7 @@ func parallelReader(m *DowngradableRWMutex, clocked, cunlock, cdone chan bool) {
func doTestParallelReaders(numReaders, gomaxprocs int) {
runtime.GOMAXPROCS(gomaxprocs)
- var m DowngradableRWMutex
+ var m RWMutex
clocked := make(chan bool)
cunlock := make(chan bool)
cdone := make(chan bool)
@@ -55,7 +55,7 @@ func TestParallelReaders(t *testing.T) {
doTestParallelReaders(4, 2)
}
-func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+func reader(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
for i := 0; i < numIterations; i++ {
rwm.RLock()
n := atomic.AddInt32(activity, 1)
@@ -70,7 +70,7 @@ func reader(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone
cdone <- true
}
-func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+func writer(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
for i := 0; i < numIterations; i++ {
rwm.Lock()
n := atomic.AddInt32(activity, 10000)
@@ -85,7 +85,7 @@ func writer(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone
cdone <- true
}
-func downgradingWriter(rwm *DowngradableRWMutex, numIterations int, activity *int32, cdone chan bool) {
+func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone chan bool) {
for i := 0; i < numIterations; i++ {
rwm.Lock()
n := atomic.AddInt32(activity, 10000)
@@ -112,7 +112,7 @@ func HammerDowngradableRWMutex(gomaxprocs, numReaders, numIterations int) {
runtime.GOMAXPROCS(gomaxprocs)
// Number of active readers + 10000 * number of active writers.
var activity int32
- var rwm DowngradableRWMutex
+ var rwm RWMutex
cdone := make(chan bool)
go writer(&rwm, numIterations, &activity, cdone)
go downgradingWriter(&rwm, numIterations, &activity, cdone)
@@ -148,3 +148,58 @@ func TestDowngradableRWMutex(t *testing.T) {
HammerDowngradableRWMutex(10, 10, n)
HammerDowngradableRWMutex(10, 5, n)
}
+
+func TestRWDoubleTryLock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestRWTryLockAfterLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.Lock()
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestRWTryLockUnlock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ rwm.Unlock()
+ if !rwm.TryLock() {
+ t.Fatal("failed to aquire lock after unlock")
+ }
+}
+
+func TestTryRLockAfterLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.Lock()
+ if rwm.TryRLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestTryLockAfterRLock(t *testing.T) {
+ var rwm RWMutex
+ rwm.RLock()
+ if rwm.TryLock() {
+ t.Fatal("unexpectedly succeeded in aquiring locked mutex")
+ }
+}
+
+func TestDoubleTryRLock(t *testing.T) {
+ var rwm RWMutex
+ if !rwm.TryRLock() {
+ t.Fatal("failed to aquire lock")
+ }
+ if !rwm.TryRLock() {
+ t.Fatal("failed to read aquire read locked lock")
+ }
+}
diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go
index 1f6007aa1..995c0346e 100644
--- a/third_party/gvsync/downgradable_rwmutex_unsafe.go
+++ b/pkg/sync/rwmutex_unsafe.go
@@ -3,8 +3,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build go1.12
-// +build !go1.14
+// +build go1.13
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
@@ -16,10 +16,9 @@
// - RUnlock -> Lock (via writerSem)
// - DowngradeLock -> RLock (via readerSem)
-package gvsync
+package sync
import (
- "sync"
"sync/atomic"
"unsafe"
)
@@ -27,20 +26,48 @@ import (
//go:linkname runtimeSemacquire sync.runtime_Semacquire
func runtimeSemacquire(s *uint32)
-// DowngradableRWMutex is identical to sync.RWMutex, but adds the DowngradeLock
-// method.
-type DowngradableRWMutex struct {
- w sync.Mutex // held if there are pending writers
- writerSem uint32 // semaphore for writers to wait for completing readers
- readerSem uint32 // semaphore for readers to wait for completing writers
- readerCount int32 // number of pending readers
- readerWait int32 // number of departing readers
+//go:linkname runtimeSemrelease sync.runtime_Semrelease
+func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
+
+// RWMutex is identical to sync.RWMutex, but adds the DowngradeLock,
+// TryLock and TryRLock methods.
+type RWMutex struct {
+ w Mutex // held if there are pending writers
+ writerSem uint32 // semaphore for writers to wait for completing readers
+ readerSem uint32 // semaphore for readers to wait for completing writers
+ readerCount int32 // number of pending readers
+ readerWait int32 // number of departing readers
}
const rwmutexMaxReaders = 1 << 30
+// TryRLock locks rw for reading. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryRLock() bool {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ for {
+ rc := atomic.LoadInt32(&rw.readerCount)
+ if rc < 0 {
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ if !atomic.CompareAndSwapInt32(&rw.readerCount, rc, rc+1) {
+ continue
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.readerSem))
+ }
+ return true
+ }
+}
+
// RLock locks rw for reading.
-func (rw *DowngradableRWMutex) RLock() {
+func (rw *RWMutex) RLock() {
if RaceEnabled {
RaceDisable()
}
@@ -55,14 +82,14 @@ func (rw *DowngradableRWMutex) RLock() {
}
// RUnlock undoes a single RLock call.
-func (rw *DowngradableRWMutex) RUnlock() {
+func (rw *RWMutex) RUnlock() {
if RaceEnabled {
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
}
if r := atomic.AddInt32(&rw.readerCount, -1); r < 0 {
if r+1 == 0 || r+1 == -rwmutexMaxReaders {
- panic("RUnlock of unlocked DowngradableRWMutex")
+ panic("RUnlock of unlocked RWMutex")
}
// A writer is pending.
if atomic.AddInt32(&rw.readerWait, -1) == 0 {
@@ -75,8 +102,36 @@ func (rw *DowngradableRWMutex) RUnlock() {
}
}
+// TryLock locks rw for writing. It returns true if it succeeds and false
+// otherwise. It does not block.
+func (rw *RWMutex) TryLock() bool {
+ if RaceEnabled {
+ RaceDisable()
+ }
+ // First, resolve competition with other writers.
+ if !rw.w.TryLock() {
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ // Only proceed if there are no readers.
+ if !atomic.CompareAndSwapInt32(&rw.readerCount, 0, -rwmutexMaxReaders) {
+ rw.w.Unlock()
+ if RaceEnabled {
+ RaceEnable()
+ }
+ return false
+ }
+ if RaceEnabled {
+ RaceEnable()
+ RaceAcquire(unsafe.Pointer(&rw.writerSem))
+ }
+ return true
+}
+
// Lock locks rw for writing.
-func (rw *DowngradableRWMutex) Lock() {
+func (rw *RWMutex) Lock() {
if RaceEnabled {
RaceDisable()
}
@@ -95,7 +150,7 @@ func (rw *DowngradableRWMutex) Lock() {
}
// Unlock unlocks rw for writing.
-func (rw *DowngradableRWMutex) Unlock() {
+func (rw *RWMutex) Unlock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.writerSem))
RaceRelease(unsafe.Pointer(&rw.readerSem))
@@ -104,7 +159,7 @@ func (rw *DowngradableRWMutex) Unlock() {
// Announce to readers there is no active writer.
r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders)
if r >= rwmutexMaxReaders {
- panic("Unlock of unlocked DowngradableRWMutex")
+ panic("Unlock of unlocked RWMutex")
}
// Unblock blocked readers, if any.
for i := 0; i < int(r); i++ {
@@ -118,7 +173,7 @@ func (rw *DowngradableRWMutex) Unlock() {
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
-func (rw *DowngradableRWMutex) DowngradeLock() {
+func (rw *RWMutex) DowngradeLock() {
if RaceEnabled {
RaceRelease(unsafe.Pointer(&rw.readerSem))
RaceDisable()
@@ -126,7 +181,7 @@ func (rw *DowngradableRWMutex) DowngradeLock() {
// Announce to readers there is no active writer and one additional reader.
r := atomic.AddInt32(&rw.readerCount, rwmutexMaxReaders+1)
if r >= rwmutexMaxReaders+1 {
- panic("DowngradeLock of unlocked DowngradableRWMutex")
+ panic("DowngradeLock of unlocked RWMutex")
}
// Unblock blocked readers, if any. Note that this loop starts as 1 since r
// includes this goroutine.
diff --git a/third_party/gvsync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go
index 382eeed43..eda6fb131 100644
--- a/third_party/gvsync/seqatomic_unsafe.go
+++ b/pkg/sync/seqatomic_unsafe.go
@@ -13,7 +13,7 @@ import (
"strings"
"unsafe"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Value is a required type parameter.
@@ -26,17 +26,17 @@ type Value struct{}
// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race
// with any writer critical sections in sc.
-func SeqAtomicLoad(sc *gvsync.SeqCount, ptr *Value) Value {
+func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value {
// This function doesn't use SeqAtomicTryLoad because doing so is
// measurably, significantly (~20%) slower; Go is awful at inlining.
var val Value
for {
epoch := sc.BeginRead()
- if gvsync.RaceEnabled {
+ if sync.RaceEnabled {
// runtime.RaceDisable() doesn't actually stop the race detector,
// so it can't help us here. Instead, call runtime.memmove
// directly, which is not instrumented by the race detector.
- gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
// This is ~40% faster for short reads than going through memmove.
val = *ptr
@@ -52,10 +52,10 @@ func SeqAtomicLoad(sc *gvsync.SeqCount, ptr *Value) Value {
// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read
// would race with a writer critical section, SeqAtomicTryLoad returns
// (unspecified, false).
-func SeqAtomicTryLoad(sc *gvsync.SeqCount, epoch gvsync.SeqCountEpoch, ptr *Value) (Value, bool) {
+func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) {
var val Value
- if gvsync.RaceEnabled {
- gvsync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
+ if sync.RaceEnabled {
+ sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val))
} else {
val = *ptr
}
@@ -66,7 +66,7 @@ func init() {
var val Value
typ := reflect.TypeOf(val)
name := typ.Name()
- if ptrs := gvsync.PointersInType(typ, name); len(ptrs) != 0 {
+ if ptrs := sync.PointersInType(typ, name); len(ptrs) != 0 {
panic(fmt.Sprintf("SeqAtomicLoad<%s> is invalid since values %s of type %s contain pointers:\n%s", typ, name, typ, strings.Join(ptrs, "\n")))
}
}
diff --git a/third_party/gvsync/seqatomictest/BUILD b/pkg/sync/seqatomictest/BUILD
index c858c20c4..5c38c783e 100644
--- a/third_party/gvsync/seqatomictest/BUILD
+++ b/pkg/sync/seqatomictest/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -8,7 +8,7 @@ go_template_instance(
out = "seqatomic_int_unsafe.go",
package = "seqatomic",
suffix = "Int",
- template = "//third_party/gvsync:generic_seqatomic",
+ template = "//pkg/sync:generic_seqatomic",
types = {
"Value": "int",
},
@@ -17,9 +17,8 @@ go_template_instance(
go_library(
name = "seqatomic",
srcs = ["seqatomic_int_unsafe.go"],
- importpath = "gvisor.dev/gvisor/third_party/gvsync/seqatomic",
deps = [
- "//third_party/gvsync",
+ "//pkg/sync",
],
)
@@ -27,8 +26,6 @@ go_test(
name = "seqatomic_test",
size = "small",
srcs = ["seqatomic_test.go"],
- embed = [":seqatomic"],
- deps = [
- "//third_party/gvsync",
- ],
+ library = ":seqatomic",
+ deps = ["//pkg/sync"],
)
diff --git a/third_party/gvsync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomictest/seqatomic_test.go
index a5447f589..2c4568b07 100644
--- a/third_party/gvsync/seqatomictest/seqatomic_test.go
+++ b/pkg/sync/seqatomictest/seqatomic_test.go
@@ -19,11 +19,11 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/third_party/gvsync"
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSeqAtomicLoadUncontended(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
if got := SeqAtomicLoadInt(&seq, &data); got != want {
@@ -32,7 +32,7 @@ func TestSeqAtomicLoadUncontended(t *testing.T) {
}
func TestSeqAtomicLoadAfterWrite(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -44,7 +44,7 @@ func TestSeqAtomicLoadAfterWrite(t *testing.T) {
}
func TestSeqAtomicLoadDuringWrite(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
var data int
const want = 1
seq.BeginWrite()
@@ -59,7 +59,7 @@ func TestSeqAtomicLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadUncontended(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
const want = 1
data := want
epoch := seq.BeginRead()
@@ -69,7 +69,7 @@ func TestSeqAtomicTryLoadUncontended(t *testing.T) {
}
func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -80,7 +80,7 @@ func TestSeqAtomicTryLoadDuringWrite(t *testing.T) {
}
func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
var data int
epoch := seq.BeginRead()
seq.BeginWrite()
@@ -91,7 +91,7 @@ func TestSeqAtomicTryLoadAfterWrite(t *testing.T) {
}
func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
@@ -104,7 +104,7 @@ func BenchmarkSeqAtomicLoadIntUncontended(b *testing.B) {
}
func BenchmarkSeqAtomicTryLoadIntUncontended(b *testing.B) {
- var seq gvsync.SeqCount
+ var seq sync.SeqCount
const want = 42
data := want
b.RunParallel(func(pb *testing.PB) {
diff --git a/third_party/gvsync/seqcount.go b/pkg/sync/seqcount.go
index 2c9c2c3d6..a1e895352 100644
--- a/third_party/gvsync/seqcount.go
+++ b/pkg/sync/seqcount.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package gvsync
+package sync
import (
"fmt"
diff --git a/third_party/gvsync/seqcount_test.go b/pkg/sync/seqcount_test.go
index 085e574b3..6eb7b4b59 100644
--- a/third_party/gvsync/seqcount_test.go
+++ b/pkg/sync/seqcount_test.go
@@ -3,7 +3,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package gvsync
+package sync
import (
"reflect"
diff --git a/third_party/gvsync/gvsync.go b/pkg/sync/sync.go
index 3bbef13c3..b16cf5333 100644
--- a/third_party/gvsync/gvsync.go
+++ b/pkg/sync/sync.go
@@ -3,5 +3,5 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package gvsync provides synchronization primitives.
-package gvsync
+// Package sync provides synchronization primitives.
+package sync
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that
+ // receivers (unsub[i], unsub[numReceivers]) are still subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events from src and also clears them from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementations of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The event package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterrputible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to isReady() return
+ // true, then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..5e216b045
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ ptr+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..f4c06f194
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD ptr+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/syncevent/waiter_asm_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
new file mode 100644
index 000000000..19d6b0b15
--- /dev/null
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -0,0 +1,24 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64 arm64
+
+package syncevent
+
+import (
+ "unsafe"
+)
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..0f74a689c
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(ptr unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), ptr)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an
+// event, return that event using a blocking check, and then unset the event as
+// pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(1)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..ad271e1a0
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.16
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAck() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserr/BUILD b/pkg/syserr/BUILD
index 5665ad4ee..7d760344a 100644
--- a/pkg/syserr/BUILD
+++ b/pkg/syserr/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,7 +9,6 @@ go_library(
"netstack.go",
"syserr.go",
],
- importpath = "gvisor.dev/gvisor/pkg/syserr",
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux",
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 8ff922c69..5ae10939d 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -22,7 +22,7 @@ import (
// Mapping for tcpip.Error types.
var (
ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
- ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV)
ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index bd3f9fd28..b13c15d9b 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "syserror",
srcs = ["syserror.go"],
- importpath = "gvisor.dev/gvisor/pkg/syserror",
visibility = ["//visibility:public"],
)
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 1987e89cc..798e07b01 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -27,8 +27,10 @@ import (
var (
E2BIG = error(syscall.E2BIG)
EACCES = error(syscall.EACCES)
+ EADDRINUSE = error(syscall.EADDRINUSE)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
@@ -45,6 +47,7 @@ var (
ELIBBAD = error(syscall.ELIBBAD)
ELOOP = error(syscall.ELOOP)
EMFILE = error(syscall.EMFILE)
+ EMLINK = error(syscall.EMLINK)
EMSGSIZE = error(syscall.EMSGSIZE)
ENAMETOOLONG = error(syscall.ENAMETOOLONG)
ENOATTR = ENODATA
@@ -58,6 +61,7 @@ var (
ENOMEM = error(syscall.ENOMEM)
ENOSPC = error(syscall.ENOSPC)
ENOSYS = error(syscall.ENOSYS)
+ ENOTCONN = error(syscall.ENOTCONN)
ENOTDIR = error(syscall.ENOTDIR)
ENOTEMPTY = error(syscall.ENOTEMPTY)
ENOTSOCK = error(syscall.ENOTSOCK)
@@ -69,6 +73,7 @@ var (
EPERM = error(syscall.EPERM)
EPIPE = error(syscall.EPIPE)
ERANGE = error(syscall.ERANGE)
+ EREMOTE = error(syscall.EREMOTE)
EROFS = error(syscall.EROFS)
ESPIPE = error(syscall.ESPIPE)
ESRCH = error(syscall.ESRCH)
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 3c2b2b5ea..454e07662 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,12 +7,12 @@ go_library(
srcs = [
"tcpip.go",
"time_unsafe.go",
+ "timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip/buffer",
- "//pkg/tcpip/iptables",
"//pkg/waiter",
],
)
@@ -22,5 +21,12 @@ go_test(
name = "tcpip_test",
size = "small",
srcs = ["tcpip_test.go"],
- embed = [":tcpip"],
+ library = ":tcpip",
+)
+
+go_test(
+ name = "tcpip_x_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ deps = [":tcpip"],
)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 78df5a0b1..a984f1712 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,14 +1,13 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "gonet",
srcs = ["gonet.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet",
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
@@ -22,7 +21,7 @@ go_test(
name = "gonet_test",
size = "small",
srcs = ["gonet_test.go"],
- embed = [":gonet"],
+ library = ":gonet",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index cd6ce930a..d82ed5205 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -20,9 +20,9 @@ import (
"errors"
"io"
"net"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -43,18 +43,28 @@ func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
-// A Listener is a wrapper around a tcpip endpoint that implements
+// A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
// net.Listener.
-type Listener struct {
+type TCPListener struct {
stack *stack.Stack
ep tcpip.Endpoint
wq *waiter.Queue
cancel chan struct{}
}
-// NewListener creates a new Listener.
-func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
- // Create TCP endpoint, bind it, then start listening.
+// NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
+func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
+ return &TCPListener{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ cancel: make(chan struct{}),
+ }
+}
+
+// ListenTCP creates a new TCPListener.
+func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
+ // Create a TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
if err != nil {
@@ -81,28 +91,23 @@ func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkPr
}
}
- return &Listener{
- stack: s,
- ep: ep,
- wq: &wq,
- cancel: make(chan struct{}),
- }, nil
+ return NewTCPListener(s, &wq, ep), nil
}
// Close implements net.Listener.Close.
-func (l *Listener) Close() error {
+func (l *TCPListener) Close() error {
l.ep.Close()
return nil
}
// Shutdown stops the HTTP server.
-func (l *Listener) Shutdown() {
+func (l *TCPListener) Shutdown() {
l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
close(l.cancel) // broadcast cancellation
}
// Addr implements net.Listener.Addr.
-func (l *Listener) Addr() net.Addr {
+func (l *TCPListener) Addr() net.Addr {
a, err := l.ep.GetLocalAddress()
if err != nil {
return nil
@@ -208,9 +213,9 @@ func (d *deadlineTimer) SetDeadline(t time.Time) error {
return nil
}
-// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
// interface.
-type Conn struct {
+type TCPConn struct {
deadlineTimer
wq *waiter.Queue
@@ -228,9 +233,9 @@ type Conn struct {
read buffer.View
}
-// NewConn creates a new Conn.
-func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
- c := &Conn{
+// NewTCPConn creates a new TCPConn.
+func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
+ c := &TCPConn{
wq: wq,
ep: ep,
}
@@ -239,7 +244,7 @@ func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
}
// Accept implements net.Conn.Accept.
-func (l *Listener) Accept() (net.Conn, error) {
+func (l *TCPListener) Accept() (net.Conn, error) {
n, wq, err := l.ep.Accept()
if err == tcpip.ErrWouldBlock {
@@ -272,7 +277,7 @@ func (l *Listener) Accept() (net.Conn, error) {
}
}
- return NewConn(wq, n), nil
+ return NewTCPConn(wq, n), nil
}
type opErrorer interface {
@@ -323,13 +328,18 @@ func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, a
}
// Read implements net.Conn.Read.
-func (c *Conn) Read(b []byte) (int, error) {
+func (c *TCPConn) Read(b []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
deadline := c.readCancel()
numRead := 0
+ defer func() {
+ if numRead != 0 {
+ c.ep.ModerateRecvBuf(numRead)
+ }
+ }()
for numRead != len(b) {
if len(c.read) == 0 {
var err error
@@ -352,7 +362,7 @@ func (c *Conn) Read(b []byte) (int, error) {
}
// Write implements net.Conn.Write.
-func (c *Conn) Write(b []byte) (int, error) {
+func (c *TCPConn) Write(b []byte) (int, error) {
deadline := c.writeCancel()
// Check if deadlineTimer has already expired.
@@ -431,7 +441,7 @@ func (c *Conn) Write(b []byte) (int, error) {
}
// Close implements net.Conn.Close.
-func (c *Conn) Close() error {
+func (c *TCPConn) Close() error {
c.ep.Close()
return nil
}
@@ -440,7 +450,7 @@ func (c *Conn) Close() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
-func (c *Conn) CloseRead() error {
+func (c *TCPConn) CloseRead() error {
if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -451,7 +461,7 @@ func (c *Conn) CloseRead() error {
// should just use Close.
//
// A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
-func (c *Conn) CloseWrite() error {
+func (c *TCPConn) CloseWrite() error {
if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
return c.newOpError("close", errors.New(terr.String()))
}
@@ -459,7 +469,7 @@ func (c *Conn) CloseWrite() error {
}
// LocalAddr implements net.Conn.LocalAddr.
-func (c *Conn) LocalAddr() net.Addr {
+func (c *TCPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
@@ -468,7 +478,7 @@ func (c *Conn) LocalAddr() net.Addr {
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *Conn) RemoteAddr() net.Addr {
+func (c *TCPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
@@ -476,7 +486,7 @@ func (c *Conn) RemoteAddr() net.Addr {
return fullToTCPAddr(a)
}
-func (c *Conn) newOpError(op string, err error) *net.OpError {
+func (c *TCPConn) newOpError(op string, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "tcp",
@@ -494,14 +504,14 @@ func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
}
-// DialTCP creates a new TCP Conn connected to the specified address.
-func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+// DialTCP creates a new TCPConn connected to the specified address.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
return DialContextTCP(context.Background(), s, addr, network)
}
-// DialContextTCP creates a new TCP Conn connected to the specified address
+// DialContextTCP creates a new TCPConn connected to the specified address
// with the option of adding cancellation and timeouts.
-func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
// Create TCP endpoint, then connect.
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
@@ -543,12 +553,12 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
}
}
- return NewConn(&wq, ep), nil
+ return NewTCPConn(&wq, ep), nil
}
-// A PacketConn is a wrapper around a tcpip endpoint that implements
-// net.PacketConn.
-type PacketConn struct {
+// A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
+// net.Conn and net.PacketConn.
+type UDPConn struct {
deadlineTimer
stack *stack.Stack
@@ -556,12 +566,23 @@ type PacketConn struct {
wq *waiter.Queue
}
-// DialUDP creates a new PacketConn.
+// NewUDPConn creates a new UDPConn.
+func NewUDPConn(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
+ c := &UDPConn{
+ stack: s,
+ ep: ep,
+ wq: wq,
+ }
+ c.deadlineTimer.init()
+ return c
+}
+
+// DialUDP creates a new UDPConn.
//
// If laddr is nil, a local address is automatically chosen.
//
-// If raddr is nil, the PacketConn is left unconnected.
-func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+// If raddr is nil, the UDPConn is left unconnected.
+func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
var wq waiter.Queue
ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
if err != nil {
@@ -580,12 +601,7 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- c := PacketConn{
- stack: s,
- ep: ep,
- wq: &wq,
- }
- c.deadlineTimer.init()
+ c := NewUDPConn(s, &wq, ep)
if raddr != nil {
if err := c.ep.Connect(*raddr); err != nil {
@@ -599,14 +615,14 @@ func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.Netw
}
}
- return &c, nil
+ return c, nil
}
-func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+func (c *UDPConn) newOpError(op string, err error) *net.OpError {
return c.newRemoteOpError(op, nil, err)
}
-func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
return &net.OpError{
Op: op,
Net: "udp",
@@ -617,22 +633,22 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
}
// RemoteAddr implements net.Conn.RemoteAddr.
-func (c *PacketConn) RemoteAddr() net.Addr {
+func (c *UDPConn) RemoteAddr() net.Addr {
a, err := c.ep.GetRemoteAddress()
if err != nil {
return nil
}
- return fullToTCPAddr(a)
+ return fullToUDPAddr(a)
}
// Read implements net.Conn.Read
-func (c *PacketConn) Read(b []byte) (int, error) {
+func (c *UDPConn) Read(b []byte) (int, error) {
bytesRead, _, err := c.ReadFrom(b)
return bytesRead, err
}
// ReadFrom implements net.PacketConn.ReadFrom.
-func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
@@ -644,12 +660,12 @@ func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return copy(b, read), fullToUDPAddr(addr), nil
}
-func (c *PacketConn) Write(b []byte) (int, error) {
+func (c *UDPConn) Write(b []byte) (int, error) {
return c.WriteTo(b, nil)
}
// WriteTo implements net.PacketConn.WriteTo.
-func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
deadline := c.writeCancel()
// Check if deadline has already expired.
@@ -707,13 +723,13 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// Close implements net.PacketConn.Close.
-func (c *PacketConn) Close() error {
+func (c *UDPConn) Close() error {
c.ep.Close()
return nil
}
// LocalAddr implements net.PacketConn.LocalAddr.
-func (c *PacketConn) LocalAddr() net.Addr {
+func (c *UDPConn) LocalAddr() net.Addr {
a, err := c.ep.GetLocalAddress()
if err != nil {
return nil
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 8ced960bb..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -41,7 +41,7 @@ const (
)
func TestTimeouts(t *testing.T) {
- nc := NewConn(nil, nil)
+ nc := NewTCPConn(nil, nil)
dlfs := []struct {
name string
f func(time.Time) error
@@ -127,12 +127,16 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -151,10 +155,8 @@ func TestCloseReader(t *testing.T) {
buf := make([]byte, 256)
n, err := c.Read(buf)
- got, ok := err.(*net.OpError)
- want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err.Error() != want.String() {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
+ if n != 0 || err != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, err)
}
}()
sender, err := connect(s, addr)
@@ -170,13 +172,17 @@ func TestCloseReader(t *testing.T) {
sender.close()
}
-// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// TestCloseReaderWithForwarder tests that TCPConn.Close wakes TCPConn.Read when
// using tcp.Forwarder.
func TestCloseReaderWithForwarder(t *testing.T) {
s, err := newLoopbackStack()
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -194,7 +200,7 @@ func TestCloseReaderWithForwarder(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
// Give c.Read() a chance to block before closing the connection.
time.AfterFunc(time.Millisecond*50, func() {
@@ -203,10 +209,8 @@ func TestCloseReaderWithForwarder(t *testing.T) {
buf := make([]byte, 256)
n, e := c.Read(buf)
- got, ok := e.(*net.OpError)
- want := tcpip.ErrConnectionAborted
- if n != 0 || !ok || got.Err.Error() != want.String() {
- t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
+ if n != 0 || e != io.EOF {
+ t.Errorf("c.Read() = (%d, %v), want (0, EOF)", n, e)
}
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -229,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -261,7 +256,7 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseRead(); err != nil {
t.Errorf("c.CloseRead() = %v", err)
@@ -282,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -295,7 +294,7 @@ func TestCloseWrite(t *testing.T) {
defer ep.Close()
r.Complete(false)
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
n, e := c.Read(make([]byte, 256))
if n != 0 || e != io.EOF {
@@ -313,7 +312,7 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("connect() = %v", terr)
}
- c := NewConn(tc.wq, tc.ep)
+ c := NewTCPConn(tc.wq, tc.ep)
if err := c.CloseWrite(); err != nil {
t.Errorf("c.CloseWrite() = %v", err)
@@ -338,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -357,7 +360,7 @@ func TestUDPForwarder(t *testing.T) {
}
defer ep.Close()
- c := NewConn(&wq, ep)
+ c := NewTCPConn(&wq, ep)
buf := make([]byte, 256)
n, e := c.Read(buf)
@@ -395,12 +398,16 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
- l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, e := ListenTCP(s, addr, ipv4.ProtocolNumber)
if e != nil {
t.Fatalf("NewListener() = %v", e)
}
@@ -444,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -496,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -545,7 +560,7 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
addr := tcpip.FullAddress{NICID, ip, 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
- l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+ l, err := ListenTCP(s, addr, ipv4.ProtocolNumber)
if err != nil {
return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
}
@@ -566,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -628,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -645,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -663,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index a7bf0c4dc..5e135c50d 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"prependable.go",
"view.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/buffer",
visibility = ["//visibility:public"],
)
@@ -20,5 +18,5 @@ go_test(
"prependable_test.go",
"view_test.go",
],
- embed = [":buffer"],
+ library = ":buffer",
)
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 2f9a23d61..57d1922ab 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -42,7 +42,7 @@ func NewPrependableFromView(v View, extraCap int) Prependable {
if extraCap == 0 {
return Prependable{buf: v, usedIdx: 0}
}
- buf := make([]byte, extraCap, extraCap + len(v))
+ buf := make([]byte, extraCap, extraCap+len(v))
buf = append(buf, v...)
return Prependable{buf: buf, usedIdx: extraCap}
}
@@ -83,3 +83,9 @@ func (p *Prependable) Prepend(size int) []byte {
p.usedIdx -= size
return p.View()[:size:size]
}
+
+// DeepCopy copies p and the bytes backing it.
+func (p Prependable) DeepCopy() Prependable {
+ p.buf = append(View(nil), p.buf...)
+ return p
+}
diff --git a/pkg/tcpip/buffer/prependable_test.go b/pkg/tcpip/buffer/prependable_test.go
index 43660c307..435a94a61 100644
--- a/pkg/tcpip/buffer/prependable_test.go
+++ b/pkg/tcpip/buffer/prependable_test.go
@@ -45,6 +45,6 @@ func TestNewPrependableFromView(t *testing.T) {
if !reflect.DeepEqual(prep, testCase.want) {
t.Errorf("NewPrependableFromView(%#v, %d) = %#v; want %#v", testCase.view, testCase.extraSize, prep, testCase.want)
}
- } )
+ })
}
}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..ea0c5413d 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -15,6 +15,11 @@
// Package buffer provides the implementation of a buffer view.
package buffer
+import (
+ "bytes"
+ "io"
+)
+
// View is a slice of a buffer, with convenience methods.
type View []byte
@@ -45,11 +50,31 @@ func (v *View) CapLength(length int) {
*v = (*v)[:length:length]
}
+// Reader returns a bytes.Reader for v.
+func (v *View) Reader() bytes.Reader {
+ var r bytes.Reader
+ r.Reset(*v)
+ return r
+}
+
// ToVectorisedView returns a VectorisedView containing the receiver.
func (v View) ToVectorisedView() VectorisedView {
+ if len(v) == 0 {
+ return VectorisedView{}
+ }
return NewVectorisedView(len(v), []View{v})
}
+// IsEmpty returns whether v is of length zero.
+func (v View) IsEmpty() bool {
+ return len(v) == 0
+}
+
+// Size returns the length of v.
+func (v View) Size() int {
+ return len(v)
+}
+
// VectorisedView is a vectorised version of View using non contiguous memory.
// It supports all the convenience methods supported by View.
//
@@ -65,7 +90,8 @@ func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
-// TrimFront removes the first "count" bytes of the vectorised view.
+// TrimFront removes the first "count" bytes of the vectorised view. It panics
+// if count > vv.Size().
func (vv *VectorisedView) TrimFront(count int) {
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
@@ -74,8 +100,49 @@ func (vv *VectorisedView) TrimFront(count int) {
return
}
count -= len(vv.views[0])
- vv.RemoveFirst()
+ vv.removeFirst()
+ }
+}
+
+// Read implements io.Reader.
+func (vv *VectorisedView) Read(v View) (copied int, err error) {
+ count := len(v)
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ copy(v[copied:], vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return copied, nil
+ }
+ count -= len(vv.views[0])
+ copy(v[copied:], vv.views[0])
+ copied += len(vv.views[0])
+ vv.removeFirst()
+ }
+ if copied == 0 {
+ return 0, io.EOF
+ }
+ return copied, nil
+}
+
+// ReadToVV reads up to n bytes from vv to dstVV and removes them from vv. It
+// returns the number of bytes copied.
+func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int) {
+ for count > 0 && len(vv.views) > 0 {
+ if count < len(vv.views[0]) {
+ vv.size -= count
+ dstVV.AppendView(vv.views[0][:count])
+ vv.views[0].TrimFront(count)
+ copied += count
+ return
+ }
+ count -= len(vv.views[0])
+ dstVV.AppendView(vv.views[0])
+ copied += len(vv.views[0])
+ vv.removeFirst()
}
+ return copied
}
// CapLength irreversibly reduces the length of the vectorised view.
@@ -105,29 +172,45 @@ func (vv *VectorisedView) CapLength(length int) {
// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this VectorisedView,
// the method will avoid allocations and use the buffer to store the Views of the clone.
-func (vv VectorisedView) Clone(buffer []View) VectorisedView {
+func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
-// First returns the first view of the vectorised view.
-func (vv VectorisedView) First() View {
+// PullUp returns the first "count" bytes of the vectorised view. If those
+// bytes aren't already contiguous inside the vectorised view, PullUp will
+// reallocate as needed to make them contiguous. PullUp fails and returns false
+// when count > vv.Size().
+func (vv *VectorisedView) PullUp(count int) (View, bool) {
if len(vv.views) == 0 {
- return nil
+ return nil, count == 0
+ }
+ if count <= len(vv.views[0]) {
+ return vv.views[0][:count], true
+ }
+ if count > vv.size {
+ return nil, false
}
- return vv.views[0]
-}
-// RemoveFirst removes the first view of the vectorised view.
-func (vv *VectorisedView) RemoveFirst() {
- if len(vv.views) == 0 {
- return
+ newFirst := NewView(count)
+ i := 0
+ for offset := 0; offset < count; i++ {
+ copy(newFirst[offset:], vv.views[i])
+ if count-offset < len(vv.views[i]) {
+ vv.views[i].TrimFront(count - offset)
+ break
+ }
+ offset += len(vv.views[i])
+ vv.views[i] = nil
}
- vv.size -= len(vv.views[0])
- vv.views = vv.views[1:]
+ // We're guaranteed that i > 0, since count is too large for the first
+ // view.
+ vv.views[i-1] = newFirst
+ vv.views = vv.views[i-1:]
+ return newFirst, true
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
-func (vv VectorisedView) Size() int {
+func (vv *VectorisedView) Size() int {
return vv.size
}
@@ -135,7 +218,7 @@ func (vv VectorisedView) Size() int {
//
// If the vectorised view contains a single view, that view will be returned
// directly.
-func (vv VectorisedView) ToView() View {
+func (vv *VectorisedView) ToView() View {
if len(vv.views) == 1 {
return vv.views[0]
}
@@ -147,7 +230,7 @@ func (vv VectorisedView) ToView() View {
}
// Views returns the slice containing the all views.
-func (vv VectorisedView) Views() []View {
+func (vv *VectorisedView) Views() []View {
return vv.views
}
@@ -156,3 +239,28 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ if len(v) == 0 {
+ return
+ }
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
+
+// Readers returns a bytes.Reader for each of vv's views.
+func (vv *VectorisedView) Readers() []bytes.Reader {
+ readers := make([]bytes.Reader, 0, len(vv.views))
+ for _, v := range vv.views {
+ readers = append(readers, v.Reader())
+ }
+ return readers
+}
+
+// removeFirst panics when len(vv.views) < 1.
+func (vv *VectorisedView) removeFirst() {
+ vv.size -= len(vv.views[0])
+ vv.views[0] = nil
+ vv.views = vv.views[1:]
+}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index ebc3a17b7..726e54de9 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -16,6 +16,7 @@
package buffer
import (
+ "bytes"
"reflect"
"testing"
)
@@ -233,3 +234,288 @@ func TestToClone(t *testing.T) {
})
}
}
+
+func TestVVReadToVV(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ wantBytes string
+ leftVV VectorisedView
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ wantBytes: "0123456789",
+ leftVV: vv(20, "01234567890123456789"),
+ },
+ {
+ comment: "largeVV, multiple views, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ wantBytes: "123345",
+ leftVV: vv(7, "567", "8910"),
+ },
+ {
+ comment: "smallVV (multiple views), large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ wantBytes: "123",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "smallVV (single view), large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ wantBytes: "1",
+ leftVV: vv(0, ""),
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ wantBytes: "",
+ leftVV: vv(0, ""),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ var readTo VectorisedView
+ inSize := tc.vv.Size()
+ copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc: %+v", got, want, tc)
+ }
+ if got, want := string(readTo.ToView()), tc.wantBytes; got != want {
+ t.Errorf("unexpected content in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want {
+ t.Errorf("unexpected data left in vv after read got: %+v, want: %+v", got, want)
+ }
+ })
+ }
+}
+
+func TestVVRead(t *testing.T) {
+ testCases := []struct {
+ comment string
+ vv VectorisedView
+ bytesToRead int
+ readBytes string
+ leftBytes string
+ wantError bool
+ }{
+ {
+ comment: "large VV, short read",
+ vv: vv(30, "012345678901234567890123456789"),
+ bytesToRead: 10,
+ readBytes: "0123456789",
+ leftBytes: "01234567890123456789",
+ },
+ {
+ comment: "largeVV, multiple buffers, short read",
+ vv: vv(13, "123", "345", "567", "8910"),
+ bytesToRead: 6,
+ readBytes: "123345",
+ leftBytes: "5678910",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(3, "1", "2", "3"),
+ bytesToRead: 10,
+ readBytes: "123",
+ leftBytes: "",
+ },
+ {
+ comment: "smallVV, large read",
+ vv: vv(1, "1"),
+ bytesToRead: 10,
+ readBytes: "1",
+ leftBytes: "",
+ },
+ {
+ comment: "emptyVV, large read",
+ vv: vv(0, ""),
+ bytesToRead: 10,
+ readBytes: "",
+ wantError: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.comment, func(t *testing.T) {
+ readTo := NewView(tc.bytesToRead)
+ inSize := tc.vv.Size()
+ copied, err := tc.vv.Read(readTo)
+ if !tc.wantError && err != nil {
+ t.Fatalf("unexpected error in tc.vv.Read(..) = %s", err)
+ }
+ readTo = readTo[:copied]
+ if got, want := copied, len(tc.readBytes); got != want {
+ t.Errorf("incorrect number of bytes copied returned in ReadToVV got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(readTo), tc.readBytes; got != want {
+ t.Errorf("unexpected data in readTo got: %s, want: %s", got, want)
+ }
+ if got, want := tc.vv.Size(), inSize-copied; got != want {
+ t.Errorf("test VV has incorrect size after reading got: %d, want: %d, tc.vv: %+v", got, want, tc.vv)
+ }
+ if got, want := string(tc.vv.ToView()), tc.leftBytes; got != want {
+ t.Errorf("vv has incorrect data after Read got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+var pullUpTestCases = []struct {
+ comment string
+ in VectorisedView
+ count int
+ want []byte
+ result VectorisedView
+ ok bool
+}{
+ {
+ comment: "simple case",
+ in: vv(2, "12"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "12"),
+ ok: true,
+ },
+ {
+ comment: "entire View",
+ in: vv(2, "1", "2"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(2, "1", "2"),
+ ok: true,
+ },
+ {
+ comment: "spanning across two Views",
+ in: vv(3, "1", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+ {
+ comment: "spanning across all Views",
+ in: vv(5, "1", "23", "45"),
+ count: 5,
+ want: []byte("12345"),
+ result: vv(5, "12345"),
+ ok: true,
+ },
+ {
+ comment: "count = 0",
+ in: vv(1, "1"),
+ count: 0,
+ want: []byte{},
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count = size",
+ in: vv(1, "1"),
+ count: 1,
+ want: []byte("1"),
+ result: vv(1, "1"),
+ ok: true,
+ },
+ {
+ comment: "count too large",
+ in: vv(3, "1", "23"),
+ count: 4,
+ want: nil,
+ result: vv(3, "1", "23"),
+ ok: false,
+ },
+ {
+ comment: "empty vv",
+ in: vv(0, ""),
+ count: 1,
+ want: nil,
+ result: vv(0, ""),
+ ok: false,
+ },
+ {
+ comment: "empty vv, count = 0",
+ in: vv(0, ""),
+ count: 0,
+ want: nil,
+ result: vv(0, ""),
+ ok: true,
+ },
+ {
+ comment: "empty views",
+ in: vv(3, "", "1", "", "23"),
+ count: 2,
+ want: []byte("12"),
+ result: vv(3, "12", "3"),
+ ok: true,
+ },
+}
+
+func TestPullUp(t *testing.T) {
+ for _, c := range pullUpTestCases {
+ got, ok := c.in.PullUp(c.count)
+
+ // Is the return value right?
+ if ok != c.ok {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
+ c.comment, c.count, c.in, ok, c.ok)
+ }
+ if bytes.Compare(got, View(c.want)) != 0 {
+ t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
+ c.comment, c.count, c.in, got, c.want)
+ }
+
+ // Is the underlying structure right?
+ if !reflect.DeepEqual(c.in, c.result) {
+ t.Errorf("Test %q failed when calling PullUp(%d). Got vv with structure %v. Wanted %v",
+ c.comment, c.count, c.in, c.result)
+ }
+ }
+}
+
+func TestToVectorisedView(t *testing.T) {
+ testCases := []struct {
+ in View
+ want VectorisedView
+ }{
+ {nil, VectorisedView{}},
+ {View{}, VectorisedView{}},
+ {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}},
+ }
+ for _, tc := range testCases {
+ if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
+ }
+ }
+}
+
+func TestAppendView(t *testing.T) {
+ testCases := []struct {
+ vv VectorisedView
+ in View
+ want VectorisedView
+ }{
+ {VectorisedView{}, nil, VectorisedView{}},
+ {VectorisedView{}, View{}, VectorisedView{}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
+ {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}},
+ }
+ for _, tc := range testCases {
+ tc.vv.AppendView(tc.in)
+ if got, want := tc.vv, tc.want; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v).ToVectorisedView failed got: %+v, want: %+v", tc.in, got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index b6fa6fc37..c984470e6 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,12 +6,12 @@ go_library(
name = "checker",
testonly = 1,
srcs = ["checker.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/checker",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 2f15bf1f1..b769094dc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,6 +34,9 @@ type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+// ControlMessagesChecker is a function to check a property of ancillary data.
+type ControlMessagesChecker func(*testing.T, tcpip.ControlMessages)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
@@ -104,6 +108,8 @@ func DstAddr(addr tcpip.Address) NetworkChecker {
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
func TTL(ttl uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
var v uint8
switch ip := h[0].(type) {
case header.IPv4:
@@ -158,6 +164,44 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
+ } else if got := cm.TClass; got != want {
+ t.Errorf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
+// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
+func ReceiveTOS(want uint8) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTOS {
+ t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
+ } else if got := cm.TOS; got != want {
+ t.Errorf("got cm.TOS = %d, want %d", got, want)
+ }
+ }
+}
+
+// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
+// ControlMessages.
+func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPPacketInfo {
+ t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
+ } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
+ t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
@@ -280,12 +324,30 @@ func SrcPort(port uint16) TransportChecker {
// DstPort creates a checker that checks the destination port.
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if p := h.DestinationPort(); p != port {
t.Errorf("Bad destination port, got %v, want %v", p, port)
}
}
}
+// NoChecksum creates a checker that checks if the checksum is zero.
+func NoChecksum(noChecksum bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ udp, ok := h.(header.UDP)
+ if !ok {
+ return
+ }
+
+ if b := udp.Checksum() == 0; b != noChecksum {
+ t.Errorf("bad checksum state, got %t, want %t", b, noChecksum)
+ }
+ }
+}
+
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
@@ -306,6 +368,7 @@ func SeqNum(seq uint32) TransportChecker {
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -320,6 +383,8 @@ func AckNum(seq uint32) TransportChecker {
// Window creates a checker that checks the tcp window.
func Window(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -351,6 +416,8 @@ func TCPFlags(flags uint8) TransportChecker {
// given mask, match the supplied flags.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -368,6 +435,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
// If wndscale is negative, the window scale option must not be present.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -464,6 +533,8 @@ func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
// skipped.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
tcp, ok := h.(header.TCP)
if !ok {
return
@@ -582,6 +653,8 @@ func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
// Payload creates a checker that checks the payload.
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
if got := h.Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
@@ -614,6 +687,7 @@ func ICMPv4(checkers ...TransportChecker) NetworkChecker {
func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -625,9 +699,10 @@ func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
}
// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
-func ICMPv4Code(want byte) TransportChecker {
+func ICMPv4Code(want header.ICMPv4Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv4, ok := h.(header.ICMPv4)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
@@ -670,6 +745,7 @@ func ICMPv6(checkers ...TransportChecker) NetworkChecker {
func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -681,9 +757,10 @@ func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
}
// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
-func ICMPv6Code(want byte) TransportChecker {
+func ICMPv6Code(want header.ICMPv6Code) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
+
icmpv6, ok := h.(header.ICMPv6)
if !ok {
t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
@@ -698,7 +775,7 @@ func ICMPv6Code(want byte) TransportChecker {
// message for type of ty, with potentially additional checks specified by
// checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
// NDP message as far as the size of the message (minSize) is concerned. The
// values within the message are up to checkers to validate.
func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) NetworkChecker {
@@ -730,9 +807,9 @@ func NDP(msgType header.ICMPv6Type, minSize int, checkers ...TransportChecker) N
// Neighbor Solicitation message (as per the raw wire format), with potentially
// additional checks specified by checkers.
//
-// checkers may assume that a valid ICMPv6 is passed to it containing a valid
-// NDPNS message as far as the size of the messages concerned. The values within
-// the message are up to checkers to validate.
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNS message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
func NDPNS(checkers ...TransportChecker) NetworkChecker {
return NDP(header.ICMPv6NeighborSolicit, header.NDPNSMinimumSize, checkers...)
}
@@ -750,7 +827,162 @@ func NDPNSTargetAddress(want tcpip.Address) TransportChecker {
ns := header.NDPNeighborSolicit(icmp.NDPPayload())
if got := ns.TargetAddress(); got != want {
- t.Fatalf("got %T.TargetAddress = %s, want = %s", ns, got, want)
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", ns, got, want)
+ }
+ }
+}
+
+// NDPNA creates a checker that checks that the packet contains a valid NDP
+// Neighbor Advertisement message (as per the raw wire format), with potentially
+// additional checks specified by checkers.
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPNA message as far as the size of the message is concerned. The values
+// within the message are up to checkers to validate.
+func NDPNA(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6NeighborAdvert, header.NDPNAMinimumSize, checkers...)
+}
+
+// NDPNATargetAddress creates a checker that checks the Target Address field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNATargetAddress(want tcpip.Address) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.TargetAddress(); got != want {
+ t.Errorf("got %T.TargetAddress() = %s, want = %s", na, got, want)
}
}
}
+
+// NDPNASolicitedFlag creates a checker that checks the Solicited field of
+// a header.NDPNeighborAdvert.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNASolicitedFlag(want bool) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+
+ if got := na.SolicitedFlag(); got != want {
+ t.Errorf("got %T.SolicitedFlag = %t, want = %t", na, got, want)
+ }
+ }
+}
+
+// ndpOptions checks that optsBuf only contains opts.
+func ndpOptions(t *testing.T, optsBuf header.NDPOptions, opts []header.NDPOption) {
+ t.Helper()
+
+ it, err := optsBuf.Iter(true)
+ if err != nil {
+ t.Errorf("optsBuf.Iter(true): %s", err)
+ return
+ }
+
+ i := 0
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ t.Fatalf("unexpected error when iterating over NDP options: %s", err)
+ }
+ if done {
+ break
+ }
+
+ if i >= len(opts) {
+ t.Errorf("got unexpected option: %s", opt)
+ continue
+ }
+
+ switch wantOpt := opts[i].(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPSourceLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ case header.NDPTargetLinkLayerAddressOption:
+ gotOpt, ok := opt.(header.NDPTargetLinkLayerAddressOption)
+ if !ok {
+ t.Errorf("got type = %T at index = %d; want = %T", opt, i, wantOpt)
+ } else if got, want := gotOpt.EthernetAddress(), wantOpt.EthernetAddress(); got != want {
+ t.Errorf("got EthernetAddress() = %s at index %d, want = %s", got, i, want)
+ }
+ default:
+ t.Fatalf("checker not implemented for expected NDP option: %T", wantOpt)
+ }
+
+ i++
+ }
+
+ if missing := opts[i:]; len(missing) > 0 {
+ t.Errorf("missing options: %s", missing)
+ }
+}
+
+// NDPNAOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNA message as far as the size is concerned.
+func NDPNAOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ na := header.NDPNeighborAdvert(icmp.NDPPayload())
+ ndpOptions(t, na.Options(), opts)
+ }
+}
+
+// NDPNSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Neighbor Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPNS message as far as the size is concerned.
+func NDPNSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ ns := header.NDPNeighborSolicit(icmp.NDPPayload())
+ ndpOptions(t, ns.Options(), opts)
+ }
+}
+
+// NDPRS creates a checker that checks that the packet contains a valid NDP
+// Router Solicitation message (as per the raw wire format).
+//
+// Checkers may assume that a valid ICMPv6 is passed to it containing a valid
+// NDPRS as far as the size of the message is concerned. The values within the
+// message are up to checkers to validate.
+func NDPRS(checkers ...TransportChecker) NetworkChecker {
+ return NDP(header.ICMPv6RouterSolicit, header.NDPRSMinimumSize, checkers...)
+}
+
+// NDPRSOptions creates a checker that checks that the packet contains the
+// provided NDP options within an NDP Router Solicitation message.
+//
+// The returned TransportChecker assumes that a valid ICMPv6 is passed to it
+// containing a valid NDPRS message as far as the size is concerned.
+func NDPRSOptions(opts []header.NDPOption) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+
+ icmp := h.(header.ICMPv6)
+ rs := header.NDPRouterSolicit(icmp.NDPPayload())
+ ndpOptions(t, rs.Options(), opts)
+ }
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 0c5c20cea..ff2719291 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,15 +1,11 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "jenkins",
srcs = ["jenkins.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
)
go_test(
@@ -18,5 +14,5 @@ go_test(
srcs = [
"jenkins_test.go",
],
- embed = [":jenkins"],
+ library = ":jenkins",
)
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index a3485b35c..d87797617 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,5 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -15,15 +14,17 @@ go_library(
"interfaces.go",
"ipv4.go",
"ipv6.go",
+ "ipv6_extension_headers.go",
"ipv6_fragment.go",
"ndp_neighbor_advert.go",
"ndp_neighbor_solicit.go",
"ndp_options.go",
"ndp_router_advert.go",
+ "ndp_router_solicit.go",
+ "ndpoptionidentifier_string.go",
"tcp.go",
"udp.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/header",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
@@ -38,12 +39,16 @@ go_test(
size = "small",
srcs = [
"checksum_test.go",
+ "ipv6_test.go",
"ipversion_test.go",
"tcp_test.go",
],
deps = [
":header",
+ "//pkg/rand",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
@@ -52,8 +57,13 @@ go_test(
size = "small",
srcs = [
"eth_test.go",
+ "ipv6_extension_headers_test.go",
"ndp_test.go",
],
- embed = [":header"],
- deps = ["//pkg/tcpip"],
+ library = ":header",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
index 718a4720a..83189676e 100644
--- a/pkg/tcpip/header/arp.go
+++ b/pkg/tcpip/header/arp.go
@@ -14,14 +14,33 @@
package header
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// ARPProtocolNumber is the ARP network protocol number.
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
- ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+ ARPSize = 28
+)
+
+// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header.
+type ARPHardwareType uint16
+
+// Typical ARP HardwareType values. Some of the constants have to be specific
+// values as they are egressed on the wire in the HTYPE field of an ARP header.
+const (
+ ARPHardwareNone ARPHardwareType = 0
+ // ARPHardwareEther specifically is the HTYPE for Ethernet as specified
+ // in the IANA list here:
+ //
+ // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2
+ ARPHardwareEther ARPHardwareType = 1
+ ARPHardwareLoopback ARPHardwareType = 2
)
// ARPOp is an ARP opcode.
@@ -36,54 +55,64 @@ const (
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte
-func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
-func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
-func (a ARP) hardwareAddressSize() int { return int(a[4]) }
-func (a ARP) protocolAddressSize() int { return int(a[5]) }
+const (
+ hTypeOffset = 0
+ protocolOffset = 2
+ haAddressSizeOffset = 4
+ protoAddressSizeOffset = 5
+ opCodeOffset = 6
+ senderHAAddressOffset = 8
+ senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize
+ targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize
+ targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize
+)
+
+func (a ARP) hardwareAddressType() ARPHardwareType {
+ return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:]))
+}
+
+func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) }
+func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) }
+func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) }
// Op is the ARP opcode.
-func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) }
// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
- a[6] = uint8(op >> 8)
- a[7] = uint8(op)
+ binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op))
}
// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
- a[0], a[1] = 0, 1 // htypeEthernet
- a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
- a[4] = 6 // macSize
- a[5] = uint8(IPv4AddressSize)
+ binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther))
+ binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber))
+ a[haAddressSizeOffset] = EthernetAddressSize
+ a[protoAddressSizeOffset] = uint8(IPv4AddressSize)
}
// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
- const s = 8
- return a[s : s+6]
+ return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
- const s = 8 + 6
- return a[s : s+4]
+ return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize]
}
// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
- const s = 8 + 6 + 4
- return a[s : s+6]
+ return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize]
}
// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
- const s = 8 + 6 + 4 + 6
- return a[s : s+4]
+ return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize]
}
// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
@@ -91,10 +120,8 @@ func (a ARP) IsValid() bool {
if len(a) < ARPSize {
return false
}
- const htypeEthernet = 1
- const macSize = 6
- return a.hardwareAddressSpace() == htypeEthernet &&
+ return a.hardwareAddressType() == ARPHardwareEther &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
- a.hardwareAddressSize() == macSize &&
+ a.hardwareAddressSize() == EthernetAddressSize &&
a.protocolAddressSize() == IPv4AddressSize
}
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
index 9749c7f4d..14a4b2b44 100644
--- a/pkg/tcpip/header/checksum.go
+++ b/pkg/tcpip/header/checksum.go
@@ -45,12 +45,139 @@ func calculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
return ChecksumCombine(uint16(v), uint16(v>>16)), odd
}
+func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bool) {
+ v := initial
+
+ if odd {
+ v += uint32(buf[0])
+ buf = buf[1:]
+ }
+
+ l := len(buf)
+ odd = l&1 != 0
+ if odd {
+ l--
+ v += uint32(buf[l]) << 8
+ }
+ for (l - 64) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[64:]
+ l = l - 64
+ }
+ if (l - 32) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ i += 16
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[32:]
+ l = l - 32
+ }
+ if (l - 16) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
+ v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
+ v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
+ v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
+ buf = buf[16:]
+ l = l - 16
+ }
+ if (l - 8) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
+ v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
+ buf = buf[8:]
+ l = l - 8
+ }
+ if (l - 4) >= 0 {
+ i := 0
+ v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+ v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
+ buf = buf[4:]
+ l = l - 4
+ }
+
+ // At this point since l was even before we started unrolling
+ // there can be only two bytes left to add.
+ if l != 0 {
+ v += (uint32(buf[0]) << 8) + uint32(buf[1])
+ }
+
+ return ChecksumCombine(uint16(v), uint16(v>>16)), odd
+}
+
+// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in
+// the given byte array. This function uses a non-optimized implementation. Its
+// only retained for reference and to use as a benchmark/test. Most code should
+// use the header.Checksum function.
+//
+// The initial checksum must have been computed on an even number of bytes.
+func ChecksumOld(buf []byte, initial uint16) uint16 {
+ s, _ := calculateChecksum(buf, false, uint32(initial))
+ return s
+}
+
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
-// given byte array.
+// given byte array. This function uses an optimized unrolled version of the
+// checksum algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
- s, _ := calculateChecksum(buf, false, uint32(initial))
+ s, _ := unrolledCalculateChecksum(buf, false, uint32(initial))
return s
}
@@ -86,7 +213,7 @@ func ChecksumVVWithOffset(vv buffer.VectorisedView, initial uint16, off int, siz
}
v = v[:l]
- sum, odd = calculateChecksum(v, odd, uint32(sum))
+ sum, odd = unrolledCalculateChecksum(v, odd, uint32(sum))
size -= len(v)
if size == 0 {
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go
index 86b466c1c..309403482 100644
--- a/pkg/tcpip/header/checksum_test.go
+++ b/pkg/tcpip/header/checksum_test.go
@@ -17,6 +17,8 @@
package header_test
import (
+ "fmt"
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) {
})
}
}
+
+func TestChecksum(t *testing.T) {
+ var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024}
+ type testCase struct {
+ buf []byte
+ initial uint16
+ csumOrig uint16
+ csumNew uint16
+ }
+ testCases := make([]testCase, 100000)
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for i := range testCases {
+ testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)])
+ testCases[i].initial = uint16(rnd.Intn(65536))
+ rnd.Read(testCases[i].buf)
+ }
+
+ for i := range testCases {
+ testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial)
+ testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial)
+ if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want {
+ t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want)
+ }
+ }
+}
+
+func BenchmarkChecksum(b *testing.B) {
+ var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536}
+
+ checkSumImpls := []struct {
+ fn func([]byte, uint16) uint16
+ name string
+ }{
+ {header.ChecksumOld, fmt.Sprintf("checksum_old")},
+ {header.Checksum, fmt.Sprintf("checksum")},
+ }
+
+ for _, csumImpl := range checkSumImpls {
+ // Ensure same buffer generation for test consistency.
+ rnd := rand.New(rand.NewSource(42))
+ for _, bufSz := range bufSizes {
+ b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) {
+ tc := struct {
+ buf []byte
+ initial uint16
+ csum uint16
+ }{
+ buf: make([]byte, bufSz),
+ initial: uint16(rnd.Intn(65536)),
+ }
+ rnd.Read(tc.buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ tc.csum = csumImpl.fn(tc.buf, tc.initial)
+ }
+ })
+ }
+ }
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
index f5d2c127f..eaface8cb 100644
--- a/pkg/tcpip/header/eth.go
+++ b/pkg/tcpip/header/eth.go
@@ -53,6 +53,10 @@ const (
// (all bits set to 0).
unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
+ // EthernetBroadcastAddress is an ethernet address that addresses every node
+ // on a local link.
+ EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
+
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
@@ -134,3 +138,44 @@ func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
// addr is a valid unicast ethernet address.
return true
}
+
+// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
+// for a multicast IPv4 address.
+//
+// addr MUST be a multicast IPv4 address.
+func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
+ var linkAddrBytes [EthernetAddressSize]byte
+ // RFC 1112 Host Extensions for IP Multicasting
+ //
+ // 6.4. Extensions to an Ethernet Local Network Module:
+ //
+ // An IP host group address is mapped to an Ethernet multicast
+ // address by placing the low-order 23-bits of the IP address
+ // into the low-order 23 bits of the Ethernet multicast address
+ // 01-00-5E-00-00-00 (hex).
+ linkAddrBytes[0] = 0x1
+ linkAddrBytes[2] = 0x5e
+ linkAddrBytes[3] = addr[1] & 0x7F
+ copy(linkAddrBytes[4:], addr[IPv4AddressSize-2:])
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
+
+// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
+// for a multicast IPv6 address.
+//
+// addr MUST be a multicast IPv6 address.
+func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
+ // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
+ //
+ // 7. Address Mapping -- Multicast
+ //
+ // An IPv6 packet with a multicast destination address DST,
+ // consisting of the sixteen octets DST[1] through DST[16], is
+ // transmitted to the Ethernet multicast address whose first
+ // two octets are the value 3333 hexadecimal and whose last
+ // four octets are the last four octets of DST.
+ linkAddrBytes := []byte(addr[IPv6AddressSize-EthernetAddressSize:])
+ linkAddrBytes[0] = 0x33
+ linkAddrBytes[1] = 0x33
+ return tcpip.LinkAddress(linkAddrBytes[:])
+}
diff --git a/pkg/tcpip/header/eth_test.go b/pkg/tcpip/header/eth_test.go
index 6634c90f5..14413f2ce 100644
--- a/pkg/tcpip/header/eth_test.go
+++ b/pkg/tcpip/header/eth_test.go
@@ -66,3 +66,37 @@ func TestIsValidUnicastEthernetAddress(t *testing.T) {
})
}
}
+
+func TestEthernetAddressFromMulticastIPv4Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4 Multicast without 24th bit set",
+ addr: "\xe0\x7e\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ {
+ name: "IPv4 Multicast with 24th bit set",
+ addr: "\xe0\xfe\xdc\xba",
+ expectedLinkAddr: "\x01\x00\x5e\x7e\xdc\xba",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := EthernetAddressFromMulticastIPv4Address(test.addr); got != test.expectedLinkAddr {
+ t.Fatalf("got EthernetAddressFromMulticastIPv4Address(%s) = %s, want = %s", test.addr, got, test.expectedLinkAddr)
+ }
+ })
+ }
+}
+
+func TestEthernetAddressFromMulticastIPv6Address(t *testing.T) {
+ addr := tcpip.Address("\xff\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x1a")
+ if got, want := EthernetAddressFromMulticastIPv6Address(addr), tcpip.LinkAddress("\x33\x33\x0d\x0e\x0f\x1a"); got != want {
+ t.Fatalf("got EthernetAddressFromMulticastIPv6Address(%s) = %s, want = %s", addr, got, want)
+ }
+}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index 0cac6c0a5..be03fb086 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -54,6 +54,9 @@ const (
// ICMPv4Type is the ICMP type field described in RFC 792.
type ICMPv4Type byte
+// ICMPv4Code is the ICMP code field described in RFC 792.
+type ICMPv4Code byte
+
// Typical values of ICMPv4Type defined in RFC 792.
const (
ICMPv4EchoReply ICMPv4Type = 0
@@ -69,10 +72,13 @@ const (
ICMPv4InfoReply ICMPv4Type = 16
)
-// Values for ICMP code as defined in RFC 792.
+// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792.
const (
- ICMPv4PortUnreachable = 3
- ICMPv4FragmentationNeeded = 4
+ ICMPv4TTLExceeded ICMPv4Code = 0
+ ICMPv4HostUnreachable ICMPv4Code = 1
+ ICMPv4ProtoUnreachable ICMPv4Code = 2
+ ICMPv4PortUnreachable ICMPv4Code = 3
+ ICMPv4FragmentationNeeded ICMPv4Code = 4
)
// Type is the ICMP type field.
@@ -82,10 +88,10 @@ func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
-func (b ICMPv4) Code() byte { return b[1] }
+func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
-func (b ICMPv4) SetCode(c byte) { b[1] = c }
+func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index b4037b6c8..20b01d8f4 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -52,7 +52,7 @@ const (
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement
// including the NDP Target Link Layer option for an Ethernet
// address.
- ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + ndpTargetEthernetLinkLayerAddressSize
+ ICMPv6NeighborAdvertSize = ICMPv6HeaderSize + NDPNAMinimumSize + NDPLinkLayerAddressSize
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
@@ -92,7 +92,6 @@ const (
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
type ICMPv6Type byte
-// Typical values of ICMPv6Type defined in RFC 4443.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
@@ -110,11 +109,38 @@ const (
ICMPv6RedirectMsg ICMPv6Type = 137
)
-// Values for ICMP code as defined in RFC 4443.
+// ICMPv6Code is the ICMP code field described in RFC 4443.
+type ICMPv6Code byte
+
+// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
+// section 3.1.
+const (
+ ICMPv6NetworkUnreachable ICMPv6Code = 0
+ ICMPv6Prohibited ICMPv6Code = 1
+ ICMPv6BeyondScope ICMPv6Code = 2
+ ICMPv6AddressUnreachable ICMPv6Code = 3
+ ICMPv6PortUnreachable ICMPv6Code = 4
+ ICMPv6Policy ICMPv6Code = 5
+ ICMPv6RejectRoute ICMPv6Code = 6
+)
+
+// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3.
const (
- ICMPv6PortUnreachable = 4
+ ICMPv6HopLimitExceeded ICMPv6Code = 0
+ ICMPv6ReassemblyTimeout ICMPv6Code = 1
)
+// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
+const (
+ ICMPv6ErroneousHeader ICMPv6Code = 0
+ ICMPv6UnknownHeader ICMPv6Code = 1
+ ICMPv6UnknownOption ICMPv6Code = 2
+)
+
+// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
+// the code field. (Types not mentioned above.)
+const ICMPv6UnusedCode ICMPv6Code = 0
+
// Type is the ICMP type field.
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
@@ -122,10 +148,10 @@ func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
-func (b ICMPv6) Code() byte { return b[1] }
+func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
-func (b ICMPv6) SetCode(c byte) { b[1] = c }
+func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index e5360e7c1..680eafd16 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -38,7 +38,8 @@ const (
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv4Fields struct {
- // IHL is the "internet header length" field of an IPv4 packet.
+ // IHL is the "internet header length" field of an IPv4 packet. The value
+ // is in bytes.
IHL uint8
// TOS is the "type of service" field of an IPv4 packet.
@@ -100,6 +101,11 @@ const (
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
+ // IPv4AllSystems is the all systems IPv4 multicast address as per
+ // IANA's IPv4 Multicast Address Space Registry. See
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml.
+ IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01"
+
// IPv4Broadcast is the broadcast address of the IPv4 procotol.
IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
@@ -138,7 +144,7 @@ func IPVersion(b []byte) int {
}
// HeaderLength returns the value of the "header length" field of the ipv4
-// header.
+// header. The length returned is in bytes.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
@@ -158,6 +164,11 @@ func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
+// More returns whether the more fragments flag is set.
+func (b IPv4) More() bool {
+ return b.Flags()&IPv4FlagMoreFragments != 0
+}
+
// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
@@ -304,3 +315,12 @@ func IsV4MulticastAddress(addr tcpip.Address) bool {
}
return (addr[0] & 0xf0) == 0xe0
}
+
+// IsV4LoopbackAddress determines if the provided address is an IPv4 loopback
+// address (belongs to 127.0.0.1/8 subnet).
+func IsV4LoopbackAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv4AddressSize {
+ return false
+ }
+ return addr[0] == 0x7f
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index f1e60911b..ea3823898 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -15,7 +15,9 @@
package header
import (
+ "crypto/sha256"
"encoding/binary"
+ "fmt"
"strings"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -26,7 +28,9 @@ const (
// IPv6PayloadLenOffset is the offset of the PayloadLength field in
// IPv6 header.
IPv6PayloadLenOffset = 4
- nextHdr = 6
+ // IPv6NextHeaderOffset is the offset of the NextHeader field in
+ // IPv6 header.
+ IPv6NextHeaderOffset = 6
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
@@ -83,16 +87,58 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ // IPv6AllRoutersMulticastAddress is a link-local multicast group that
+ // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers on a link.
+ //
+ // The address is ff02::2.
+ IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
+ // IPv6Loopback is the IPv6 Loopback address.
+ IPv6Loopback tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any tcpip.Address = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+
+ // IIDSize is the size of an interface identifier (IID), in bytes, as
+ // defined by RFC 4291 section 2.5.1.
+ IIDSize = 8
+
+ // IIDOffsetInIPv6Address is the offset, in bytes, from the start
+ // of an IPv6 address to the beginning of the interface identifier
+ // (IID) for auto-generated addresses. That is, all bytes before
+ // the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
+ // bytes including and after the IIDOffsetInIPv6Address-th byte are
+ // for the IID.
+ IIDOffsetInIPv6Address = 8
+
+ // OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
+ // for the secret key used to generate an opaque interface identifier as
+ // outlined by RFC 7217.
+ OpaqueIIDSecretKeyMinBytes = 16
+
+ // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
+ // is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeByteIdx = 1
+
+ // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
+ // within the byte holding the field, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeMask = 0xF
+
+ // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
+ // a multicast IPv6 address that indicates the address has link-local scope,
+ // as per RFC 4291 section 2.7.
+ ipv6LinkLocalMulticastScope = 2
)
-// IPv6EmptySubnet is the empty IPv6 subnet.
+// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
+// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
+// be contained within this subnet.
var IPv6EmptySubnet = func() tcpip.Subnet {
subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any))
if err != nil {
@@ -123,7 +169,7 @@ func (b IPv6) HopLimit() uint8 {
// NextHeader returns the value of the "next header" field of the ipv6 header.
func (b IPv6) NextHeader() uint8 {
- return b[nextHdr]
+ return b[IPv6NextHeaderOffset]
}
// TransportProtocol implements Network.TransportProtocol.
@@ -183,7 +229,7 @@ func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
- b[nextHdr] = v
+ b[IPv6NextHeaderOffset] = v
}
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
@@ -195,7 +241,7 @@ func (IPv6) SetChecksum(uint16) {
func (b IPv6) Encode(i *IPv6Fields) {
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
- b[nextHdr] = i.NextHeader
+ b[IPv6NextHeaderOffset] = i.NextHeader
b[hopLimit] = i.HopLimit
b.SetSourceAddress(i.SrcAddr)
b.SetDestinationAddress(i.DstAddr)
@@ -264,27 +310,43 @@ func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
+// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
+// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+//
+// buf MUST be at least 8 bytes.
+func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
+ buf[0] = linkAddr[0] ^ 2
+ buf[1] = linkAddr[1]
+ buf[2] = linkAddr[2]
+ buf[3] = 0xFF
+ buf[4] = 0xFE
+ buf[5] = linkAddr[3]
+ buf[6] = linkAddr[4]
+ buf[7] = linkAddr[5]
+}
+
+// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit
+// Ethernet/MAC address, as per RFC 4291 section 2.5.1.
+func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
+ var buf [IIDSize]byte
+ EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ return buf
+}
+
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
- // Convert a 48-bit MAC to an EUI-64 and then prepend the link-local
- // header, FE80::.
+ // Convert a 48-bit MAC to a modified EUI-64 and then prepend the
+ // link-local header, FE80::.
//
// The conversion is very nearly:
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
// Note the capital A. The conversion aa->Aa involves a bit flip.
- lladdrb := [16]byte{
- 0: 0xFE,
- 1: 0x80,
- 8: linkAddr[0] ^ 2,
- 9: linkAddr[1],
- 10: linkAddr[2],
- 11: 0xFF,
- 12: 0xFE,
- 13: linkAddr[3],
- 14: linkAddr[4],
- 15: linkAddr[5],
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
}
+ EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
return tcpip.Address(lladdrb[:])
}
@@ -296,3 +358,145 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+
+// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
+// link-local multicast address.
+func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+}
+
+// IsV6UniqueLocalAddress determines if the provided address is an IPv6
+// unique-local address (within the prefix FC00::/7).
+func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+ // According to RFC 4193 section 3.1, a unique local address has the prefix
+ // FC00::/7.
+ return (addr[0] & 0xfe) == 0xfc
+}
+
+// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
+// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
+//
+// The opaque IID is generated from the cryptographic hash of the concatenation
+// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
+// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
+// MUST be generated to a pseudo-random number. See RFC 4086 for randomness
+// requirements for security.
+//
+// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
+// array for the buffer will not be allocated.
+func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
+ // As per RFC 7217 section 5, the opaque identifier can be generated as a
+ // cryptographic hash of the concatenation of each of the function parameters.
+ // Note, we omit the optional Network_ID field.
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write([]byte(prefix.ID()[:IIDOffsetInIPv6Address]))
+ h.Write([]byte(nicName))
+ h.Write([]byte{dadCounter})
+ h.Write(secretKey)
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ return append(buf, sum[:IIDSize]...)
+}
+
+// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
+// an opaque IID.
+func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
+ lladdrb := [IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ return tcpip.Address(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey))
+}
+
+// IPv6AddressScope is the scope of an IPv6 address.
+type IPv6AddressScope int
+
+const (
+ // LinkLocalScope indicates a link-local address.
+ LinkLocalScope IPv6AddressScope = iota
+
+ // UniqueLocalScope indicates a unique-local address.
+ UniqueLocalScope
+
+ // GlobalScope indicates a global address.
+ GlobalScope
+)
+
+// ScopeForIPv6Address returns the scope for an IPv6 address.
+func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
+ if len(addr) != IPv6AddressSize {
+ return GlobalScope, tcpip.ErrBadAddress
+ }
+
+ switch {
+ case IsV6LinkLocalMulticastAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6LinkLocalAddress(addr):
+ return LinkLocalScope, nil
+
+ case IsV6UniqueLocalAddress(addr):
+ return UniqueLocalScope, nil
+
+ default:
+ return GlobalScope, nil
+ }
+}
+
+// InitialTempIID generates the initial temporary IID history value to generate
+// temporary SLAAC addresses with.
+//
+// Panics if initialTempIIDHistory is not at least IIDSize bytes.
+func InitialTempIID(initialTempIIDHistory []byte, seed []byte, nicID tcpip.NICID) {
+ h := sha256.New()
+ // h.Write never returns an error.
+ h.Write(seed)
+ var nicIDBuf [4]byte
+ binary.BigEndian.PutUint32(nicIDBuf[:], uint32(nicID))
+ h.Write(nicIDBuf[:])
+
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ if n := copy(initialTempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+}
+
+// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an
+// associated stable/permanent SLAAC address.
+//
+// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be
+// used when generating a new temporary IID.
+//
+// Panics if tempIIDHistory is not at least IIDSize bytes.
+func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ addrBytes := []byte(stableAddr)
+ h := sha256.New()
+ h.Write(tempIIDHistory)
+ h.Write(addrBytes[IIDOffsetInIPv6Address:])
+ var sumBuf [sha256.Size]byte
+ sum := h.Sum(sumBuf[:0])
+
+ // The rightmost 64 bits of sum are saved for the next iteration.
+ if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
+ panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
+ }
+
+ // The leftmost 64 bits of sum is used as the IID.
+ if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize {
+ panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize))
+ }
+
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: IIDOffsetInIPv6Address * 8,
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go
new file mode 100644
index 000000000..3499d8399
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers.go
@@ -0,0 +1,551 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bufio"
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
+type IPv6ExtensionHeaderIdentifier uint8
+
+const (
+ // IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
+ // Hop Options extension header, as per RFC 8200 section 4.3.
+ IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
+
+ // IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
+ // header, as per RFC 8200 section 4.4.
+ IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
+
+ // IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
+ // extension header, as per RFC 8200 section 4.5.
+ IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
+
+ // IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
+ // Destination Options extension header, as per RFC 8200 section 4.6.
+ IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
+
+ // IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
+ // of an IPv6 payload, as per RFC 8200 section 4.7.
+ IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
+)
+
+const (
+ // ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when
+ // a node encounters an unrecognized option.
+ ipv6UnknownExtHdrOptionActionMask = 192
+
+ // ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard
+ // from the action value for an unrecognized option identifier.
+ ipv6UnknownExtHdrOptionActionShift = 6
+
+ // ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
+ // within an IPv6RoutingExtHdr.
+ ipv6RoutingExtHdrSegmentsLeftIdx = 1
+
+ // IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in
+ // bytes.
+ IPv6FragmentExtHdrLength = 8
+
+ // ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
+ // Fragment Offset field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFragmentOffsetOffset = 0
+
+ // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to
+ // discard from the Fragment Offset.
+ ipv6FragmentExtHdrFragmentOffsetShift = 3
+
+ // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
+ // IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrFlagsIdx = 1
+
+ // ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
+ // flags field of an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrMFlagMask = 1
+
+ // ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
+ // field within an IPv6FragmentExtHdr.
+ ipv6FragmentExtHdrIdentificationOffset = 2
+
+ // ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
+ // field. That is, given a Length field of 2, the extension header expects
+ // 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
+ // details about the first 8 bytes' exclusion from the Length field).
+ ipv6ExtHdrLenBytesPerUnit = 8
+
+ // ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an
+ // extension header's Length field following the Length field.
+ //
+ // The Length field excludes the first 8 bytes, but the Next Header and Length
+ // field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
+ // after the Length field.
+ //
+ // This ensures that every extension header is at least 8 bytes.
+ ipv6ExtHdrLenBytesExcluded = 6
+
+ // IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
+ // extension header's Fragment Offset field. That is, given a Fragment Offset
+ // of 2, the extension header is indiciating that the fragment's payload
+ // starts at the 16th byte in the reassembled packet.
+ IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
+)
+
+// IPv6PayloadHeader is implemented by the various headers that can be found
+// in an IPv6 payload.
+//
+// These headers include IPv6 extension headers or upper layer data.
+type IPv6PayloadHeader interface {
+ isIPv6PayloadHeader()
+}
+
+// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
+// encounters a Next Header field it does not recognize as an IPv6 extension
+// header.
+type IPv6RawPayloadHeader struct {
+ Identifier IPv6ExtensionHeaderIdentifier
+ Buf buffer.VectorisedView
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
+
+// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
+type ipv6OptionsExtHdr []byte
+
+// Iter returns an iterator over the IPv6 extension header options held in b.
+func (b ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
+ it := IPv6OptionsExtHdrOptionsIterator{}
+ it.reader.Reset(b)
+ return it
+}
+
+// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
+// options.
+//
+// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
+// used, no changes to the underlying buffer may happen. Doing so may cause
+// undefined and unexpected behaviour. It is fine to obtain an
+// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
+// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
+// obtained before modification is no longer used.
+type IPv6OptionsExtHdrOptionsIterator struct {
+ reader bytes.Reader
+}
+
+// IPv6OptionUnknownAction is the action that must be taken if the processing
+// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
+type IPv6OptionUnknownAction int
+
+const (
+ // IPv6OptionUnknownActionSkip indicates that the unrecognized option must
+ // be skipped and the node should continue processing the header.
+ IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
+
+ // IPv6OptionUnknownActionDiscard indicates that the packet must be silently
+ // discarded.
+ IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
+
+ // IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
+ // discarded and the node must send an ICMP Parameter Problem, Code 2, message
+ // to the packet's source, regardless of whether or not the packet's
+ // Destination was a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
+
+ // IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
+ // packet must be discarded and the node must send an ICMP Parameter Problem,
+ // Code 2, message to the packet's source only if the packet's Destination was
+ // not a multicast address.
+ IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
+)
+
+// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
+type IPv6ExtHdrOption interface {
+ // UnknownAction returns the action to take in response to an unrecognized
+ // option.
+ UnknownAction() IPv6OptionUnknownAction
+
+ // isIPv6ExtHdrOption is used to "lock" this interface so it is not
+ // implemented by other packages.
+ isIPv6ExtHdrOption()
+}
+
+// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier.
+type IPv6ExtHdrOptionIndentifier uint8
+
+const (
+ // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides 1 byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0
+
+ // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that
+ // provides variable length byte padding, as outlined in RFC 8200 section 4.2.
+ ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1
+)
+
+// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
+// header option that is unknown by the parsing utilities.
+type IPv6UnknownExtHdrOption struct {
+ Identifier IPv6ExtHdrOptionIndentifier
+ Data []byte
+}
+
+// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
+func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
+ return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
+}
+
+// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
+func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
+
+// Next returns the next option in the options data.
+//
+// If the next item is not a known extension header option,
+// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
+//
+// The return is of the format (option, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the options data, or an error occured.
+func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
+ for {
+ temp, err := i.reader.ReadByte()
+ if err != nil {
+ // If we can't read the first byte of a new option, then we know the
+ // options buffer has been exhausted and we are done iterating.
+ return nil, true, nil
+ }
+ id := IPv6ExtHdrOptionIndentifier(temp)
+
+ // If the option identifier indicates the option is a Pad1 option, then we
+ // know the option does not have Length and Data fields. End processing of
+ // the Pad1 option and continue processing the buffer as a new option.
+ if id == ipv6Pad1ExtHdrOptionIdentifier {
+ continue
+ }
+
+ length, err := i.reader.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the reader to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
+ }
+
+ // Special-case the variable length padding option to avoid a copy.
+ if id == ipv6PadNExtHdrOptionIdentifier {
+ // Do we have enough bytes in the reader for the PadN option?
+ if n := i.reader.Len(); n < int(length) {
+ // Reset the reader to effectively consume the remaining buffer.
+ i.reader.Reset(nil)
+
+ // We return the same error as if we failed to read a non-padding option
+ // so consumers of this iterator don't need to differentiate between
+ // padding and non-padding options.
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
+ }
+
+ if _, err := i.reader.Seek(int64(length), io.SeekCurrent); err != nil {
+ panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err))
+ }
+
+ // End processing of the PadN option and continue processing the buffer as
+ // a new option.
+ continue
+ }
+
+ bytes := make([]byte, length)
+ if n, err := io.ReadFull(&i.reader, bytes); err != nil {
+ // io.ReadFull may return io.EOF if i.reader has been exhausted. We use
+ // io.ErrUnexpectedEOF instead as the io.EOF is unexpected given the
+ // Length field found in the option.
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+
+ return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
+ }
+
+ return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
+ }
+}
+
+// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
+// extension header.
+type IPv6HopByHopOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
+// extension header.
+type IPv6DestinationOptionsExtHdr struct {
+ ipv6OptionsExtHdr
+}
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
+
+// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
+// data as outlined in RFC 8200 section 4.4.
+type IPv6RoutingExtHdr []byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
+
+// SegmentsLeft returns the Segments Left field.
+func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
+ return b[ipv6RoutingExtHdrSegmentsLeftIdx]
+}
+
+// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
+// data as outlined in RFC 8200 section 4.5.
+//
+// Note, the buffer does not include the Next Header and Reserved fields.
+type IPv6FragmentExtHdr [6]byte
+
+// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
+func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
+
+// FragmentOffset returns the Fragment Offset field.
+//
+// This value indicates where the buffer following the Fragment extension header
+// starts in the target (reassembled) packet.
+func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
+}
+
+// More returns the More (M) flag.
+//
+// This indicates whether any fragments are expected to succeed b.
+func (b IPv6FragmentExtHdr) More() bool {
+ return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
+}
+
+// ID returns the Identification field.
+//
+// This value is used to uniquely identify the packet, between a
+// souce and destination.
+func (b IPv6FragmentExtHdr) ID() uint32 {
+ return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
+}
+
+// IsAtomic returns whether the fragment header indicates an atomic fragment. An
+// atomic fragment is a fragment that contains all the data required to
+// reassemble a full packet.
+func (b IPv6FragmentExtHdr) IsAtomic() bool {
+ return !b.More() && b.FragmentOffset() == 0
+}
+
+// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
+//
+// The IPv6 payload may contain IPv6 extension headers before any upper layer
+// data.
+//
+// Note, between when an IPv6PayloadIterator is obtained and last used, no
+// changes to the payload may happen. Doing so may cause undefined and
+// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
+// over the first few headers then modify the backing payload so long as the
+// IPv6PayloadIterator obtained before modification is no longer used.
+type IPv6PayloadIterator struct {
+ // The identifier of the next header to parse.
+ nextHdrIdentifier IPv6ExtensionHeaderIdentifier
+
+ // reader is an io.Reader over payload.
+ reader bufio.Reader
+ payload buffer.VectorisedView
+
+ // Indicates to the iterator that it should return the remaining payload as a
+ // raw payload on the next call to Next.
+ forceRaw bool
+}
+
+// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
+// extension headers, or a raw payload if the payload cannot be parsed.
+func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.VectorisedView) IPv6PayloadIterator {
+ readers := payload.Readers()
+ readerPs := make([]io.Reader, 0, len(readers))
+ for i := range readers {
+ readerPs = append(readerPs, &readers[i])
+ }
+
+ return IPv6PayloadIterator{
+ nextHdrIdentifier: nextHdrIdentifier,
+ payload: payload.Clone(nil),
+ // We need a buffer of size 1 for calls to bufio.Reader.ReadByte.
+ reader: *bufio.NewReaderSize(io.MultiReader(readerPs...), 1),
+ }
+}
+
+// AsRawHeader returns the remaining payload of i as a raw header and
+// optionally consumes the iterator.
+//
+// If consume is true, calls to Next after calling AsRawHeader on i will
+// indicate that the iterator is done.
+func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
+ identifier := i.nextHdrIdentifier
+
+ var buf buffer.VectorisedView
+ if consume {
+ // Since we consume the iterator, we return the payload as is.
+ buf = i.payload
+
+ // Mark i as done.
+ *i = IPv6PayloadIterator{
+ nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
+ }
+ } else {
+ buf = i.payload.Clone(nil)
+ }
+
+ return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
+}
+
+// Next returns the next item in the payload.
+//
+// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
+// will be returned with the remaining bytes and next header identifier.
+//
+// The return is of the format (header, done, error). done will be true when
+// Next is unable to return anything because the iterator has reached the end of
+// the payload, or an error occured.
+func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
+ // We could be forced to return i as a raw header when the previous header was
+ // a fragment extension header as the data following the fragment extension
+ // header may not be complete.
+ if i.forceRaw {
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+
+ // Is the header we are parsing a known extension header?
+ switch i.nextHdrIdentifier {
+ case IPv6HopByHopOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6RoutingExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6RoutingExtHdr(bytes), false, nil
+ case IPv6FragmentExtHdrIdentifier:
+ var data [6]byte
+ // We ignore the returned bytes becauase we know the fragment extension
+ // header specific data will fit in data.
+ nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
+ if err != nil {
+ return nil, true, err
+ }
+
+ fragmentExtHdr := IPv6FragmentExtHdr(data)
+
+ // If the packet is not the first fragment, do not attempt to parse anything
+ // after the fragment extension header as the payload following the fragment
+ // extension header should not contain any headers; the first fragment must
+ // hold all the headers up to and including any upper layer headers, as per
+ // RFC 8200 section 4.5.
+ if fragmentExtHdr.FragmentOffset() != 0 {
+ i.forceRaw = true
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return fragmentExtHdr, false, nil
+ case IPv6DestinationOptionsExtHdrIdentifier:
+ nextHdrIdentifier, bytes, err := i.nextHeaderData(false /* fragmentHdr */, nil)
+ if err != nil {
+ return nil, true, err
+ }
+
+ i.nextHdrIdentifier = nextHdrIdentifier
+ return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: bytes}, false, nil
+ case IPv6NoNextHeaderIdentifier:
+ // This indicates the end of the IPv6 payload.
+ return nil, true, nil
+
+ default:
+ // The header we are parsing is not a known extension header. Return the
+ // raw payload.
+ return i.AsRawHeader(true /* consume */), false, nil
+ }
+}
+
+// nextHeaderData returns the extension header's Next Header field and raw data.
+//
+// fragmentHdr indicates that the extension header being parsed is the Fragment
+// extension header so the Length field should be ignored as it is Reserved
+// for the Fragment extension header.
+//
+// If bytes is not nil, extension header specific data will be read into bytes
+// if it has enough capacity. If bytes is provided but does not have enough
+// capacity for the data, nextHeaderData will panic.
+func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, []byte, error) {
+ // We ignore the number of bytes read because we know we will only ever read
+ // at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
+ // would return io.EOF to indicate that io.Reader has reached the end of the
+ // payload.
+ nextHdrIdentifier, err := i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ var length uint8
+ length, err = i.reader.ReadByte()
+ i.payload.TrimFront(1)
+ if err != nil {
+ if fragmentHdr {
+ return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+
+ return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
+ }
+ if fragmentHdr {
+ length = 0
+ }
+
+ bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
+ if bytes == nil {
+ bytes = make([]byte, bytesLen)
+ } else if n := len(bytes); n < bytesLen {
+ panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
+ }
+
+ n, err := io.ReadFull(&i.reader, bytes)
+ i.payload.TrimFront(n)
+ if err != nil {
+ return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
+ }
+
+ return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil
+}
diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go
new file mode 100644
index 000000000..ab20c5f37
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_extension_headers_test.go
@@ -0,0 +1,992 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold the same Identifier value and
+// contain the same bytes in Buf, even if the bytes are split across views
+// differently.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6RawPayloadHeader) Equal(b IPv6RawPayloadHeader) bool {
+ return a.Identifier == b.Identifier && bytes.Equal(a.Buf.ToView(), b.Buf.ToView())
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6HopByHopOptionsExtHdr) Equal(b IPv6HopByHopOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+// Equal returns true of a and b are equivalent.
+//
+// Note, Equal will return true if a and b hold equivalent ipv6OptionsExtHdrs.
+//
+// Needed to use cmp.Equal on IPv6RawPayloadHeader as it contains unexported
+// fields.
+func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool {
+ return bytes.Equal(a.ipv6OptionsExtHdr, b.ipv6OptionsExtHdr)
+}
+
+func TestIPv6UnknownExtHdrOption(t *testing.T) {
+ tests := []struct {
+ name string
+ identifier IPv6ExtHdrOptionIndentifier
+ expectedUnknownAction IPv6OptionUnknownAction
+ }{
+ {
+ name: "Skip with zero LSBs",
+ identifier: 0,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with zero LSBs",
+ identifier: 64,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with zero LSBs",
+ identifier: 128,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with zero LSBs",
+ identifier: 192,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ {
+ name: "Skip with non-zero LSBs",
+ identifier: 63,
+ expectedUnknownAction: IPv6OptionUnknownActionSkip,
+ },
+ {
+ name: "Discard with non-zero LSBs",
+ identifier: 127,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscard,
+ },
+ {
+ name: "Discard and ICMP with non-zero LSBs",
+ identifier: 191,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMP,
+ },
+ {
+ name: "Discard and ICMP for non multicast destination with non-zero LSBs",
+ identifier: 255,
+ expectedUnknownAction: IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := &IPv6UnknownExtHdrOption{Identifier: test.identifier, Data: []byte{1, 2, 3, 4}}
+ if a := opt.UnknownAction(); a != test.expectedUnknownAction {
+ t.Fatalf("got UnknownAction() = %d, want = %d", a, test.expectedUnknownAction)
+ }
+ })
+ }
+
+}
+
+func TestIPv6OptionsExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ err error
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ },
+ {
+ name: "Two options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ },
+ },
+ {
+ name: "Three options",
+ bytes: []byte{
+ 255, 0,
+ 254, 1, 1,
+ 253, 4, 2, 3, 4, 5,
+ },
+ },
+ {
+ name: "Single unknown only identifier",
+ bytes: []byte{255},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 1",
+ bytes: []byte{255, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Single unknown too small with length = 2",
+ bytes: []byte{255, 2, 1},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown only identifier",
+ bytes: []byte{
+ 255, 0,
+ 254,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown missing data",
+ bytes: []byte{
+ 255, 0,
+ 254, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid first with second unknown too small",
+ bytes: []byte{
+ 255, 0,
+ 254, 2, 1,
+ },
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "One Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Multiple Pad1",
+ bytes: []byte{0, 0, 0},
+ },
+ {
+ name: "Multiple PadN",
+ bytes: []byte{
+ // Pad3
+ 1, 1, 1,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Pad5 too small middle of data buffer",
+ bytes: []byte{1, 3, 1, 2},
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Pad5 no data",
+ bytes: []byte{1, 3},
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) {
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as expectedErr.
+ if !errors.Is(err, expectedErr) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, expectedErr)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if expectedErr != nil {
+ t.Fatalf("expected error when iterating; want = %s", expectedErr)
+ }
+
+ return
+ }
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ check(t, extHdr.Iter(), test.err)
+ })
+ })
+ }
+}
+
+func TestIPv6OptionsExtHdrIter(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ expected []IPv6ExtHdrOption
+ }{
+ {
+ name: "Single unknown with zero length",
+ bytes: []byte{255, 0},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ },
+ },
+ {
+ name: "Single unknown with non-zero length",
+ bytes: []byte{255, 3, 1, 2, 3},
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{1, 2, 3}},
+ },
+ },
+ {
+ name: "Single Pad1",
+ bytes: []byte{0},
+ },
+ {
+ name: "Two Pad1",
+ bytes: []byte{0, 0},
+ },
+ {
+ name: "Single Pad3",
+ bytes: []byte{1, 1, 1},
+ },
+ {
+ name: "Single Pad5",
+ bytes: []byte{1, 3, 1, 2, 3},
+ },
+ {
+ name: "Multiple Pad",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Pad2
+ 1, 0,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Pad4
+ 1, 2, 1, 2,
+
+ // Pad5
+ 1, 3, 1, 2, 3,
+ },
+ },
+ {
+ name: "Multiple options",
+ bytes: []byte{
+ // Pad1
+ 0,
+
+ // Unknown
+ 255, 0,
+
+ // Pad2
+ 1, 0,
+
+ // Unknown
+ 254, 1, 1,
+
+ // Pad3
+ 1, 1, 1,
+
+ // Unknown
+ 253, 4, 2, 3, 4, 5,
+
+ // Pad4
+ 1, 2, 1, 2,
+ },
+ expected: []IPv6ExtHdrOption{
+ &IPv6UnknownExtHdrOption{Identifier: 255, Data: []byte{}},
+ &IPv6UnknownExtHdrOption{Identifier: 254, Data: []byte{1}},
+ &IPv6UnknownExtHdrOption{Identifier: 253, Data: []byte{2, 3, 4, 5}},
+ },
+ },
+ }
+
+ checkIter := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expected []IPv6ExtHdrOption) {
+ for i, e := range expected {
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, opt); diff != "" {
+ t.Errorf("(i=%d) got option mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ opt, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if opt != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", opt)
+ }
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Run("Hop By Hop", func(t *testing.T) {
+ extHdr := IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+
+ t.Run("Destination", func(t *testing.T) {
+ extHdr := IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: test.bytes}
+ checkIter(t, extHdr.Iter(), test.expected)
+ })
+ })
+ }
+}
+
+func TestIPv6RoutingExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes []byte
+ segmentsLeft uint8
+ }{
+ {
+ name: "Zeroes",
+ bytes: []byte{0, 0, 0, 0, 0, 0},
+ segmentsLeft: 0,
+ },
+ {
+ name: "Ones",
+ bytes: []byte{1, 1, 1, 1, 1, 1},
+ segmentsLeft: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: []byte{1, 2, 3, 4, 5, 6},
+ segmentsLeft: 2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6RoutingExtHdr(test.bytes)
+ if got := extHdr.SegmentsLeft(); got != test.segmentsLeft {
+ t.Errorf("got SegmentsLeft() = %d, want = %d", got, test.segmentsLeft)
+ }
+ })
+ }
+}
+
+func TestIPv6FragmentExtHdr(t *testing.T) {
+ tests := []struct {
+ name string
+ bytes [6]byte
+ fragmentOffset uint16
+ more bool
+ id uint32
+ }{
+ {
+ name: "Zeroes",
+ bytes: [6]byte{0, 0, 0, 0, 0, 0},
+ fragmentOffset: 0,
+ more: false,
+ id: 0,
+ },
+ {
+ name: "Ones",
+ bytes: [6]byte{0, 9, 0, 0, 0, 1},
+ fragmentOffset: 1,
+ more: true,
+ id: 1,
+ },
+ {
+ name: "Mixed",
+ bytes: [6]byte{68, 9, 128, 4, 2, 1},
+ fragmentOffset: 2177,
+ more: true,
+ id: 2147746305,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ extHdr := IPv6FragmentExtHdr(test.bytes)
+ if got := extHdr.FragmentOffset(); got != test.fragmentOffset {
+ t.Errorf("got FragmentOffset() = %d, want = %d", got, test.fragmentOffset)
+ }
+ if got := extHdr.More(); got != test.more {
+ t.Errorf("got More() = %t, want = %t", got, test.more)
+ }
+ if got := extHdr.ID(); got != test.id {
+ t.Errorf("got ID() = %d, want = %d", got, test.id)
+ }
+ })
+ }
+}
+
+func makeVectorisedViewFromByteBuffers(bs ...[]byte) buffer.VectorisedView {
+ size := 0
+ var vs []buffer.View
+
+ for _, b := range bs {
+ vs = append(vs, buffer.View(b))
+ size += len(b)
+ }
+
+ return buffer.NewVectorisedView(size, vs)
+}
+
+func TestIPv6ExtHdrIterErr(t *testing.T) {
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ err error
+ }{
+ {
+ name: "Upper layer only without data",
+ firstNextHdr: 255,
+ },
+ {
+ name: "Upper layer only with data",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "No next header",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ },
+ {
+ name: "No next header with data",
+ firstNextHdr: IPv6NoNextHeaderIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{1, 2, 3, 4}),
+ },
+ {
+ name: "Valid single hop by hop",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Hop by hop too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single fragment",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2, 1}),
+ },
+ {
+ name: "Fragment too small",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 68, 9, 128, 4, 2}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single destination",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3, 4}),
+ },
+ {
+ name: "Destination too small",
+ firstNextHdr: IPv6DestinationOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 4, 1, 2, 3}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid single routing",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5, 6}),
+ },
+ {
+ name: "Valid single routing across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2}, []byte{3, 4, 5, 6}),
+ },
+ {
+ name: "Routing too small with zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 0, 1, 2, 3, 4, 5}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Valid routing with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Valid routing with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7, 8}),
+ },
+ {
+ name: "Routing too small with non-zero length field",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Routing too small with non-zero length field across views",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{255, 1, 1, 2, 3, 4, 5, 6}, []byte{1, 2, 3, 4, 5, 6, 7}),
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "Mixed",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3, 4,
+ }),
+ },
+ {
+ name: "Mixed without upper layer data but last ext hdr too small",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // (Atomic) Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 255, 4, 1, 2, 3,
+ }),
+ err: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i := 0; ; i++ {
+ _, done, err := it.Next()
+ if err != nil {
+ // If we encountered a non-nil error while iterating, make sure it is
+ // is the same error as test.err.
+ if !errors.Is(err, test.err) {
+ t.Fatalf("got %d-th Next() = %v, want = %v", i, err, test.err)
+ }
+
+ return
+ }
+ if done {
+ // If we are done (without an error), make sure that we did not expect
+ // an error.
+ if test.err != nil {
+ t.Fatalf("expected error when iterating; want = %s", test.err)
+ }
+
+ return
+ }
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrIter(t *testing.T) {
+ routingExtHdrWithUpperLayerData := buffer.View([]byte{255, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4})
+ upperLayerData := buffer.View([]byte{1, 2, 3, 4})
+ tests := []struct {
+ name string
+ firstNextHdr IPv6ExtensionHeaderIdentifier
+ payload buffer.VectorisedView
+ expected []IPv6PayloadHeader
+ }{
+ // With a non-atomic fragment that is not the first fragment, the payload
+ // after the fragment will not be parsed because the payload is expected to
+ // only hold upper layer data.
+ {
+ name: "hopbyhop - fragment (not first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 2117, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ //
+ // Even though we have a routing ext header here, it should be
+ // be interpretted as raw bytes as only the first fragment is expected
+ // to hold headers.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "hopbyhop - fragment (first) - routing - upper",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Fragment extension header.
+ //
+ // More = 1, Fragment Offset = 0, ID = 2147746305
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 0, 1, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2, 3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6FragmentExtHdr([6]byte{0, 1, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 68, 9, 128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{68, 9, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6RoutingExtHdrIdentifier,
+ Buf: routingExtHdrWithUpperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // If we have an atomic fragment, the payload following the fragment
+ // extension header should be parsed normally.
+ {
+ name: "atomic fragment - routing - destination - upper",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6, 128, 4, 2, 1,
+
+ // Routing extension header.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Destination Options extension header.
+ 255, 0, 1, 4, 1, 2, 3, 4,
+
+ // Upper layer data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - routing - upper (across views)",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6RoutingExtHdrIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Routing extension header.
+ 255, 0, 1, 2}, []byte{3, 4, 5, 6,
+
+ // Upper layer data.
+ 1, 2}, []byte{3, 4}),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ },
+ },
+ },
+ {
+ name: "atomic fragment - destination - no next header",
+ firstNextHdr: IPv6FragmentExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Fragment extension header.
+ //
+ // Res (Reserved) bits are 1 which should not affect anything.
+ uint8(IPv6DestinationOptionsExtHdrIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Destination Options extension header.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 0, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "routing - atomic fragment - no next header (across views)",
+ firstNextHdr: IPv6RoutingExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Reserved bits are 1 which should not affect anything.
+ uint8(IPv6NoNextHeaderIdentifier), 255, 0, 6}, []byte{128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{0, 6, 128, 4, 2, 1}),
+ },
+ },
+ {
+ name: "hopbyhop - routing - fragment - no next header",
+ firstNextHdr: IPv6HopByHopOptionsExtHdrIdentifier,
+ payload: makeVectorisedViewFromByteBuffers([]byte{
+ // Hop By Hop Options extension header.
+ uint8(IPv6RoutingExtHdrIdentifier), 0, 1, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ uint8(IPv6FragmentExtHdrIdentifier), 0, 1, 2, 3, 4, 5, 6,
+
+ // Fragment extension header.
+ //
+ // Fragment Offset = 32; Res = 6.
+ uint8(IPv6NoNextHeaderIdentifier), 0, 1, 6, 128, 4, 2, 1,
+
+ // Random data.
+ 1, 2, 3, 4,
+ }),
+ expected: []IPv6PayloadHeader{
+ IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr: []byte{1, 4, 1, 2, 3, 4}},
+ IPv6RoutingExtHdr([]byte{1, 2, 3, 4, 5, 6}),
+ IPv6FragmentExtHdr([6]byte{1, 6, 128, 4, 2, 1}),
+ IPv6RawPayloadHeader{
+ Identifier: IPv6NoNextHeaderIdentifier,
+ Buf: upperLayerData.ToVectorisedView(),
+ },
+ },
+ },
+
+ // Test the raw payload for common transport layer protocol numbers.
+ {
+ name: "TCP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(TCPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "UDP raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(UDPProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv4 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv4ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "ICMPv6 raw payload",
+ firstNextHdr: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: IPv6ExtensionHeaderIdentifier(ICMPv6ProtocolNumber),
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: upperLayerData.ToVectorisedView(),
+ }},
+ },
+ {
+ name: "Unknwon next header raw payload (across views)",
+ firstNextHdr: 255,
+ payload: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ expected: []IPv6PayloadHeader{IPv6RawPayloadHeader{
+ Identifier: 255,
+ Buf: makeVectorisedViewFromByteBuffers(upperLayerData[:2], upperLayerData[2:]),
+ }},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ it := MakeIPv6PayloadIterator(test.firstNextHdr, test.payload)
+
+ for i, e := range test.expected {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(i=%d) Next(): %s", i, err)
+ }
+ if done {
+ t.Errorf("(i=%d) unexpectedly done iterating", i)
+ }
+ if diff := cmp.Diff(e, extHdr); diff != "" {
+ t.Errorf("(i=%d) got ext hdr mismatch (-want +got):\n%s", i, diff)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+
+ extHdr, done, err := it.Next()
+ if err != nil {
+ t.Errorf("(last) Next(): %s", err)
+ }
+ if !done {
+ t.Errorf("(last) iterator unexpectedly not done")
+ }
+ if extHdr != nil {
+ t.Errorf("(last) got Next() = %T, want = nil", extHdr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
new file mode 100644
index 000000000..426a873b1
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -0,0 +1,417 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+ "bytes"
+ "crypto/sha256"
+ "fmt"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+)
+
+func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
+ expectedIID := [header.IIDSize]byte{0, 2, 3, 255, 254, 4, 5, 6}
+
+ if diff := cmp.Diff(expectedIID, header.EthernetAddressToModifiedEUI64(linkAddr)); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64(%s) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+
+ var buf [header.IIDSize]byte
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
+ if diff := cmp.Diff(expectedIID, buf); diff != "" {
+ t.Errorf("EthernetAddressToModifiedEUI64IntoBuf(%s, _) mismatch (-want +got):\n%s", linkAddr, diff)
+ }
+}
+
+func TestLinkLocalAddr(t *testing.T) {
+ if got, want := header.LinkLocalAddr(linkAddr), tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x02\x03\xff\xfe\x04\x05\x06"); got != want {
+ t.Errorf("got LinkLocalAddr(%s) = %s, want = %s", linkAddr, got, want)
+ }
+}
+
+func TestAppendOpaqueInterfaceIdentifier(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ prefix: header.IPv6LinkLocalPrefix.Subnet(),
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ prefix: func() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: "\x01\x02\x03\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ return addrWithPrefix.Subnet()
+ }(),
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ h := sha256.New()
+ h.Write([]byte(test.prefix.ID()[:header.IIDOffsetInIPv6Address]))
+ h.Write([]byte(test.nicName))
+ h.Write([]byte{test.dadCounter})
+ if k := test.secretKey; k != nil {
+ h.Write(k)
+ }
+ var hashSum [sha256.Size]byte
+ h.Sum(hashSum[:0])
+ want := hashSum[:header.IIDSize]
+
+ // Passing a nil buffer should result in a new buffer returned with the
+ // IID.
+ if got := header.AppendOpaqueInterfaceIdentifier(nil, test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(nil, %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+
+ // Passing a buffer with sufficient capacity for the IID should populate
+ // the buffer provided.
+ var iidBuf [header.IIDSize]byte
+ if got := header.AppendOpaqueInterfaceIdentifier(iidBuf[:0], test.prefix, test.nicName, test.dadCounter, test.secretKey); !bytes.Equal(got, want) {
+ t.Errorf("got AppendOpaqueInterfaceIdentifier(iidBuf[:0], %s, %s, %d, %x) = %x, want = %x", test.prefix, test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ if got := iidBuf[:]; !bytes.Equal(got, want) {
+ t.Errorf("got iidBuf = %x, want = %x", got, want)
+ }
+ })
+ }
+}
+
+func TestLinkLocalAddrWithOpaqueIID(t *testing.T) {
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes * 2]byte
+ if n, err := rand.Read(secretKeyBuf[:]); err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ } else if want := header.OpaqueIIDSecretKeyMinBytes * 2; n != want {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", want, n)
+ }
+
+ prefix := header.IPv6LinkLocalPrefix.Subnet()
+
+ tests := []struct {
+ name string
+ prefix tcpip.Subnet
+ nicName string
+ dadCounter uint8
+ secretKey []byte
+ }{
+ {
+ name: "SecretKey of minimum size",
+ nicName: "eth0",
+ dadCounter: 0,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes],
+ },
+ {
+ name: "SecretKey of less than minimum size",
+ nicName: "eth10",
+ dadCounter: 1,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes/2],
+ },
+ {
+ name: "SecretKey of more than minimum size",
+ nicName: "eth11",
+ dadCounter: 2,
+ secretKey: secretKeyBuf[:header.OpaqueIIDSecretKeyMinBytes*2],
+ },
+ {
+ name: "Nil SecretKey and empty nicName",
+ nicName: "",
+ dadCounter: 3,
+ secretKey: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ addrBytes := [header.IPv6AddressSize]byte{
+ 0: 0xFE,
+ 1: 0x80,
+ }
+
+ want := tcpip.Address(header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ test.nicName,
+ test.dadCounter,
+ test.secretKey,
+ ))
+
+ if got := header.LinkLocalAddrWithOpaqueIID(test.nicName, test.dadCounter, test.secretKey); got != want {
+ t.Errorf("got LinkLocalAddrWithOpaqueIID(%s, %d, %x) = %s, want = %s", test.nicName, test.dadCounter, test.secretKey, got, want)
+ }
+ })
+ }
+}
+
+func TestIsV6UniqueLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Unique 1",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Valid Unique 2",
+ addr: uniqueLocalAddr1,
+ expected: true,
+ },
+ {
+ name: "Link Local",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: true,
+ },
+ {
+ name: "Valid Link Local Multicast with flags",
+ addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ expected: true,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Multicast",
+ addr: "\xe0\x00\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: true,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: false,
+ },
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Link Local",
+ addr: "\xa9\xfe\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestScopeForIPv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ scope header.IPv6AddressScope
+ err *tcpip.Error
+ }{
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ scope: header.UniqueLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ scope: header.GlobalScope,
+ err: nil,
+ },
+ {
+ name: "IPv4",
+ addr: "\x01\x02\x03\x04",
+ scope: header.GlobalScope,
+ err: tcpip.ErrBadAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ got, err := header.ScopeForIPv6Address(test.addr)
+ if err != test.err {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (_, %v), want = (_, %v)", test.addr, err, test.err)
+ }
+ if got != test.scope {
+ t.Errorf("got header.IsV6UniqueLocalAddress(%s) = (%d, _), want = (%d, _)", test.addr, got, test.scope)
+ }
+ })
+ }
+}
+
+func TestSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want tcpip.Address
+ }{
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x0e\x0f\xa0",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x0e\x0f\xa0",
+ },
+ {
+ addr: "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\xdd\x01\x02\x03",
+ want: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff\x01\x02\x03",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.SolicitedNodeAddr(test.addr); got != test.want {
+ t.Fatalf("got header.SolicitedNodeAddr(%s) = %s, want = %s", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ndp_options.go b/pkg/tcpip/header/ndp_options.go
index 98310ea23..5d3975c56 100644
--- a/pkg/tcpip/header/ndp_options.go
+++ b/pkg/tcpip/header/ndp_options.go
@@ -15,24 +15,46 @@
package header
import (
+ "bytes"
"encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "math"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
)
+// NDPOptionIdentifier is an NDP option type identifier.
+type NDPOptionIdentifier uint8
+
const (
- // NDPTargetLinkLayerAddressOptionType is the type of the Target
- // Link-Layer Address option, as per RFC 4861 section 4.6.1.
- NDPTargetLinkLayerAddressOptionType = 2
+ // NDPSourceLinkLayerAddressOptionType is the type of the Source Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPSourceLinkLayerAddressOptionType NDPOptionIdentifier = 1
- // ndpTargetEthernetLinkLayerAddressSize is the size of a Target
- // Link Layer Option for an Ethernet address.
- ndpTargetEthernetLinkLayerAddressSize = 8
+ // NDPTargetLinkLayerAddressOptionType is the type of the Target Link Layer
+ // Address option, as per RFC 4861 section 4.6.1.
+ NDPTargetLinkLayerAddressOptionType NDPOptionIdentifier = 2
- // ndpPrefixInformationType is the type of the Prefix Information
+ // NDPPrefixInformationType is the type of the Prefix Information
// option, as per RFC 4861 section 4.6.2.
- ndpPrefixInformationType = 3
+ NDPPrefixInformationType NDPOptionIdentifier = 3
+
+ // NDPRecursiveDNSServerOptionType is the type of the Recursive DNS
+ // Server option, as per RFC 8106 section 5.1.
+ NDPRecursiveDNSServerOptionType NDPOptionIdentifier = 25
+
+ // NDPDNSSearchListOptionType is the type of the DNS Search List option,
+ // as per RFC 8106 section 5.2.
+ NDPDNSSearchListOptionType = 31
+)
+
+const (
+ // NDPLinkLayerAddressSize is the size of a Source or Target Link Layer
+ // Address option for an Ethernet address.
+ NDPLinkLayerAddressSize = 8
// ndpPrefixInformationLength is the expected length, in bytes, of the
// body of an NDP Prefix Information option, as per RFC 4861 section
@@ -84,10 +106,39 @@ const (
// within an NDPPrefixInformation.
ndpPrefixInformationPrefixOffset = 14
- // NDPPrefixInformationInfiniteLifetime is a value that represents
- // infinity for the Valid and Preferred Lifetime fields in a NDP Prefix
- // Information option. Its value is (2^32 - 1)s = 4294967295s
- NDPPrefixInformationInfiniteLifetime = time.Second * 4294967295
+ // ndpRecursiveDNSServerLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerLifetimeOffset = 2
+
+ // ndpRecursiveDNSServerAddressesOffset is the start of the addresses
+ // for IPv6 Recursive DNS Servers within an NDPRecursiveDNSServer.
+ ndpRecursiveDNSServerAddressesOffset = 6
+
+ // minNDPRecursiveDNSServerLength is the minimum NDP Recursive DNS Server
+ // option's body size when it contains at least one IPv6 address, as per
+ // RFC 8106 section 5.3.1.
+ minNDPRecursiveDNSServerBodySize = 22
+
+ // ndpDNSSearchListLifetimeOffset is the start of the 4-byte
+ // Lifetime field within an NDPDNSSearchList.
+ ndpDNSSearchListLifetimeOffset = 2
+
+ // ndpDNSSearchListDomainNamesOffset is the start of the DNS search list
+ // domain names within an NDPDNSSearchList.
+ ndpDNSSearchListDomainNamesOffset = 6
+
+ // minNDPDNSSearchListBodySize is the minimum NDP DNS Search List option's
+ // body size when it contains at least one domain name, as per RFC 8106
+ // section 5.3.1.
+ minNDPDNSSearchListBodySize = 14
+
+ // maxDomainNameLabelLength is the maximum length of a domain name
+ // label, as per RFC 1035 section 3.1.
+ maxDomainNameLabelLength = 63
+
+ // maxDomainNameLength is the maximum length of a domain name, including
+ // label AND label length octet, as per RFC 1035 section 3.1.
+ maxDomainNameLength = 255
// lengthByteUnits is the multiplier factor for the Length field of an
// NDP option. That is, the length field for NDP options is in units of
@@ -95,9 +146,158 @@ const (
lengthByteUnits = 8
)
+var (
+ // NDPInfiniteLifetime is a value that represents infinity for the
+ // 4-byte lifetime fields found in various NDP options. Its value is
+ // (2^32 - 1)s = 4294967295s.
+ //
+ // This is a variable instead of a constant so that tests can change
+ // this value to a smaller value. It should only be modified by tests.
+ NDPInfiniteLifetime = time.Second * math.MaxUint32
+)
+
+// NDPOptionIterator is an iterator of NDPOption.
+//
+// Note, between when an NDPOptionIterator is obtained and last used, no changes
+// to the NDPOptions may happen. Doing so may cause undefined and unexpected
+// behaviour. It is fine to obtain an NDPOptionIterator, iterate over the first
+// few NDPOption then modify the backing NDPOptions so long as the
+// NDPOptionIterator obtained before modification is no longer used.
+type NDPOptionIterator struct {
+ opts *bytes.Buffer
+}
+
+// Potential errors when iterating over an NDPOptions.
+var (
+ ErrNDPOptMalformedBody = errors.New("NDP option has a malformed body")
+ ErrNDPOptMalformedHeader = errors.New("NDP option has a malformed header")
+)
+
+// Next returns the next element in the backing NDPOptions, or true if we are
+// done, or false if an error occured.
+//
+// The return can be read as option, done, error. Note, option should only be
+// used if done is false and error is nil.
+func (i *NDPOptionIterator) Next() (NDPOption, bool, error) {
+ for {
+ // Do we still have elements to look at?
+ if i.opts.Len() == 0 {
+ return nil, true, nil
+ }
+
+ // Get the Type field.
+ temp, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading the option's Type field: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
+ // we start parsing an option; we expect the buffer to contain enough
+ // bytes for the whole option.
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Type field: %w", io.ErrUnexpectedEOF)
+ }
+ kind := NDPOptionIdentifier(temp)
+
+ // Get the Length field.
+ length, err := i.opts.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading the option's Length field for %s: %s", kind, err))
+ }
+
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Length field for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
+
+ // This would indicate an erroneous NDP option as the Length field should
+ // never be 0.
+ if length == 0 {
+ return nil, true, fmt.Errorf("zero valued Length field for %s: %w", kind, ErrNDPOptMalformedHeader)
+ }
+
+ // Get the body.
+ numBytes := int(length) * lengthByteUnits
+ numBodyBytes := numBytes - 2
+ body := i.opts.Next(numBodyBytes)
+ if len(body) < numBodyBytes {
+ return nil, true, fmt.Errorf("unexpectedly exhausted buffer when reading the option's Body for %s: %w", kind, io.ErrUnexpectedEOF)
+ }
+
+ switch kind {
+ case NDPSourceLinkLayerAddressOptionType:
+ return NDPSourceLinkLayerAddressOption(body), false, nil
+
+ case NDPTargetLinkLayerAddressOptionType:
+ return NDPTargetLinkLayerAddressOption(body), false, nil
+
+ case NDPPrefixInformationType:
+ // Make sure the length of a Prefix Information option
+ // body is ndpPrefixInformationLength, as per RFC 4861
+ // section 4.6.2.
+ if numBodyBytes != ndpPrefixInformationLength {
+ return nil, true, fmt.Errorf("got %d bytes for NDP Prefix Information option's body, expected %d bytes: %w", numBodyBytes, ndpPrefixInformationLength, ErrNDPOptMalformedBody)
+ }
+
+ return NDPPrefixInformation(body), false, nil
+
+ case NDPRecursiveDNSServerOptionType:
+ opt := NDPRecursiveDNSServer(body)
+ if err := opt.checkAddresses(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
+ case NDPDNSSearchListOptionType:
+ opt := NDPDNSSearchList(body)
+ if err := opt.checkDomainNames(); err != nil {
+ return nil, true, err
+ }
+
+ return opt, false, nil
+
+ default:
+ // We do not yet recognize the option, just skip for
+ // now. This is okay because RFC 4861 allows us to
+ // skip/ignore any unrecognized options. However,
+ // we MUST recognized all the options in RFC 4861.
+ //
+ // TODO(b/141487990): Handle all NDP options as defined
+ // by RFC 4861.
+ }
+ }
+}
+
// NDPOptions is a buffer of NDP options as defined by RFC 4861 section 4.6.
type NDPOptions []byte
+// Iter returns an iterator of NDPOption.
+//
+// If check is true, Iter will do an integrity check on the options by iterating
+// over it and returning an error if detected.
+//
+// See NDPOptionIterator for more information.
+func (b NDPOptions) Iter(check bool) (NDPOptionIterator, error) {
+ it := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ if check {
+ it2 := NDPOptionIterator{
+ opts: bytes.NewBuffer(b),
+ }
+
+ for {
+ if _, done, err := it2.Next(); err != nil || done {
+ return it, err
+ }
+ }
+ }
+
+ return it, nil
+}
+
// Serialize serializes the provided list of NDP options into o.
//
// Note, b must be of sufficient size to hold all the options in s. See
@@ -116,7 +316,7 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
continue
}
- b[0] = o.Type()
+ b[0] = byte(o.Type())
// We know this safe because paddedLength would have returned
// 0 if o had an invalid length (> 255 * lengthByteUnits).
@@ -137,15 +337,17 @@ func (b NDPOptions) Serialize(s NDPOptionsSerializer) int {
return done
}
-// ndpOption is the set of functions to be implemented by all NDP option types.
-type ndpOption interface {
- // Type returns the type of this ndpOption.
- Type() uint8
+// NDPOption is the set of functions to be implemented by all NDP option types.
+type NDPOption interface {
+ fmt.Stringer
+
+ // Type returns the type of the receiver.
+ Type() NDPOptionIdentifier
- // Length returns the length of the body of this ndpOption, in bytes.
+ // Length returns the length of the body of the receiver, in bytes.
Length() int
- // serializeInto serializes this ndpOption into the provided byte
+ // serializeInto serializes the receiver into the provided byte
// buffer.
//
// Note, the caller MUST provide a byte buffer with size of at least
@@ -154,15 +356,15 @@ type ndpOption interface {
// buffer is not of sufficient size.
//
// serializeInto will return the number of bytes that was used to
- // serialize this ndpOption. Implementers must only use the number of
- // bytes required to serialize this ndpOption. Callers MAY provide a
+ // serialize the receiver. Implementers must only use the number of
+ // bytes required to serialize the receiver. Callers MAY provide a
// larger buffer than required to serialize into.
serializeInto([]byte) int
}
// paddedLength returns the length of o, in bytes, with any padding bytes, if
// required.
-func paddedLength(o ndpOption) int {
+func paddedLength(o NDPOption) int {
l := o.Length()
if l == 0 {
@@ -201,7 +403,7 @@ func paddedLength(o ndpOption) int {
}
// NDPOptionsSerializer is a serializer for NDP options.
-type NDPOptionsSerializer []ndpOption
+type NDPOptionsSerializer []NDPOption
// Length returns the total number of bytes required to serialize.
func (b NDPOptionsSerializer) Length() int {
@@ -214,6 +416,46 @@ func (b NDPOptionsSerializer) Length() int {
return l
}
+// NDPSourceLinkLayerAddressOption is the NDP Source Link Layer Option
+// as defined by RFC 4861 section 4.6.1.
+//
+// It is the first X bytes following the NDP option's Type and Length field
+// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
+type NDPSourceLinkLayerAddressOption tcpip.LinkAddress
+
+// Type implements NDPOption.Type.
+func (o NDPSourceLinkLayerAddressOption) Type() NDPOptionIdentifier {
+ return NDPSourceLinkLayerAddressOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPSourceLinkLayerAddressOption) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPSourceLinkLayerAddressOption) serializeInto(b []byte) int {
+ return copy(b, o)
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPSourceLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPSourceLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPSourceLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
// NDPTargetLinkLayerAddressOption is the NDP Target Link Layer Option
// as defined by RFC 4861 section 4.6.1.
//
@@ -221,21 +463,39 @@ func (b NDPOptionsSerializer) Length() int {
// where X is the value in Length multiplied by lengthByteUnits - 2 bytes.
type NDPTargetLinkLayerAddressOption tcpip.LinkAddress
-// Type implements ndpOption.Type.
-func (o NDPTargetLinkLayerAddressOption) Type() uint8 {
+// Type implements NDPOption.Type.
+func (o NDPTargetLinkLayerAddressOption) Type() NDPOptionIdentifier {
return NDPTargetLinkLayerAddressOptionType
}
-// Length implements ndpOption.Length.
+// Length implements NDPOption.Length.
func (o NDPTargetLinkLayerAddressOption) Length() int {
return len(o)
}
-// serializeInto implements ndpOption.serializeInto.
+// serializeInto implements NDPOption.serializeInto.
func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
return copy(b, o)
}
+// String implements fmt.Stringer.String.
+func (o NDPTargetLinkLayerAddressOption) String() string {
+ return fmt.Sprintf("%T(%s)", o, tcpip.LinkAddress(o))
+}
+
+// EthernetAddress will return an ethernet (MAC) address if the
+// NDPTargetLinkLayerAddressOption's body has at minimum EthernetAddressSize
+// bytes. If the body has more than EthernetAddressSize bytes, only the first
+// EthernetAddressSize bytes are returned as that is all that is needed for an
+// Ethernet address.
+func (o NDPTargetLinkLayerAddressOption) EthernetAddress() tcpip.LinkAddress {
+ if len(o) >= EthernetAddressSize {
+ return tcpip.LinkAddress(o[:EthernetAddressSize])
+ }
+
+ return tcpip.LinkAddress([]byte(nil))
+}
+
// NDPPrefixInformation is the NDP Prefix Information option as defined by
// RFC 4861 section 4.6.2.
//
@@ -243,17 +503,17 @@ func (o NDPTargetLinkLayerAddressOption) serializeInto(b []byte) int {
// ndpPrefixInformationLength bytes.
type NDPPrefixInformation []byte
-// Type implements ndpOption.Type.
-func (o NDPPrefixInformation) Type() uint8 {
- return ndpPrefixInformationType
+// Type implements NDPOption.Type.
+func (o NDPPrefixInformation) Type() NDPOptionIdentifier {
+ return NDPPrefixInformationType
}
-// Length implements ndpOption.Length.
+// Length implements NDPOption.Length.
func (o NDPPrefixInformation) Length() int {
return ndpPrefixInformationLength
}
-// serializeInto implements ndpOption.serializeInto.
+// serializeInto implements NDPOption.serializeInto.
func (o NDPPrefixInformation) serializeInto(b []byte) int {
used := copy(b, o)
@@ -269,6 +529,17 @@ func (o NDPPrefixInformation) serializeInto(b []byte) int {
return used
}
+// String implements fmt.Stringer.String.
+func (o NDPPrefixInformation) String() string {
+ return fmt.Sprintf("%T(O=%t, A=%t, PL=%s, VL=%s, Prefix=%s)",
+ o,
+ o.OnLinkFlag(),
+ o.AutonomousAddressConfigurationFlag(),
+ o.PreferredLifetime(),
+ o.ValidLifetime(),
+ o.Subnet())
+}
+
// PrefixLength returns the value in the number of leading bits in the Prefix
// that are valid.
//
@@ -302,7 +573,7 @@ func (o NDPPrefixInformation) AutonomousAddressConfigurationFlag() bool {
//
// Note, a value of 0 implies the prefix should not be considered as on-link,
// and a value of infinity/forever is represented by
-// NDPPrefixInformationInfiniteLifetime.
+// NDPInfiniteLifetime.
func (o NDPPrefixInformation) ValidLifetime() time.Duration {
// The field is the time in seconds, as per RFC 4861 section 4.6.2.
return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpPrefixInformationValidLifetimeOffset:]))
@@ -315,7 +586,7 @@ func (o NDPPrefixInformation) ValidLifetime() time.Duration {
//
// Note, a value of 0 implies that addresses generated from the prefix should
// no longer remain preferred, and a value of infinity is represented by
-// NDPPrefixInformationInfiniteLifetime.
+// NDPInfiniteLifetime.
//
// Also note that the value of this field MUST NOT exceed the Valid Lifetime
// field to avoid preferring addresses that are no longer valid, for the
@@ -334,3 +605,295 @@ func (o NDPPrefixInformation) PreferredLifetime() time.Duration {
func (o NDPPrefixInformation) Prefix() tcpip.Address {
return tcpip.Address(o[ndpPrefixInformationPrefixOffset:][:IPv6AddressSize])
}
+
+// Subnet returns the Prefix field and Prefix Length field represented in a
+// tcpip.Subnet.
+func (o NDPPrefixInformation) Subnet() tcpip.Subnet {
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: o.Prefix(),
+ PrefixLen: int(o.PrefixLength()),
+ }
+ return addrWithPrefix.Subnet()
+}
+
+// NDPRecursiveDNSServer is the NDP Recursive DNS Server option, as defined by
+// RFC 8106 section 5.1.
+//
+// To make sure that the option meets its minimum length and does not end in the
+// middle of a DNS server's IPv6 address, the length of a valid
+// NDPRecursiveDNSServer must meet the following constraint:
+// (Length - ndpRecursiveDNSServerAddressesOffset) % IPv6AddressSize == 0
+type NDPRecursiveDNSServer []byte
+
+// Type returns the type of an NDP Recursive DNS Server option.
+//
+// Type implements NDPOption.Type.
+func (NDPRecursiveDNSServer) Type() NDPOptionIdentifier {
+ return NDPRecursiveDNSServerOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPRecursiveDNSServer) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPRecursiveDNSServer) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpRecursiveDNSServerLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPRecursiveDNSServer) String() string {
+ lt := o.Lifetime()
+ addrs, err := o.Addresses()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, addrs, lt)
+}
+
+// Lifetime returns the length of time that the DNS server addresses
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the addresses should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+//
+// Lifetime may panic if o does not have enough bytes to hold the Lifetime
+// field.
+func (o NDPRecursiveDNSServer) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpRecursiveDNSServerLifetimeOffset:]))
+}
+
+// Addresses returns the recursive DNS server IPv6 addresses that may be
+// used for name resolution.
+//
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) Addresses() ([]tcpip.Address, error) {
+ var addrs []tcpip.Address
+ return addrs, o.iterAddresses(func(addr tcpip.Address) { addrs = append(addrs, addr) })
+}
+
+// checkAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and returns any error it encounters.
+func (o NDPRecursiveDNSServer) checkAddresses() error {
+ return o.iterAddresses(nil)
+}
+
+// iterAddresses iterates over the addresses in an NDP Recursive DNS Server
+// option and calls a function with each valid unicast IPv6 address.
+//
+// Note, the addresses MAY be link-local addresses.
+func (o NDPRecursiveDNSServer) iterAddresses(fn func(tcpip.Address)) error {
+ if l := len(o); l < minNDPRecursiveDNSServerBodySize {
+ return fmt.Errorf("got %d bytes for NDP Recursive DNS Server option's body, expected at least %d bytes: %w", l, minNDPRecursiveDNSServerBodySize, io.ErrUnexpectedEOF)
+ }
+
+ o = o[ndpRecursiveDNSServerAddressesOffset:]
+ l := len(o)
+ if l%IPv6AddressSize != 0 {
+ return fmt.Errorf("NDP Recursive DNS Server option's body ends in the middle of an IPv6 address (addresses body size = %d bytes): %w", l, ErrNDPOptMalformedBody)
+ }
+
+ for i := 0; len(o) != 0; i++ {
+ addr := tcpip.Address(o[:IPv6AddressSize])
+ if !IsV6UnicastAddress(addr) {
+ return fmt.Errorf("%d-th address (%s) in NDP Recursive DNS Server option is not a valid unicast IPv6 address: %w", i, addr, ErrNDPOptMalformedBody)
+ }
+
+ if fn != nil {
+ fn(addr)
+ }
+
+ o = o[IPv6AddressSize:]
+ }
+
+ return nil
+}
+
+// NDPDNSSearchList is the NDP DNS Search List option, as defined by
+// RFC 8106 section 5.2.
+type NDPDNSSearchList []byte
+
+// Type implements NDPOption.Type.
+func (o NDPDNSSearchList) Type() NDPOptionIdentifier {
+ return NDPDNSSearchListOptionType
+}
+
+// Length implements NDPOption.Length.
+func (o NDPDNSSearchList) Length() int {
+ return len(o)
+}
+
+// serializeInto implements NDPOption.serializeInto.
+func (o NDPDNSSearchList) serializeInto(b []byte) int {
+ used := copy(b, o)
+
+ // Zero out the reserved bytes that are before the Lifetime field.
+ for i := 0; i < ndpDNSSearchListLifetimeOffset; i++ {
+ b[i] = 0
+ }
+
+ return used
+}
+
+// String implements fmt.Stringer.String.
+func (o NDPDNSSearchList) String() string {
+ lt := o.Lifetime()
+ domainNames, err := o.DomainNames()
+ if err != nil {
+ return fmt.Sprintf("%T([] valid for %s; err = %s)", o, lt, err)
+ }
+ return fmt.Sprintf("%T(%s valid for %s)", o, domainNames, lt)
+}
+
+// Lifetime returns the length of time that the DNS search list of domain names
+// in this option may be used for name resolution.
+//
+// Note, a value of 0 implies the domain names should no longer be used,
+// and a value of infinity/forever is represented by NDPInfiniteLifetime.
+func (o NDPDNSSearchList) Lifetime() time.Duration {
+ // The field is the time in seconds, as per RFC 8106 section 5.1.
+ return time.Second * time.Duration(binary.BigEndian.Uint32(o[ndpDNSSearchListLifetimeOffset:]))
+}
+
+// DomainNames returns a DNS search list of domain names.
+//
+// DomainNames will parse the backing buffer as outlined by RFC 1035 section
+// 3.1 and return a list of strings, with all domain names in lower case.
+func (o NDPDNSSearchList) DomainNames() ([]string, error) {
+ var domainNames []string
+ return domainNames, o.iterDomainNames(func(domainName string) { domainNames = append(domainNames, domainName) })
+}
+
+// checkDomainNames iterates over the domain names in an NDP DNS Search List
+// option and returns any error it encounters.
+func (o NDPDNSSearchList) checkDomainNames() error {
+ return o.iterDomainNames(nil)
+}
+
+// iterDomainNames iterates over the domain names in an NDP DNS Search List
+// option and calls a function with each valid domain name.
+func (o NDPDNSSearchList) iterDomainNames(fn func(string)) error {
+ if l := len(o); l < minNDPDNSSearchListBodySize {
+ return fmt.Errorf("got %d bytes for NDP DNS Search List option's body, expected at least %d bytes: %w", l, minNDPDNSSearchListBodySize, io.ErrUnexpectedEOF)
+ }
+
+ var searchList bytes.Reader
+ searchList.Reset(o[ndpDNSSearchListDomainNamesOffset:])
+
+ var scratch [maxDomainNameLength]byte
+ domainName := bytes.NewBuffer(scratch[:])
+
+ // Parse the domain names, as per RFC 1035 section 3.1.
+ for searchList.Len() != 0 {
+ domainName.Reset()
+
+ // Parse a label within a domain name, as per RFC 1035 section 3.1.
+ for {
+ // The first byte is the label length.
+ labelLenByte, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ // ReadByte should only ever return nil or io.EOF.
+ panic(fmt.Sprintf("unexpected error when reading a label's length: %s", err))
+ }
+
+ // We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected
+ // once we start parsing a domain name; we expect the buffer to contain
+ // enough bytes for the whole domain name.
+ return fmt.Errorf("unexpected exhausted buffer while parsing a new label for a domain from NDP Search List option: %w", io.ErrUnexpectedEOF)
+ }
+ labelLen := int(labelLenByte)
+
+ // A zero-length label implies the end of a domain name.
+ if labelLen == 0 {
+ // If the domain name is empty or we have no callback function, do
+ // nothing further with the current domain name.
+ if domainName.Len() == 0 || fn == nil {
+ break
+ }
+
+ // Ignore the trailing period in the parsed domain name.
+ domainName.Truncate(domainName.Len() - 1)
+ fn(domainName.String())
+ break
+ }
+
+ // The label's length must not exceed the maximum length for a label.
+ if labelLen > maxDomainNameLabelLength {
+ return fmt.Errorf("label length of %d bytes is greater than the max label length of %d bytes for an NDP Search List option: %w", labelLen, maxDomainNameLabelLength, ErrNDPOptMalformedBody)
+ }
+
+ // The label (and trailing period) must not make the domain name too long.
+ if labelLen+1 > domainName.Cap()-domainName.Len() {
+ return fmt.Errorf("label would make an NDP Search List option's domain name longer than the max domain name length of %d bytes: %w", maxDomainNameLength, ErrNDPOptMalformedBody)
+ }
+
+ // Copy the label and add a trailing period.
+ for i := 0; i < labelLen; i++ {
+ b, err := searchList.ReadByte()
+ if err != nil {
+ if err != io.EOF {
+ panic(fmt.Sprintf("unexpected error when reading domain name's label: %s", err))
+ }
+
+ return fmt.Errorf("read %d out of %d bytes for a domain name's label from NDP Search List option: %w", i, labelLen, io.ErrUnexpectedEOF)
+ }
+
+ // As per RFC 1035 section 2.3.1:
+ // 1) the label must only contain ASCII include letters, digits and
+ // hyphens
+ // 2) the first character in a label must be a letter
+ // 3) the last letter in a label must be a letter or digit
+
+ if !isLetter(b) {
+ if i == 0 {
+ return fmt.Errorf("first character of a domain name's label in an NDP Search List option must be a letter, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+
+ if b == '-' {
+ if i == labelLen-1 {
+ return fmt.Errorf("last character of a domain name's label in an NDP Search List option must not be a hyphen (-): %w", ErrNDPOptMalformedBody)
+ }
+ } else if !isDigit(b) {
+ return fmt.Errorf("domain name's label in an NDP Search List option may only contain letters, digits and hyphens, got character code = %d: %w", b, ErrNDPOptMalformedBody)
+ }
+ }
+
+ // If b is an upper case character, make it lower case.
+ if isUpperLetter(b) {
+ b = b - 'A' + 'a'
+ }
+
+ if err := domainName.WriteByte(b); err != nil {
+ panic(fmt.Sprintf("unexpected error writing label to domain name buffer: %s", err))
+ }
+ }
+ if err := domainName.WriteByte('.'); err != nil {
+ panic(fmt.Sprintf("unexpected error writing trailing period to domain name buffer: %s", err))
+ }
+ }
+ }
+
+ return nil
+}
+
+func isLetter(b byte) bool {
+ return b >= 'a' && b <= 'z' || isUpperLetter(b)
+}
+
+func isUpperLetter(b byte) bool {
+ return b >= 'A' && b <= 'Z'
+}
+
+func isDigit(b byte) bool {
+ return b >= '0' && b <= '9'
+}
diff --git a/pkg/tcpip/header/ndp_router_solicit.go b/pkg/tcpip/header/ndp_router_solicit.go
new file mode 100644
index 000000000..9e67ba95d
--- /dev/null
+++ b/pkg/tcpip/header/ndp_router_solicit.go
@@ -0,0 +1,36 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
+// the body of an ICMPv6 packet.
+//
+// See RFC 4861 section 4.1 for more details.
+type NDPRouterSolicit []byte
+
+const (
+ // NDPRSMinimumSize is the minimum size of a valid NDP Router
+ // Solicitation message (body of an ICMPv6 packet).
+ NDPRSMinimumSize = 4
+
+ // ndpRSOptionsOffset is the start of the NDP options in an
+ // NDPRouterSolicit.
+ ndpRSOptionsOffset = 4
+)
+
+// Options returns an NDPOptions of the the options body.
+func (b NDPRouterSolicit) Options() NDPOptions {
+ return NDPOptions(b[ndpRSOptionsOffset:])
+}
diff --git a/pkg/tcpip/header/ndp_test.go b/pkg/tcpip/header/ndp_test.go
index 0bbf67a2b..dc4591253 100644
--- a/pkg/tcpip/header/ndp_test.go
+++ b/pkg/tcpip/header/ndp_test.go
@@ -16,9 +16,14 @@ package header
import (
"bytes"
+ "errors"
+ "fmt"
+ "io"
+ "regexp"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -36,18 +41,18 @@ func TestNDPNeighborSolicit(t *testing.T) {
ns := NDPNeighborSolicit(b)
addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
if got := ns.TargetAddress(); got != addr {
- t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr)
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr)
}
// Test updating the Target Address.
addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
ns.SetTargetAddress(addr2)
if got := ns.TargetAddress(); got != addr2 {
- t.Fatalf("got ns.TargetAddress = %s, want %s", got, addr2)
+ t.Errorf("got ns.TargetAddress = %s, want %s", got, addr2)
}
// Make sure the address got updated in the backing buffer.
if got := tcpip.Address(b[ndpNSTargetAddessOffset:][:IPv6AddressSize]); got != addr2 {
- t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
}
}
@@ -65,56 +70,56 @@ func TestNDPNeighborAdvert(t *testing.T) {
na := NDPNeighborAdvert(b)
addr := tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10")
if got := na.TargetAddress(); got != addr {
- t.Fatalf("got TargetAddress = %s, want %s", got, addr)
+ t.Errorf("got TargetAddress = %s, want %s", got, addr)
}
// Test getting the Router Flag.
if got := na.RouterFlag(); !got {
- t.Fatalf("got RouterFlag = false, want = true")
+ t.Errorf("got RouterFlag = false, want = true")
}
// Test getting the Solicited Flag.
if got := na.SolicitedFlag(); got {
- t.Fatalf("got SolicitedFlag = true, want = false")
+ t.Errorf("got SolicitedFlag = true, want = false")
}
// Test getting the Override Flag.
if got := na.OverrideFlag(); !got {
- t.Fatalf("got OverrideFlag = false, want = true")
+ t.Errorf("got OverrideFlag = false, want = true")
}
// Test updating the Target Address.
addr2 := tcpip.Address("\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x11")
na.SetTargetAddress(addr2)
if got := na.TargetAddress(); got != addr2 {
- t.Fatalf("got TargetAddress = %s, want %s", got, addr2)
+ t.Errorf("got TargetAddress = %s, want %s", got, addr2)
}
// Make sure the address got updated in the backing buffer.
if got := tcpip.Address(b[ndpNATargetAddressOffset:][:IPv6AddressSize]); got != addr2 {
- t.Fatalf("got targetaddress buffer = %s, want %s", got, addr2)
+ t.Errorf("got targetaddress buffer = %s, want %s", got, addr2)
}
// Test updating the Router Flag.
na.SetRouterFlag(false)
if got := na.RouterFlag(); got {
- t.Fatalf("got RouterFlag = true, want = false")
+ t.Errorf("got RouterFlag = true, want = false")
}
// Test updating the Solicited Flag.
na.SetSolicitedFlag(true)
if got := na.SolicitedFlag(); !got {
- t.Fatalf("got SolicitedFlag = false, want = true")
+ t.Errorf("got SolicitedFlag = false, want = true")
}
// Test updating the Override Flag.
na.SetOverrideFlag(false)
if got := na.OverrideFlag(); got {
- t.Fatalf("got OverrideFlag = true, want = false")
+ t.Errorf("got OverrideFlag = true, want = false")
}
// Make sure flags got updated in the backing buffer.
if got := b[ndpNAFlagsOffset]; got != 64 {
- t.Fatalf("got flags byte = %d, want = 64")
+ t.Errorf("got flags byte = %d, want = 64", got)
}
}
@@ -128,27 +133,181 @@ func TestNDPRouterAdvert(t *testing.T) {
ra := NDPRouterAdvert(b)
if got := ra.CurrHopLimit(); got != 64 {
- t.Fatalf("got ra.CurrHopLimit = %d, want = 64", got)
+ t.Errorf("got ra.CurrHopLimit = %d, want = 64", got)
}
if got := ra.ManagedAddrConfFlag(); !got {
- t.Fatalf("got ManagedAddrConfFlag = false, want = true")
+ t.Errorf("got ManagedAddrConfFlag = false, want = true")
}
if got := ra.OtherConfFlag(); got {
- t.Fatalf("got OtherConfFlag = true, want = false")
+ t.Errorf("got OtherConfFlag = true, want = false")
}
if got, want := ra.RouterLifetime(), time.Second*258; got != want {
- t.Fatalf("got ra.RouterLifetime = %d, want = %d", got, want)
+ t.Errorf("got ra.RouterLifetime = %d, want = %d", got, want)
}
if got, want := ra.ReachableTime(), time.Millisecond*50595078; got != want {
- t.Fatalf("got ra.ReachableTime = %d, want = %d", got, want)
+ t.Errorf("got ra.ReachableTime = %d, want = %d", got, want)
}
if got, want := ra.RetransTimer(), time.Millisecond*117967114; got != want {
- t.Fatalf("got ra.RetransTimer = %d, want = %d", got, want)
+ t.Errorf("got ra.RetransTimer = %d, want = %d", got, want)
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "SLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "SLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ sll := NDPSourceLinkLayerAddressOption(test.buf)
+ if got := sll.EthernetAddress(); got != test.expected {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
+ }
+}
+
+// TestNDPSourceLinkLayerAddressOptionSerialize tests serializing a
+// NDPSourceLinkLayerAddressOption.
+func TestNDPSourceLinkLayerAddressOptionSerialize(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedBuf []byte
+ addr tcpip.LinkAddress
+ }{
+ {
+ "Ethernet",
+ make([]byte, 8),
+ []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ "\x01\x02\x03\x04\x05\x06",
+ },
+ {
+ "Padding",
+ []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+ []byte{1, 2, 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0},
+ "\x01\x02\x03\x04\x05\x06\x07\x08",
+ },
+ {
+ "Empty",
+ nil,
+ nil,
+ "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ serializer := NDPOptionsSerializer{
+ NDPSourceLinkLayerAddressOption(test.addr),
+ }
+ if got, want := int(serializer.Length()), len(test.expectedBuf); got != want {
+ t.Fatalf("got Length = %d, want = %d", got, want)
+ }
+ opts.Serialize(serializer)
+ if !bytes.Equal(test.buf, test.expectedBuf) {
+ t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+ sll := next.(NDPSourceLinkLayerAddressOption)
+ if got, want := []byte(sll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := sll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got sll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPTargetLinkLayerAddressOptionEthernetAddress tests getting the
+// Ethernet address from an NDPTargetLinkLayerAddressOption.
+func TestNDPTargetLinkLayerAddressOptionEthernetAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expected tcpip.LinkAddress
+ }{
+ {
+ "ValidMAC",
+ []byte{1, 2, 3, 4, 5, 6},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ {
+ "TLLBodyTooShort",
+ []byte{1, 2, 3, 4, 5},
+ tcpip.LinkAddress([]byte(nil)),
+ },
+ {
+ "TLLBodyLargerThanNeeded",
+ []byte{1, 2, 3, 4, 5, 6, 7, 8},
+ tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ tll := NDPTargetLinkLayerAddressOption(test.buf)
+ if got := tll.EthernetAddress(); got != test.expected {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, test.expected)
+ }
+ })
}
}
@@ -175,8 +334,8 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
},
{
"Empty",
- []byte{},
- []byte{},
+ nil,
+ nil,
"",
},
}
@@ -194,6 +353,44 @@ func TestNDPTargetLinkLayerAddressOptionSerialize(t *testing.T) {
if !bytes.Equal(test.buf, test.expectedBuf) {
t.Fatalf("got b = %d, want = %d", test.buf, test.expectedBuf)
}
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ if len(test.expectedBuf) > 0 {
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+ tll := next.(NDPTargetLinkLayerAddressOption)
+ if got, want := []byte(tll), test.expectedBuf[2:]; !bytes.Equal(got, want) {
+ t.Fatalf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+
+ if got, want := tll.EthernetAddress(), tcpip.LinkAddress(test.expectedBuf[2:][:EthernetAddressSize]); got != want {
+ t.Errorf("got tll.EthernetAddress = %s, want = %s", got, want)
+ }
+ }
+
+ // Iterator should not return anything else.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
})
}
}
@@ -232,39 +429,1093 @@ func TestNDPPrefixInformationOption(t *testing.T) {
t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expectedBuf)
}
- // First two bytes are the Type and Length fields, which are not part of
- // the option body.
- pi := NDPPrefixInformation(targetBuf[2:])
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ pi := next.(NDPPrefixInformation)
if got := pi.Type(); got != 3 {
- t.Fatalf("got Type = %d, want = 3", got)
+ t.Errorf("got Type = %d, want = 3", got)
}
if got := pi.Length(); got != 30 {
- t.Fatalf("got Length = %d, want = 30", got)
+ t.Errorf("got Length = %d, want = 30", got)
}
if got := pi.PrefixLength(); got != 43 {
- t.Fatalf("got PrefixLength = %d, want = 43", got)
+ t.Errorf("got PrefixLength = %d, want = 43", got)
}
if pi.OnLinkFlag() {
- t.Fatalf("got OnLinkFlag = true, want = false")
+ t.Error("got OnLinkFlag = true, want = false")
}
if !pi.AutonomousAddressConfigurationFlag() {
- t.Fatalf("got AutonomousAddressConfigurationFlag = false, want = true")
+ t.Error("got AutonomousAddressConfigurationFlag = false, want = true")
}
if got, want := pi.ValidLifetime(), 16909060*time.Second; got != want {
- t.Fatalf("got ValidLifetime = %d, want = %d", got, want)
+ t.Errorf("got ValidLifetime = %d, want = %d", got, want)
}
if got, want := pi.PreferredLifetime(), 84281096*time.Second; got != want {
- t.Fatalf("got PreferredLifetime = %d, want = %d", got, want)
+ t.Errorf("got PreferredLifetime = %d, want = %d", got, want)
}
if got, want := pi.Prefix(), tcpip.Address("\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18"); got != want {
- t.Fatalf("got Prefix = %s, want = %s", got, want)
+ t.Errorf("got Prefix = %s, want = %s", got, want)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 25, 3, 0, 0,
+ 1, 2, 4, 8,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPRecursiveDNSServer(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Type(); got != 25 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16909320*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ want := []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ }
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, want); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+func TestNDPRecursiveDNSServerOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ addrs []tcpip.Address
+ }{
+ {
+ "Valid1Addr",
+ []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ },
+ },
+ {
+ "Valid2Addr",
+ []byte{
+ 25, 5, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ },
+ },
+ {
+ "Valid3Addr",
+ []byte{
+ 25, 7, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16,
+ 17, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 17,
+ },
+ 0,
+ []tcpip.Address{
+ "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x10",
+ "\x11\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x11",
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Iterator should get our option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPRecursiveDNSServerOptionType {
+ t.Fatalf("got Type = %d, want = %d", got, NDPRecursiveDNSServerOptionType)
+ }
+
+ opt, ok := next.(NDPRecursiveDNSServer)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPRecursiveDNSServer", next)
+ }
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ addrs, err := opt.Addresses()
+ if err != nil {
+ t.Errorf("opt.Addresses() = %s", err)
+ }
+ if diff := cmp.Diff(addrs, test.addrs); diff != "" {
+ t.Errorf("mismatched addresses (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+ })
+ }
+}
+
+// TestNDPDNSSearchListOption tests the getters of NDPDNSSearchList.
+func TestNDPDNSSearchListOption(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ lifetime time.Duration
+ domainNames []string
+ err error
+ }{
+ {
+ name: "Valid1Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "abc",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Label",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 5,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 0,
+ 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 5 * time.Second,
+ domainNames: []string{
+ "abc.abcd",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3Label",
+ buf: []byte{
+ 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ },
+ lifetime: 16777216 * time.Second,
+ domainNames: []string{
+ "abc.abcd.e",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid2Domains",
+ buf: []byte{
+ 0, 0,
+ 1, 2, 3, 4,
+ 3, 'a', 'b', 'c',
+ 0,
+ 2, 'd', 'e',
+ 3, 'x', 'y', 'z',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: 16909060 * time.Second,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid3DomainsMixedCase",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ 1, 'J',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ "j",
+ },
+ err: nil,
+ },
+ {
+ name: "ValidDomainAfterNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 'B', 'c',
+ 0, 0, 0, 0,
+ 2, 'd', 'E',
+ 3, 'X', 'y', 'z',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abc",
+ "de.xyz",
+ },
+ err: nil,
+ },
+ {
+ name: "Valid0Domains",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 0,
+ 0, 0, 0, 0, 0, 0, 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: nil,
+ },
+ {
+ name: "NoTrailingNull",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLength",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 8, 'a', 'b', 'c', 'd', 'e', 'f', 'g',
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "IncorrectLengthWithNULL",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 7, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "LabelOfLength63",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk",
+ },
+ err: nil,
+ },
+ {
+ name: "LabelOfLength64",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 64, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DomainNameOfLength255",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: []string{
+ "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijk.abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghij",
+ },
+ err: nil,
+ },
+ {
+ name: "DomainNameOfLength256",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ },
+ lifetime: 0,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '9', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "StartingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, '-', 'b', 'c',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingHyphenForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '-',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: nil,
+ err: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "EndingDigitForLabel",
+ buf: []byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 3, 'a', 'b', '9',
+ 0,
+ 0, 0, 0,
+ },
+ lifetime: time.Second,
+ domainNames: []string{
+ "ab9",
+ },
+ err: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opt := NDPDNSSearchList(test.buf)
+
+ if got := opt.Lifetime(); got != test.lifetime {
+ t.Errorf("got Lifetime = %d, want = %d", got, test.lifetime)
+ }
+ domainNames, err := opt.DomainNames()
+ if !errors.Is(err, test.err) {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, test.domainNames); diff != "" {
+ t.Errorf("mismatched domain names (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestNDPSearchListOptionDomainNameLabelInvalidSymbols(t *testing.T) {
+ for r := rune(0); r <= 255; r++ {
+ t.Run(fmt.Sprintf("RuneVal=%d", r), func(t *testing.T) {
+ buf := []byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 3, 'a', 0 /* will be replaced */, 'c',
+ 0,
+ 0, 0, 0,
+ }
+ buf[8] = uint8(r)
+ opt := NDPDNSSearchList(buf)
+
+ // As per RFC 1035 section 2.3.1, the label must only include ASCII
+ // letters, digits and hyphens (a-z, A-Z, 0-9, -).
+ var expectedErr error
+ re := regexp.MustCompile(`[a-zA-Z0-9-]`)
+ if !re.Match([]byte{byte(r)}) {
+ expectedErr = ErrNDPOptMalformedBody
+ }
+
+ if domainNames, err := opt.DomainNames(); !errors.Is(err, expectedErr) {
+ t.Errorf("got opt.DomainNames() = (%s, %v), want = (_, %v)", domainNames, err, ErrNDPOptMalformedBody)
+ }
+ })
+ }
+}
+
+func TestNDPDNSSearchListOptionSerialize(t *testing.T) {
+ b := []byte{
+ 9, 8,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ }
+ targetBuf := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
+ expected := []byte{
+ 31, 3, 0, 0,
+ 1, 0, 0, 0,
+ 3, 'a', 'b', 'c',
+ 4, 'a', 'b', 'c', 'd',
+ 1, 'e',
+ 0,
+ 0, 0, 0, 0,
+ }
+ opts := NDPOptions(targetBuf)
+ serializer := NDPOptionsSerializer{
+ NDPDNSSearchList(b),
+ }
+ if got, want := opts.Serialize(serializer), len(expected); got != want {
+ t.Errorf("got Serialize = %d, want = %d", got, want)
+ }
+ if !bytes.Equal(targetBuf, expected) {
+ t.Fatalf("got targetBuf = %x, want = %x", targetBuf, expected)
+ }
+
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got := next.Type(); got != NDPDNSSearchListOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPDNSSearchListOptionType)
+ }
+
+ opt, ok := next.(NDPDNSSearchList)
+ if !ok {
+ t.Fatalf("next (type = %T) cannot be casted to an NDPDNSSearchList", next)
+ }
+ if got := opt.Type(); got != 31 {
+ t.Errorf("got Type = %d, want = 31", got)
+ }
+ if got := opt.Length(); got != 22 {
+ t.Errorf("got Length = %d, want = 22", got)
+ }
+ if got, want := opt.Lifetime(), 16777216*time.Second; got != want {
+ t.Errorf("got Lifetime = %s, want = %s", got, want)
+ }
+ domainNames, err := opt.DomainNames()
+ if err != nil {
+ t.Errorf("opt.DomainNames() = %s", err)
+ }
+ if diff := cmp.Diff(domainNames, []string{"abc.abcd.e"}); diff != "" {
+ t.Errorf("domain names mismatch (-want +got):\n%s", diff)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
+ }
+}
+
+// TestNDPOptionsIterCheck tests that Iter will return false if the NDPOptions
+// the iterator was returned for is malformed.
+func TestNDPOptionsIterCheck(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expectedErr error
+ }{
+ {
+ name: "ZeroLengthField",
+ buf: []byte{0, 0, 0, 0, 0, 0, 0, 0},
+ expectedErr: ErrNDPOptMalformedHeader,
+ },
+ {
+ name: "ValidSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallSourceLinkLayerAddressOption",
+ buf: []byte{1, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "ValidTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5, 6},
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallTargetLinkLayerAddressOption",
+ buf: []byte{2, 1, 1, 2, 3, 4, 5},
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "ValidPrefixInformation",
+ buf: []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "TooSmallPrefixInformation",
+ buf: []byte{
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "InvalidPrefixInformationLength",
+ buf: []byte{
+ 3, 3, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformation",
+ buf: []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "ValidSourceAndTargetLinkLayerAddressWithPrefixInformationWithUnrecognized",
+ buf: []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // 255 is an unrecognized type. If 255 ends up
+ // being the type for some recognized type,
+ // update 255 to some other unrecognized value.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "InvalidRecursiveDNSServerCutsOffAddress",
+ buf: []byte{
+ 25, 4, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ 0, 1, 2, 3, 4, 5, 6, 7,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "InvalidRecursiveDNSServerInvalidLengthField",
+ buf: []byte{
+ 25, 2, 0, 0,
+ 0, 0, 0, 0,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "RecursiveDNSServerTooSmall",
+ buf: []byte{
+ 25, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ {
+ name: "RecursiveDNSServerMulticast",
+ buf: []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "RecursiveDNSServerUnspecified",
+ buf: []byte{
+ 25, 3, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListLargeCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 62, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListNonCompliantRFC1035",
+ buf: []byte{
+ 31, 33, 0, 0,
+ 0, 0, 0, 0,
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 63, 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
+ 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
+ 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
+ 'i', 'j', 'k',
+ 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ expectedErr: ErrNDPOptMalformedBody,
+ },
+ {
+ name: "DNSSearchListValidSmall",
+ buf: []byte{
+ 31, 2, 0, 0,
+ 0, 0, 0, 0,
+ 6, 'a', 'b', 'c', 'd', 'e', 'f',
+ 0,
+ },
+ expectedErr: nil,
+ },
+ {
+ name: "DNSSearchListTooSmall",
+ buf: []byte{
+ 31, 1, 0, 0,
+ 0, 0, 0,
+ },
+ expectedErr: io.ErrUnexpectedEOF,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := NDPOptions(test.buf)
+
+ if _, err := opts.Iter(true); !errors.Is(err, test.expectedErr) {
+ t.Fatalf("got Iter(true) = (_, %v), want = (_, %v)", err, test.expectedErr)
+ }
+
+ // test.buf may be malformed but we chose not to check
+ // the iterator so it must return true.
+ if _, err := opts.Iter(false); err != nil {
+ t.Fatalf("got Iter(false) = (_, %s), want = (_, nil)", err)
+ }
+ })
+ }
+}
+
+// TestNDPOptionsIter tests that we can iterator over a valid NDPOptions. Note,
+// this test does not actually check any of the option's getters, it simply
+// checks the option Type and Body. We have other tests that tests the option
+// field gettings given an option body and don't need to duplicate those tests
+// here.
+func TestNDPOptionsIter(t *testing.T) {
+ buf := []byte{
+ // Source Link-Layer Address.
+ 1, 1, 1, 2, 3, 4, 5, 6,
+
+ // Target Link-Layer Address.
+ 2, 1, 7, 8, 9, 10, 11, 12,
+
+ // 255 is an unrecognized type. If 255 ends up being the type
+ // for some recognized type, update 255 to some other
+ // unrecognized value. Note, this option should be skipped when
+ // iterating.
+ 255, 2, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 8,
+
+ // Prefix information.
+ 3, 4, 43, 64,
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+ 0, 0, 0, 0,
+ 9, 10, 11, 12,
+ 13, 14, 15, 16,
+ 17, 18, 19, 20,
+ 21, 22, 23, 24,
+ }
+
+ opts := NDPOptions(buf)
+ it, err := opts.Iter(true)
+ if err != nil {
+ t.Fatalf("got Iter = (_, %s), want = (_, nil)", err)
+ }
+
+ // Test the first (Source Link-Layer) option.
+ next, done, err := it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPSourceLinkLayerAddressOption)), buf[2:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPSourceLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPSourceLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Target Link-Layer) option.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := []byte(next.(NDPTargetLinkLayerAddressOption)), buf[10:][:6]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPTargetLinkLayerAddressOptionType {
+ t.Errorf("got Type = %d, want = %d", got, NDPTargetLinkLayerAddressOptionType)
+ }
+
+ // Test the next (Prefix Information) option.
+ // Note, the unrecognized option should be skipped.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Fatalf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if done {
+ t.Fatal("got Next = (_, true, _), want = (_, false, _)")
+ }
+ if got, want := next.(NDPPrefixInformation), buf[34:][:30]; !bytes.Equal(got, want) {
+ t.Errorf("got Next = (%x, _, _), want = (%x, _, _)", got, want)
+ }
+ if got := next.Type(); got != NDPPrefixInformationType {
+ t.Errorf("got Type = %d, want = %d", got, NDPPrefixInformationType)
+ }
+
+ // Iterator should not return anything else.
+ next, done, err = it.Next()
+ if err != nil {
+ t.Errorf("got Next = (_, _, %s), want = (_, _, nil)", err)
+ }
+ if !done {
+ t.Error("got Next = (_, false, _), want = (_, true, _)")
+ }
+ if next != nil {
+ t.Errorf("got Next = (%x, _, _), want = (nil, _, _)", next)
}
}
diff --git a/pkg/tcpip/header/ndpoptionidentifier_string.go b/pkg/tcpip/header/ndpoptionidentifier_string.go
new file mode 100644
index 000000000..6fe9a336b
--- /dev/null
+++ b/pkg/tcpip/header/ndpoptionidentifier_string.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NDPOptionIdentifier ."; DO NOT EDIT.
+
+package header
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[NDPSourceLinkLayerAddressOptionType-1]
+ _ = x[NDPTargetLinkLayerAddressOptionType-2]
+ _ = x[NDPPrefixInformationType-3]
+ _ = x[NDPRecursiveDNSServerOptionType-25]
+}
+
+const (
+ _NDPOptionIdentifier_name_0 = "NDPSourceLinkLayerAddressOptionTypeNDPTargetLinkLayerAddressOptionTypeNDPPrefixInformationType"
+ _NDPOptionIdentifier_name_1 = "NDPRecursiveDNSServerOptionType"
+)
+
+var (
+ _NDPOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
+)
+
+func (i NDPOptionIdentifier) String() string {
+ switch {
+ case 1 <= i && i <= 3:
+ i -= 1
+ return _NDPOptionIdentifier_name_0[_NDPOptionIdentifier_index_0[i]:_NDPOptionIdentifier_index_0[i+1]]
+ case i == 25:
+ return _NDPOptionIdentifier_name_1
+ default:
+ return "NDPOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 82cfe785c..4c6f808e5 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -66,6 +66,14 @@ const (
TCPOptionSACK = 5
)
+// Option Lengths.
+const (
+ TCPOptionMSSLength = 4
+ TCPOptionTSLength = 10
+ TCPOptionWSLength = 3
+ TCPOptionSackPermittedLength = 2
+)
+
// TCPFields contains the fields of a TCP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type TCPFields struct {
@@ -81,7 +89,8 @@ type TCPFields struct {
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
- // DataOffset is the "data offset" field of a TCP packet.
+ // DataOffset is the "data offset" field of a TCP packet. It is the length of
+ // the TCP header in bytes.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
@@ -213,7 +222,8 @@ func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
-// DataOffset returns the "data offset" field of the tcp header.
+// DataOffset returns the "data offset" field of the tcp header. The return
+// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
@@ -238,6 +248,11 @@ func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
+// UrgentPointer returns the "urgent pointer" field of the tcp header.
+func (b TCP) UrgentPointer() uint16 {
+ return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
+}
+
// SetSourcePort sets the "source port" field of the tcp header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
@@ -253,6 +268,37 @@ func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[TCPChecksumOffset:], checksum)
}
+// SetDataOffset sets the data offset field of the tcp header. headerLen should
+// be the length of the TCP header in bytes.
+func (b TCP) SetDataOffset(headerLen uint8) {
+ b[TCPDataOffset] = (headerLen / 4) << 4
+}
+
+// SetSequenceNumber sets the sequence number field of the tcp header.
+func (b TCP) SetSequenceNumber(seqNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
+}
+
+// SetAckNumber sets the ack number field of the tcp header.
+func (b TCP) SetAckNumber(ackNum uint32) {
+ binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
+}
+
+// SetFlags sets the flags field of the tcp header.
+func (b TCP) SetFlags(flags uint8) {
+ b[TCPFlagsOffset] = flags
+}
+
+// SetWindowSize sets the window size field of the tcp header.
+func (b TCP) SetWindowSize(rcvwnd uint16) {
+ binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
+}
+
+// SetUrgentPoiner sets the window size field of the tcp header.
+func (b TCP) SetUrgentPoiner(urgentPointer uint16) {
+ binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
+}
+
// CalculateChecksum calculates the checksum of the tcp segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
@@ -456,14 +502,11 @@ func ParseTCPOptions(b []byte) TCPOptions {
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeMSSOption(mss uint32, b []byte) int {
- // mssOptionSize is the number of bytes in a valid MSS option.
- const mssOptionSize = 4
-
- if len(b) < mssOptionSize {
+ if len(b) < TCPOptionMSSLength {
return 0
}
- b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
- return mssOptionSize
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss)
+ return TCPOptionMSSLength
}
// EncodeWSOption encodes the WS TCP option with the WS value in the
@@ -471,10 +514,10 @@ func EncodeMSSOption(mss uint32, b []byte) int {
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeWSOption(ws int, b []byte) int {
- if len(b) < 3 {
+ if len(b) < TCPOptionWSLength {
return 0
}
- b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws)
return int(b[1])
}
@@ -483,10 +526,10 @@ func EncodeWSOption(ws int, b []byte) int {
// just returns without encoding anything. It returns the number of bytes
// written to the provided buffer.
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
- if len(b) < 10 {
+ if len(b) < TCPOptionTSLength {
return 0
}
- b[0], b[1] = TCPOptionTS, 10
+ b[0], b[1] = TCPOptionTS, TCPOptionTSLength
binary.BigEndian.PutUint32(b[2:], tsVal)
binary.BigEndian.PutUint32(b[6:], tsEcr)
return int(b[1])
@@ -497,11 +540,11 @@ func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
// encoding anything. It returns the number of bytes written to the provided
// buffer.
func EncodeSACKPermittedOption(b []byte) int {
- if len(b) < 2 {
+ if len(b) < TCPOptionSackPermittedLength {
return 0
}
- b[0], b[1] = TCPOptionSACKPermitted, 2
+ b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength
return int(b[1])
}
@@ -556,3 +599,23 @@ func AddTCPOptionPadding(options []byte, offset int) int {
}
return paddingToAdd
}
+
+// Acceptable checks if a segment that starts at segSeq and has length segLen is
+// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
+// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
+func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
+ if rcvNxt == rcvAcc {
+ return segLen == 0 && segSeq == rcvNxt
+ }
+ if segLen == 0 {
+ // rcvWnd is incremented by 1 because that is Linux's behavior despite the
+ // RFC.
+ return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
+ }
+ // Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
+ // the payload, so we'll accept any payload that overlaps the receieve window.
+ // segSeq < rcvAcc is more correct according to RFC, however, Linux does it
+ // differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior
+ // as Linux.
+ return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc)
+}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index 74412c894..9339d637f 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -99,6 +99,11 @@ func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
+// SetLength sets the "length" field of the udp header.
+func (b UDP) SetLength(length uint16) {
+ binary.BigEndian.PutUint16(b[udpLength:], length)
+}
+
// CalculateChecksum calculates the checksum of the udp packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
deleted file mode 100644
index cc5f531e2..000000000
--- a/pkg/tcpip/iptables/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "iptables",
- srcs = [
- "iptables.go",
- "targets.go",
- "types.go",
- ],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
- visibility = ["//visibility:public"],
- deps = ["//pkg/tcpip/buffer"],
-)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
deleted file mode 100644
index 68c68d4aa..000000000
--- a/pkg/tcpip/iptables/iptables.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2019 The gVisor authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package iptables supports packet filtering and manipulation via the iptables
-// tool.
-package iptables
-
-const (
- tablenameNat = "nat"
- tablenameMangle = "mangle"
-)
-
-// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
-const (
- chainNamePrerouting = "PREROUTING"
- chainNameInput = "INPUT"
- chainNameForward = "FORWARD"
- chainNameOutput = "OUTPUT"
- chainNamePostrouting = "POSTROUTING"
-)
-
-// DefaultTables returns a default set of tables. Each chain is set to accept
-// all packets.
-func DefaultTables() IPTables {
- return IPTables{
- Tables: map[string]Table{
- tablenameNat: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Input: unconditionalAcceptChain(chainNameInput),
- Output: unconditionalAcceptChain(chainNameOutput),
- Postrouting: unconditionalAcceptChain(chainNamePostrouting),
- },
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Input: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
- Postrouting: UnconditionalAcceptTarget{},
- },
- UserChains: map[string]Chain{},
- },
- tablenameMangle: Table{
- BuiltinChains: map[Hook]Chain{
- Prerouting: unconditionalAcceptChain(chainNamePrerouting),
- Output: unconditionalAcceptChain(chainNameOutput),
- },
- DefaultTargets: map[Hook]Target{
- Prerouting: UnconditionalAcceptTarget{},
- Output: UnconditionalAcceptTarget{},
- },
- UserChains: map[string]Chain{},
- },
- },
- Priorities: map[Hook][]string{
- Prerouting: []string{tablenameMangle, tablenameNat},
- Output: []string{tablenameMangle, tablenameNat},
- },
- }
-}
-
-func unconditionalAcceptChain(name string) Chain {
- return Chain{
- Name: name,
- Rules: []Rule{
- Rule{
- Target: UnconditionalAcceptTarget{},
- },
- },
- }
-}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
deleted file mode 100644
index 19a7f77e3..000000000
--- a/pkg/tcpip/iptables/targets.go
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// This file contains various Targets.
-
-package iptables
-
-import "gvisor.dev/gvisor/pkg/tcpip/buffer"
-
-// UnconditionalAcceptTarget accepts all packets.
-type UnconditionalAcceptTarget struct{}
-
-// Action implements Target.Action.
-func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
- return Accept, ""
-}
-
-// UnconditionalDropTarget denies all packets.
-type UnconditionalDropTarget struct{}
-
-// Action implements Target.Action.
-func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
- return Drop, ""
-}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
deleted file mode 100644
index 42a79ef9f..000000000
--- a/pkg/tcpip/iptables/types.go
+++ /dev/null
@@ -1,196 +0,0 @@
-// Copyright 2019 The gVisor authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package iptables
-
-import (
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
-)
-
-// A Hook specifies one of the hooks built into the network stack.
-//
-// Userspace app Userspace app
-// ^ |
-// | v
-// [Input] [Output]
-// ^ |
-// | v
-// | routing
-// | |
-// | v
-// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
-type Hook uint
-
-// These values correspond to values in include/uapi/linux/netfilter.h.
-const (
- // Prerouting happens before a packet is routed to applications or to
- // be forwarded.
- Prerouting Hook = iota
-
- // Input happens before a packet reaches an application.
- Input
-
- // Forward happens once it's decided that a packet should be forwarded
- // to another host.
- Forward
-
- // Output happens after a packet is written by an application to be
- // sent out.
- Output
-
- // Postrouting happens just before a packet goes out on the wire.
- Postrouting
-
- // The total number of hooks.
- NumHooks
-)
-
-// A Verdict is returned by a rule's target to indicate how traversal of rules
-// should (or should not) continue.
-type Verdict int
-
-const (
- // Accept indicates the packet should continue traversing netstack as
- // normal.
- Accept Verdict = iota
-
- // Drop inicates the packet should be dropped, stopping traversing
- // netstack.
- Drop
-
- // Stolen indicates the packet was co-opted by the target and should
- // stop traversing netstack.
- Stolen
-
- // Queue indicates the packet should be queued for userspace processing.
- Queue
-
- // Repeat indicates the packet should re-traverse the chains for the
- // current hook.
- Repeat
-
- // None indicates no verdict was reached.
- None
-
- // Jump indicates a jump to another chain.
- Jump
-
- // Continue indicates that traversal should continue at the next rule.
- Continue
-
- // Return indicates that traversal should return to the calling chain.
- Return
-)
-
-// IPTables holds all the tables for a netstack.
-type IPTables struct {
- // Tables maps table names to tables. User tables have arbitrary names.
- Tables map[string]Table
-
- // Priorities maps each hook to a list of table names. The order of the
- // list is the order in which each table should be visited for that
- // hook.
- Priorities map[Hook][]string
-}
-
-// A Table defines a set of chains and hooks into the network stack. The
-// currently supported tables are:
-// * nat
-// * mangle
-type Table struct {
- // BuiltinChains holds the un-deletable chains built into netstack. If
- // a hook isn't present in the map, this table doesn't utilize that
- // hook.
- BuiltinChains map[Hook]Chain
-
- // DefaultTargets holds a target for each hook that will be executed if
- // chain traversal doesn't yield a verdict.
- DefaultTargets map[Hook]Target
-
- // UserChains holds user-defined chains for the keyed by name. Users
- // can give their chains arbitrary names.
- UserChains map[string]Chain
-
- // Chains maps names to chains for both builtin and user-defined chains.
- // Its entries point to Chains already either in BuiltinChains or
- // UserChains, and its purpose is to make looking up tables by name
- // fast.
- Chains map[string]*Chain
-
- // Metadata holds information about the Table that is useful to users
- // of IPTables, but not to the netstack IPTables code itself.
- metadata interface{}
-}
-
-// ValidHooks returns a bitmap of the builtin hooks for the given table.
-func (table *Table) ValidHooks() uint32 {
- hooks := uint32(0)
- for hook, _ := range table.BuiltinChains {
- hooks |= 1 << hook
- }
- return hooks
-}
-
-// Metadata returns the metadata object stored in table.
-func (table *Table) Metadata() interface{} {
- return table.metadata
-}
-
-// SetMetadata sets the metadata object stored in table.
-func (table *Table) SetMetadata(metadata interface{}) {
- table.metadata = metadata
-}
-
-// A Chain defines a list of rules for packet processing. When a packet
-// traverses a chain, it is checked against each rule until either a rule
-// returns a verdict or the chain ends.
-//
-// By convention, builtin chains end with a rule that matches everything and
-// returns either Accept or Drop. User-defined chains end with Return. These
-// aren't strictly necessary here, but the iptables tool writes tables this way.
-type Chain struct {
- // Name is the chain name.
- Name string
-
- // Rules is the list of rules to traverse.
- Rules []Rule
-}
-
-// A Rule is a packet processing rule. It consists of two pieces. First it
-// contains zero or more matchers, each of which is a specification of which
-// packets this rule applies to. If there are no matchers in the rule, it
-// applies to any packet.
-type Rule struct {
- // Matchers is the list of matchers for this rule.
- Matchers []Matcher
-
- // Target is the action to invoke if all the matchers match the packet.
- Target Target
-}
-
-// A Matcher is the interface for matching packets.
-type Matcher interface {
- // Match returns whether the packet matches and whether the packet
- // should be "hotdropped", i.e. dropped immediately. This is usually
- // used for suspicious packets.
- Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool)
-}
-
-// A Target is the interface for taking an action for a packet.
-type Target interface {
- // Action takes an action on the packet and returns a verdict on how
- // traversal should (or should not) continue. If the return value is
- // Jump, it also returns the name of the chain to jump to.
- Action(packet buffer.VectorisedView) (Verdict, string)
-}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 97a794986..39ca774ef 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -1,15 +1,16 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "channel",
srcs = ["channel.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/channel",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 14f197a77..c95aef63c 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -18,61 +18,177 @@
package channel
import (
+ "context"
+
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
- Header buffer.View
- Payload buffer.View
- Proto tcpip.NetworkProtocolNumber
- GSO *stack.GSO
+ Pkt *stack.PacketBuffer
+ Proto tcpip.NetworkProtocolNumber
+ GSO *stack.GSO
+ Route stack.Route
+}
+
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // c is the outbound packet channel.
+ c chan PacketInfo
+ // mu protects fields below.
+ mu sync.RWMutex
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ select {
+ case p := <-q.c:
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ select {
+ case pkt := <-q.c:
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+ select {
+ case q.c <- p:
+ wrote = true
+ default:
+ }
+ q.mu.Lock()
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ return len(q.c)
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
}
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
- dispatcher stack.NetworkDispatcher
- mtu uint32
- linkAddr tcpip.LinkAddress
- GSO bool
+ dispatcher stack.NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+ LinkEPCapabilities stack.LinkEndpointCapabilities
- // C is where outbound packets are queued.
- C chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- C: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
}
+// Close closes e. Further packet injections will panic. Reads continue to
+// succeed until all packets are read.
+func (e *Endpoint) Close() {
+ e.q.Close()
+}
+
+// Read does non-blocking read one packet from the outbound packet queue.
+func (e *Endpoint) Read() (PacketInfo, bool) {
+ return e.q.Read()
+}
+
+// ReadContext does blocking read for one packet from the outbound packet queue.
+// It can be cancelled by ctx, and in this case, it returns false.
+func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ return e.q.ReadContext(ctx)
+}
+
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.C:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
-// Inject injects an inbound packet.
-func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.InjectLinkAddr(protocol, "", vv)
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
+// InjectInbound injects an inbound packet.
+func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
-func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, remote, "" /* local */, protocol, vv.Clone(nil), nil /* linkHeader */)
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
}
// Attach saves the stack network-layer dispatcher for use later when packets
@@ -94,11 +210,7 @@ func (e *Endpoint) MTU() uint32 {
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
- caps := stack.LinkEndpointCapabilities(0)
- if e.GSO {
- caps |= stack.CapabilityHardwareGSO
- }
- return caps
+ return e.LinkEPCapabilities
}
// GSOMaxSize returns the maximum GSO packet size.
@@ -118,65 +230,81 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
}
// WritePacket stores outbound packets into the channel.
-func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
p := PacketInfo{
- Header: hdr.View(),
- Proto: protocol,
- Payload: payload.ToView(),
- GSO: gso,
+ Pkt: pkt,
+ Proto: protocol,
+ GSO: gso,
+ Route: route,
}
- select {
- case e.C <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// WritePackets stores outbound packets into the channel.
-func (e *Endpoint) WritePackets(_ *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- payloadView := payload.ToView()
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Clone r then release its resource so we only get the relevant fields from
+ // stack.Route without holding a reference to a NIC's endpoint.
+ route := r.Clone()
+ route.Release()
n := 0
-packetLoop:
- for i := range hdrs {
- hdr := &hdrs[i].Hdr
- off := hdrs[i].Off
- size := hdrs[i].Size
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
p := PacketInfo{
- Header: hdr.View(),
- Proto: protocol,
- Payload: buffer.NewViewFromBytes(payloadView[off : off+size]),
- GSO: gso,
+ Pkt: pkt,
+ Proto: protocol,
+ GSO: gso,
+ Route: route,
}
- select {
- case e.C <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
p := PacketInfo{
- Header: packet.ToView(),
- Proto: 0,
- Payload: buffer.View{},
- GSO: nil,
+ Pkt: stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }),
+ Proto: 0,
+ GSO: nil,
}
- select {
- case e.C <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 8fa9e3984..10072eac1 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -13,11 +12,11 @@ go_library(
"mmap_unsafe.go",
"packet_dispatchers.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
+ "//pkg/binary",
+ "//pkg/iovec",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -31,12 +30,13 @@ go_test(
name = "fdbased_test",
size = "small",
srcs = ["endpoint_test.go"],
- embed = [":fdbased"],
+ library = ":fdbased",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/rawfile",
"//pkg/tcpip/stack",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index ae4858529..975309fc8 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,10 +41,12 @@ package fdbased
import (
"fmt"
- "sync"
"syscall"
"golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/iovec"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -91,7 +93,7 @@ func (p PacketDispatchMode) String() string {
case PacketMMap:
return "PacketMMap"
default:
- return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
}
}
@@ -384,37 +386,46 @@ const (
_VIRTIO_NET_HDR_GSO_TCPV6 = 4
)
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
- eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
// Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
}
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if e.hdrSize > 0 {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
+ }
+
+ var builder iovec.Builder
+ fd := e.fds[pkt.Hash%uint32(len(e.fds))]
if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
vnetHdr := virtioNetHdr{}
- vnetHdrBuf := vnetHdrToByteSlice(&vnetHdr)
if gso != nil {
- vnetHdr.hdrLen = uint16(hdr.UsedLength())
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
if gso.NeedsCsum {
vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
vnetHdr.csumStart = header.EthernetMinimumSize + gso.L3HdrLen
vnetHdr.csumOffset = gso.CsumOffset
}
- if gso.Type != stack.GSONone && uint16(payload.Size()) > gso.MSS {
+ if gso.Type != stack.GSONone && uint16(pkt.Data.Size()) > gso.MSS {
switch gso.Type {
case stack.GSOTCPv4:
vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
@@ -427,136 +438,129 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
}
- return rawfile.NonBlockingWrite3(e.fds[0], vnetHdrBuf, hdr.View(), payload.ToView())
+ vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
+ builder.Add(vnetHdrBuf)
}
- if payload.Size() == 0 {
- return rawfile.NonBlockingWrite(e.fds[0], hdr.View())
+ for _, v := range pkt.Views() {
+ builder.Add(v)
}
-
- return rawfile.NonBlockingWrite3(e.fds[0], hdr.View(), payload.ToView(), nil)
+ return rawfile.NonBlockingWriteIovec(fd, builder.Build())
}
-// WritePackets writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- var ethHdrBuf []byte
- // hdr + data
- iovLen := 2
- if e.hdrSize > 0 {
- // Add ethernet header if needed.
- ethHdrBuf = make([]byte, header.EthernetMinimumSize)
- eth := header.Ethernet(ethHdrBuf)
- ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
- Type: protocol,
+func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) {
+ // Send a batch of packets through batchFD.
+ mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch))
+ for _, pkt := range batch {
+ if e.hdrSize > 0 {
+ e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt)
}
- // Preserve the src address if it's set in the route.
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
- } else {
- ethHdr.SrcAddr = e.addr
- }
- eth.Encode(ethHdr)
- iovLen++
- }
-
- n := len(hdrs)
-
- views := payload.Views()
- /*
- * Each bondary in views can add one more iovec.
- *
- * payload | | | |
- * -----------------------------
- * packets | | | | | | |
- * -----------------------------
- * iovecs | | | | | | | | |
- */
- iovec := make([]syscall.Iovec, n*iovLen+len(views)-1)
- mmsgHdrs := make([]rawfile.MMsgHdr, n)
-
- iovecIdx := 0
- viewIdx := 0
- viewOff := 0
- off := 0
- nextOff := 0
- for i := range hdrs {
- prevIovecIdx := iovecIdx
- mmsgHdr := &mmsgHdrs[i]
- mmsgHdr.Msg.Iov = &iovec[iovecIdx]
- packetSize := hdrs[i].Size
- hdr := &hdrs[i].Hdr
-
- off = hdrs[i].Off
- if off != nextOff {
- // We stop in a different point last time.
- size := packetSize
- viewIdx = 0
- viewOff = 0
- for size > 0 {
- if size >= len(views[viewIdx]) {
- viewIdx++
- viewOff = 0
- size -= len(views[viewIdx])
- } else {
- viewOff = size
- size = 0
+ var vnetHdrBuf []byte
+ if e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
+ vnetHdr := virtioNetHdr{}
+ if pkt.GSOOptions != nil {
+ vnetHdr.hdrLen = uint16(pkt.HeaderSize())
+ if pkt.GSOOptions.NeedsCsum {
+ vnetHdr.flags = _VIRTIO_NET_HDR_F_NEEDS_CSUM
+ vnetHdr.csumStart = header.EthernetMinimumSize + pkt.GSOOptions.L3HdrLen
+ vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset
+ }
+ if pkt.GSOOptions.Type != stack.GSONone && uint16(pkt.Data.Size()) > pkt.GSOOptions.MSS {
+ switch pkt.GSOOptions.Type {
+ case stack.GSOTCPv4:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV4
+ case stack.GSOTCPv6:
+ vnetHdr.gsoType = _VIRTIO_NET_HDR_GSO_TCPV6
+ default:
+ panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type))
+ }
+ vnetHdr.gsoSize = pkt.GSOOptions.MSS
}
}
- }
- nextOff = off + packetSize
-
- if ethHdrBuf != nil {
- v := &iovec[iovecIdx]
- v.Base = &ethHdrBuf[0]
- v.Len = uint64(len(ethHdrBuf))
- iovecIdx++
+ vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr)
}
- v := &iovec[iovecIdx]
- hdrView := hdr.View()
- v.Base = &hdrView[0]
- v.Len = uint64(len(hdrView))
- iovecIdx++
-
- for packetSize > 0 {
- vec := &iovec[iovecIdx]
- iovecIdx++
-
- v := views[viewIdx]
- vec.Base = &v[viewOff]
- s := len(v) - viewOff
- if s <= packetSize {
- viewIdx++
- viewOff = 0
- } else {
- s = packetSize
- viewOff += s
- }
- vec.Len = uint64(s)
- packetSize -= s
+ var builder iovec.Builder
+ builder.Add(vnetHdrBuf)
+ for _, v := range pkt.Views() {
+ builder.Add(v)
}
+ iovecs := builder.Build()
- mmsgHdr.Msg.Iovlen = uint64(iovecIdx - prevIovecIdx)
+ var mmsgHdr rawfile.MMsgHdr
+ mmsgHdr.Msg.Iov = &iovecs[0]
+ mmsgHdr.Msg.Iovlen = uint64(len(iovecs))
+ mmsgHdrs = append(mmsgHdrs, mmsgHdr)
}
packets := 0
- for packets < n {
- sent, err := rawfile.NonBlockingSendMMsg(e.fds[0], mmsgHdrs)
+ for len(mmsgHdrs) > 0 {
+ sent, err := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
if err != nil {
return packets, err
}
packets += sent
mmsgHdrs = mmsgHdrs[sent:]
}
+
return packets, nil
}
+// WritePackets writes outbound packets to the underlying file descriptors. If
+// one is not currently writable, the packet is dropped.
+//
+// Being a batch API, each packet in pkts should have the following
+// fields populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Preallocate to avoid repeated reallocation as we append to batch.
+ // batchSz is 47 because when SWGSO is in use then a single 65KB TCP
+ // segment can get split into 46 segments of 1420 bytes and a single 216
+ // byte segment.
+ const batchSz = 47
+ batch := make([]*stack.PacketBuffer, 0, batchSz)
+ batchFD := -1
+ sentPackets := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if len(batch) == 0 {
+ batchFD = e.fds[pkt.Hash%uint32(len(e.fds))]
+ }
+ pktFD := e.fds[pkt.Hash%uint32(len(e.fds))]
+ if sendNow := pktFD != batchFD; !sendNow {
+ batch = append(batch, pkt)
+ continue
+ }
+ n, err := e.sendBatch(batchFD, batch)
+ sentPackets += n
+ if err != nil {
+ return sentPackets, err
+ }
+ batch = batch[:0]
+ batch = append(batch, pkt)
+ batchFD = pktFD
+ }
+
+ if len(batch) != 0 {
+ n, err := e.sendBatch(batchFD, batch)
+ sentPackets += n
+ if err != nil {
+ return sentPackets, err
+ }
+ }
+ return sentPackets, nil
+}
+
+// viewsEqual tests whether v1 and v2 refer to the same backing bytes.
+func viewsEqual(vs1, vs2 []buffer.View) bool {
+ return len(vs1) == len(vs2) && (len(vs1) == 0 || &vs1[0] == &vs2[0])
+}
+
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
- return rawfile.NonBlockingWrite(e.fds[0], packet.ToView())
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return rawfile.NonBlockingWrite(e.fds[0], vv.ToView())
}
// InjectOutobund implements stack.InjectableEndpoint.InjectOutbound.
@@ -583,6 +587,14 @@ func (e *endpoint) GSOMaxSize() uint32 {
return e.gsoMaxSize
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.hdrSize > 0 {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
// to the FD, but does not read from it. All reads come from injected packets.
type InjectableEndpoint struct {
@@ -598,8 +610,8 @@ func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
// InjectInbound injects an inbound packet.
-func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
+func (e *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
// NewInjectable creates a new fd-based InjectableEndpoint.
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 59378b96c..709f829c8 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -26,6 +26,7 @@ import (
"time"
"unsafe"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -43,43 +44,75 @@ const (
)
type packetInfo struct {
- raddr tcpip.LinkAddress
- proto tcpip.NetworkProtocolNumber
- contents buffer.View
- linkHeader buffer.View
+ Raddr tcpip.LinkAddress
+ Proto tcpip.NetworkProtocolNumber
+ Contents *stack.PacketBuffer
+}
+
+type packetContents struct {
+ LinkHeader buffer.View
+ NetworkHeader buffer.View
+ TransportHeader buffer.View
+ Data buffer.View
+}
+
+func checkPacketInfoEqual(t *testing.T, got, want packetInfo) {
+ t.Helper()
+ if diff := cmp.Diff(
+ want, got,
+ cmp.Transformer("ExtractPacketBuffer", func(pk *stack.PacketBuffer) *packetContents {
+ if pk == nil {
+ return nil
+ }
+ return &packetContents{
+ LinkHeader: pk.LinkHeader().View(),
+ NetworkHeader: pk.NetworkHeader().View(),
+ TransportHeader: pk.TransportHeader().View(),
+ Data: pk.Data.ToView(),
+ }
+ }),
+ ); diff != "" {
+ t.Errorf("unexpected packetInfo (-want +got):\n%s", diff)
+ }
}
type context struct {
- t *testing.T
- fds [2]int
- ep stack.LinkEndpoint
- ch chan packetInfo
- done chan struct{}
+ t *testing.T
+ readFDs []int
+ writeFDs []int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
}
func newContext(t *testing.T, opt *Options) *context {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ firstFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+ secondFDPair, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
if err != nil {
t.Fatalf("Socketpair failed: %v", err)
}
- done := make(chan struct{}, 1)
+ done := make(chan struct{}, 2)
opt.ClosedFunc = func(*tcpip.Error) {
done <- struct{}{}
}
- opt.FDs = []int{fds[1]}
+ opt.FDs = []int{firstFDPair[1], secondFDPair[1]}
ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
c := &context{
- t: t,
- fds: fds,
- ep: ep,
- ch: make(chan packetInfo, 100),
- done: done,
+ t: t,
+ readFDs: []int{firstFDPair[0], secondFDPair[0]},
+ writeFDs: opt.FDs,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
}
ep.Attach(c)
@@ -88,13 +121,22 @@ func newContext(t *testing.T, opt *Options) *context {
}
func (c *context) cleanup() {
- syscall.Close(c.fds[0])
+ for _, fd := range c.readFDs {
+ syscall.Close(fd)
+ }
<-c.done
- syscall.Close(c.fds[1])
+ <-c.done
+ for _, fd := range c.writeFDs {
+ syscall.Close(fd)
+ }
}
-func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
- c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader}
+func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ c.ch <- packetInfo{remote, protocol, pkt}
+}
+
+func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
}
func TestNoEthernetProperties(t *testing.T) {
@@ -137,7 +179,7 @@ func TestAddress(t *testing.T) {
}
}
-func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
+func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32, hash uint32) {
c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth, GSOMaxSize: gsoMaxSize})
defer c.cleanup()
@@ -145,19 +187,28 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
RemoteLinkAddress: raddr,
}
- // Build header.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
- b := hdr.Prepend(100)
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ // Build payload.
+ payload := buffer.NewView(plen)
+ if _, err := rand.Read(payload); err != nil {
+ t.Fatalf("rand.Read(payload): %s", err)
}
- // Build payload and write.
- payload := make(buffer.View, plen)
- for i := range payload {
- payload[i] = uint8(rand.Intn(256))
+ // Build packet buffer.
+ const netHdrLen = 100
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()) + netHdrLen,
+ Data: payload.ToVectorisedView(),
+ })
+ pkt.Hash = hash
+
+ // Build header.
+ b := pkt.NetworkHeader().Push(netHdrLen)
+ if _, err := rand.Read(b); err != nil {
+ t.Fatalf("rand.Read(b): %s", err)
}
- want := append(hdr.View(), payload...)
+
+ // Write.
+ want := append(append(buffer.View(nil), b...), payload...)
var gso *stack.GSO
if gsoMaxSize != 0 {
gso = &stack.GSO{
@@ -169,13 +220,14 @@ func testWritePacket(t *testing.T, plen int, eth bool, gsoMaxSize uint32) {
L3HdrLen: header.IPv4MaximumHeaderSize,
}
}
- if err := c.ep.WritePacket(r, gso, hdr, payload.ToVectorisedView(), proto); err != nil {
+ if err := c.ep.WritePacket(r, gso, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
- // Read from fd, then compare with what we wrote.
+ // Read from the corresponding FD, then compare with what we wrote.
b = make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ fd := c.readFDs[hash%uint32(len(c.readFDs))]
+ n, err := syscall.Read(fd, b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -236,7 +288,7 @@ func TestWritePacket(t *testing.T) {
t.Run(
fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v", eth, plen, gso),
func(t *testing.T) {
- testWritePacket(t, plen, eth, gso)
+ testWritePacket(t, plen, eth, gso, 0)
},
)
}
@@ -244,6 +296,27 @@ func TestWritePacket(t *testing.T) {
}
}
+func TestHashedWritePacket(t *testing.T) {
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+ gsos := []uint32{0, 32768}
+ hashes := []uint32{0, 1}
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ for _, gso := range gsos {
+ for _, hash := range hashes {
+ t.Run(
+ fmt.Sprintf("Eth=%v,PayloadLen=%v,GSOMaxSize=%v,Hash=%d", eth, plen, gso, hash),
+ func(t *testing.T) {
+ testWritePacket(t, plen, eth, gso, hash)
+ },
+ )
+ }
+ }
+ }
+ }
+}
+
func TestPreserveSrcAddress(t *testing.T) {
baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
@@ -256,16 +329,20 @@ func TestPreserveSrcAddress(t *testing.T) {
LocalLinkAddress: baddr,
}
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
- if err := c.ep.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ // TODO(b/153685824): Figure out if this should use c.ep.MaxHeaderLength().
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.VectorisedView{},
+ })
+ if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
// Read from the FD, then compare with what we wrote.
b := make([]byte, mtu)
- n, err := syscall.Read(c.fds[0], b)
+ n, err := syscall.Read(c.readFDs[0], b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
@@ -288,28 +365,29 @@ func TestDeliverPacket(t *testing.T) {
defer c.cleanup()
// Build packet.
- b := make([]byte, plen)
- all := b
- for i := range b {
- b[i] = uint8(rand.Intn(256))
+ all := make([]byte, plen)
+ if _, err := rand.Read(all); err != nil {
+ t.Fatalf("rand.Read(all): %s", err)
}
-
- var hdr header.Ethernet
- if !eth {
- // So that it looks like an IPv4 packet.
- b[0] = 0x40
- } else {
- hdr = make(header.Ethernet, header.EthernetMinimumSize)
+ // Make it look like an IPv4 packet.
+ all[0] = 0x40
+
+ wantPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ Data: buffer.NewViewFromBytes(all).ToVectorisedView(),
+ })
+ if eth {
+ hdr := header.Ethernet(wantPkt.LinkHeader().Push(header.EthernetMinimumSize))
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
Type: proto,
})
- all = append(hdr, b...)
+ all = append(hdr, all...)
}
// Write packet via the file descriptor.
- if _, err := syscall.Write(c.fds[0], all); err != nil {
+ if _, err := syscall.Write(c.readFDs[0], all); err != nil {
t.Fatalf("Write failed: %v", err)
}
@@ -317,18 +395,15 @@ func TestDeliverPacket(t *testing.T) {
select {
case pi := <-c.ch:
want := packetInfo{
- raddr: raddr,
- proto: proto,
- contents: b,
- linkHeader: buffer.View(hdr),
+ Raddr: raddr,
+ Proto: proto,
+ Contents: wantPkt,
}
if !eth {
- want.proto = header.IPv4ProtocolNumber
- want.raddr = ""
- }
- if !reflect.DeepEqual(want, pi) {
- t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ want.Proto = header.IPv4ProtocolNumber
+ want.Raddr = ""
}
+ checkPacketInfoEqual(t, pi, want)
case <-time.After(10 * time.Second):
t.Fatalf("Timed out waiting for packet")
}
@@ -455,3 +530,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) {
}
}
+
+// fakeNetworkDispatcher delivers packets to pkts.
+type fakeNetworkDispatcher struct {
+ pkts []*stack.PacketBuffer
+}
+
+func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ d.pkts = append(d.pkts, pkt)
+}
+
+func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestDispatchPacketFormat(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ newDispatcher func(fd int, e *endpoint) (linkDispatcher, error)
+ }{
+ {
+ name: "readVDispatcher",
+ newDispatcher: newReadVDispatcher,
+ },
+ {
+ name: "recvMMsgDispatcher",
+ newDispatcher: newRecvMMsgDispatcher,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ // Create a socket pair to send/recv.
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer syscall.Close(fds[0])
+ defer syscall.Close(fds[1])
+
+ data := []byte{
+ // Ethernet header.
+ 1, 2, 3, 4, 5, 60,
+ 1, 2, 3, 4, 5, 61,
+ 8, 0,
+ // Mock network header.
+ 40, 41, 42, 43,
+ }
+ err = syscall.Sendmsg(fds[1], data, nil, nil, 0)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create and run dispatcher once.
+ sink := &fakeNetworkDispatcher{}
+ d, err := test.newDispatcher(fds[0], &endpoint{
+ hdrSize: header.EthernetMinimumSize,
+ dispatcher: sink,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok, err := d.dispatch(); !ok || err != nil {
+ t.Fatalf("d.dispatch() = %v, %v", ok, err)
+ }
+
+ // Verify packet.
+ if got, want := len(sink.pkts), 1; got != want {
+ t.Fatalf("len(sink.pkts) = %d, want %d", got, want)
+ }
+ pkt := sink.pkts[0]
+ if got, want := pkt.LinkHeader().View().Size(), header.EthernetMinimumSize; got != want {
+ t.Errorf("pkt.LinkHeader().View().Size() = %d, want %d", got, want)
+ }
+ if got, want := pkt.Data.Size(), 4; got != want {
+ t.Errorf("pkt.Data.Size() = %d, want %d", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_unsafe.go b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
index 97a477b61..df14eaad1 100644
--- a/pkg/tcpip/link/fdbased/endpoint_unsafe.go
+++ b/pkg/tcpip/link/fdbased/endpoint_unsafe.go
@@ -17,16 +17,7 @@
package fdbased
import (
- "reflect"
"unsafe"
)
const virtioNetHdrSize = int(unsafe.Sizeof(virtioNetHdr{}))
-
-func vnetHdrToByteSlice(hdr *virtioNetHdr) (slice []byte) {
- sh := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- sh.Data = uintptr(unsafe.Pointer(hdr))
- sh.Len = virtioNetHdrSize
- sh.Cap = virtioNetHdrSize
- return
-}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 554d45715..c475dda20 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -18,6 +18,7 @@ package fdbased
import (
"encoding/binary"
+ "fmt"
"syscall"
"golang.org/x/sys/unix"
@@ -25,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
const (
@@ -169,10 +171,9 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(pkt)
+ eth := header.Ethernet(pkt)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -189,7 +190,14 @@ func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}), buffer.View(eth))
+ pbuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(pkt).ToVectorisedView(),
+ })
+ if d.e.hdrSize > 0 {
+ if _, ok := pbuf.LinkHeader().Consume(d.e.hdrSize); !ok {
+ panic(fmt.Sprintf("LinkHeader().Consume(%d) must succeed", d.e.hdrSize))
+ }
+ }
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pbuf)
return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 12168a1dc..8c3ca86d6 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -103,7 +103,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
d.allocateViews(BufConfig)
n, err := rawfile.BlockingReadv(d.fd, d.iovecs)
- if err != nil {
+ if n == 0 || err != nil {
return false, err
}
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
@@ -111,17 +111,22 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
// isn't used and it isn't in a view.
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(n, BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[0][:header.EthernetMinimumSize])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -138,11 +143,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(n, BufConfig)
- vv := buffer.NewVectorisedView(n, d.views[:used])
- vv.TrimFront(d.e.hdrSize)
-
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
@@ -166,7 +167,7 @@ type recvMMsgDispatcher struct {
// iovecs is an array of array of iovec records where each iovec base
// pointer and length are initialzed to the corresponding view above,
- // except when GSO is neabled then the first iovec in each array of
+ // except when GSO is enabled then the first iovec in each array of
// iovecs points to a buffer for the vnet header which is stripped
// before the views are passed up the stack for further processing.
iovecs [][]syscall.Iovec
@@ -265,17 +266,22 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
if d.e.Capabilities()&stack.CapabilityHardwareGSO != 0 {
n -= virtioNetHdrSize
}
- if n <= d.e.hdrSize {
- return false, nil
- }
+
+ used := d.capViews(k, int(n), BufConfig)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)),
+ })
var (
p tcpip.NetworkProtocolNumber
remote, local tcpip.LinkAddress
- eth header.Ethernet
)
if d.e.hdrSize > 0 {
- eth = header.Ethernet(d.views[k][0])
+ hdr, ok := pkt.LinkHeader().Consume(d.e.hdrSize)
+ if !ok {
+ return false, nil
+ }
+ eth := header.Ethernet(hdr)
p = eth.Type()
remote = eth.SourceAddress()
local = eth.DestinationAddress()
@@ -292,10 +298,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
}
- used := d.capViews(k, int(n), BufConfig)
- vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
- vv.TrimFront(d.e.hdrSize)
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
+ d.e.dispatcher.DeliverNetworkPacket(remote, local, p, pkt)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
index 23e4d1418..6bf3805b7 100644
--- a/pkg/tcpip/link/loopback/BUILD
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -1,12 +1,11 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "loopback",
srcs = ["loopback.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/loopback",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index a3b48fa73..38aa694e4 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -76,36 +76,47 @@ func (*endpoint) Wait() {}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
-func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
-
- // Because we're immediately turning around and writing the packet back to the
- // rx path, we intentionally don't preserve the remote and local link
- // addresses from the stack.Route we're passed.
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
+func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // Construct data as the unparsed portion for the loopback packet.
+ data := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
+ // Because we're immediately turning around and writing the packet back
+ // to the rx path, we intentionally don't preserve the remote and local
+ // link addresses from the stack.Route we're passed.
+ newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: data,
+ })
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, newPkt)
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
- // Reject the packet if it's shorter than an ethernet header.
- if packet.Size() < header.EthernetMinimumSize {
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ })
+ // There should be an ethernet header at the beginning of vv.
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ // Reject the packet if it's shorter than an ethernet header.
return tcpip.ErrBadAddress
}
-
- // There should be an ethernet header at the beginning of packet.
- linkHeader := header.Ethernet(packet.First()[:header.EthernetMinimumSize])
- packet.TrimFront(len(linkHeader))
- e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), packet, buffer.View(linkHeader))
+ linkHeader := header.Ethernet(hdr)
+ e.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, linkHeader.Type(), pkt)
return nil
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareLoopback
+}
+
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index 1bab380b0..e7493e5c5 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,18 +1,15 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "muxed",
srcs = ["injectable.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/muxed",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -21,7 +18,7 @@ go_test(
name = "muxed_test",
size = "small",
srcs = ["injectable_test.go"],
- embed = [":muxed"],
+ library = ":muxed",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 682b60291..56a611825 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -18,6 +18,7 @@ package muxed
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -80,33 +81,33 @@ func (m *InjectableEndpoint) IsAttached() bool {
}
// InjectInbound implements stack.InjectableLinkEndpoint.
-func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
- m.dispatcher.DeliverNetworkPacket(m, "" /* remote */, "" /* local */, protocol, vv, nil /* linkHeader */)
+func (m *InjectableEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ m.dispatcher.DeliverNetworkPacket("" /* remote */, "" /* local */, protocol, pkt)
}
// WritePackets writes outbound packets to the appropriate
// LinkInjectableEndpoint based on the RemoteAddress. HandleLocal only works if
// r.RemoteAddress has a route registered in this endpoint.
-func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (m *InjectableEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
endpoint, ok := m.routes[r.RemoteAddress]
if !ok {
return 0, tcpip.ErrNoRoute
}
- return endpoint.WritePackets(r, gso, hdrs, payload, protocol)
+ return endpoint.WritePackets(r, gso, pkts, protocol)
}
// WritePacket writes outbound packets to the appropriate LinkInjectableEndpoint
// based on the RemoteAddress. HandleLocal only works if r.RemoteAddress has a
// route registered in this endpoint.
-func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (m *InjectableEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if endpoint, ok := m.routes[r.RemoteAddress]; ok {
- return endpoint.WritePacket(r, gso, hdr, payload, protocol)
+ return endpoint.WritePacket(r, gso, protocol, pkt)
}
return tcpip.ErrNoRoute
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (m *InjectableEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (m *InjectableEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
// WriteRawPacket doesn't get a route or network address, so there's
// nowhere to write this.
return tcpip.ErrNoRoute
@@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() {
}
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unsupported operation")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 9cd300af8..3e4afcdad 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -46,12 +46,14 @@ func TestInjectableEndpointRawDispatch(t *testing.T) {
func TestInjectableEndpointDispatch(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: 1,
+ Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, hdr,
- buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), ipv4.ProtocolNumber)
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
@@ -65,11 +67,14 @@ func TestInjectableEndpointDispatch(t *testing.T) {
func TestInjectableEndpointDispatchHdrOnly(t *testing.T) {
endpoint, sock, dstIP := makeTestInjectableEndpoint(t)
- hdr := buffer.NewPrependable(1)
- hdr.Prepend(1)[0] = 0xFA
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: 1,
+ Data: buffer.NewView(0).ToVectorisedView(),
+ })
+ pkt.TransportHeader().Push(1)[0] = 0xFA
packetRoute := stack.Route{RemoteAddress: dstIP}
- endpoint.WritePacket(&packetRoute, nil /* gso */, hdr,
- buffer.NewView(0).ToVectorisedView(), ipv4.ProtocolNumber)
+ endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt)
buf := make([]byte, 6500)
bytesRead, err := sock.Read(buf)
if err != nil {
diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD
new file mode 100644
index 000000000..2cdb23475
--- /dev/null
+++ b/pkg/tcpip/link/nested/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nested",
+ srcs = [
+ "nested.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "nested_test",
+ size = "small",
+ srcs = [
+ "nested_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go
new file mode 100644
index 000000000..d40de54df
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nested provides helpers to implement the pattern of nested
+// stack.LinkEndpoints.
+package nested
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher
+// that can be used to implement nesting safely by providing lifecycle
+// concurrency guards.
+//
+// See the tests in this package for example usage.
+type Endpoint struct {
+ child stack.LinkEndpoint
+ embedder stack.NetworkDispatcher
+
+ // mu protects dispatcher.
+ mu sync.RWMutex
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.GSOEndpoint = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+
+// Init initializes a nested.Endpoint that uses embedder as the dispatcher for
+// child on Attach.
+//
+// See the tests in this package for example usage.
+func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) {
+ e.child = child
+ e.embedder = embedder
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverNetworkPacket(remote, local, protocol, pkt)
+ }
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.mu.RLock()
+ d := e.dispatcher
+ e.mu.RUnlock()
+ if d != nil {
+ d.DeliverOutboundPacket(remote, local, protocol, pkt)
+ }
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.mu.Lock()
+ e.dispatcher = dispatcher
+ e.mu.Unlock()
+ // If we're attaching to a valid dispatcher, pass embedder as the dispatcher
+ // to our child, otherwise detach the child by giving it a nil dispatcher.
+ var pass stack.NetworkDispatcher
+ if dispatcher != nil {
+ pass = e.embedder
+ }
+ e.child.Attach(pass)
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ e.mu.RLock()
+ isAttached := e.dispatcher != nil
+ e.mu.RUnlock()
+ return isAttached
+}
+
+// MTU implements stack.LinkEndpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.child.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.child.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.child.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.child.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ return e.child.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ return e.child.WritePackets(r, gso, pkts, protocol)
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ return e.child.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.
+func (e *Endpoint) Wait() {
+ e.child.Wait()
+}
+
+// GSOMaxSize implements stack.GSOEndpoint.
+func (e *Endpoint) GSOMaxSize() uint32 {
+ if e, ok := e.child.(stack.GSOEndpoint); ok {
+ return e.GSOMaxSize()
+ }
+ return 0
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.child.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.child.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go
new file mode 100644
index 000000000..c1f9d308c
--- /dev/null
+++ b/pkg/tcpip/link/nested/nested_test.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nested_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type parentEndpoint struct {
+ nested.Endpoint
+}
+
+var _ stack.LinkEndpoint = (*parentEndpoint)(nil)
+var _ stack.NetworkDispatcher = (*parentEndpoint)(nil)
+
+type childEndpoint struct {
+ stack.LinkEndpoint
+ dispatcher stack.NetworkDispatcher
+}
+
+var _ stack.LinkEndpoint = (*childEndpoint)(nil)
+
+func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ c.dispatcher = dispatcher
+}
+
+func (c *childEndpoint) IsAttached() bool {
+ return c.dispatcher != nil
+}
+
+type counterDispatcher struct {
+ count int
+}
+
+var _ stack.NetworkDispatcher = (*counterDispatcher)(nil)
+
+func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ d.count++
+}
+
+func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
+func TestNestedLinkEndpoint(t *testing.T) {
+ const emptyAddress = tcpip.LinkAddress("")
+
+ var (
+ childEP childEndpoint
+ nestedEP parentEndpoint
+ disp counterDispatcher
+ )
+ nestedEP.Endpoint.Init(&childEP, &nestedEP)
+
+ if childEP.IsAttached() {
+ t.Error("On init, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("On init, nestedEP.IsAttached() = true, want = false")
+ }
+
+ nestedEP.Attach(&disp)
+ if disp.count != 0 {
+ t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count)
+ }
+ if !childEP.IsAttached() {
+ t.Error("After attach, childEP.IsAttached() = false, want = true")
+ }
+ if !nestedEP.IsAttached() {
+ t.Error("After attach, nestedEP.IsAttached() = false, want = true")
+ }
+
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
+ if disp.count != 1 {
+ t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count)
+ }
+
+ nestedEP.Attach(nil)
+ if childEP.IsAttached() {
+ t.Error("After detach, childEP.IsAttached() = true, want = false")
+ }
+ if nestedEP.IsAttached() {
+ t.Error("After detach, nestedEP.IsAttached() = true, want = false")
+ }
+
+ disp.count = 0
+ nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
+ if disp.count != 0 {
+ t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count)
+ }
+
+}
diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD
new file mode 100644
index 000000000..6fff160ce
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "packetsocket",
+ srcs = ["endpoint.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go
new file mode 100644
index 000000000..3922c2a04
--- /dev/null
+++ b/pkg/tcpip/link/packetsocket/endpoint.go
@@ -0,0 +1,50 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package packetsocket provides a link layer endpoint that provides the ability
+// to loop outbound packets to any AF_PACKET sockets that may be interested in
+// the outgoing packet.
+package packetsocket
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ nested.Endpoint
+}
+
+// New creates a new packetsocket LinkEndpoint.
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ e := &endpoint{}
+ e.Endpoint.Init(lower, e)
+ return e
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD
new file mode 100644
index 000000000..1d0079bd6
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fifo",
+ srcs = [
+ "endpoint.go",
+ "packet_buffer_queue.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
new file mode 100644
index 000000000..fc1e34fc7
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -0,0 +1,227 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fifo provides the implementation of data-link layer endpoints that
+// wrap another endpoint and queues all outbound packets and asynchronously
+// dispatches them to the lower endpoint.
+package fifo
+
+import (
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// endpoint represents a LinkEndpoint which implements a FIFO queue for all
+// outgoing packets. endpoint can have 1 or more underlying queueDispatchers.
+// All outgoing packets are consistenly hashed to a single underlying queue
+// using the PacketBuffer.Hash if set, otherwise all packets are queued to the
+// first queue to avoid reordering in case of missing hash.
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ lower stack.LinkEndpoint
+ wg sync.WaitGroup
+ dispatchers []*queueDispatcher
+}
+
+// queueDispatcher is responsible for dispatching all outbound packets in its
+// queue. It will also smartly batch packets when possible and write them
+// through the lower LinkEndpoint.
+type queueDispatcher struct {
+ lower stack.LinkEndpoint
+ q *packetBufferQueue
+ newPacketWaker sleep.Waker
+ closeWaker sleep.Waker
+}
+
+// New creates a new fifo link endpoint with the n queues with maximum
+// capacity of queueLen.
+func New(lower stack.LinkEndpoint, n int, queueLen int) stack.LinkEndpoint {
+ e := &endpoint{
+ lower: lower,
+ }
+ // Create the required dispatchers
+ for i := 0; i < n; i++ {
+ qd := &queueDispatcher{
+ q: &packetBufferQueue{limit: queueLen},
+ lower: lower,
+ }
+ e.dispatchers = append(e.dispatchers, qd)
+ e.wg.Add(1)
+ go func() {
+ defer e.wg.Done()
+ qd.dispatchLoop()
+ }()
+ }
+ return e
+}
+
+func (q *queueDispatcher) dispatchLoop() {
+ const newPacketWakerID = 1
+ const closeWakerID = 2
+ s := sleep.Sleeper{}
+ s.AddWaker(&q.newPacketWaker, newPacketWakerID)
+ s.AddWaker(&q.closeWaker, closeWakerID)
+ defer s.Done()
+
+ const batchSize = 32
+ var batch stack.PacketBufferList
+ for {
+ id, ok := s.Fetch(true)
+ if ok && id == closeWakerID {
+ return
+ }
+ for pkt := q.q.dequeue(); pkt != nil; pkt = q.q.dequeue() {
+ batch.PushBack(pkt)
+ if batch.Len() < batchSize && !q.q.empty() {
+ continue
+ }
+ // We pass a protocol of zero here because each packet carries its
+ // NetworkProtocol.
+ q.lower.WritePackets(nil /* route */, nil /* gso */, batch, 0 /* protocol */)
+ for pkt := batch.Front(); pkt != nil; pkt = pkt.Next() {
+ pkt.EgressRoute.Release()
+ batch.Remove(pkt)
+ }
+ batch.Reset()
+ }
+ }
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU.
+func (e *endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (e *endpoint) GSOMaxSize() uint32 {
+ if gso, ok := e.lower.(stack.GSOEndpoint); ok {
+ return gso.GSOMaxSize()
+ }
+ return 0
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket.
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ // WritePacket caller's do not set the following fields in PacketBuffer
+ // so we populate them here.
+ newRoute := r.Clone()
+ pkt.EgressRoute = &newRoute
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
+ d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
+ if !d.q.enqueue(pkt) {
+ return tcpip.ErrNoBufferSpace
+ }
+ d.newPacketWaker.Assert()
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.WritePackets.
+//
+// Being a batch API, each packet in pkts should have the following fields
+// populated:
+// - pkt.EgressRoute
+// - pkt.GSOOptions
+// - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ enqueued := 0
+ for pkt := pkts.Front(); pkt != nil; {
+ d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)]
+ nxt := pkt.Next()
+ // Since qdisc can hold onto a packet for long we should Clone
+ // the route here to ensure it doesn't get released while the
+ // packet is still in our queue.
+ newRoute := pkt.EgressRoute.Clone()
+ pkt.EgressRoute = &newRoute
+ if !d.q.enqueue(pkt) {
+ if enqueued > 0 {
+ d.newPacketWaker.Assert()
+ }
+ return enqueued, tcpip.ErrNoBufferSpace
+ }
+ pkt = nxt
+ enqueued++
+ d.newPacketWaker.Assert()
+ }
+ return enqueued, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3267): Queue these packets as well once
+ // WriteRawPacket takes PacketBuffer instead of VectorisedView.
+ return e.lower.WriteRawPacket(vv)
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *endpoint) Wait() {
+ e.lower.Wait()
+
+ // The linkEP is gone. Teardown the outbound dispatcher goroutines.
+ for i := range e.dispatchers {
+ e.dispatchers[i].closeWaker.Assert()
+ }
+
+ e.wg.Wait()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
new file mode 100644
index 000000000..eb5abb906
--- /dev/null
+++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fifo
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// packetBufferQueue is a bounded, thread-safe queue of PacketBuffers.
+//
+type packetBufferQueue struct {
+ mu sync.Mutex
+ list stack.PacketBufferList
+ limit int
+ used int
+}
+
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *packetBufferQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
+// empty determines if the queue is empty.
+func (q *packetBufferQueue) empty() bool {
+ q.mu.Lock()
+ r := q.emptyLocked()
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No PacketBuffers are immediately dropped in case
+// the queue becomes full due to the new limit.
+func (q *packetBufferQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given packet to the queue.
+//
+// Returns true when the PacketBuffer is successfully added to the queue, in
+// which case ownership of the reference is transferred to the queue. And
+// returns false if the queue is full, in which case ownership is retained by
+// the caller.
+func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
+
+// dequeue removes and returns the next PacketBuffer from queue, if one exists.
+// Ownership is transferred to the caller.
+func (q *packetBufferQueue) dequeue() *stack.PacketBuffer {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 05c7b8024..14b527bc2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -12,10 +12,7 @@ go_library(
"errors.go",
"rawfile_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/rawfile",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index dda3b10a6..99313ee25 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -14,7 +14,7 @@
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 44e25d475..f4c32c2da 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
return nil
}
-// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a
-// single syscall. It fails if partial data is written.
-func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
- // If the is no second buffer, issue a regular write.
- if len(b2) == 0 {
- return NonBlockingWrite(fd, b1)
- }
-
- // We have two buffers. Build the iovec that represents them and issue
- // a writev syscall.
- iovec := [3]syscall.Iovec{
- {
- Base: &b1[0],
- Len: uint64(len(b1)),
- },
- {
- Base: &b2[0],
- Len: uint64(len(b2)),
- },
- }
- iovecLen := uintptr(2)
-
- if len(b3) > 0 {
- iovecLen++
- iovec[2].Base = &b3[0]
- iovec[2].Len = uint64(len(b3))
- }
-
+// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall.
+// It fails if partial data is written.
+func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error {
+ iovecLen := uintptr(len(iovec))
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
if e != 0 {
return TranslateErrno(e)
}
-
return nil
}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index 0a5ea3dc4..13243ebbb 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,12 +10,10 @@ go_library(
"sharedmem_unsafe.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem",
- visibility = [
- "//:sandbox",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
@@ -31,8 +28,9 @@ go_test(
srcs = [
"sharedmem_test.go",
],
- embed = [":sharedmem"],
+ library = ":sharedmem",
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 330ed5e94..87020ec08 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -11,8 +10,7 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/pipe",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
)
go_test(
@@ -20,5 +18,6 @@ go_test(
srcs = [
"pipe_test.go",
],
- embed = [":pipe"],
+ library = ":pipe",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index 59ef69a8b..dc239a0d0 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -18,8 +18,9 @@ import (
"math/rand"
"reflect"
"runtime"
- "sync"
"testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func TestSimpleReadWrite(t *testing.T) {
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index de1ce043d..3ba06af73 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,8 +8,7 @@ go_library(
"rx.go",
"tx.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip/link/sharedmem/pipe",
@@ -22,7 +20,7 @@ go_test(
srcs = [
"queue_test.go",
],
- embed = [":queue"],
+ library = ":queue",
deps = [
"//pkg/tcpip/link/sharedmem/pipe",
],
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 279e2b457..7fb8a6c49 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -23,11 +23,11 @@
package sharedmem
import (
- "sync"
"sync/atomic"
"syscall"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -183,26 +183,33 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
-// WritePacket writes outbound packets to the file descriptor. If it is not
-// currently writable, the packet is dropped.
-func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- // Add the ethernet header here.
- eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
- DstAddr: r.RemoteLinkAddress,
+ DstAddr: remote,
Type: protocol,
}
- if r.LocalLinkAddress != "" {
- ethHdr.SrcAddr = r.LocalLinkAddress
+
+ // Preserve the src address if it's set in the route.
+ if local != "" {
+ ethHdr.SrcAddr = local
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt)
- v := payload.ToView()
+ views := pkt.Views()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(hdr.View(), v)
+ ok := e.tx.transmit(views...)
e.mu.Unlock()
if !ok {
@@ -213,16 +220,16 @@ func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependa
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
- v := packet.ToView()
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ views := vv.Views()
// Transmit the packet.
e.mu.Lock()
- ok := e.tx.transmit(v, buffer.View{})
+ ok := e.tx.transmit(views...)
e.mu.Unlock()
if !ok {
@@ -268,13 +275,18 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
rxb[i].Size = e.bufferSize
}
- if n < header.EthernetMinimumSize {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(b).ToVectorisedView(),
+ })
+
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
continue
}
+ eth := header.Ethernet(hdr)
// Send packet up the stack.
- eth := header.Ethernet(b)
- d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), buffer.View(b[header.EthernetMinimumSize:]).ToVectorisedView(), buffer.View(eth))
+ d.DeliverNetworkPacket(eth.SourceAddress(), eth.DestinationAddress(), eth.Type(), pkt)
}
// Clean state.
@@ -283,3 +295,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
e.completed.Done()
}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType
+func (*endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index f3e9705c9..22d5c97f1 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -22,11 +22,11 @@ import (
"math/rand"
"os"
"strings"
- "sync"
"syscall"
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -131,19 +131,22 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
return c
}
-func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
c.mu.Lock()
c.packets = append(c.packets, packetInfo{
- addr: remoteLinkAddr,
- proto: proto,
- vv: vv.Clone(nil),
- linkHeader: linkHeader,
+ addr: remoteLinkAddr,
+ proto: proto,
+ vv: pkt.Data.Clone(nil),
})
c.mu.Unlock()
c.packetCh <- struct{}{}
}
+func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (c *testContext) cleanup() {
c.ep.Close()
closeFDs(&c.txCfg)
@@ -263,18 +266,23 @@ func TestSimpleSend(t *testing.T) {
for iters := 1000; iters > 0; iters-- {
func() {
+ hdrLen, dataLen := rand.Intn(10000), rand.Intn(10000)
+
// Prepare and send packet.
- n := rand.Intn(10000)
- hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
- hdrBuf := hdr.Prepend(n)
+ hdrBuf := buffer.NewView(hdrLen)
randomFill(hdrBuf)
- n = rand.Intn(10000)
- buf := buffer.NewView(n)
- randomFill(buf)
+ data := buffer.NewView(dataLen)
+ randomFill(data)
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrLen + int(c.ep.MaxHeaderLength()),
+ Data: data.ToVectorisedView(),
+ })
+ copy(pkt.NetworkHeader().Push(hdrLen), hdrBuf)
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), proto); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -311,7 +319,7 @@ func TestSimpleSend(t *testing.T) {
// Compare contents skipping the ethernet header added by the
// endpoint.
- merged := append(hdrBuf, buf...)
+ merged := append(hdrBuf, data...)
if uint32(len(contents)) < pi.Size {
t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
}
@@ -338,12 +346,14 @@ func TestPreserveSrcAddressInSend(t *testing.T) {
LocalLinkAddress: newLocalLinkAddress,
}
- // WritePacket panics given a prependable with anything less than
- // the minimum size of the ethernet header.
- hdr := buffer.NewPrependable(header.EthernetMinimumSize)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // WritePacket panics given a prependable with anything less than
+ // the minimum size of the ethernet header.
+ ReserveHeaderBytes: header.EthernetMinimumSize,
+ })
proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buffer.VectorisedView{}, proto); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
@@ -395,9 +405,12 @@ func TestFillTxQueue(t *testing.T) {
// until the tx queue if full.
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -411,8 +424,11 @@ func TestFillTxQueue(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -436,8 +452,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// Send two packets so that the id slice has at least two slots.
for i := 2; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -456,8 +475,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
// until the tx queue if full.
ids := make(map[uint64]struct{})
for i := queuePipeSize / 40; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -471,8 +493,11 @@ func TestFillTxQueueAfterBadCompletion(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
@@ -494,8 +519,11 @@ func TestFillTxMemory(t *testing.T) {
// we fill the memory.
ids := make(map[uint64]struct{})
for i := queueDataSize / bufferSize; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -510,8 +538,11 @@ func TestFillTxMemory(t *testing.T) {
}
// Next attempt to write must fail.
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt)
if want := tcpip.ErrWouldBlock; err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
@@ -535,8 +566,11 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Each packet is uses up one buffer, so write as many as possible
// until there is only one buffer left.
for i := queueDataSize/bufferSize - 1; i > 0; i-- {
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
@@ -547,17 +581,22 @@ func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
// Attempt to write a two-buffer packet. It must fail.
{
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- uu := buffer.NewView(bufferSize).ToVectorisedView()
- if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, hdr, uu, header.IPv4ProtocolNumber); err != want {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buffer.NewView(bufferSize).ToVectorisedView(),
+ })
+ if want, err := tcpip.ErrWouldBlock, c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != want {
t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
}
}
// Attempt to write the one-buffer packet again. It must succeed.
{
- hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
- if err := c.ep.WritePacket(&r, nil /* gso */, hdr, buf.ToVectorisedView(), header.IPv4ProtocolNumber); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(c.ep.MaxHeaderLength()),
+ Data: buf.ToVectorisedView(),
+ })
+ if err := c.ep.WritePacket(&r, nil /* gso */, header.IPv4ProtocolNumber, pkt); err != nil {
t.Fatalf("WritePacket failed unexpectedly: %v", err)
}
}
@@ -640,7 +679,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.First())
+ rcvd := []byte(c.packets[0].vv.ToView())
c.packets = c.packets[:0]
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
index 6b8d7859d..44f421c2d 100644
--- a/pkg/tcpip/link/sharedmem/tx.go
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -18,6 +18,7 @@ import (
"math"
"syscall"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/link/sharedmem/queue"
)
@@ -76,9 +77,9 @@ func (t *tx) cleanup() {
syscall.Munmap(t.data)
}
-// transmit sends a packet made up of up to two buffers. Returns a boolean that
-// specifies whether the packet was successfully transmitted.
-func (t *tx) transmit(a, b []byte) bool {
+// transmit sends a packet made of bufs. Returns a boolean that specifies
+// whether the packet was successfully transmitted.
+func (t *tx) transmit(bufs ...buffer.View) bool {
// Pull completions from the tx queue and add their buffers back to the
// pool so that we can reuse them.
for {
@@ -93,7 +94,10 @@ func (t *tx) transmit(a, b []byte) bool {
}
bSize := t.bufs.entrySize
- total := uint32(len(a) + len(b))
+ total := uint32(0)
+ for _, data := range bufs {
+ total += uint32(len(data))
+ }
bufCount := (total + bSize - 1) / bSize
// Allocate enough buffers to hold all the data.
@@ -115,7 +119,7 @@ func (t *tx) transmit(a, b []byte) bool {
// Copy data into allocated buffers.
nBuf := buf
var dBuf []byte
- for _, data := range [][]byte{a, b} {
+ for _, data := range bufs {
for len(data) > 0 {
if len(dBuf) == 0 {
dBuf = t.data[nBuf.Offset:][:nBuf.Size]
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 1756114e6..7cbc305e7 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,15 +8,13 @@ go_library(
"pcap.go",
"sniffer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/sniffer",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 39757ea2a..4fb127978 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -21,11 +21,9 @@
package sniffer
import (
- "bytes"
"encoding/binary"
"fmt"
"io"
- "os"
"sync/atomic"
"time"
@@ -33,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -42,26 +41,29 @@ import (
// LogPackets must be accessed atomically.
var LogPackets uint32 = 1
-// LogPacketsToFile is a flag used to enable or disable logging packets to a
-// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// LogPacketsToPCAP is a flag used to enable or disable logging packets to a
+// pcap writer. Valid values are 0 or 1. A writer must have been specified when the
// sniffer was created for this flag to have effect.
//
-// LogPacketsToFile must be accessed atomically.
-var LogPacketsToFile uint32 = 1
+// LogPacketsToPCAP must be accessed atomically.
+var LogPacketsToPCAP uint32 = 1
type endpoint struct {
- dispatcher stack.NetworkDispatcher
- lower stack.LinkEndpoint
- file *os.File
+ nested.Endpoint
+ writer io.Writer
maxPCAPLen uint32
}
+var _ stack.GSOEndpoint = (*endpoint)(nil)
+var _ stack.LinkEndpoint = (*endpoint)(nil)
+var _ stack.NetworkDispatcher = (*endpoint)(nil)
+
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
- return &endpoint{
- lower: lower,
- }
+ sniffer := &endpoint{}
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer
}
func zoneOffset() (int32, error) {
@@ -92,132 +94,72 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
})
}
-// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
-// another endpoint and logs packets and they traverse the endpoint.
+// NewWithWriter creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
//
-// Packets can be logged to file in the pcap format. A sniffer created
-// with this function will not emit packets using the standard log
-// package.
+// Packets are logged to writer in the pcap format. A sniffer created with this
+// function will not emit packets using the standard log package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
-// less than or equal too snapLen will be saved in their entirety. Longer
+// less than or equal to snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
- if err := writePCAPHeader(file, snapLen); err != nil {
+func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) (stack.LinkEndpoint, error) {
+ if err := writePCAPHeader(writer, snapLen); err != nil {
return nil, err
}
- return &endpoint{
- lower: lower,
- file: file,
+ sniffer := &endpoint{
+ writer: writer,
maxPCAPLen: snapLen,
- }, nil
+ }
+ sniffer.Endpoint.Init(lower, sniffer)
+ return sniffer, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
-func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("recv", protocol, vv.First(), nil)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- vs := vv.Views()
- length := vv.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
- panic(err)
- }
- for _, v := range vs {
- if length == 0 {
- break
- }
- if len(v) > length {
- v = v[:length]
- }
- if _, err := buf.Write([]byte(v)); err != nil {
- panic(err)
- }
- length -= len(v)
- }
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
-}
-
-// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
-// and registers with the lower endpoint as its dispatcher so that "e" is called
-// for inbound packets.
-func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- e.lower.Attach(e)
+func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.dumpPacket("recv", nil, protocol, pkt)
+ e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt)
}
-// IsAttached implements stack.LinkEndpoint.IsAttached.
-func (e *endpoint) IsAttached() bool {
- return e.dispatcher != nil
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt)
}
-// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
-// lower endpoint.
-func (e *endpoint) MTU() uint32 {
- return e.lower.MTU()
-}
-
-// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
-// request to the lower endpoint.
-func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
- return e.lower.Capabilities()
-}
-
-// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
-// the request to the lower endpoint.
-func (e *endpoint) MaxHeaderLength() uint16 {
- return e.lower.MaxHeaderLength()
-}
-
-func (e *endpoint) LinkAddress() tcpip.LinkAddress {
- return e.lower.LinkAddress()
-}
-
-// GSOMaxSize returns the maximum GSO packet size.
-func (e *endpoint) GSOMaxSize() uint32 {
- if gso, ok := e.lower.(stack.GSOEndpoint); ok {
- return gso.GSOMaxSize()
- }
- return 0
-}
-
-func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", protocol, hdr.View(), gso)
+func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ writer := e.writer
+ if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
+ logPacket(prefix, protocol, pkt, gso)
}
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- hdrBuf := hdr.View()
- length := len(hdrBuf) + payload.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
+ if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
+ totalLength := pkt.Size()
+ length := totalLength
+ if max := int(e.maxPCAPLen); length > max {
+ length = max
}
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
+ if err := binary.Write(writer, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(totalLength))); err != nil {
panic(err)
}
- if len(hdrBuf) > length {
- hdrBuf = hdrBuf[:length]
- }
- if _, err := buf.Write(hdrBuf); err != nil {
- panic(err)
+ write := func(b []byte) {
+ if len(b) > length {
+ b = b[:length]
+ }
+ for len(b) != 0 {
+ n, err := writer.Write(b)
+ if err != nil {
+ panic(err)
+ }
+ b = b[n:]
+ length -= n
+ }
}
- length -= len(hdrBuf)
- logVectorisedView(payload, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
+ for _, v := range pkt.Views() {
+ if length == 0 {
+ break
+ }
+ write(v)
}
}
}
@@ -225,68 +167,30 @@ func (e *endpoint) dumpPacket(gso *stack.GSO, hdr buffer.Prependable, payload bu
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- e.dumpPacket(gso, hdr, payload, protocol)
- return e.lower.WritePacket(r, gso, hdr, payload, protocol)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.dumpPacket("send", gso, protocol, pkt)
+ return e.Endpoint.WritePacket(r, gso, protocol, pkt)
}
// WritePackets implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and
// forwards the request to the lower endpoint.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- view := payload.ToView()
- for _, d := range hdrs {
- e.dumpPacket(gso, d.Hdr, buffer.NewVectorisedView(d.Size, []buffer.View{view[d.Off:][:d.Size]}), protocol)
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.dumpPacket("send", gso, protocol, pkt)
}
- return e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+ return e.Endpoint.WritePackets(r, gso, pkts, protocol)
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
- if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
- logPacket("send", 0, buffer.View("[raw packet, no header available]"), nil /* gso */)
- }
- if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
- length := packet.Size()
- if length > int(e.maxPCAPLen) {
- length = int(e.maxPCAPLen)
- }
-
- buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
- if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(packet.Size()))); err != nil {
- panic(err)
- }
- logVectorisedView(packet, length, buf)
- if _, err := e.file.Write(buf.Bytes()); err != nil {
- panic(err)
- }
- }
- return e.lower.WriteRawPacket(packet)
-}
-
-func logVectorisedView(vv buffer.VectorisedView, length int, buf *bytes.Buffer) {
- if length <= 0 {
- return
- }
- for _, v := range vv.Views() {
- if len(v) > length {
- v = v[:length]
- }
- n, err := buf.Write(v)
- if err != nil {
- panic(err)
- }
- length -= n
- if length == 0 {
- return
- }
- }
+func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ e.dumpPacket("send", nil, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ return e.Endpoint.WriteRawPacket(vv)
}
-// Wait implements stack.LinkEndpoint.Wait.
-func (*endpoint) Wait() {}
-
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -295,30 +199,47 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
size := uint16(0)
var fragmentOffset uint16
var moreFragments bool
+
+ // Examine the packet using a new VV. Backing storage must not be written.
+ vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
+
switch protocol {
case header.IPv4ProtocolNumber:
- ipv4 := header.IPv4(b)
+ hdr, ok := vv.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ ipv4 := header.IPv4(hdr)
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- b = b[ipv4.HeaderLength():]
+ vv.TrimFront(int(ipv4.HeaderLength()))
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
- ipv6 := header.IPv6(b)
+ hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ ipv6 := header.IPv6(hdr)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
- b = b[header.IPv6MinimumSize:]
+ vv.TrimFront(header.IPv6MinimumSize)
case header.ARPProtocolNumber:
- arp := header.ARP(b)
+ hdr, ok := vv.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
+ vv.TrimFront(header.ARPSize)
+ arp := header.ARP(hdr)
log.Infof(
- "%s arp %v (%v) -> %v (%v) valid:%v",
+ "%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
@@ -338,7 +259,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv4(b)
+ hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv4(hdr)
icmpType := "unknown"
if fragmentOffset == 0 {
switch icmp.Type() {
@@ -366,12 +291,16 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
icmpType = "info reply"
}
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv6(b)
+ hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv6(hdr)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
@@ -397,13 +326,17 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
- log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+ log.Infof("%s %s %s -> %s %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
transName = "udp"
- udp := header.UDP(b)
- if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
+ hdr, ok := vv.PullUp(header.UDPMinimumSize)
+ if !ok {
+ break
+ }
+ udp := header.UDP(hdr)
+ if fragmentOffset == 0 {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
@@ -412,15 +345,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.TCPProtocolNumber:
transName = "tcp"
- tcp := header.TCP(b)
- if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
+ hdr, ok := vv.PullUp(header.TCPMinimumSize)
+ if !ok {
+ break
+ }
+ tcp := header.TCP(hdr)
+ if fragmentOffset == 0 {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if offset > len(tcp) && !moreFragments {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
+ if offset > vv.Size() && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
break
}
@@ -436,7 +373,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
flagsStr[i] = ' '
}
}
- details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
@@ -445,7 +382,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
}
default:
- log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+ log.Infof("%s %s -> %s unknown transport protocol: %d", prefix, src, dst, transProto)
return
}
@@ -453,5 +390,5 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
details += fmt.Sprintf(" gso: %+v", gso)
}
- log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+ log.Infof("%s %s %s:%d -> %s:%d len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 92dce8fac..6c137f693 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,12 +1,26 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/tun",
- visibility = [
- "//visibility:public",
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..3b1510a33
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,383 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef(ctx)
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ linkCaps := stack.CapabilityNone
+ if isTap {
+ linkCaps |= stack.CapabilityResolutionRequired
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ isTap: prefix == "tap",
+ }
+ endpoint.Endpoint.LinkEPCapabilities = linkCaps
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: len(ethHdr),
+ Data: buffer.View(data).ToVectorisedView(),
+ })
+ copy(pkt.LinkHeader().Push(len(ethHdr)), ethHdr)
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader().View().IsEmpty() {
+ d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt)
+ }
+ vv.AppendView(info.Pkt.LinkHeader().View())
+ }
+
+ // Append upper headers.
+ vv.AppendView(info.Pkt.NetworkHeader().View())
+ vv.AppendView(info.Pkt.TransportHeader().View())
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+ isTap bool
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef(ctx context.Context) {
+ e.DecRefWithDestructor(ctx, func(context.Context) {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType {
+ if e.isTap {
+ return header.ARPHardwareEther
+ }
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.isTap {
+ return
+ }
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ hdr := &header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: protocol,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = e.LinkAddress()
+ }
+
+ eth.Encode(hdr)
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header.
+func (e *tunEndpoint) MaxHeaderLength() uint16 {
+ if e.isTap {
+ return header.EthernetMinimumSize
+ }
+ return 0
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 0746dc8ec..ee84c3d96 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -8,14 +7,12 @@ go_library(
srcs = [
"waitable.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/link/waitable",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
@@ -25,10 +22,11 @@ go_test(
srcs = [
"waitable_test.go",
],
- embed = [":waitable"],
+ library = ":waitable",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
],
)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index a04fc1062..b152a0f26 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/gate"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -50,12 +51,21 @@ func New(lower stack.LinkEndpoint) *Endpoint {
// It is called by the link-layer endpoint being wrapped when a packet arrives,
// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
// been called.
-func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
if !e.dispatchGate.Enter() {
return
}
- e.dispatcher.DeliverNetworkPacket(e, remote, local, protocol, vv, linkHeader)
+ e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt)
+ e.dispatchGate.Leave()
+}
+
+// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket.
+func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+ e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt)
e.dispatchGate.Leave()
}
@@ -99,12 +109,12 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WritePacket(r, gso, hdr, payload, protocol)
+ err := e.lower.WritePacket(r, gso, protocol, pkt)
e.writeGate.Leave()
return err
}
@@ -112,23 +122,23 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
// WritePackets implements stack.LinkEndpoint.WritePackets. It is called by
// higher-level protocols to write packets. It only forwards packets to the
// lower endpoint if Wait or WaitWrite haven't been called.
-func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
if !e.writeGate.Enter() {
- return len(hdrs), nil
+ return pkts.Len(), nil
}
- n, err := e.lower.WritePackets(r, gso, hdrs, payload, protocol)
+ n, err := e.lower.WritePackets(r, gso, pkts, protocol)
e.writeGate.Leave()
return n, err
}
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
-func (e *Endpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
if !e.writeGate.Enter() {
return nil
}
- err := e.lower.WriteRawPacket(packet)
+ err := e.lower.WriteRawPacket(vv)
e.writeGate.Leave()
return err
}
@@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() {
// Wait implements stack.LinkEndpoint.Wait.
func (e *Endpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return e.lower.ARPHardwareType()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ e.lower.AddHeader(local, remote, protocol, pkt)
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 5f0f8fa2d..94827fc56 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -35,10 +36,14 @@ type countedEndpoint struct {
dispatcher stack.NetworkDispatcher
}
-func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
e.dispatchCount++
}
+func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.attachCount++
e.dispatcher = dispatcher
@@ -65,45 +70,55 @@ func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
-func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
e.writeCount++
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- e.writeCount += len(hdrs)
- return len(hdrs), nil
+func (e *countedEndpoint) WritePackets(r *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ e.writeCount += pkts.Len()
+ return pkts.Len(), nil
}
-func (e *countedEndpoint) WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error {
+func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
e.writeCount++
return nil
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("unimplemented")
+}
+
// Wait implements stack.LinkEndpoint.Wait.
func (*countedEndpoint) Wait() {}
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("unimplemented")
+}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
// Write and check that it goes through.
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on dispatches, then try to write. It must go through.
wep.WaitDispatch()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
// Wait on writes, then try to write. It must not go through.
wep.WaitWrite()
- wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
+ wep.WritePacket(nil, nil /* gso */, 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.writeCount != want {
t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
}
@@ -120,21 +135,21 @@ func TestWaitDispatch(t *testing.T) {
}
// Dispatch and check that it goes through.
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 1; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on writes, then try to dispatch. It must go through.
wep.WaitWrite()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
// Wait on dispatches, then try to dispatch. It must not go through.
wep.WaitDispatch()
- ep.dispatcher.DeliverNetworkPacket(ep, "", "", 0, buffer.VectorisedView{}, buffer.View{})
+ ep.dispatcher.DeliverNetworkPacket("", "", 0, stack.NewPacketBuffer(stack.PacketBufferOptions{}))
if want := 2; ep.dispatchCount != want {
t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 9d16ff8c9..46083925c 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
package(licenses = ["notice"])
@@ -12,6 +12,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index df0d3a8c0..eddf7b725 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,15 +1,11 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "arp",
srcs = ["arp.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/arp",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 46178459e..920872c3f 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -42,7 +42,8 @@ const (
// endpoint implements stack.NetworkEndpoint.
type endpoint struct {
- nicid tcpip.NICID
+ protocol *protocol
+ nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
}
@@ -58,43 +59,39 @@ func (e *endpoint) MTU() uint32 {
}
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
+ return e.nicID
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &stack.NetworkEndpointID{ProtocolAddress}
-}
-
-func (e *endpoint) PrefixLen() int {
- return 0
-}
-
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
func (e *endpoint) Close() {}
-func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderParams, *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, []stack.PacketDescriptor, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) (int, *tcpip.Error) {
+func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, stack.NetworkHeaderParams) (int, *tcpip.Error) {
return 0, tcpip.ErrNotSupported
}
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- v := vv.First()
- h := header.ARP(v)
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.ARP(pkt.NetworkHeader().View())
if !h.IsValid() {
return
}
@@ -102,23 +99,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
switch h.Op() {
case header.ARPRequest:
localAddr := tcpip.Address(h.ProtocolAddressTarget())
- if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
- hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
- pkt := header.ARP(hdr.Prepend(header.ARPSize))
- pkt.SetIPv4OverEthernet()
- pkt.SetOp(header.ARPReply)
- copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
- copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
- copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
- copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
- e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize,
+ })
+ packet := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
+ packet.SetIPv4OverEthernet()
+ packet.SetOp(header.ARPReply)
+ copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(packet.HardwareAddressTarget(), h.HardwareAddressSender())
+ copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
fallthrough // also fill the cache from requests
case header.ARPReply:
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr)
}
}
@@ -135,76 +134,77 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- if addrWithPrefix.Address != ProtocolAddress {
- return nil, tcpip.ErrBadLocalAddress
- }
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
- nicid: nicid,
+ protocol: p,
+ nicID: nicID,
linkEP: sender,
linkAddrCache: linkAddrCache,
- }, nil
+ }
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
- h := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize,
+ })
+ h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize))
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
copy(h.HardwareAddressSender(), linkEP.LinkAddress())
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
- return broadcastMAC, true
+ return header.EthernetBroadcastAddress, true
}
if header.IsV4MulticastAddress(addr) {
- // RFC 1112 Host Extensions for IP Multicasting
- //
- // 6.4. Extensions to an Ethernet Local Network Module:
- //
- // An IP host group address is mapped to an Ethernet multicast
- // address by placing the low-order 23-bits of the IP address
- // into the low-order 23 bits of the Ethernet multicast address
- // 01-00-5E-00-00-00 (hex).
- return tcpip.LinkAddress([]byte{
- 0x01,
- 0x00,
- 0x5e,
- addr[header.IPv4AddressSize-3] & 0x7f,
- addr[header.IPv4AddressSize-2],
- addr[header.IPv4AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv4Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.NetworkProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ _, ok = pkt.NetworkHeader().Consume(header.ARPSize)
+ if !ok {
+ return 0, false, false
+ }
+ return 0, false, true
+}
// NewProtocol returns an ARP network protocol.
func NewProtocol() stack.NetworkProtocol {
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 88b57ec03..c2c3e6891 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -15,6 +15,7 @@
package arp_test
import (
+ "context"
"strconv"
"testing"
"time"
@@ -31,10 +32,14 @@ import (
)
const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
type testContext struct {
@@ -49,8 +54,7 @@ func newTestContext(t *testing.T) *testContext {
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
@@ -83,7 +87,7 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP.C)
+ c.linkEP.Close()
}
func TestDirectRequest(t *testing.T) {
@@ -102,21 +106,23 @@ func TestDirectRequest(t *testing.T) {
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
- c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
+ c.linkEP.InjectInbound(arp.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ }))
}
for i, address := range []tcpip.Address{stackAddr1, stackAddr2} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
inject(address)
- pkt := <-c.linkEP.C
- if pkt.Proto != arp.ProtocolNumber {
- t.Fatalf("expected ARP response, got network protocol number %d", pkt.Proto)
+ pi, _ := c.linkEP.ReadContext(context.Background())
+ if pi.Proto != arp.ProtocolNumber {
+ t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto)
}
- rep := header.ARP(pkt.Header)
+ rep := header.ARP(pi.Pkt.NetworkHeader().View())
if !rep.IsValid() {
- t.Fatalf("invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep)
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -132,12 +138,53 @@ func TestDirectRequest(t *testing.T) {
}
inject(stackAddrBad)
- select {
- case pkt := <-c.linkEP.C:
+ // Sleep tests are gross, but this will only potentially flake
+ // if there's a bug. If there is no bug this will reliably
+ // succeed.
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ if pkt, ok := c.linkEP.ReadContext(ctx); ok {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
- case <-time.After(100 * time.Millisecond):
- // Sleep tests are gross, but this will only potentially flake
- // if there's a bug. If there is no bug this will reliably
- // succeed.
+ }
+}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: stackLinkAddr2,
+ expectLinkAddr: stackLinkAddr2,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 2cad0a0b6..d1c728ccf 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -24,10 +23,10 @@ go_library(
"reassembler.go",
"reassembler_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/log",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
],
@@ -41,14 +40,6 @@ go_test(
"fragmentation_test.go",
"reassembler_test.go",
],
- embed = [":fragmentation"],
+ library = ":fragmentation",
deps = ["//pkg/tcpip/buffer"],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "reassembler_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6da5238ec..1827666c5 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -17,28 +17,58 @@
package fragmentation
import (
+ "errors"
"fmt"
"log"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
-// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
-const DefaultReassembleTimeout = 30 * time.Second
+const (
+ // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+ DefaultReassembleTimeout = 30 * time.Second
-// HighFragThreshold is the threshold at which we start trimming old
-// fragmented packets. Linux uses a default value of 4 MB. See
-// net.ipv4.ipfrag_high_thresh for more information.
-const HighFragThreshold = 4 << 20 // 4MB
+ // HighFragThreshold is the threshold at which we start trimming old
+ // fragmented packets. Linux uses a default value of 4 MB. See
+ // net.ipv4.ipfrag_high_thresh for more information.
+ HighFragThreshold = 4 << 20 // 4MB
-// LowFragThreshold is the threshold we reach to when we start dropping
-// older fragmented packets. It's important that we keep enough room for newer
-// packets to be re-assembled. Hence, this needs to be lower than
-// HighFragThreshold enough. Linux uses a default value of 3 MB. See
-// net.ipv4.ipfrag_low_thresh for more information.
-const LowFragThreshold = 3 << 20 // 3MB
+ // LowFragThreshold is the threshold we reach to when we start dropping
+ // older fragmented packets. It's important that we keep enough room for newer
+ // packets to be re-assembled. Hence, this needs to be lower than
+ // HighFragThreshold enough. Linux uses a default value of 3 MB. See
+ // net.ipv4.ipfrag_low_thresh for more information.
+ LowFragThreshold = 3 << 20 // 3MB
+
+ // minBlockSize is the minimum block size for fragments.
+ minBlockSize = 1
+)
+
+var (
+ // ErrInvalidArgs indicates to the caller that that an invalid argument was
+ // provided.
+ ErrInvalidArgs = errors.New("invalid args")
+)
+
+// FragmentID is the identifier for a fragment.
+type FragmentID struct {
+ // Source is the source address of the fragment.
+ Source tcpip.Address
+
+ // Destination is the destination address of the fragment.
+ Destination tcpip.Address
+
+ // ID is the identification value of the fragment.
+ //
+ // This is a uint32 because IPv6 uses a 32-bit identification value.
+ ID uint32
+
+ // The protocol for the packet.
+ Protocol uint8
+}
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
@@ -46,14 +76,17 @@ type Fragmentation struct {
mu sync.Mutex
highLimit int
lowLimit int
- reassemblers map[uint32]*reassembler
+ reassemblers map[FragmentID]*reassembler
rList reassemblerList
size int
timeout time.Duration
+ blockSize uint16
}
// NewFragmentation creates a new Fragmentation.
//
+// blockSize specifies the fragment block size, in bytes.
+//
// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
@@ -64,7 +97,7 @@ type Fragmentation struct {
// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
-func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
@@ -73,17 +106,46 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
lowMemoryLimit = 0
}
+ if blockSize < minBlockSize {
+ blockSize = minBlockSize
+ }
+
return &Fragmentation{
- reassemblers: make(map[uint32]*reassembler),
+ reassemblers: make(map[FragmentID]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
+ blockSize: blockSize,
}
}
-// Process processes an incoming fragment belonging to an ID
-// and returns a complete packet when all the packets belonging to that ID have been received.
-func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+// Process processes an incoming fragment belonging to an ID and returns a
+// complete packet when all the packets belonging to that ID have been received.
+//
+// [first, last] is the range of the fragment bytes.
+//
+// first must be a multiple of the block size f is configured with. The size
+// of the fragment data must be a multiple of the block size, unless there are
+// no fragments following this fragment (more set to false).
+func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) {
+ if first > last {
+ return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
+ }
+
+ if first%f.blockSize != 0 {
+ return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs)
+ }
+
+ fragmentSize := last - first + 1
+ if more && fragmentSize%f.blockSize != 0 {
+ return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs)
+ }
+
+ if l := vv.Size(); l < int(fragmentSize) {
+ return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs)
+ }
+ vv.CapLength(int(fragmentSize))
+
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
@@ -115,10 +177,12 @@ func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buf
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
if f.size > f.highLimit {
- tail := f.rList.Back()
- for f.size > f.lowLimit && tail != nil {
+ for f.size > f.lowLimit {
+ tail := f.rList.Back()
+ if tail == nil {
+ break
+ }
f.release(tail)
- tail = tail.Prev()
}
}
f.mu.Unlock()
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
index 72c0f53be..9eedd33c4 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation_test.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -15,6 +15,7 @@
package fragmentation
import (
+ "errors"
"reflect"
"testing"
"time"
@@ -33,7 +34,7 @@ func vv(size int, pieces ...string) buffer.VectorisedView {
}
type processInput struct {
- id uint32
+ id FragmentID
first uint16
last uint16
more bool
@@ -53,8 +54,8 @@ var processTestCases = []struct {
{
comment: "One ID",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -64,10 +65,10 @@ var processTestCases = []struct {
{
comment: "Two IDs",
in: []processInput{
- {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
- {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
- {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
- {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: FragmentID{ID: 1}, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: FragmentID{ID: 1}, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
@@ -81,7 +82,7 @@ var processTestCases = []struct {
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
- f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv)
if err != nil {
@@ -110,14 +111,14 @@ func TestFragmentationProcess(t *testing.T) {
func TestReassemblingTimeout(t *testing.T) {
timeout := time.Millisecond
- f := NewFragmentation(1024, 512, timeout)
+ f := NewFragmentation(minBlockSize, 1024, 512, timeout)
// Send first fragment with id = 0, first = 0, last = 0, and more = true.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
// Sleep more than the timeout.
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
- _, done, err := f.Process(0, 1, 1, false, vv(1, "1"))
+ _, done, err := f.Process(FragmentID{}, 1, 1, false, vv(1, "1"))
if err != nil {
t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err)
}
@@ -127,35 +128,35 @@ func TestReassemblingTimeout(t *testing.T) {
}
func TestMemoryLimits(t *testing.T) {
- f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{ID: 0}, 0, 0, true, vv(1, "0"))
// Send first fragment with id = 1.
- f.Process(1, 0, 0, true, vv(1, "1"))
+ f.Process(FragmentID{ID: 1}, 0, 0, true, vv(1, "1"))
// Send first fragment with id = 2.
- f.Process(2, 0, 0, true, vv(1, "2"))
+ f.Process(FragmentID{ID: 2}, 0, 0, true, vv(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
- f.Process(3, 0, 0, true, vv(1, "3"))
+ f.Process(FragmentID{ID: 3}, 0, 0, true, vv(1, "3"))
- if _, ok := f.reassemblers[0]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
- if _, ok := f.reassemblers[1]; ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 1}]; ok {
t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
}
- if _, ok := f.reassemblers[3]; !ok {
+ if _, ok := f.reassemblers[FragmentID{ID: 3}]; !ok {
t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
- f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout)
// Send first fragment with id = 0.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
// Send the same packet again.
- f.Process(0, 0, 0, true, vv(1, "0"))
+ f.Process(FragmentID{}, 0, 0, true, vv(1, "0"))
got := f.size
want := 1
@@ -163,3 +164,97 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}
+
+func TestErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ blockSize uint16
+ first uint16
+ last uint16
+ more bool
+ data string
+ err error
+ }{
+ {
+ name: "exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: false,
+ data: "01",
+ },
+ {
+ name: "exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "01",
+ },
+ {
+ name: "exact block size with more and extra data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "012",
+ },
+ {
+ name: "exact block size with more and too little data",
+ blockSize: 2,
+ first: 2,
+ last: 3,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size with more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: true,
+ data: "0",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "not exact block size without more",
+ blockSize: 2,
+ first: 2,
+ last: 2,
+ more: false,
+ data: "0",
+ },
+ {
+ name: "first not a multiple of block size",
+ blockSize: 2,
+ first: 3,
+ last: 4,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ {
+ name: "first more than last",
+ blockSize: 2,
+ first: 4,
+ last: 3,
+ more: true,
+ data: "01",
+ err: ErrInvalidArgs,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout)
+ _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, vv(len(test.data), test.data))
+ if !errors.Is(err, test.err) {
+ t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", test.first, test.last, test.more, test.data, err, test.err)
+ }
+ if done {
+ t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", test.first, test.last, test.more, test.data)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
index 9e002e396..50d30bbf0 100644
--- a/pkg/tcpip/network/fragmentation/reassembler.go
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -18,9 +18,9 @@ import (
"container/heap"
"fmt"
"math"
- "sync"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
)
@@ -32,7 +32,7 @@ type hole struct {
type reassembler struct {
reassemblerEntry
- id uint32
+ id FragmentID
size int
mu sync.Mutex
holes []hole
@@ -42,7 +42,7 @@ type reassembler struct {
creationTime time.Time
}
-func newReassembler(id uint32) *reassembler {
+func newReassembler(id FragmentID) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
index 7eee0710d..dff7c9dcb 100644
--- a/pkg/tcpip/network/fragmentation/reassembler_test.go
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -94,7 +94,7 @@ var holesTestCases = []struct {
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
- r := newReassembler(0)
+ r := newReassembler(FragmentID{})
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
index e6db5c0b0..872165866 100644
--- a/pkg/tcpip/network/hash/BUILD
+++ b/pkg/tcpip/network/hash/BUILD
@@ -1,11 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "hash",
srcs = ["hash.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/hash",
visibility = ["//visibility:public"],
deps = [
"//pkg/rand",
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
index 6a215938b..8f65713c5 100644
--- a/pkg/tcpip/network/hash/hash.go
+++ b/pkg/tcpip/network/hash/hash.go
@@ -80,12 +80,12 @@ func IPv4FragmentHash(h header.IPv4) uint32 {
// RFC 2640 (sec 4.5) is not very sharp on this aspect.
// As a reference, also Linux ignores the protocol to compute
// the hash (inet6_hash_frag).
-func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+func IPv6FragmentHash(h header.IPv6, id uint32) uint32 {
t := h.SourceAddress()
y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = h.DestinationAddress()
z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
- return Hash3Words(f.ID(), y, z, hashIV)
+ return Hash3Words(id, y, z, hashIV)
}
func rol32(v, shift uint32) uint32 {
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 666d8b92a..9007346fe 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -41,6 +42,7 @@ const (
ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
+ nicID = 1
)
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
@@ -96,16 +98,16 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
-func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
- t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) {
+ t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
-func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- t.checkValues(trans, vv, remote, local)
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ t.checkValues(trans, pkt.Data, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
}
@@ -150,50 +152,60 @@ func (*testObject) Wait() {}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
-func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
- h := header.IPv4(hdr.View())
+ h := header.IPv4(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
- h := header.IPv6(hdr.View())
+ h := header.IPv6(pkt.NetworkHeader().View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
- t.checkValues(prot, payload, srcAddr, dstAddr)
+ t.checkValues(prot, pkt.Data, srcAddr, dstAddr)
return nil
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, hdr []stack.PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
panic("not implemented")
}
-func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
+func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testObject) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ panic("not implemented")
+}
+
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv4.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv4EmptySubnet,
Gateway: ipv4Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv4.ProtocolNumber, false /* multicastLoop */)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
@@ -201,24 +213,45 @@ func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
})
- s.CreateNIC(1, loopback.New())
- s.AddAddress(1, ipv6.ProtocolNumber, local)
+ s.CreateNIC(nicID, loopback.New())
+ s.AddAddress(nicID, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: header.IPv6EmptySubnet,
Gateway: ipv6Gateway,
NIC: 1,
}})
- return s.FindRoute(1, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+ return s.FindRoute(nicID, local, remote, ipv6.ProtocolNumber, false /* multicastLoop */)
+}
+
+func buildDummyStack(t *testing.T) *stack.Stack {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, localIpv4Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, localIpv4Addr, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, localIpv6Addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, localIpv6Addr, err)
+ }
+
+ return s
}
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, nil, &o)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t))
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -226,8 +259,11 @@ func TestIPv4Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
o.protocol = 123
@@ -239,7 +275,11 @@ func TestIPv4Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -247,10 +287,8 @@ func TestIPv4Send(t *testing.T) {
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
@@ -279,7 +317,13 @@ func TestIPv4Receive(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -291,7 +335,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
name string
expectedCount int
fragmentOffset uint16
- code uint8
+ code header.ICMPv4Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -313,10 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
@@ -366,8 +407,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
- ep.HandlePacket(&r, vv)
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
@@ -378,10 +418,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv4Addr, localIpv4PrefixLen}, nil, &o, nil)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv4MinimumSize + 24
@@ -430,13 +468,25 @@ func TestIPv4FragmentationReceive(t *testing.T) {
}
// Send first segment.
- ep.HandlePacket(&r, frag1.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: frag1.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
- ep.HandlePacket(&r, frag2.ToVectorisedView())
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: frag2.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -445,10 +495,8 @@ func TestIPv4FragmentationReceive(t *testing.T) {
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, nil, &o)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t))
+ defer ep.Close()
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
@@ -456,8 +504,11 @@ func TestIPv6Send(t *testing.T) {
payload[i] = uint8(i)
}
- // Allocate the header buffer.
- hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+ // Setup the packet buffer.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(ep.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ })
// Issue the write.
o.protocol = 123
@@ -469,7 +520,11 @@ func TestIPv6Send(t *testing.T) {
if err != nil {
t.Fatalf("could not find route: %v", err)
}
- if err := ep.WritePacket(&r, nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: 123, TTL: 123, TOS: stack.DefaultTOS}, stack.PacketOut); err != nil {
+ if err := ep.WritePacket(&r, nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: 123,
+ TTL: 123,
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
@@ -477,10 +532,8 @@ func TestIPv6Send(t *testing.T) {
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
+ defer ep.Close()
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
@@ -509,7 +562,13 @@ func TestIPv6Receive(t *testing.T) {
t.Fatalf("could not find route: %v", err)
}
- ep.HandlePacket(&r, view.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: view.ToVectorisedView(),
+ })
+ if _, _, ok := proto.Parse(pkt); !ok {
+ t.Fatalf("failed to parse packet: %x", pkt.Data.ToView())
+ }
+ ep.HandlePacket(&r, pkt)
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
@@ -525,7 +584,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
expectedCount int
fragmentOffset *uint16
typ header.ICMPv6Type
- code uint8
+ code header.ICMPv6Code
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
@@ -552,11 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
- ep, err := proto.NewEndpoint(1, tcpip.AddressWithPrefix{localIpv6Addr, localIpv6PrefixLen}, nil, &o, nil)
- if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
- }
-
+ ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t))
defer ep.Close()
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
@@ -618,15 +673,26 @@ func TestIPv6ReceiveControl(t *testing.T) {
o.typ = c.expectedTyp
o.extra = c.expectedExtra
- vv := view[:len(view)-c.trunc].ToVectorisedView()
-
// Set ICMPv6 checksum.
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIpv6Addr, buffer.VectorisedView{}))
- ep.HandlePacket(&r, vv)
+ ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize))
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}
+
+// truncatedPacket returns a PacketBuffer based on a truncated view. If view,
+// after truncation, is large enough to hold a network header, it makes part of
+// view the packet's NetworkHeader and the rest its Data. Otherwise all of view
+// becomes Data.
+func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer {
+ v := view[:len(view)-trunc]
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: v.ToVectorisedView(),
+ })
+ _, _ = pkt.NetworkHeader().Consume(netHdrLen)
+ return pkt
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 58e537aad..d142b4ffa 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,10 +8,7 @@ go_library(
"icmp.go",
"ipv4.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv4",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -38,5 +34,6 @@ go_test(
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 50b363dc4..b5659a36b 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -24,8 +24,12 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv4(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv4(h)
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
@@ -33,13 +37,14 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// false.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
- // original source address doesn't match the endpoint's address.
- if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
return
}
- hlen := int(h.HeaderLength())
- if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ hlen := int(hdr.HeaderLength())
+ if pkt.Data.Size() < hlen || hdr.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
@@ -47,16 +52,19 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
}
// Skip the ip header, then deliver control message.
- vv.TrimFront(hlen)
- p := h.TransportProtocol()
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ pkt.Data.TrimFront(hlen)
+ p := hdr.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) {
stats := r.Stats()
received := stats.ICMP.V4PacketsReceived
- v := vv.First()
- if len(v) < header.ICMPv4MinimumSize {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ v, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
@@ -73,29 +81,64 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
// checksum. We'll have to reset this before we hand the packet
// off.
h.SetChecksum(0)
- gotChecksum := ^header.ChecksumVV(vv, 0 /* initial */)
+ gotChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */)
if gotChecksum != wantChecksum {
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
received.Invalid.Increment()
return
}
+ // Make a copy of data before pkt gets sent to raw socket.
+ // DeliverTransportPacket will take ownership of pkt.
+ replyData := pkt.Data.Clone(nil)
+ replyData.TrimFront(header.ICMPv4MinimumSize)
+
// It's possible that a raw socket expects to receive this.
h.SetChecksum(wantChecksum)
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
-
- vv := vv.Clone(nil)
- vv.TrimFront(header.ICMPv4MinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv4EchoReply)
- pkt.SetChecksum(0)
- pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
+
+ remoteLinkAddr := r.RemoteLinkAddress
+
+ // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP
+ // source address MUST be one of its own IP addresses (but not a broadcast
+ // or multicast address).
+ localAddr := r.LocalAddress
+ if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // Use the remote link address from the incoming packet.
+ r.ResolveWith(remoteLinkAddr)
+
+ // Prepare a reply packet.
+ icmpHdr := make(header.ICMPv4, header.ICMPv4MinimumSize)
+ copy(icmpHdr, h)
+ icmpHdr.SetType(header.ICMPv4EchoReply)
+ icmpHdr.SetChecksum(0)
+ icmpHdr.SetChecksum(^header.Checksum(icmpHdr, header.ChecksumVV(replyData, 0)))
+ dataVV := buffer.View(icmpHdr).ToVectorisedView()
+ dataVV.Append(replyData)
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: dataVV,
+ })
+
+ // Send out the reply packet.
sent := stats.ICMP.V4PacketsSent
- if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: header.ICMPv4ProtocolNumber,
+ TTL: r.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -104,19 +147,22 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
- vv.TrimFront(header.ICMPv4MinimumSize)
+ pkt.Data.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
+ case header.ICMPv4HostUnreachable:
+ e.handleControl(stack.ControlNoRoute, 0, pkt)
+
case header.ICMPv4PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
case header.ICMPv4FragmentationNeeded:
mtu := uint32(h.MTU())
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
}
case header.ICMPv4SrcQuench:
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 90f4406e5..79872ec9a 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -44,31 +44,29 @@ const (
// buckets is the number of identifier buckets.
buckets = 2048
+
+ // The size of a fragment block, in bytes, as per RFC 791 section 3.1,
+ // page 14.
+ fragmentblockSize = 8
)
type endpoint struct {
- nicid tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
- linkEP stack.LinkEndpoint
- dispatcher stack.TransportDispatcher
- fragmentation *fragmentation.Fragmentation
- protocol *protocol
+ nicID tcpip.NICID
+ linkEP stack.LinkEndpoint
+ dispatcher stack.TransportDispatcher
+ protocol *protocol
+ stack *stack.Stack
}
// NewEndpoint creates a new ipv4 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
- e := &endpoint{
- nicid: nicid,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
- linkEP: linkEP,
- dispatcher: dispatcher,
- fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
- protocol: p,
- }
-
- return e, nil
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
+ return &endpoint{
+ nicID: nicID,
+ linkEP: linkEP,
+ dispatcher: dispatcher,
+ protocol: p,
+ stack: st,
+ }
}
// DefaultTTL is the default time-to-live value for this endpoint.
@@ -89,17 +87,7 @@ func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
-}
-
-// ID returns the ipv4 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen returns the ipv4 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+ return e.nicID
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
@@ -116,14 +104,18 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
// writePacketFragments calls e.linkEP.WritePacket with each packet fragment to
-// write. It assumes that the IP header is entirely in hdr but does not assume
-// that only the IP header is in hdr. It assumes that the input packet's stated
-// length matches the length of the hdr+payload. mtu includes the IP header and
-// options. This does not support the DontFragment IP flag.
-func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, mtu int) *tcpip.Error {
+// write. It assumes that the IP header is already present in pkt.NetworkHeader.
+// pkt.TransportHeader may be set. mtu includes the IP header and options. This
+// does not support the DontFragment IP flag.
+func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, pkt *stack.PacketBuffer) *tcpip.Error {
// This packet is too big, it needs to be fragmented.
- ip := header.IPv4(hdr.View())
+ ip := header.IPv4(pkt.NetworkHeader().View())
flags := ip.Flags()
// Update mtu to take into account the header, which will exist in all
@@ -137,76 +129,88 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff
outerMTU := innerMTU + int(ip.HeaderLength())
offset := ip.FragmentOffset()
- originalAvailableLength := hdr.AvailableLength()
+
+ // Keep the length reserved for link-layer, we need to create fragments with
+ // the same reserved length.
+ reservedForLink := pkt.AvailableHeaderBytes()
+
+ // Destroy the packet, pull all payloads out for fragmentation.
+ transHeader, data := pkt.TransportHeader().View(), pkt.Data
+
+ // Where possible, the first fragment that is sent has the same
+ // number of bytes reserved for header as the input packet. The link-layer
+ // endpoint may depend on this for looking at, eg, L4 headers.
+ transFitsFirst := len(transHeader) <= innerMTU
+
for i := 0; i < n; i++ {
- // Where possible, the first fragment that is sent has the same
- // hdr.UsedLength() as the input packet. The link-layer endpoint may depends
- // on this for looking at, eg, L4 headers.
- h := ip
- if i > 0 {
- hdr = buffer.NewPrependable(int(ip.HeaderLength()) + originalAvailableLength)
- h = header.IPv4(hdr.Prepend(int(ip.HeaderLength())))
- copy(h, ip[:ip.HeaderLength()])
+ reserve := reservedForLink + int(ip.HeaderLength())
+ if i == 0 && transFitsFirst {
+ // Reserve for transport header if it's going to be put in the first
+ // fragment.
+ reserve += len(transHeader)
+ }
+ fragPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: reserve,
+ })
+ fragPkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+
+ // Copy data for the fragment.
+ avail := innerMTU
+
+ if n := len(transHeader); n > 0 {
+ if n > avail {
+ n = avail
+ }
+ if i == 0 && transFitsFirst {
+ copy(fragPkt.TransportHeader().Push(n), transHeader)
+ } else {
+ fragPkt.Data.AppendView(transHeader[:n:n])
+ }
+ transHeader = transHeader[n:]
+ avail -= n
+ }
+
+ if avail > 0 {
+ n := data.Size()
+ if n > avail {
+ n = avail
+ }
+ data.ReadToVV(&fragPkt.Data, n)
+ avail -= n
}
+
+ copied := uint16(innerMTU - avail)
+
+ // Set lengths in header and calculate checksum.
+ h := header.IPv4(fragPkt.NetworkHeader().Push(len(ip)))
+ copy(h, ip)
if i != n-1 {
h.SetTotalLength(uint16(outerMTU))
h.SetFlagsFragmentOffset(flags|header.IPv4FlagMoreFragments, offset)
} else {
- h.SetTotalLength(uint16(h.HeaderLength()) + uint16(payload.Size()))
+ h.SetTotalLength(uint16(h.HeaderLength()) + copied)
h.SetFlagsFragmentOffset(flags, offset)
}
h.SetChecksum(0)
h.SetChecksum(^h.CalculateChecksum())
- offset += uint16(innerMTU)
- if i > 0 {
- newPayload := payload.Clone([]buffer.View{})
- newPayload.CapLength(innerMTU)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayload.Size())
- continue
- }
- // Special handling for the first fragment because it comes from the hdr.
- if outerMTU >= hdr.UsedLength() {
- // This fragment can fit all of hdr and possibly some of payload, too.
- newPayload := payload.Clone([]buffer.View{})
- newPayloadLength := outerMTU - hdr.UsedLength()
- newPayload.CapLength(newPayloadLength)
- if err := e.linkEP.WritePacket(r, gso, hdr, newPayload, ProtocolNumber); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- payload.TrimFront(newPayloadLength)
- } else {
- // The fragment is too small to fit all of hdr.
- startOfHdr := hdr
- startOfHdr.TrimBack(hdr.UsedLength() - outerMTU)
- emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
- if err := e.linkEP.WritePacket(r, gso, startOfHdr, emptyVV, ProtocolNumber); err != nil {
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- // Add the unused bytes of hdr into the payload that remains to be sent.
- restOfHdr := hdr.View()[outerMTU:]
- tmp := buffer.NewVectorisedView(len(restOfHdr), []buffer.View{buffer.NewViewFromBytes(restOfHdr)})
- tmp.Append(payload)
- payload = tmp
+ offset += copied
+
+ // Send out the fragment.
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil {
+ return err
}
+ r.Stats().IP.PacketsSent.Increment()
}
return nil
}
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
- ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
- length := uint16(hdr.UsedLength() + payloadSize)
- id := uint32(0)
- if length > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
- }
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
+ length := uint16(pkt.Size())
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic
+ // datagrams. Since the DF bit is never being set here, all datagrams
+ // are non-atomic and need an ID.
+ id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
@@ -218,28 +222,49 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
DstAddr: r.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
- e.addIPHeader(r, &hdr, payload.Size(), params)
-
- if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.addIPHeader(r, pkt, params)
+
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
+ // iptables is telling us to drop the packet.
+ return nil
+ }
+
+ // If the packet is manipulated as per NAT Ouput rules, handle packet
+ // based on destination address and do not send the packet to link layer.
+ // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than
+ // only NATted packets, but removing this check short circuits broadcasts
+ // before they are sent out to other hosts.
+ if pkt.NatDone {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress())
+ if err == nil {
+ route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress())
+ ep.HandlePacket(&route, pkt)
+ return nil
+ }
+ }
+
+ if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+ e.HandlePacket(&loopedR, pkt)
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
- if hdr.UsedLength()+payload.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
- return e.writePacketFragments(r, gso, hdr, payload, int(e.linkEP.MTU()))
+ if pkt.Size() > int(e.linkEP.MTU()) && (gso == nil || gso.Type == stack.GSONone) {
+ return e.writePacketFragments(r, gso, int(e.linkEP.MTU()), pkt)
}
- if err := e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber); err != nil {
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
return err
}
r.Stats().IP.PacketsSent.Increment()
@@ -247,34 +272,76 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
}
// WritePackets implements stack.NetworkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("multiple packets in local loop")
}
- if loop&stack.PacketOut == 0 {
- return len(hdrs), nil
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
+ }
+
+ for pkt := pkts.Front(); pkt != nil; {
+ e.addIPHeader(r, pkt, params)
+ pkt = pkt.Next()
+ }
+
+ nicName := e.stack.FindNICNameFromID(e.NICID())
+ // iptables filtering. All packets that reach here are locally
+ // generated.
+ ipt := e.stack.IPTables()
+ dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName)
+ if len(dropped) == 0 && len(natPkts) == 0 {
+ // Fast path: If no packets are to be dropped then we can just invoke the
+ // faster WritePackets API directly.
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
}
- for i := range hdrs {
- e.addIPHeader(r, &hdrs[i].Hdr, hdrs[i].Size, params)
+ // Slow Path as we are dropping some packets in the batch degrade to
+ // emitting one packet at a time.
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if _, ok := dropped[pkt]; ok {
+ continue
+ }
+ if _, ok := natPkts[pkt]; ok {
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()); err == nil {
+ src := netHeader.SourceAddress()
+ dst := netHeader.DestinationAddress()
+ route := r.ReverseRoute(src, dst)
+ ep.HandlePacket(&route, pkt)
+ n++
+ continue
+ }
+ }
+ if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil {
+ r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
+ return n, err
+ }
+ n++
}
- n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
- return n, err
+ return n, nil
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
- ip := header.IPv4(payload.First())
- if !ip.IsValid(payload.Size()) {
+ h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ ip := header.IPv4(h)
+ if !ip.IsValid(pkt.Data.Size()) {
return tcpip.ErrInvalidOptionValue
}
// Always set the total length.
- ip.SetTotalLength(uint16(payload.Size()))
+ ip.SetTotalLength(uint16(pkt.Data.Size()))
// Set the source address when zero.
if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
@@ -287,51 +354,49 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
// Set the packet ID when zero.
if ip.ID() == 0 {
- id := uint32(0)
- if payload.Size() > header.IPv4MaximumHeaderSize+8 {
- // Packets of 68 bytes or less are required by RFC 791 to not be
- // fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
+ // RFC 6864 section 4.3 mandates uniqueness of ID values for
+ // non-atomic datagrams, so assign an ID to all such datagrams
+ // according to the definition given in RFC 6864 section 4.
+ if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 {
+ ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)))
}
- ip.SetID(uint16(id))
}
// Always set the checksum.
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, payload)
+ if r.Loop&stack.PacketLoop != 0 {
+ e.HandlePacket(r, pkt.Clone())
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
- // If we want to send the packet to a link-layer,
- // we have to reserve space for an Ethernet header.
- hdr := buffer.NewPrependableFromView(payload.ToView(), int(e.linkEP.MaxHeaderLength()))
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+
+ return e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
- h := header.IPv4(headerView)
- if !h.IsValid(vv.Size()) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv4(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- hlen := int(h.HeaderLength())
- tlen := int(h.TotalLength())
- vv.TrimFront(hlen)
- vv.CapLength(tlen - hlen)
+ // iptables filtering. All packets that reach here are intended for
+ // this machine and will not be forwarded.
+ ipt := e.stack.IPTables()
+ if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
- more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
- if more || h.FragmentOffset() != 0 {
- if vv.Size() == 0 {
+ if h.More() || h.FragmentOffset() != 0 {
+ if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 {
// Drop the packet as it's marked as a fragment but has
// no payload.
r.Stats().IP.MalformedPacketsReceived.Increment()
@@ -339,10 +404,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
return
}
// The packet is a fragment, let's try to reassemble it.
- last := h.FragmentOffset() + uint16(vv.Size()) - 1
+ last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1
// Drop the packet if the fragmentOffset is incorrect. i.e the
- // combination of fragmentOffset and vv.size() causes a wrap
- // around resulting in last being less than the offset.
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
if last < h.FragmentOffset() {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -350,7 +415,20 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
var ready bool
var err error
- vv, ready, err = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
+ pkt.Data, ready, err = e.protocol.fragmentation.Process(
+ // As per RFC 791 section 2.3, the identification value is unique
+ // for a source-destination pair and protocol.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: uint32(h.ID()),
+ Protocol: h.Protocol(),
+ },
+ h.FragmentOffset(),
+ last,
+ h.More(),
+ pkt.Data,
+ )
if err != nil {
r.Stats().IP.MalformedPacketsReceived.Increment()
r.Stats().IP.MalformedFragmentsReceived.Increment()
@@ -362,12 +440,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
p := h.TransportProtocol()
if p == header.ICMPv4ProtocolNumber {
- headerView.CapLength(hlen)
- e.handleICMP(r, headerView, vv)
+ e.handleICMP(r, pkt)
return
}
r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
}
// Close cleans up resources associated with the endpoint.
@@ -381,6 +458,8 @@ type protocol struct {
// uint8 portion of it is meaningful and it must be accessed
// atomically.
defaultTTL uint32
+
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv4 protocol number.
@@ -436,6 +515,45 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv4(hdr)
+
+ // Header may have options, determine the true header length.
+ headerLen := int(ipHdr.HeaderLength())
+ if headerLen < header.IPv4MinimumSize {
+ // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in
+ // order for the packet to be valid. Figure out if we want to reject this
+ // case.
+ headerLen = header.IPv4MinimumSize
+ }
+ hdr, ok = pkt.NetworkHeader().Consume(headerLen)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr = header.IPv4(hdr)
+
+ // If this is a fragment, don't bother parsing the transport header.
+ parseTransportHeader := true
+ if ipHdr.More() || ipHdr.FragmentOffset() != 0 {
+ parseTransportHeader = false
+ }
+
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
+ return ipHdr.TransportProtocol(), parseTransportHeader, true
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -467,5 +585,10 @@ func NewProtocol() stack.NetworkProtocol {
}
hashIV := r[buckets]
- return &protocol{ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL}
+ return &protocol{
+ ids: ids,
+ hashIV: hashIV,
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 99f84acd7..197e3bc51 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -17,9 +17,11 @@ package ipv4_test
import (
"bytes"
"encoding/hex"
+ "fmt"
"math/rand"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -90,15 +92,11 @@ func TestExcludeBroadcast(t *testing.T) {
})
}
-// makeHdrAndPayload generates a randomize packet. hdrLength indicates how much
+// makeRandPkt generates a randomize packet. hdrLength indicates how much
// data should already be in the header before WritePacket. extraLength
// indicates how much extra space should be in the header. The payload is made
// from many Views of the sizes listed in viewSizes.
-func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) {
- hdr := buffer.NewPrependable(hdrLength + extraLength)
- hdr.Prepend(hdrLength)
- rand.Read(hdr.View())
-
+func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer {
var views []buffer.View
totalLength := 0
for _, s := range viewSizes {
@@ -107,18 +105,26 @@ func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.
views = append(views, newView)
totalLength += s
}
- payload := buffer.NewVectorisedView(totalLength, views)
- return hdr, payload
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrLength + extraLength,
+ Data: buffer.NewVectorisedView(totalLength, views),
+ })
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
+ if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil {
+ panic(fmt.Sprintf("rand.Read: %s", err))
+ }
+ return pkt
}
// comparePayloads compared the contents of all the packets against the contents
// of the source packet.
-func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) {
+func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) {
t.Helper()
// Make a complete array of the sourcePacketInfo packet.
- source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize])
- source = append(source, sourcePacketInfo.Header.View()...)
- source = append(source, sourcePacketInfo.Payload.ToView()...)
+ source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize])
+ vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views())
+ source = append(source, vv.ToView()...)
// Make a copy of the IP header, which will be modified in some fields to make
// an expected header.
@@ -131,8 +137,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
var reassembledPayload []byte
for i, packet := range packets {
// Confirm that the packet is valid.
- allBytes := packet.Header.View().ToVectorisedView()
- allBytes.Append(packet.Payload)
+ allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views())
ip := header.IPv4(allBytes.ToView())
if !ip.IsValid(len(ip)) {
t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip))
@@ -143,12 +148,22 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
if got, want := len(ip), int(mtu); got > want {
t.Errorf("fragment is too large, got %d want %d", got, want)
}
- if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want {
- t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ if i == 0 {
+ got := packet.NetworkHeader().View().Size() + packet.TransportHeader().View().Size()
+ // sourcePacketInfo does not have NetworkHeader added, simulate one.
+ want := header.IPv4MinimumSize + sourcePacketInfo.TransportHeader().View().Size()
+ // Check that it kept the transport header in packet.TransportHeader if
+ // it fits in the first fragment.
+ if want < int(mtu) && got != want {
+ t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want)
+ }
}
- if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
+ if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
+ if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
+ t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ }
if i < len(packets)-1 {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
} else {
@@ -173,7 +188,7 @@ func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packe
type errorChannel struct {
*channel.Endpoint
- Ch chan packetInfo
+ Ch chan *stack.PacketBuffer
packetCollectorErrors []*tcpip.Error
}
@@ -183,17 +198,11 @@ type errorChannel struct {
func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
return &errorChannel{
Endpoint: channel.New(size, mtu, linkAddr),
- Ch: make(chan packetInfo, size),
+ Ch: make(chan *stack.PacketBuffer, size),
packetCollectorErrors: packetCollectorErrors,
}
}
-// packetInfo holds all the information about an outbound packet.
-type packetInfo struct {
- Header buffer.Prependable
- Payload buffer.VectorisedView
-}
-
// Drain removes all outbound packets from the channel and counts them.
func (e *errorChannel) Drain() int {
c := 0
@@ -208,14 +217,9 @@ func (e *errorChannel) Drain() int {
}
// WritePacket stores outbound packets into the channel.
-func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
- p := packetInfo{
- Header: hdr,
- Payload: payload,
- }
-
+func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
select {
- case e.Ch <- p:
+ case e.Ch <- pkt:
default:
}
@@ -291,19 +295,19 @@ func TestFragmentation(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
- source := packetInfo{
- Header: hdr,
- // Save the source payload because WritePacket will modify it.
- Payload: payload.Clone([]buffer.View{}),
- }
+ pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes)
+ source := pkt.Clone()
c := buildContext(t, nil, ft.mtu)
- err := c.Route.WritePacket(ft.gso, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, pkt)
if err != nil {
t.Errorf("err got %v, want %v", err, nil)
}
- var results []packetInfo
+ var results []*stack.PacketBuffer
L:
for {
select {
@@ -343,9 +347,13 @@ func TestFragmentationErrors(t *testing.T) {
for _, ft := range fragTests {
t.Run(ft.description, func(t *testing.T) {
- hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
+ pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes)
c := buildContext(t, ft.packetCollectorErrors, ft.mtu)
- err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, stack.NetworkHeaderParams{Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS})
+ err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{
+ Protocol: tcp.ProtocolNumber,
+ TTL: 42,
+ TOS: stack.DefaultTOS,
+ }, pkt)
for i := 0; i < len(ft.packetCollectorErrors)-1; i++ {
if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want {
t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want)
@@ -451,7 +459,7 @@ func TestInvalidFragments(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
- const nicid tcpip.NICID = 42
+ const nicID tcpip.NICID = 42
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{
ipv4.NewProtocol(),
@@ -461,10 +469,12 @@ func TestInvalidFragments(t *testing.T) {
var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30})
var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31})
ep := channel.New(10, 1500, linkAddr)
- s.CreateNIC(nicid, sniffer.New(ep))
+ s.CreateNIC(nicID, sniffer.New(ep))
for _, pkt := range tc.packets {
- ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}))
+ ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}),
+ }))
}
if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want {
@@ -476,3 +486,423 @@ func TestInvalidFragments(t *testing.T) {
})
}
}
+
+// TestReceiveFragments feeds fragments in through the incoming packet path to
+// test reassembly
+func TestReceiveFragments(t *testing.T) {
+ const (
+ nicID = 1
+
+ addr1 = "\x0c\xa8\x00\x01" // 192.168.0.1
+ addr2 = "\x0c\xa8\x00\x02" // 192.168.0.2
+ addr3 = "\x0c\xa8\x00\x03" // 192.168.0.3
+ )
+
+ // Build and return a UDP header containing payload.
+ udpGen := func(payloadLen int, multiplier uint8, src, dst tcpip.Address) buffer.View {
+ payload := buffer.NewView(payloadLen)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + len(payload)
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ // UDP header plus a payload of 0..256
+ ipv4Payload1Addr1ToAddr2 := udpGen(256, 1, addr1, addr2)
+ udpPayload1Addr1ToAddr2 := ipv4Payload1Addr1ToAddr2[header.UDPMinimumSize:]
+ ipv4Payload1Addr3ToAddr2 := udpGen(256, 1, addr3, addr2)
+ udpPayload1Addr3ToAddr2 := ipv4Payload1Addr3ToAddr2[header.UDPMinimumSize:]
+ // UDP header plus a payload of 0..256 in increments of 2.
+ ipv4Payload2Addr1ToAddr2 := udpGen(128, 2, addr1, addr2)
+ udpPayload2Addr1ToAddr2 := ipv4Payload2Addr1ToAddr2[header.UDPMinimumSize:]
+ // UDP header plus a payload of 0..256 in increments of 3.
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 791 section 3.1 page 14).
+ ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2)
+ udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:]
+
+ type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ id uint16
+ flags uint8
+ fragmentOffset uint16
+ payload buffer.View
+ }
+
+ tests := []struct {
+ name string
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "No fragmentation with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "More fragments without payload",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Non-zero fragment offset without payload",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 8,
+ payload: ipv4Payload1Addr1ToAddr2,
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload3Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload3Addr1ToAddr2[:63],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 63,
+ payload: ipv4Payload3Addr1ToAddr2[63:],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Second fragment has MoreFlags set",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 2,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 2,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload2Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 2,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload2Addr1ToAddr2[64:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr3ToAddr2[:32],
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 64,
+ payload: ipv4Payload1Addr1ToAddr2[64:],
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ id: 1,
+ flags: 0,
+ fragmentOffset: 32,
+ payload: ipv4Payload1Addr3ToAddr2[32:],
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
+ },
+ {
+ name: "Fragment without followup",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ id: 1,
+ flags: header.IPv4FlagMoreFragments,
+ fragmentOffset: 0,
+ payload: ipv4Payload1Addr1ToAddr2[:64],
+ },
+ },
+ expectedPayloads: nil,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ // Setup a stack and endpoint.
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, tcpip.LinkAddress("\xf0\x00"))
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, header.IPv4ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv4ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ // Prepare and send the fragments.
+ for _, frag := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv4MinimumSize)
+
+ // Serialize IPv4 fixed header.
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: header.IPv4MinimumSize + uint16(len(frag.payload)),
+ ID: frag.id,
+ Flags: frag.flags,
+ FragmentOffset: frag.fragmentOffset,
+ TTL: 64,
+ Protocol: uint8(header.UDPProtocolNumber),
+ SrcAddr: frag.srcAddr,
+ DstAddr: frag.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.AppendView(frag.payload)
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, expectedPayload := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index f06622a8b..bcc64994e 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,14 +8,12 @@ go_library(
"icmp.go",
"ipv6.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/network/ipv6",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/network/fragmentation",
"//pkg/tcpip/stack",
],
)
@@ -29,10 +26,11 @@ go_test(
"ipv6_test.go",
"ndp_test.go",
],
- embed = [":ipv6"],
+ library = ":ipv6",
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
@@ -40,5 +38,6 @@ go_test(
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index c3f1dd488..66d3a953a 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,6 +15,8 @@
package ipv6
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -25,26 +27,35 @@ import (
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
-func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
- h := header.IPv6(vv.First())
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ hdr := header.IPv6(h)
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
// is truncated, which would cause IsValid to return false.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
- // original source address doesn't match the endpoint's address.
- if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ // original source address doesn't match an address we own.
+ src := hdr.SourceAddress()
+ if e.stack.CheckLocalAddress(e.NICID(), ProtocolNumber, src) == 0 {
return
}
// Skip the IP header, then handle the fragmentation header if there
// is one.
- vv.TrimFront(header.IPv6MinimumSize)
- p := h.TransportProtocol()
+ pkt.Data.TrimFront(header.IPv6MinimumSize)
+ p := hdr.TransportProtocol()
if p == header.IPv6FragmentHeader {
- f := header.IPv6Fragment(vv.First())
- if !f.IsValid() || f.FragmentOffset() != 0 {
+ f, ok := pkt.Data.PullUp(header.IPv6FragmentHeaderSize)
+ if !ok {
+ return
+ }
+ fragHdr := header.IPv6Fragment(f)
+ if !fragHdr.IsValid() || fragHdr.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
return
@@ -52,145 +63,183 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.
// Skip fragmentation header and find out the actual protocol
// number.
- vv.TrimFront(header.IPv6FragmentHeaderSize)
- p = f.TransportProtocol()
+ pkt.Data.TrimFront(header.IPv6FragmentHeaderSize)
+ p = fragHdr.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
- e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+ e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt)
}
-func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) {
stats := r.Stats().ICMP
sent := stats.V6PacketsSent
received := stats.V6PacketsReceived
- v := vv.First()
- if len(v) < header.ICMPv6MinimumSize {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ v, ok := pkt.Data.PullUp(header.ICMPv6HeaderSize)
+ if !ok {
received.Invalid.Increment()
return
}
h := header.ICMPv6(v)
- iph := header.IPv6(netHeader)
+ iph := header.IPv6(pkt.NetworkHeader().View())
// Validate ICMPv6 checksum before processing the packet.
//
- // Only the first view in vv is accounted for by h. To account for the
- // rest of vv, a shallow copy is made and the first view is removed.
// This copy is used as extra payload during the checksum calculation.
- payload := vv
- payload.RemoveFirst()
+ payload := pkt.Data.Clone(nil)
+ payload.TrimFront(len(h))
if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want {
received.Invalid.Increment()
return
}
- // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
- // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
- // in the IPv6 header is not set to 255.
- switch h.Type() {
- case header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborAdvert,
- header.ICMPv6RouterSolicit,
- header.ICMPv6RouterAdvert,
- header.ICMPv6RedirectMsg:
- if iph.HopLimit() != header.NDPHopLimit {
- received.Invalid.Increment()
- return
- }
+ isNDPValid := func() bool {
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255, or the ICMPv6 Code field is not
+ // set to 0.
+ //
+ // As per RFC 6980 section 5, nodes MUST silently drop NDP messages if the
+ // packet includes a fragmentation header.
+ return !hasFragmentHeader && iph.HopLimit() == header.NDPHopLimit && h.Code() == 0
}
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv6PacketTooBig:
received.PacketTooBig.Increment()
- if len(v) < header.ICMPv6PacketTooBigMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6PacketTooBigMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := h.MTU()
- e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
+ mtu := header.ICMPv6(hdr).MTU()
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt)
case header.ICMPv6DstUnreachable:
received.DstUnreachable.Increment()
- if len(v) < header.ICMPv6DstUnreachableMinimumSize {
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6DstUnreachableMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
- switch h.Code() {
+ pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
+ switch header.ICMPv6(hdr).Code() {
+ case header.ICMPv6NetworkUnreachable:
+ e.handleControl(stack.ControlNetworkUnreachable, 0, pkt)
case header.ICMPv6PortUnreachable:
- e.handleControl(stack.ControlPortUnreachable, 0, vv)
+ e.handleControl(stack.ControlPortUnreachable, 0, pkt)
}
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
+ if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor solicitation, so
+ // payload.ToView() always returns the solicitation. Per RFC 6980 section 5,
+ // NDP messages cannot be fragmented. Also note that in the common case NDP
+ // datagrams are very small and ToView() will not incur allocations.
+ ns := header.NDPNeighborSolicit(payload.ToView())
+ it, err := ns.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NS option, drop the packet.
received.Invalid.Increment()
return
}
- ns := header.NDPNeighborSolicit(h.NDPPayload())
targetAddr := ns.TargetAddress()
s := r.Stack()
- rxNICID := r.NICID()
-
- isTentative, err := s.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now, drop this packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
- // If the target address is tentative and the source
- // of the packet is a unicast (specified) address, then
- // the source of the packet is attempting to perform
- // address resolution on the target. In this case, the
- // solicitation is silently ignored, as per RFC 4862
- // section 5.4.3.
+ } else if isTentative {
+ // If the target address is tentative and the source of the packet is a
+ // unicast (specified) address, then the source of the packet is
+ // attempting to perform address resolution on the target. In this case,
+ // the solicitation is silently ignored, as per RFC 4862 section 5.4.3.
//
- // If the target address is tentative and the source of
- // the packet is the unspecified address (::), then we
- // know another node is also performing DAD for the
- // same address (since targetAddr is tentative for us,
- // we know we are also performing DAD on it). In this
- // case we let the stack know so it can handle such a
- // scenario and do nothing further with the NDP NS.
- if iph.SourceAddress() == header.IPv6Any {
- s.DupTentativeAddrDetected(rxNICID, targetAddr)
+ // If the target address is tentative and the source of the packet is the
+ // unspecified address (::), then we know another node is also performing
+ // DAD for the same address (since the target address is tentative for us,
+ // we know we are also performing DAD on it). In this case we let the
+ // stack know so it can handle such a scenario and do nothing further with
+ // the NS.
+ if r.RemoteAddress == header.IPv6Any {
+ s.DupTentativeAddrDetected(e.nicID, targetAddr)
}
- // Do not handle neighbor solicitations targeted
- // to an address that is tentative on the received
- // NIC any further.
+ // Do not handle neighbor solicitations targeted to an address that is
+ // tentative on the NIC any further.
return
}
- // At this point we know that targetAddr is not tentative on
- // rxNICID so the packet is processed as defined in RFC 4861,
- // as per RFC 4862 section 5.4.3.
+ // At this point we know that the target address is not tentative on the NIC
+ // so the packet is processed as defined in RFC 4861, as per RFC 4862
+ // section 5.4.3.
- if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
- // We don't have a useful answer; the best we can do is ignore the request.
+ // Is the NS targetting us?
+ if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 {
return
}
- optsSerializer := header.NDPOptionsSerializer{
- header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress[:]),
+ // If the NS message contains the Source Link-Layer Address option, update
+ // the link address cache with the value of the option.
+ //
+ // TODO(b/148429853): Properly process the NS message and do Neighbor
+ // Unreachability Detection.
+ var sourceLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPSourceLinkLayerAddressOption:
+ // No RFCs define what to do when an NS message has multiple Source
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(sourceLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ sourceLinkAddr = opt.EthernetAddress()
+ }
+ }
+
+ unspecifiedSource := r.RemoteAddress == header.IPv6Any
+
+ // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST
+ // NOT be included when the source IP address is the unspecified address.
+ // Otherwise, on link layers that have addresses this option MUST be
+ // included in multicast solicitations and SHOULD be included in unicast
+ // solicitations.
+ if len(sourceLinkAddr) == 0 {
+ if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ }
+ } else if unspecifiedSource {
+ received.Invalid.Increment()
+ return
+ } else {
+ e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr)
}
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()))
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborAdvert)
- na := header.NDPNeighborAdvert(pkt.NDPPayload())
- na.SetSolicitedFlag(true)
- na.SetOverrideFlag(true)
- na.SetTargetAddress(targetAddr)
- opts := na.Options()
- opts.Serialize(optsSerializer)
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
@@ -203,16 +252,43 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- // TODO(tamird/ghanan): there exists an explicit NDP option that is
- // used to update the neighbor table with link addresses for a
- // neighbor from an NS (see the Source Link Layer option RFC
- // 4861 section 4.6.1 and section 7.2.3).
+ // As per RFC 4861 section 7.2.4, if the the source of the solicitation is
+ // the unspecified address, the node MUST set the Solicited flag to zero and
+ // multicast the advertisement to the all-nodes address.
+ solicited := true
+ if unspecifiedSource {
+ solicited = false
+ r.RemoteAddress = header.IPv6AllNodesMulticastAddress
+ }
+
+ // If the NS has a source link-layer option, use the link address it
+ // specifies as the remote link address for the response instead of the
+ // source link address of the packet.
//
- // Furthermore, the entirety of NDP handling here seems to be
- // contradicted by RFC 4861.
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link
+ // address cache for the right destination link address instead of manually
+ // patching the route with the remote link address if one is specified in a
+ // Source Link-Layer Address option.
+ if len(sourceLinkAddr) != 0 {
+ r.RemoteLinkAddress = sourceLinkAddr
+ }
+
+ optsSerializer := header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6NeighborAdvertMinimumSize + int(optsSerializer.Length()),
+ })
+ packet := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ packet.SetType(header.ICMPv6NeighborAdvert)
+ na := header.NDPNeighborAdvert(packet.NDPPayload())
+ na.SetSolicitedFlag(solicited)
+ na.SetOverrideFlag(true)
+ na.SetTargetAddress(targetAddr)
+ opts := na.Options()
+ opts.Serialize(optsSerializer)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
// RFC 4861 Neighbor Discovery for IP version 6 (IPv6)
//
@@ -220,7 +296,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
//
// The IP Hop Limit field has a value of 255, i.e., the packet
// could not possibly have been forwarded by a router.
- if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: stack.DefaultTOS}, pkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -228,64 +304,121 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborAdvert:
received.NeighborAdvert.Increment()
- if len(v) < header.ICMPv6NeighborAdvertSize {
+ if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the neighbor advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ na := header.NDPNeighborAdvert(payload.ToView())
+ it, err := na.Options().Iter(true)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
received.Invalid.Increment()
return
}
- na := header.NDPNeighborAdvert(h.NDPPayload())
targetAddr := na.TargetAddress()
stack := r.Stack()
- rxNICID := r.NICID()
- isTentative, err := stack.IsAddrTentative(rxNICID, targetAddr)
- if err != nil {
- // We will only get an error if rxNICID is unrecognized,
- // which should not happen. For now short-circuit this
- // packet.
+ if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil {
+ // We will only get an error if the NIC is unrecognized, which should not
+ // happen. For now short-circuit this packet.
//
// TODO(b/141002840): Handle this better?
return
- }
-
- if isTentative {
- // We just got an NA from a node that owns an address we
- // are performing DAD on, implying the address is not
- // unique. In this case we let the stack know so it can
- // handle such a scenario and do nothing furthur with
+ } else if isTentative {
+ // We just got an NA from a node that owns an address we are performing
+ // DAD on, implying the address is not unique. In this case we let the
+ // stack know so it can handle such a scenario and do nothing furthur with
// the NDP NA.
- stack.DupTentativeAddrDetected(rxNICID, targetAddr)
+ stack.DupTentativeAddrDetected(e.nicID, targetAddr)
return
}
- // At this point we know that the targetAddress is not tentative
- // on rxNICID. However, targetAddr may still be assigned to
- // rxNICID but not tentative (it could be permanent). Such a
- // scenario is beyond the scope of RFC 4862. As such, we simply
- // ignore such a scenario for now and proceed as normal.
+ // At this point we know that the target address is not tentative on the
+ // NIC. However, the target address may still be assigned to the NIC but not
+ // tentative (it could be permanent). Such a scenario is beyond the scope of
+ // RFC 4862. As such, we simply ignore such a scenario for now and proceed
+ // as normal.
+ //
+ // TODO(b/143147598): Handle the scenario described above. Also inform the
+ // netstack integration that a duplicate address was detected outside of
+ // DAD.
+
+ // If the NA message has the target link layer option, update the link
+ // address cache with the link address for the target of the message.
//
- // TODO(b/143147598): Handle the scenario described above. Also
- // inform the netstack integration that a duplicate address was
- // detected outside of DAD.
+ // TODO(b/148429853): Properly process the NA message and do Neighbor
+ // Unreachability Detection.
+ var targetLinkAddr tcpip.LinkAddress
+ for {
+ opt, done, err := it.Next()
+ if err != nil {
+ // This should never happen as Iter(true) above did not return an error.
+ panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err))
+ }
+ if done {
+ break
+ }
+
+ switch opt := opt.(type) {
+ case header.NDPTargetLinkLayerAddressOption:
+ // No RFCs define what to do when an NA message has multiple Target
+ // Link-Layer Address options. Since no interface can have multiple
+ // link-layer addresses, we consider such messages invalid.
+ if len(targetLinkAddr) != 0 {
+ received.Invalid.Increment()
+ return
+ }
+
+ targetLinkAddr = opt.EthernetAddress()
+ }
+ }
- e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
- if targetAddr != r.RemoteAddress {
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
+ if len(targetLinkAddr) != 0 {
+ e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr)
}
case header.ICMPv6EchoRequest:
received.EchoRequest.Increment()
- if len(v) < header.ICMPv6EchoMinimumSize {
+ icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize)
+ if !ok {
received.Invalid.Increment()
return
}
- vv.TrimFront(header.ICMPv6EchoMinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
- copy(pkt, h)
- pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
- if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+
+ remoteLinkAddr := r.RemoteLinkAddress
+
+ // As per RFC 4291 section 2.7, multicast addresses must not be used as
+ // source addresses in IPv6 packets.
+ localAddr := r.LocalAddress
+ if header.IsV6MulticastAddress(r.LocalAddress) {
+ localAddr = ""
+ }
+
+ r, err := r.Stack().FindRoute(e.NICID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ // If we cannot find a route to the destination, silently drop the packet.
+ return
+ }
+ defer r.Release()
+
+ // Use the link address from the source of the original packet.
+ r.ResolveWith(remoteLinkAddr)
+
+ replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize,
+ Data: pkt.Data,
+ })
+ packet := header.ICMPv6(replyPkt.TransportHeader().Push(header.ICMPv6EchoMinimumSize))
+ copy(packet, icmpHdr)
+ packet.SetType(header.ICMPv6EchoReply)
+ packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data))
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil {
sent.Dropped.Increment()
return
}
@@ -293,11 +426,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6EchoReply:
received.EchoReply.Increment()
- if len(v) < header.ICMPv6EchoMinimumSize {
+ if pkt.Data.Size() < header.ICMPv6EchoMinimumSize {
received.Invalid.Increment()
return
}
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, netHeader, vv)
+ e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt)
case header.ICMPv6TimeExceeded:
received.TimeExceeded.Increment()
@@ -307,12 +440,64 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6RouterSolicit:
received.RouterSolicit.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
case header.ICMPv6RouterAdvert:
received.RouterAdvert.Increment()
+ // Is the NDP payload of sufficient size to hold a Router
+ // Advertisement?
+ if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
+
+ routerAddr := iph.SourceAddress()
+
+ //
+ // Validate the RA as per RFC 4861 section 6.1.2.
+ //
+
+ // Is the IP Source Address a link-local address?
+ if !header.IsV6LinkLocalAddress(routerAddr) {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ // The remainder of payload must be only the router advertisement, so
+ // payload.ToView() always returns the advertisement. Per RFC 6980 section
+ // 5, NDP messages cannot be fragmented. Also note that in the common case
+ // NDP datagrams are very small and ToView() will not incur allocations.
+ ra := header.NDPRouterAdvert(payload.ToView())
+ opts := ra.Options()
+
+ // Are options valid as per the wire format?
+ if _, err := opts.Iter(true); err != nil {
+ // ...No, silently drop the packet.
+ received.Invalid.Increment()
+ return
+ }
+
+ //
+ // At this point, we have a valid Router Advertisement, as far
+ // as RFC 4861 section 6.1.2 is concerned.
+ //
+
+ // Tell the NIC to handle the RA.
+ stack := r.Stack()
+ rxNICID := r.NICID()
+ stack.HandleNDPRA(rxNICID, routerAddr, ra)
+
case header.ICMPv6RedirectMsg:
received.RedirectMsg.Increment()
+ if !isNDPValid() {
+ received.Invalid.Increment()
+ return
+ }
default:
received.Invalid.Increment()
@@ -331,8 +516,6 @@ const (
icmpV6LengthOffset = 25
)
-var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
@@ -341,24 +524,34 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
+
+ // TODO(b/148672031): Use stack.FindRoute instead of manually creating the
+ // route here. Note, we would need the nicID to do this properly so the right
+ // NIC (associated to linkEP) is used to send the NDP NS message.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: broadcastMAC,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
}
- hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- copy(pkt[icmpV6OptOffset-len(addr):], addr)
- pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
- pkt[icmpV6LengthOffset] = 1
- copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
-
- length := uint16(hdr.UsedLength())
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize,
+ })
+ icmpHdr := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6NeighborAdvertSize))
+ icmpHdr.SetType(header.ICMPv6NeighborSolicit)
+ copy(icmpHdr[icmpV6OptOffset-len(addr):], addr)
+ icmpHdr[icmpV6OptOffset] = ndpOptSrcLinkAddr
+ icmpHdr[icmpV6LengthOffset] = 1
+ copy(icmpHdr[icmpV6LengthOffset+1:], linkEP.LinkAddress())
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ length := uint16(pkt.Size())
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
@@ -368,29 +561,13 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
// TODO(stijlist): count this in ICMP stats.
- return linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+ return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if header.IsV6MulticastAddress(addr) {
- // RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
- //
- // 7. Address Mapping -- Multicast
- //
- // An IPv6 packet with a multicast destination address DST,
- // consisting of the sixteen octets DST[1] through DST[16], is
- // transmitted to the Ethernet multicast address whose first
- // two octets are the value 3333 hexadecimal and whose last
- // four octets are the last four octets of DST.
- return tcpip.LinkAddress([]byte{
- 0x33,
- 0x33,
- addr[header.IPv6AddressSize-4],
- addr[header.IPv6AddressSize-3],
- addr[header.IPv6AddressSize-2],
- addr[header.IPv6AddressSize-1],
- }), true
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
}
- return "", false
+ return tcpip.LinkAddress([]byte(nil)), false
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index b112303b6..9e4eeea77 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "context"
"reflect"
"strings"
"testing"
@@ -31,7 +32,11 @@ import (
const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+ linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
+ linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -55,7 +60,7 @@ func (*stubLinkEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
-func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.NetworkProtocolNumber) *tcpip.Error {
+func (*stubLinkEndpoint) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error {
return nil
}
@@ -65,7 +70,7 @@ type stubDispatcher struct {
stack.TransportDispatcher
}
-func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, buffer.View, buffer.VectorisedView) {
+func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) {
}
type stubLinkAddressCache struct {
@@ -109,10 +114,8 @@ func TestICMPCounts(t *testing.T) {
if netProto == nil {
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{lladdr1, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
+ defer ep.Close()
r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
if err != nil {
@@ -120,48 +123,90 @@ func TestICMPCounts(t *testing.T) {
}
defer r.Release()
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
- typ header.ICMPv6Type
- size int
+ typ header.ICMPv6Type
+ size int
+ extraData []byte
}{
- {header.ICMPv6DstUnreachable, header.ICMPv6DstUnreachableMinimumSize},
- {header.ICMPv6PacketTooBig, header.ICMPv6PacketTooBigMinimumSize},
- {header.ICMPv6TimeExceeded, header.ICMPv6MinimumSize},
- {header.ICMPv6ParamProblem, header.ICMPv6MinimumSize},
- {header.ICMPv6EchoRequest, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6EchoReply, header.ICMPv6EchoMinimumSize},
- {header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize},
- {header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize},
- {header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize},
- {header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize},
- {header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize},
- }
-
- handleIPv6Payload := func(hdr buffer.Prependable) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ {
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ },
+ {
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ },
+ {
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ },
+ {
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ },
+ {
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ },
+ {
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ },
+ }
+
+ handleIPv6Payload := func(icmp header.ICMPv6) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize,
+ Data: buffer.View(icmp).ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(&r, hdr.View().ToVectorisedView())
+ ep.HandlePacket(&r, pkt)
}
for _, typ := range types {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
-
- handleIPv6Payload(hdr)
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
+ handleIPv6Payload(icmp)
}
// Construct an empty ICMP packet so that
// Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented.
- handleIPv6Payload(buffer.NewPrependable(header.IPv6MinimumSize))
+ handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize)))
icmpv6Stats := s.Stats().ICMP.V6PacketsReceived
visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) {
@@ -214,8 +259,7 @@ func newTestContext(t *testing.T) *testContext {
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
@@ -228,7 +272,7 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -262,32 +306,40 @@ func newTestContext(t *testing.T) *testContext {
}
func (c *testContext) cleanup() {
- close(c.linkEP0.C)
- close(c.linkEP1.C)
+ c.linkEP0.Close()
+ c.linkEP1.Close()
}
type routeArgs struct {
- src, dst *channel.Endpoint
- typ header.ICMPv6Type
+ src, dst *channel.Endpoint
+ typ header.ICMPv6Type
+ remoteLinkAddr tcpip.LinkAddress
}
func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header.ICMPv6)) {
t.Helper()
- pkt := <-args.src.C
+ pi, _ := args.src.ReadContext(context.Background())
{
- views := []buffer.View{pkt.Header, pkt.Payload}
- size := len(pkt.Header) + len(pkt.Payload)
- vv := buffer.NewVectorisedView(size, views)
- args.dst.InjectLinkAddr(pkt.Proto, args.dst.LinkAddress(), vv)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pi.Pkt.Size(), pi.Pkt.Views()),
+ })
+ args.dst.InjectLinkAddr(pi.Proto, args.dst.LinkAddress(), pkt)
}
- if pkt.Proto != ProtocolNumber {
- t.Errorf("unexpected protocol number %d", pkt.Proto)
+ if pi.Proto != ProtocolNumber {
+ t.Errorf("unexpected protocol number %d", pi.Proto)
return
}
- ipv6 := header.IPv6(pkt.Header)
+
+ if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress {
+ t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr)
+ }
+
+ // Pull the full payload since network header. Needed for header.IPv6 to
+ // extract its payload.
+ ipv6 := header.IPv6(stack.PayloadSince(pi.Pkt.NetworkHeader()))
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
t.Errorf("unexpected transport protocol number %d", transProto)
@@ -334,7 +386,7 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err)
}
for _, args := range []routeArgs{
- {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit},
+ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
{src: c.linkEP1, dst: c.linkEP0, typ: header.ICMPv6NeighborAdvert},
} {
routeICMPv6Packet(t, args, func(t *testing.T, icmpv6 header.ICMPv6) {
@@ -361,97 +413,104 @@ func TestLinkResolution(t *testing.T) {
}
func TestICMPChecksumValidationSimple(t *testing.T) {
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
{
- "DstUnreachable",
- header.ICMPv6DstUnreachable,
- header.ICMPv6DstUnreachableMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "DstUnreachable",
+ typ: header.ICMPv6DstUnreachable,
+ size: header.ICMPv6DstUnreachableMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.DstUnreachable
},
},
{
- "PacketTooBig",
- header.ICMPv6PacketTooBig,
- header.ICMPv6PacketTooBigMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "PacketTooBig",
+ typ: header.ICMPv6PacketTooBig,
+ size: header.ICMPv6PacketTooBigMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.PacketTooBig
},
},
{
- "TimeExceeded",
- header.ICMPv6TimeExceeded,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "TimeExceeded",
+ typ: header.ICMPv6TimeExceeded,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.TimeExceeded
},
},
{
- "ParamProblem",
- header.ICMPv6ParamProblem,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "ParamProblem",
+ typ: header.ICMPv6ParamProblem,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.ParamProblem
},
},
{
- "EchoRequest",
- header.ICMPv6EchoRequest,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoRequest",
+ typ: header.ICMPv6EchoRequest,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoRequest
},
},
{
- "EchoReply",
- header.ICMPv6EchoReply,
- header.ICMPv6EchoMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "EchoReply",
+ typ: header.ICMPv6EchoReply,
+ size: header.ICMPv6EchoMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.EchoReply
},
},
{
- "RouterSolicit",
- header.ICMPv6RouterSolicit,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterSolicit
},
},
{
- "RouterAdvert",
- header.ICMPv6RouterAdvert,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RouterAdvert
},
},
{
- "NeighborSolicit",
- header.ICMPv6NeighborSolicit,
- header.ICMPv6NeighborSolicitMinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborSolicit
},
},
{
- "NeighborAdvert",
- header.ICMPv6NeighborAdvert,
- header.ICMPv6NeighborAdvertSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.NeighborAdvert
},
},
{
- "RedirectMsg",
- header.ICMPv6RedirectMsg,
- header.ICMPv6MinimumSize,
- func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return stats.RedirectMsg
},
},
@@ -483,22 +542,25 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
)
}
- handleIPv6Payload := func(typ header.ICMPv6Type, size int, checksum bool) {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ handleIPv6Payload := func(checksum bool) {
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView()))
}
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(size),
+ PayloadLength: uint16(len(icmp)),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: header.NDPHopLimit,
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -515,7 +577,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
// Without setting checksum, the incoming packet should
// be invalid.
- handleIPv6Payload(typ.typ, typ.size, false)
+ handleIPv6Payload(false)
if got := invalid.Value(); got != 1 {
t.Fatalf("got invalid = %d, want = 1", got)
}
@@ -525,7 +587,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) {
}
// When checksum is set, it should be received.
- handleIPv6Payload(typ.typ, typ.size, true)
+ handleIPv6Payload(true)
if got := typStat.Value(); got != 1 {
t.Fatalf("got %s = %d, want = 1", typ.name, got)
}
@@ -657,12 +719,12 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
icmpSize := size + payloadSize
hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
- pkt := header.ICMPv6(hdr.Prepend(icmpSize))
- pkt.SetType(typ)
- payloadFn(pkt.Payload())
+ icmpHdr := header.ICMPv6(hdr.Prepend(icmpSize))
+ icmpHdr.SetType(typ)
+ payloadFn(icmpHdr.Payload())
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, buffer.VectorisedView{}))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -673,7 +735,10 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -831,14 +896,14 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
handleIPv6Payload := func(typ header.ICMPv6Type, size, payloadSize int, payloadFn func(buffer.View), checksum bool) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + size)
- pkt := header.ICMPv6(hdr.Prepend(size))
- pkt.SetType(typ)
+ icmpHdr := header.ICMPv6(hdr.Prepend(size))
+ icmpHdr.SetType(typ)
payload := buffer.NewView(payloadSize)
payloadFn(payload)
if checksum {
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, payload.ToVectorisedView()))
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, lladdr1, lladdr0, payload.ToVectorisedView()))
}
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -849,9 +914,10 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
SrcAddr: lladdr1,
DstAddr: lladdr0,
})
- e.Inject(ProtocolNumber,
- buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize,
- []buffer.View{hdr.View(), payload}))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}),
+ })
+ e.InjectInbound(ProtocolNumber, pkt)
}
stats := s.Stats().ICMP.V6PacketsReceived
@@ -889,3 +955,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ p := NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 5898f8f9e..0eafe9790 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -21,11 +21,13 @@
package ipv6
import (
+ "fmt"
"sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -43,13 +45,12 @@ const (
)
type endpoint struct {
- nicid tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
+ nicID tcpip.NICID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
protocol *protocol
+ stack *stack.Stack
}
// DefaultTTL is the default hop limit for this endpoint.
@@ -65,17 +66,7 @@ func (e *endpoint) MTU() uint32 {
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
- return e.nicid
-}
-
-// ID returns the ipv6 endpoint ID.
-func (e *endpoint) ID() *stack.NetworkEndpointID {
- return &e.id
-}
-
-// PrefixLen returns the ipv6 endpoint subnet prefix length in bits.
-func (e *endpoint) PrefixLen() int {
- return e.prefixLen
+ return e.nicID
}
// Capabilities implements stack.NetworkEndpoint.Capabilities.
@@ -97,9 +88,9 @@ func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
-func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) {
- length := uint16(hdr.UsedLength() + payloadSize)
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) {
+ length := uint16(pkt.Size())
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(params.Protocol),
@@ -108,86 +99,336 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
}
// WritePacket writes a packet to the given destination address and protocol.
-func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
- e.addIPHeader(r, &hdr, payload.Size(), params)
-
- if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
+func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.addIPHeader(r, pkt, params)
+
+ if r.Loop&stack.PacketLoop != 0 {
loopedR := r.MakeLoopedRoute()
- e.HandlePacket(&loopedR, vv)
+
+ e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ // The inbound path expects an unparsed packet.
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
loopedR.Release()
}
- if loop&stack.PacketOut == 0 {
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
r.Stats().IP.PacketsSent.Increment()
- return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
+ return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
- if loop&stack.PacketLoop != 0 {
+func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
+ if r.Loop&stack.PacketLoop != 0 {
panic("not implemented")
}
- if loop&stack.PacketOut == 0 {
- return len(hdrs), nil
+ if r.Loop&stack.PacketOut == 0 {
+ return pkts.Len(), nil
}
- for i := range hdrs {
- hdr := &hdrs[i].Hdr
- size := hdrs[i].Size
- e.addIPHeader(r, hdr, size, params)
+ for pb := pkts.Front(); pb != nil; pb = pb.Next() {
+ e.addIPHeader(r, pb, params)
}
- n, err := e.linkEP.WritePackets(r, gso, hdrs, payload, ProtocolNumber)
+ n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
r.Stats().IP.PacketsSent.IncrementBy(uint64(n))
return n, err
}
// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
- // TODO(b/119580726): Support IPv6 header-included packets.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // TODO(b/146666412): Support IPv6 header-included packets.
return tcpip.ErrNotSupported
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
- headerView := vv.First()
- h := header.IPv6(headerView)
- if !h.IsValid(vv.Size()) {
+func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
+ h := header.IPv6(pkt.NetworkHeader().View())
+ if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
return
}
- vv.TrimFront(header.IPv6MinimumSize)
- vv.CapLength(int(h.PayloadLength()))
-
- p := h.TransportProtocol()
- if p == header.ICMPv6ProtocolNumber {
- e.handleICMP(r, headerView, vv)
- return
+ // vv consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ vv := pkt.NetworkHeader().View()[header.IPv6MinimumSize:].ToVectorisedView()
+ vv.AppendView(pkt.TransportHeader().View())
+ vv.Append(pkt.Data)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv)
+ hasFragmentHeader := false
+
+ for firstHeader := true; ; firstHeader = false {
+ extHdr, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6HopByHopOptionsExtHdr:
+ // As per RFC 8200 section 4.1, the Hop By Hop extension header is
+ // restricted to appear immediately after an IPv6 fixed header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1
+ // (unrecognized next header) error in response to an extension header's
+ // Next Header field with the Hop By Hop extension header identifier.
+ if !firstHeader {
+ return
+ }
+
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Hop By Hop extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Hop By Hop extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RoutingExtHdr:
+ // As per RFC 8200 section 4.4, if a node encounters a routing header with
+ // an unrecognized routing type value, with a non-zero Segments Left
+ // value, the node must discard the packet and send an ICMP Parameter
+ // Problem, Code 0. If the Segments Left is 0, the node must ignore the
+ // Routing extension header and process the next header in the packet.
+ //
+ // Note, the stack does not yet handle any type of routing extension
+ // header, so we just make sure Segments Left is zero before processing
+ // the next extension header.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 0 for
+ // unrecognized routing types with a non-zero Segments Left value.
+ if extHdr.SegmentsLeft() != 0 {
+ return
+ }
+
+ case header.IPv6FragmentExtHdr:
+ hasFragmentHeader = true
+
+ if extHdr.IsAtomic() {
+ // This fragment extension header indicates that this packet is an
+ // atomic fragment. An atomic fragment is a fragment that contains
+ // all the data required to reassemble a full packet. As per RFC 6946,
+ // atomic fragments must not interfere with "normal" fragmented traffic
+ // so we skip processing the fragment instead of feeding it through the
+ // reassembly process below.
+ continue
+ }
+
+ // Don't consume the iterator if we have the first fragment because we
+ // will use it to validate that the first fragment holds the upper layer
+ // header.
+ rawPayload := it.AsRawHeader(extHdr.FragmentOffset() != 0 /* consume */)
+
+ if extHdr.FragmentOffset() == 0 {
+ // Check that the iterator ends with a raw payload as the first fragment
+ // should include all headers up to and including any upper layer
+ // headers, as per RFC 8200 section 4.5; only upper layer data
+ // (non-headers) should follow the fragment extension header.
+ var lastHdr header.IPv6PayloadHeader
+
+ for {
+ it, done, err := it.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ lastHdr = it
+ }
+
+ // If the last header is a raw header, then the last portion of the IPv6
+ // payload is not a known IPv6 extension header. Note, this does not
+ // mean that the last portion is an upper layer header or not an
+ // extension header because:
+ // 1) we do not yet support all extension headers
+ // 2) we do not validate the upper layer header before reassembling.
+ //
+ // This check makes sure that a known IPv6 extension header is not
+ // present after the Fragment extension header in a non-initial
+ // fragment.
+ //
+ // TODO(#2196): Support IPv6 Authentication and Encapsulated
+ // Security Payload extension headers.
+ // TODO(#2333): Validate that the upper layer header is valid.
+ switch lastHdr.(type) {
+ case header.IPv6RawPayloadHeader:
+ default:
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+ }
+
+ fragmentPayloadLen := rawPayload.Buf.Size()
+ if fragmentPayloadLen == 0 {
+ // Drop the packet as it's marked as a fragment but has no payload.
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ // The packet is a fragment, let's try to reassemble it.
+ start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit
+ last := start + uint16(fragmentPayloadLen) - 1
+
+ // Drop the packet if the fragmentOffset is incorrect. i.e the
+ // combination of fragmentOffset and pkt.Data.size() causes a
+ // wrap around resulting in last being less than the offset.
+ if last < start {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ var ready bool
+ // Note that pkt doesn't have its transport header set after reassembly,
+ // and won't until DeliverNetworkPacket sets it.
+ pkt.Data, ready, err = e.protocol.fragmentation.Process(
+ // IPv6 ignores the Protocol field since the ID only needs to be unique
+ // across source-destination pairs, as per RFC 8200 section 4.5.
+ fragmentation.FragmentID{
+ Source: h.SourceAddress(),
+ Destination: h.DestinationAddress(),
+ ID: extHdr.ID(),
+ },
+ start,
+ last,
+ extHdr.More(),
+ rawPayload.Buf,
+ )
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ r.Stats().IP.MalformedFragmentsReceived.Increment()
+ return
+ }
+
+ if ready {
+ // We create a new iterator with the reassembled packet because we could
+ // have more extension headers in the reassembled payload, as per RFC
+ // 8200 section 4.5.
+ it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data)
+ }
+
+ case header.IPv6DestinationOptionsExtHdr:
+ optsIt := extHdr.Iter()
+
+ for {
+ opt, done, err := optsIt.Next()
+ if err != nil {
+ r.Stats().IP.MalformedPacketsReceived.Increment()
+ return
+ }
+ if done {
+ break
+ }
+
+ // We currently do not support any IPv6 Destination extension header
+ // options.
+ switch opt.UnknownAction() {
+ case header.IPv6OptionUnknownActionSkip:
+ case header.IPv6OptionUnknownActionDiscard:
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMP:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest:
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem Code 2 for
+ // unrecognized IPv6 extension header options.
+ return
+ default:
+ panic(fmt.Sprintf("unrecognized action for an unrecognized Destination extension header option = %d", opt))
+ }
+ }
+
+ case header.IPv6RawPayloadHeader:
+ // If the last header in the payload isn't a known IPv6 extension header,
+ // handle it as if it is transport layer data.
+
+ // For unfragmented packets, extHdr still contains the transport header.
+ // Get rid of it.
+ //
+ // For reassembled fragments, pkt.TransportHeader is unset, so this is a
+ // no-op and pkt.Data begins with the transport header.
+ extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size())
+ pkt.Data = extHdr.Buf
+
+ if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber {
+ e.handleICMP(r, pkt, hasFragmentHeader)
+ } else {
+ r.Stats().IP.PacketsDelivered.Increment()
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ e.dispatcher.DeliverTransportPacket(r, p, pkt)
+ }
+
+ default:
+ // If we receive a packet for an extension header we do not yet handle,
+ // drop the packet for now.
+ //
+ // TODO(b/152019344): Send an ICMPv6 Parameter Problem, Code 1 error
+ // in response to unrecognized next header values.
+ r.Stats().UnknownProtocolRcvdPackets.Increment()
+ return
+ }
}
-
- r.Stats().IP.PacketsDelivered.Increment()
- e.dispatcher.DeliverTransportPacket(r, p, headerView, vv)
}
// Close cleans up resources associated with the endpoint.
func (*endpoint) Close() {}
+// NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber.
+func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return e.protocol.Number()
+}
+
type protocol struct {
// defaultTTL is the current default TTL for the protocol. Only the
// uint8 portion of it is meaningful and it must be accessed
// atomically.
- defaultTTL uint32
+ defaultTTL uint32
+ fragmentation *fragmentation.Fragmentation
}
// Number returns the ipv6 protocol number.
@@ -212,16 +453,15 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
}
// NewEndpoint creates a new ipv6 endpoint.
-func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint {
return &endpoint{
- nicid: nicid,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
+ nicID: nicID,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
protocol: p,
- }, nil
+ stack: st,
+ }
}
// SetOption implements NetworkProtocol.SetOption.
@@ -256,6 +496,83 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) {
+ hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return 0, false, false
+ }
+ ipHdr := header.IPv6(hdr)
+
+ // dataClone consists of:
+ // - Any IPv6 header bytes after the first 40 (i.e. extensions).
+ // - The transport header, if present.
+ // - Any other payload data.
+ views := [8]buffer.View{}
+ dataClone := pkt.Data.Clone(views[:])
+ dataClone.TrimFront(header.IPv6MinimumSize)
+ it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone)
+
+ // Iterate over the IPv6 extensions to find their length.
+ //
+ // Parsing occurs again in HandlePacket because we don't track the
+ // extensions in PacketBuffer. Unfortunately, that means HandlePacket
+ // has to do the parsing work again.
+ var nextHdr tcpip.TransportProtocolNumber
+ foundNext := true
+ extensionsSize := 0
+traverseExtensions:
+ for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() {
+ if err != nil {
+ break
+ }
+ // If we exhaust the extension list, the entire packet is the IPv6 header
+ // and (possibly) extensions.
+ if done {
+ extensionsSize = dataClone.Size()
+ foundNext = false
+ break
+ }
+
+ switch extHdr := extHdr.(type) {
+ case header.IPv6FragmentExtHdr:
+ // If this is an atomic fragment, we don't have to treat it specially.
+ if !extHdr.More() && extHdr.FragmentOffset() == 0 {
+ continue
+ }
+ // This is a non-atomic fragment and has to be re-assembled before we can
+ // examine the payload for a transport header.
+ foundNext = false
+
+ case header.IPv6RawPayloadHeader:
+ // We've found the payload after any extensions.
+ extensionsSize = dataClone.Size() - extHdr.Buf.Size()
+ nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier)
+ break traverseExtensions
+
+ default:
+ // Any other extension is a no-op, keep looping until we find the payload.
+ }
+ }
+
+ // Put the IPv6 header with extensions in pkt.NetworkHeader().
+ hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize)
+ if !ok {
+ panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size()))
+ }
+ ipHdr = header.IPv6(hdr)
+ pkt.Data.CapLength(int(ipHdr.PayloadLength()))
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
+
+ return nextHdr, foundNext, true
+}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
@@ -268,5 +585,8 @@ func calculateMTU(mtu uint32) uint32 {
// NewProtocol returns an IPv6 network protocol.
func NewProtocol() stack.NetworkProtocol {
- return &protocol{defaultTTL: DefaultTTL}
+ return &protocol{
+ defaultTTL: DefaultTTL,
+ fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ }
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index deaa9b7f3..0a183bfde 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -17,6 +17,7 @@ package ipv6
import (
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -33,6 +34,15 @@ const (
// The least significant 3 bytes are the same as addr2 so both addr2 and
// addr3 will have the same solicited-node address.
addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+ addr4 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x03"
+
+ // Tests use the extension header identifier values as uint8 instead of
+ // header.IPv6ExtensionHeaderIdentifier.
+ hopByHopExtHdrID = uint8(header.IPv6HopByHopOptionsExtHdrIdentifier)
+ routingExtHdrID = uint8(header.IPv6RoutingExtHdrIdentifier)
+ fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier)
+ destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier)
+ noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier)
)
// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
@@ -55,7 +65,9 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
stats := s.Stats().ICMP.V6PacketsReceived
@@ -111,7 +123,9 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
DstAddr: dst,
})
- e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
stat := s.Stats().UDP.PacketsReceived
@@ -154,6 +168,8 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// packets destined to the IPv6 solicited-node address of an assigned IPv6
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
protocolFactory stack.TransportProtocol
@@ -171,50 +187,61 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
})
- e := channel.New(10, 1280, linkAddr1)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ e := channel.New(1, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as we haven't added
- // those addresses.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as we haven't added those addresses.
test.rxf(t, s, e, addr1, snmc, 0)
- if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
}
- // Should receive a packet destined to the solicited
- // node address of addr2/addr3 now that we have added
- // added addr2.
+ // Should receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added added addr2.
test.rxf(t, s, e, addr1, snmc, 1)
- if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ if err := s.AddAddress(nicID, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr3, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have added addr3.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have added addr3.
test.rxf(t, s, e, addr1, snmc, 2)
- if err := s.RemoveAddress(1, addr2); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ if err := s.RemoveAddress(nicID, addr2); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr2, err)
}
- // Should still receive a packet destined to the
- // solicited node address of addr2/addr3 now that we
- // have removed addr2.
+ // Should still receive a packet destined to the solicited node address of
+ // addr2/addr3 now that we have removed addr2.
test.rxf(t, s, e, addr1, snmc, 3)
- if err := s.RemoveAddress(1, addr3); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ // Make sure addr3's endpoint does not get removed from the NIC by
+ // incrementing its reference count with a route.
+ r, err := s.FindRoute(nicID, addr3, addr4, ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr3, addr4, ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ if err := s.RemoveAddress(nicID, addr3); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr3, err)
}
- // Should not receive a packet destined to the solicited
- // node address of addr2/addr3 yet as both of them got
- // removed.
+ // Should not receive a packet destined to the solicited node address of
+ // addr2/addr3 yet as both of them got removed, even though a route using
+ // addr3 exists.
test.rxf(t, s, e, addr1, snmc, 3)
})
}
@@ -264,3 +291,1244 @@ func TestAddIpv6Address(t *testing.T) {
})
}
}
+
+func TestReceiveIPv6ExtHdrs(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ extHdr func(nextHdr uint8) ([]byte, uint8)
+ shouldAccept bool
+ }{
+ {
+ name: "None",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, nextHdr },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing with zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "routing with non-zero segments left",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 1, 2, 3, 4, 5}, routingExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "atomic fragment with zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 0, 0, 0, 0}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment with non-zero ID",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 0, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: true,
+ },
+ {
+ name: "fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{nextHdr, 0, 1, 0, 1, 2, 3, 4}, fragmentExtHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option skippable action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Skippable unknown.
+ 62, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "destination with unknown option discard action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard unknown.
+ 127, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP if option is unknown.
+ 191, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "destination with unknown option discard and send icmp action unless multicast dest",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ nextHdr, 1,
+
+ // Skippable unknown.
+ 63, 4, 1, 2, 3, 4,
+
+ // Discard & send ICMP unless packet is for multicast destination if
+ // option is unknown.
+ 255, 6, 1, 2, 3, 4, 5, 6,
+ }, destinationExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "routing - atomic fragment",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ nextHdr, 0, 0, 0, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "atomic fragment - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Fragment extension header.
+ routingExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, fragmentExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hop by hop (with skippable unknown) - routing",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ nextHdr, 0, 1, 0, 2, 3, 4, 5,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "routing - hop by hop (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Routing extension header.
+ hopByHopExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Hop By Hop extension header with skippable unknown option.
+ nextHdr, 0, 62, 4, 1, 2, 3, 4,
+ }, routingExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "No next header",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) { return []byte{}, noNextHdrID },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: true,
+ },
+ {
+ name: "hopbyhop (with discard unknown) - routing - atomic fragment - destination (with skippable unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with discard action for unknown option.
+ routingExtHdrID, 0, 65, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with skippable unknown option.
+ nextHdr, 0, 63, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ {
+ name: "hopbyhop (with skippable unknown) - routing - atomic fragment - destination (with discard unknown)",
+ extHdr: func(nextHdr uint8) ([]byte, uint8) {
+ return []byte{
+ // Hop By Hop extension header with skippable unknown option.
+ routingExtHdrID, 0, 62, 4, 1, 2, 3, 4,
+
+ // Routing extension header.
+ fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5,
+
+ // Fragment extension header.
+ destinationExtHdrID, 0, 0, 0, 1, 2, 3, 4,
+
+ // Destination extension header with discard action for unknown
+ // option.
+ nextHdr, 0, 65, 4, 1, 2, 3, 4,
+ }, hopByHopExtHdrID
+ },
+ shouldAccept: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ udpPayload := []byte{1, 2, 3, 4, 5, 6, 7, 8}
+ udpLength := header.UDPMinimumSize + len(udpPayload)
+ extHdrBytes, ipv6NextHdr := test.extHdr(uint8(header.UDPProtocolNumber))
+ extHdrLen := len(extHdrBytes)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + extHdrLen + udpLength)
+
+ // Serialize UDP message.
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), udpPayload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, addr1, addr2, uint16(udpLength))
+ sum = header.Checksum(udpPayload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ // Copy extension header bytes between the UDP message and the IPv6
+ // fixed header.
+ copy(hdr.Prepend(extHdrLen), extHdrBytes)
+
+ // Serialize IPv6 fixed header.
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: ipv6NextHdr,
+ HopLimit: 255,
+ SrcAddr: addr1,
+ DstAddr: addr2,
+ })
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ stats := s.Stats().UDP.PacketsReceived
+
+ if !test.shouldAccept {
+ if got := stats.Value(); got != 0 {
+ t.Errorf("got UDP Rx Packets = %d, want = 0", got)
+ }
+
+ return
+ }
+
+ // Expect a UDP packet.
+ if got := stats.Value(); got != 1 {
+ t.Errorf("got UDP Rx Packets = %d, want = 1", got)
+ }
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+
+ // Should not have any more UDP packets.
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
+
+// fragmentData holds the IPv6 payload for a fragmented IPv6 packet.
+type fragmentData struct {
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ nextHdr uint8
+ data buffer.VectorisedView
+}
+
+func TestReceiveIPv6Fragments(t *testing.T) {
+ const (
+ nicID = 1
+ udpPayload1Length = 256
+ udpPayload2Length = 128
+ // Used to test cases where the fragment blocks are not a multiple of
+ // the fragment block size of 8 (RFC 8200 section 4.5).
+ udpPayload3Length = 127
+ fragmentExtHdrLen = 8
+ // Note, not all routing extension headers will be 8 bytes but this test
+ // uses 8 byte routing extension headers for most sub tests.
+ routingExtHdrLen = 8
+ )
+
+ udpGen := func(payload []byte, multiplier uint8, src, dst tcpip.Address) buffer.View {
+ payloadLen := len(payload)
+ for i := 0; i < payloadLen; i++ {
+ payload[i] = uint8(i) * multiplier
+ }
+
+ udpLength := header.UDPMinimumSize + payloadLen
+
+ hdr := buffer.NewPrependable(udpLength)
+ u := header.UDP(hdr.Prepend(udpLength))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: uint16(udpLength),
+ })
+ copy(u.Payload(), payload)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, uint16(udpLength))
+ sum = header.Checksum(payload, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+ return hdr.View()
+ }
+
+ var udpPayload1Addr1ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr1ToAddr2 := udpPayload1Addr1ToAddr2Buf[:]
+ ipv6Payload1Addr1ToAddr2 := udpGen(udpPayload1Addr1ToAddr2, 1, addr1, addr2)
+
+ var udpPayload1Addr3ToAddr2Buf [udpPayload1Length]byte
+ udpPayload1Addr3ToAddr2 := udpPayload1Addr3ToAddr2Buf[:]
+ ipv6Payload1Addr3ToAddr2 := udpGen(udpPayload1Addr3ToAddr2, 4, addr3, addr2)
+
+ var udpPayload2Addr1ToAddr2Buf [udpPayload2Length]byte
+ udpPayload2Addr1ToAddr2 := udpPayload2Addr1ToAddr2Buf[:]
+ ipv6Payload2Addr1ToAddr2 := udpGen(udpPayload2Addr1ToAddr2, 2, addr1, addr2)
+
+ var udpPayload3Addr1ToAddr2Buf [udpPayload3Length]byte
+ udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:]
+ ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2)
+
+ tests := []struct {
+ name string
+ expectedPayload []byte
+ fragments []fragmentData
+ expectedPayloads [][]byte
+ }{
+ {
+ name: "No fragmentation",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: uint8(header.UDPProtocolNumber),
+ data: ipv6Payload1Addr1ToAddr2.ToVectorisedView(),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Atomic fragment",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Atomic fragment with size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}),
+
+ ipv6Payload3Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments out of order",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with last fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload3Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with first fragment size not a multiple of fragment block size",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[:63],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload3Addr1ToAddr2)-63,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload3Addr1ToAddr2[63:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with different IDs",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 2}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with per-fragment routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 0, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with per-fragment routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: routingExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{fragmentExtHdrID, 0, 1, 1, 2, 3, 4, 5}),
+
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 0, 2, 3, 4, 5}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ routingExtHdrLen+fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header.
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 1, 1, 2, 3, 4, 5}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 9, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 72, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 0.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 0, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ {
+ name: "Two fragments with routing header with non-zero segments left across fragments",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is fragmentExtHdrLen+8 because the
+ // first 8 bytes of the 16 byte routing extension header is in
+ // this fragment.
+ fragmentExtHdrLen+8,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 1, 0, 0, 0, 1}),
+
+ // Routing extension header (part 1)
+ //
+ // Segments left = 1.
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 1, 1, 1, 2, 3, 4, 5}),
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ // The length of this payload is
+ // fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2) because the last 8 bytes of
+ // the 16 byte routing extension header is in this fagment.
+ fragmentExtHdrLen+8+len(ipv6Payload1Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 1, More = false, ID = 1
+ buffer.View([]byte{routingExtHdrID, 0, 0, 8, 0, 0, 0, 1}),
+
+ // Routing extension header (part 2)
+ buffer.View([]byte{6, 7, 8, 9, 10, 11, 12, 13}),
+
+ ipv6Payload1Addr1ToAddr2,
+ },
+ ),
+ },
+ },
+ expectedPayloads: nil,
+ },
+ // As per RFC 6946, IPv6 atomic fragments MUST NOT interfere with "normal"
+ // fragmented traffic.
+ {
+ name: "Two fragments with atomic",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ // This fragment has the same ID as the other fragments but is an atomic
+ // fragment. It should not interfere with the other fragments.
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2),
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 1}),
+
+ ipv6Payload2Addr1ToAddr2,
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload2Addr1ToAddr2, udpPayload1Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 2}),
+
+ ipv6Payload2Addr1ToAddr2[:32],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload2Addr1ToAddr2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 2
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 2}),
+
+ ipv6Payload2Addr1ToAddr2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload2Addr1ToAddr2},
+ },
+ {
+ name: "Two interleaved fragmented packets from different sources but with same ID",
+ fragments: []fragmentData{
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[:64],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 0, More = true, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[:32],
+ },
+ ),
+ },
+ {
+ srcAddr: addr1,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 8, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr1ToAddr2[64:],
+ },
+ ),
+ },
+ {
+ srcAddr: addr3,
+ dstAddr: addr2,
+ nextHdr: fragmentExtHdrID,
+ data: buffer.NewVectorisedView(
+ fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-32,
+ []buffer.View{
+ // Fragment extension header.
+ //
+ // Fragment offset = 4, More = false, ID = 1
+ buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 32, 0, 0, 0, 1}),
+
+ ipv6Payload1Addr3ToAddr2[32:],
+ },
+ ),
+ },
+ },
+ expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2, udpPayload1Addr3ToAddr2},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err)
+ }
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ProtocolNumber, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: addr2, Port: 80}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ for _, f := range test.fragments {
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize)
+
+ // Serialize IPv6 fixed header.
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(f.data.Size()),
+ NextHeader: f.nextHdr,
+ HopLimit: 255,
+ SrcAddr: f.srcAddr,
+ DstAddr: f.dstAddr,
+ })
+
+ vv := hdr.View().ToVectorisedView()
+ vv.Append(f.data)
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: vv,
+ }))
+ }
+
+ if got, want := s.Stats().UDP.PacketsReceived.Value(), uint64(len(test.expectedPayloads)); got != want {
+ t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
+ }
+
+ for i, p := range test.expectedPayloads {
+ gotPayload, _, err := ep.Read(nil)
+ if err != nil {
+ t.Fatalf("(i=%d) Read(nil): %s", i, err)
+ }
+ if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" {
+ t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff)
+ }
+ }
+
+ if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index c32716f2e..af71a7d6b 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -20,7 +20,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
)
@@ -61,17 +63,476 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
}
- ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
- if err != nil {
- t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
- }
+ ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s)
return s, ep
}
-// TestHopLimitValidation is a test that makes sure that NDP packets are only
-// received if their IP header's hop limit is set to 255.
-func TestHopLimitValidation(t *testing.T) {
+// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
+// valid NDP NS message with the Source Link Layer Address option results in a
+// new entry in the link address cache for the sender of the message.
+func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{1, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr0)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNeighorSolicitationResponse(t *testing.T) {
+ const nicID = 1
+ nicAddr := lladdr0
+ remoteAddr := lladdr1
+ nicAddrSNMC := header.SolicitedNodeAddr(nicAddr)
+ nicLinkAddr := linkAddr0
+ remoteLinkAddr0 := linkAddr1
+ remoteLinkAddr1 := linkAddr2
+
+ tests := []struct {
+ name string
+ nsOpts header.NDPOptionsSerializer
+ nsSrcLinkAddr tcpip.LinkAddress
+ nsSrc tcpip.Address
+ nsDst tcpip.Address
+ nsInvalid bool
+ naDstLinkAddr tcpip.LinkAddress
+ naSolicited bool
+ naSrc tcpip.Address
+ naDst tcpip.Address
+ }{
+ {
+ name: "Unspecified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Unspecified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: false,
+ naSrc: nicAddr,
+ naDst: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "Unspecified source with source ll option to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: header.IPv6Any,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source with 1 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source to multicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+ {
+ name: "Specified source with 2 source ll to multicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddrSNMC,
+ nsInvalid: true,
+ },
+
+ {
+ name: "Specified source to unicast destination",
+ nsOpts: nil,
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr0,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 1 source ll different from route to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: false,
+ naDstLinkAddr: remoteLinkAddr1,
+ naSolicited: true,
+ naSrc: nicAddr,
+ naDst: remoteAddr,
+ },
+ {
+ name: "Specified source with 2 source ll to unicast destination",
+ nsOpts: header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr0[:]),
+ header.NDPSourceLinkLayerAddressOption(remoteLinkAddr1[:]),
+ },
+ nsSrcLinkAddr: remoteLinkAddr0,
+ nsSrc: remoteAddr,
+ nsDst: nicAddr,
+ nsInvalid: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(1, 1280, nicLinkAddr)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err)
+ }
+
+ ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length()
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNSSize))
+ pkt.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ ns.SetTargetAddress(nicAddr)
+ opts := ns.Options()
+ opts.Serialize(test.nsOpts)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: test.nsSrc,
+ DstAddr: test.nsDst,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ if test.nsInvalid {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ if p, got := e.Read(); got {
+ t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt)
+ }
+
+ // If we expected the NS to be invalid, we have nothing else to check.
+ return
+ }
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ p, got := e.Read()
+ if !got {
+ t.Fatal("expected an NDP NA response")
+ }
+
+ if p.Route.RemoteLinkAddress != test.naDstLinkAddr {
+ t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.naSrc),
+ checker.DstAddr(test.naDst),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPNA(
+ checker.NDPNASolicitedFlag(test.naSolicited),
+ checker.NDPNATargetAddress(nicAddr),
+ checker.NDPNAOptions([]header.NDPOption{
+ header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]),
+ }),
+ ))
+ })
+ }
+}
+
+// TestNeighorAdvertisementWithTargetLinkLayerOption tests that receiving a
+// valid NDP NA message with the Target Link Layer Address option results in a
+// new entry in the link address cache for the target of the message.
+func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ optsBuf []byte
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Valid",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7},
+ expectedLinkAddr: "\x02\x03\x04\x05\x06\x07",
+ },
+ {
+ name: "Too Small",
+ optsBuf: []byte{2, 1, 2, 3, 4, 5, 6},
+ },
+ {
+ name: "Invalid Length",
+ optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7},
+ },
+ {
+ name: "Multiple",
+ optsBuf: []byte{
+ 2, 1, 2, 3, 4, 5, 6, 7,
+ 2, 1, 2, 3, 4, 5, 6, 8,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ e := channel.New(0, 1280, linkAddr0)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err)
+ }
+
+ ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize)
+ pkt := header.ICMPv6(hdr.Prepend(ndpNASize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ ns := header.NDPNeighborAdvert(pkt.NDPPayload())
+ ns.SetTargetAddress(lladdr1)
+ opts := ns.Options()
+ copy(opts, test.optsBuf)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: lladdr1,
+ DstAddr: lladdr0,
+ })
+
+ invalid := s.Stats().ICMP.V6PacketsReceived.Invalid
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+
+ linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
+ if linkAddr != test.expectedLinkAddr {
+ t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
+ }
+
+ if test.expectedLinkAddr != "" {
+ if err != nil {
+ t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
+ }
+ if c != nil {
+ t.Errorf("got unexpected channel")
+ }
+
+ // Invalid count should not have increased.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
+ }
+ if c == nil {
+ t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
+ }
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Errorf("got invalid = %d, want = 1", got)
+ }
+ }
+ })
+ }
+}
+
+func TestNDPValidation(t *testing.T) {
setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
t.Helper()
@@ -87,94 +548,357 @@ func TestHopLimitValidation(t *testing.T) {
return s, ep, r
}
- handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
- payloadLength := hdr.UsedLength()
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) {
+ nextHdr := uint8(header.ICMPv6ProtocolNumber)
+ var extensions buffer.View
+ if atomicFragment {
+ extensions = buffer.NewView(header.IPv6FragmentExtHdrLength)
+ extensions[0] = nextHdr
+ nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier)
+ }
+
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions),
+ Data: payload.ToVectorisedView(),
+ })
+ ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions)))
ip.Encode(&header.IPv6Fields{
- PayloadLength: uint16(payloadLength),
- NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ PayloadLength: uint16(len(payload) + len(extensions)),
+ NextHeader: nextHdr,
HopLimit: hopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
- ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) {
+ t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n)
+ }
+ ep.HandlePacket(r, pkt)
}
+ var tllData [header.NDPLinkLayerAddressSize]byte
+ header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
+
types := []struct {
name string
typ header.ICMPv6Type
size int
+ extraData []byte
statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
}{
- {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterSolicit
- }},
- {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RouterAdvert
- }},
- {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborSolicit
- }},
- {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.NeighborAdvert
- }},
- {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
- return stats.RedirectMsg
- }},
+ {
+ name: "RouterSolicit",
+ typ: header.ICMPv6RouterSolicit,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ },
+ },
+ {
+ name: "RouterAdvert",
+ typ: header.ICMPv6RouterAdvert,
+ size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ },
+ },
+ {
+ name: "NeighborSolicit",
+ typ: header.ICMPv6NeighborSolicit,
+ size: header.ICMPv6NeighborSolicitMinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ },
+ },
+ {
+ name: "NeighborAdvert",
+ typ: header.ICMPv6NeighborAdvert,
+ size: header.ICMPv6NeighborAdvertMinimumSize,
+ extraData: tllData[:],
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ },
+ },
+ {
+ name: "RedirectMsg",
+ typ: header.ICMPv6RedirectMsg,
+ size: header.ICMPv6MinimumSize,
+ statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ },
+ },
+ }
+
+ subTests := []struct {
+ name string
+ atomicFragment bool
+ hopLimit uint8
+ code header.ICMPv6Code
+ valid bool
+ }{
+ {
+ name: "Valid",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: true,
+ },
+ {
+ name: "Fragmented",
+ atomicFragment: true,
+ hopLimit: header.NDPHopLimit,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid hop limit",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit - 1,
+ code: 0,
+ valid: false,
+ },
+ {
+ name: "Invalid ICMPv6 code",
+ atomicFragment: false,
+ hopLimit: header.NDPHopLimit,
+ code: 1,
+ valid: false,
+ },
}
for _, typ := range types {
t.Run(typ.name, func(t *testing.T) {
- s, ep, r := setup(t)
- defer r.Release()
+ for _, test := range subTests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
- stats := s.Stats().ICMP.V6PacketsReceived
- invalid := stats.Invalid
- typStat := typ.statCounter(stats)
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
- pkt := header.ICMPv6(hdr.Prepend(typ.size))
- pkt.SetType(typ.typ)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData)))
+ copy(icmp[typ.size:], typ.extraData)
+ icmp.SetType(typ.typ)
+ icmp.SetCode(test.code)
+ icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView()))
- // Invalid count should initially be 0.
- if got := invalid.Value(); got != 0 {
- t.Fatalf("got invalid = %d, want = 0", got)
- }
+ // Rx count of the NDP message should initially be 0.
+ if got := typStat.Value(); got != 0 {
+ t.Errorf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Errorf("got invalid = %d, want = 0", got)
+ }
- // Should not have received any ICMPv6 packets with
- // type = typ.typ.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r)
+
+ // Rx count of the NDP packet should have increased.
+ if got := typStat.Value(); got != 1 {
+ t.Errorf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ want := uint64(0)
+ if !test.valid {
+ // Invalid count should have increased.
+ want = 1
+ }
+ if got := invalid.Value(); got != want {
+ t.Errorf("got invalid = %d, want = %d", got, want)
+ }
+ })
}
+ })
+ }
+}
+
+// TestRouterAdvertValidation tests that when the NIC is configured to handle
+// NDP Router Advertisement packets, it validates the Router Advertisement
+// properly before handling them.
+func TestRouterAdvertValidation(t *testing.T) {
+ tests := []struct {
+ name string
+ src tcpip.Address
+ hopLimit uint8
+ code header.ICMPv6Code
+ ndpPayload []byte
+ expectedSuccess bool
+ }{
+ {
+ "OK",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "NonLinkLocalSourceAddr",
+ addr1,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "HopLimitNot255",
+ lladdr0,
+ 254,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NonZeroCode",
+ lladdr0,
+ 255,
+ 1,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "NDPPayloadTooSmall",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0,
+ },
+ false,
+ },
+ {
+ "OKWithOptions",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ 2, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ true,
+ },
+ {
+ "OptionWithZeroLength",
+ lladdr0,
+ 255,
+ 0,
+ []byte{
+ // RA payload
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+ 0, 0, 0, 0,
+
+ // Option #1 (TargetLinkLayerAddress)
+ // Invalid as it has 0 length.
+ 2, 0, 0, 0, 0, 0, 0, 0,
+
+ // Option #2 (unrecognized)
+ 255, 1, 0, 0, 0, 0, 0, 0,
+
+ // Option #3 (PrefixInformation)
+ 3, 4, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0,
+ },
+ false,
+ },
+ }
- // Receive the NDP packet with an invalid hop limit
- // value.
- handleIPv6Payload(hdr, header.NDPHopLimit-1, ep, &r)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
}
- // Rx count of NDP packet of type typ.typ should not
- // have increased.
- if got := typStat.Value(); got != 0 {
- t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(test.code)
+ copy(pkt.NDPPayload(), test.ndpPayload)
+ payloadLength := hdr.UsedLength()
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: test.hopLimit,
+ SrcAddr: test.src,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ rxRA := stats.RouterAdvert
+
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ if got := rxRA.Value(); got != 0 {
+ t.Fatalf("got rxRA = %d, want = 0", got)
}
- // Receive the NDP packet with a valid hop limit value.
- handleIPv6Payload(hdr, header.NDPHopLimit, ep, &r)
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
- // Rx count of NDP packet of type typ.typ should have
- // increased.
- if got := typStat.Value(); got != 1 {
- t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ if got := rxRA.Value(); got != 1 {
+ t.Fatalf("got rxRA = %d, want = 1", got)
}
- // Invalid count should not have increased again.
- if got := invalid.Value(); got != 1 {
- t.Fatalf("got invalid = %d, want = 1", got)
+ if test.expectedSuccess {
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+ } else {
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
}
})
}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 11efb4e44..2bad05a2e 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,14 +1,13 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "ports",
srcs = ["ports.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/ports",
- visibility = ["//:sandbox"],
+ visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
],
)
@@ -16,7 +15,7 @@ go_library(
go_test(
name = "ports_test",
srcs = ["ports_test.go"],
- embed = [":ports"],
+ library = ":ports",
deps = [
"//pkg/tcpip",
],
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 30cea8996..f6d592eb5 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -18,9 +18,9 @@ package ports
import (
"math"
"math/rand"
- "sync"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -41,6 +41,46 @@ type portDescriptor struct {
port uint16
}
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precidence over MostRecent.
+ LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
+}
+
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
+ if f.MostRecent {
+ rf |= MostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
+ }
+ return rf
+}
+
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
+
// PortManager manages allocating, reserving and releasing ports.
type PortManager struct {
mu sync.RWMutex
@@ -54,9 +94,144 @@ type PortManager struct {
hint uint32
}
-type portNode struct {
- reuse bool
- refs int
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
+
+const (
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
+ nextFlag
+
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
+)
+
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination, (0 though
+ // FlagMask).
+ refs [nextFlag]int
+}
+
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
+ var total int
+ for _, r := range c.refs {
+ total += r
+ }
+ return total
+}
+
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
+ var total int
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// IntersectionRefs returns the set of flags shared by all references.
+func (c FlagCounter) IntersectionRefs() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
+ if r > 0 {
+ intersection &= BitFlags(i)
+ }
+ }
+ return intersection
+}
+
+type destination struct {
+ addr tcpip.Address
+ port uint16
+}
+
+func makeDestination(a tcpip.FullAddress) destination {
+ return destination{
+ a.Addr,
+ a.Port,
+ }
+}
+
+// portNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type portNode map[destination]FlagCounter
+
+// intersectionRefs calculates the intersection of flag bit values which affect
+// the specified destination.
+//
+// If no destinations are present, all flag values are returned as there are no
+// entries to limit possible flag values of a new entry.
+//
+// In addition to the intersection, the number of intersecting refs is
+// returned.
+func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+ intersection := FlagMask
+ var count int
+
+ for d, f := range p {
+ if d == dst {
+ intersection &= f.IntersectionRefs()
+ count++
+ continue
+ }
+ // Wildcard destinations affect all destinations for TupleOnly.
+ if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ // Only bitwise and the TupleOnlyFlag.
+ intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ count++
+ }
+ }
+
+ return intersection, count
}
// deviceNode is never empty. When it has no elements, it is removed from the
@@ -64,32 +239,45 @@ type portNode struct {
type deviceNode map[tcpip.NICID]portNode
// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all portNodes. If binding to a specific device, check
+// device, check against all FlagCounters. If binding to a specific device, check
// against the unspecified device and the provided device.
-func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+//
+// If either of the port reuse flags is enabled on any of the nodes, all nodes
+// sharing a port must share at least one reuse flag. This matches Linux's
+// behavior.
+func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ flagBits := flags.Bits()
if bindToDevice == 0 {
- // Trying to binding all devices.
- if !reuse {
- // Can't bind because the (addr,port) is already bound.
- return false
- }
+ intersection := FlagMask
for _, p := range d {
- if !p.reuse {
- // Can't bind because the (addr,port) was previously bound without reuse.
+ i, c := p.intersectionRefs(dst)
+ if c == 0 {
+ continue
+ }
+ intersection &= i
+ if intersection&flagBits == 0 {
+ // Can't bind because the (addr,port) was
+ // previously bound without reuse.
return false
}
}
return true
}
+ intersection := FlagMask
+
if p, ok := d[0]; ok {
- if !reuse || !p.reuse {
+ var c int
+ intersection, c = p.intersectionRefs(dst)
+ if c > 0 && intersection&flagBits == 0 {
return false
}
}
if p, ok := d[bindToDevice]; ok {
- if !reuse || !p.reuse {
+ i, c := p.intersectionRefs(dst)
+ intersection &= i
+ if c > 0 && intersection&flagBits == 0 {
return false
}
}
@@ -103,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
if addr == anyIPAddress {
// If binding to the "any" address then check that there are no conflicts
// with all addresses.
for _, d := range b {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
@@ -117,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice
// Check that there is no conflict with the "any" address.
if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
// Check that this is no conflict with the provided address.
if d, ok := b[addr]; ok {
- if !d.isAvailable(reuse, bindToDevice) {
+ if !d.isAvailable(flags, bindToDevice, dst) {
return false
}
}
@@ -190,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
+ return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse, bindToDevice) {
+ if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
return false
}
}
@@ -212,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
+ dst := makeDestination(dest)
+
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -227,16 +417,18 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
return false
}
+ flagBits := flags.Bits()
+
// Reserve port on all network protocols.
for _, network := range networks {
desc := portDescriptor{network, transport, port}
@@ -250,12 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
d = make(deviceNode)
m[addr] = d
}
- if n, ok := d[bindToDevice]; ok {
- n.refs++
- d[bindToDevice] = n
- } else {
- d[bindToDevice] = portNode{reuse: reuse, refs: 1}
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
}
+ n := p[dst]
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ return true
+}
+
+// ReserveTuple adds a port reservation for the tuple on all given protocol.
+func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
+ flagBits := flags.Bits()
+ dst := makeDestination(dest)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // It is easier to undo the entire reservation, so if we find that the
+ // tuple can't be fully added, finish and undo the whole thing.
+ undo := false
+
+ // Reserve port on all network protocols.
+ for _, network := range networks {
+ desc := portDescriptor{network, transport, port}
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ p := d[bindToDevice]
+ if p == nil {
+ p = make(portNode)
+ }
+
+ n := p[dst]
+ if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ // Tuple already exists.
+ undo = true
+ }
+ n.AddRef(flagBits)
+ p[dst] = n
+ d[bindToDevice] = p
+ }
+
+ if undo {
+ // releasePortLocked decrements the counts (rather than setting
+ // them to zero), so it will undo the incorrect incrementing
+ // above.
+ s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ return false
}
return true
@@ -263,10 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
s.mu.Lock()
defer s.mu.Unlock()
+ s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+}
+
+func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
@@ -274,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
if !ok {
continue
}
- n, ok := d[bindToDevice]
+ p, ok := d[bindToDevice]
if !ok {
continue
}
- n.refs--
- d[bindToDevice] = n
- if n.refs == 0 {
- delete(d, bindToDevice)
+ n, ok := p[dst]
+ if !ok {
+ continue
}
- if len(d) == 0 {
- delete(m, addr)
+ n.DropRef(flags)
+ if n.TotalRefs() > 0 {
+ p[dst] = n
+ continue
+ }
+ delete(p, dst)
+ if len(p) > 0 {
+ continue
}
- if len(m) == 0 {
- delete(s.allocatedPorts, desc)
+ delete(d, bindToDevice)
+ if len(d) > 0 {
+ continue
+ }
+ delete(m, addr)
+ if len(m) > 0 {
+ continue
}
+ delete(s.allocatedPorts, desc)
}
}
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 19f4833fc..58db5868c 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -33,9 +33,10 @@ type portReserveTestAction struct {
port uint16
ip tcpip.Address
want *tcpip.Error
- reuse bool
+ flags Flags
release bool
device tcpip.NICID
+ dest tcpip.FullAddress
}
func TestPortReservation(t *testing.T) {
@@ -50,7 +51,7 @@ func TestPortReservation(t *testing.T) {
{port: 80, ip: fakeIPAddress1, want: nil},
/* N.B. Order of tests matters! */
{port: 80, ip: anyIPAddress, want: tcpip.ErrPortInUse},
- {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 80, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
},
},
{
@@ -61,7 +62,7 @@ func TestPortReservation(t *testing.T) {
/* release fakeIPAddress, but anyIPAddress is still inuse */
{port: 22, ip: fakeIPAddress, release: true},
{port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, reuse: true},
+ {port: 22, ip: fakeIPAddress, want: tcpip.ErrPortInUse, flags: Flags{LoadBalanced: true}},
/* Release port 22 from any IP address, then try to reserve fake IP address on 22 */
{port: 22, ip: anyIPAddress, want: nil, release: true},
{port: 22, ip: fakeIPAddress, want: nil},
@@ -71,36 +72,36 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 00, ip: fakeIPAddress, want: nil},
{port: 00, ip: fakeIPAddress, want: nil},
- {port: 00, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 00, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to ip with reuseport",
actions: []portReserveTestAction{
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 25, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 25, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 25, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 25, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 25, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 25, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
tname: "bind to inaddr any with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
- {port: 24, ip: anyIPAddress, reuse: true, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, release: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, release: true, want: nil},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: tcpip.ErrPortInUse},
- {port: 24, ip: anyIPAddress, release: true},
- {port: 24, ip: anyIPAddress, reuse: false, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{LoadBalanced: true}, release: true},
+ {port: 24, ip: anyIPAddress, flags: Flags{}, want: nil},
},
}, {
tname: "bind twice with device fails",
@@ -125,88 +126,200 @@ func TestPortReservation(t *testing.T) {
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind with device",
actions: []portReserveTestAction{
{port: 24, ip: fakeIPAddress, device: 123, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 789, want: nil},
{port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
- tname: "bind with reuse",
+ tname: "bind with reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
},
}, {
- tname: "binding with reuse and device",
+ tname: "binding with reuseport and device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
},
}, {
- tname: "mixing reuse and not reuse by binding to device",
+ tname: "mixing reuseport and not reuseport by binding to device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 999, want: nil},
},
}, {
- tname: "can't bind to 0 after mixing reuse and not reuse",
+ tname: "can't bind to 0 after mixing reuseport and not reuseport",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
{port: 24, ip: fakeIPAddress, device: 456, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "bind and release",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
- {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, flags: Flags{LoadBalanced: true}, want: nil},
// Release the bind to device 0 and try again.
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil},
},
}, {
- tname: "bind twice with reuse once",
+ tname: "bind twice with reuseport once",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
},
}, {
tname: "release an unreserved device",
actions: []portReserveTestAction{
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil},
// The below don't exist.
- {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
- {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, flags: Flags{}, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
// Release all.
- {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
- {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, flags: Flags{}, want: nil, release: true},
+ },
+ }, {
+ tname: "bind with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuseaddr once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, flags: Flags{}, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr and reuseport twice, and then reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ },
+ }, {
+ tname: "bind with reuseaddr, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuseport, and then reuseaddr and reuseport",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ },
+ }, {
+ tname: "bind tuple with reuseaddr, and then wildcard",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind two tuples with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind two tuples",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil},
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil},
+ },
+ }, {
+ tname: "bind wildcard, and then tuple with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil},
+ {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind wildcard twice with reuseaddr",
+ actions: []portReserveTestAction{
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
+ {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil},
},
},
} {
@@ -216,19 +329,18 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
}
}
})
-
}
}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index a57752a7c..cf0a5fefe 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -1,10 +1,11 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
go_binary(
name = "tun_tcp_connect",
srcs = ["main.go"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 2239c1e66..0ab089208 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -164,7 +164,7 @@ func main() {
// Create TCP endpoint.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
index dad8ef399..43264b76d 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -1,10 +1,11 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
go_binary(
name = "tun_tcp_echo",
srcs = ["main.go"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index bca73cbb1..9e37cab18 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -168,7 +168,7 @@ func main() {
// Create TCP endpoint, bind it, then start listening.
var wq waiter.Queue
ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
- if err != nil {
+ if e != nil {
log.Fatal(e)
}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
index 29b7d761c..45f503845 100644
--- a/pkg/tcpip/seqnum/BUILD
+++ b/pkg/tcpip/seqnum/BUILD
@@ -1,12 +1,9 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "seqnum",
srcs = ["seqnum.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/seqnum",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
)
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
index b40a3c212..d3bea7de4 100644
--- a/pkg/tcpip/seqnum/seqnum.go
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -46,11 +46,6 @@ func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
-// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
-func Overlap(a Value, b Size, x Value, y Size) bool {
- return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
-}
-
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 460db3cf8..900938dd1 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -16,35 +15,88 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "neighbor_entry_list",
+ out = "neighbor_entry_list.go",
+ package = "stack",
+ prefix = "neighborEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*neighborEntry",
+ "Linker": "*neighborEntry",
+ },
+)
+
+go_template_instance(
+ name = "packet_buffer_list",
+ out = "packet_buffer_list.go",
+ package = "stack",
+ prefix = "PacketBuffer",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*PacketBuffer",
+ "Linker": "*PacketBuffer",
+ },
+)
+
+go_template_instance(
+ name = "tuple_list",
+ out = "tuple_list.go",
+ package = "stack",
+ prefix = "tuple",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*tuple",
+ "Linker": "*tuple",
+ },
+)
+
go_library(
name = "stack",
srcs = [
+ "conntrack.go",
+ "dhcpv6configurationfromndpra_string.go",
+ "forwarder.go",
+ "headertype_string.go",
"icmp_rate_limit.go",
+ "iptables.go",
+ "iptables_state.go",
+ "iptables_targets.go",
+ "iptables_types.go",
"linkaddrcache.go",
"linkaddrentry_list.go",
"ndp.go",
+ "neighbor_cache.go",
+ "neighbor_entry.go",
+ "neighbor_entry_list.go",
+ "neighborstate_string.go",
"nic.go",
+ "nud.go",
+ "packet_buffer.go",
+ "packet_buffer_list.go",
+ "rand.go",
"registration.go",
"route.go",
"stack.go",
"stack_global_state.go",
+ "stack_options.go",
"transport_demuxer.go",
+ "tuple_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/stack",
- visibility = [
- "//visibility:public",
- ],
+ visibility = ["//visibility:public"],
deps = [
"//pkg/ilist",
+ "//pkg/log",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
+ "//pkg/tcpip/transport/tcpconntrack",
"//pkg/waiter",
"@org_golang_x_time//rate:go_default_library",
],
@@ -52,46 +104,57 @@ go_library(
go_test(
name = "stack_x_test",
- size = "small",
+ size = "medium",
srcs = [
"ndp_test.go",
+ "nud_test.go",
"stack_test.go",
"transport_demuxer_test.go",
"transport_test.go",
],
+ shard_count = 20,
deps = [
":stack",
+ "//pkg/rand",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/transport/icmp",
"//pkg/tcpip/transport/udp",
"//pkg/waiter",
- "@com_github_google_go-cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
go_test(
name = "stack_test",
size = "small",
- srcs = ["linkaddrcache_test.go"],
- embed = [":stack"],
+ srcs = [
+ "fake_time_test.go",
+ "forwarder_test.go",
+ "linkaddrcache_test.go",
+ "neighbor_cache_test.go",
+ "neighbor_entry_test.go",
+ "nic_test.go",
+ "packet_buffer_test.go",
+ ],
+ library = ":stack",
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "@com_github_dpjacques_clockwork//:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "linkaddrentry_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
new file mode 100644
index 000000000..7dd344b4f
--- /dev/null
+++ b/pkg/tcpip/stack/conntrack.go
@@ -0,0 +1,631 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// Connection tracking is used to track and manipulate packets for NAT rules.
+// The connection is created for a packet if it does not exist. Every
+// connection contains two tuples (original and reply). The tuples are
+// manipulated if there is a matching NAT rule. The packet is modified by
+// looking at the tuples in the Prerouting and Output hooks.
+//
+// Currently, only TCP tracking is supported.
+
+// Our hash table has 16K buckets.
+// TODO(gvisor.dev/issue/170): These should be tunable.
+const numBuckets = 1 << 14
+
+// Direction of the tuple.
+type direction int
+
+const (
+ dirOriginal direction = iota
+ dirReply
+)
+
+// Manipulation type for the connection.
+type manipType int
+
+const (
+ manipNone manipType = iota
+ manipDstPrerouting
+ manipDstOutput
+)
+
+// tuple holds a connection's identifying and manipulating data in one
+// direction. It is immutable.
+//
+// +stateify savable
+type tuple struct {
+ // tupleEntry is used to build an intrusive list of tuples.
+ tupleEntry
+
+ tupleID
+
+ // conn is the connection tracking entry this tuple belongs to.
+ conn *conn
+
+ // direction is the direction of the tuple.
+ direction direction
+}
+
+// tupleID uniquely identifies a connection in one direction. It currently
+// contains enough information to distinguish between any TCP or UDP
+// connection, and will need to be extended to support other protocols.
+//
+// +stateify savable
+type tupleID struct {
+ srcAddr tcpip.Address
+ srcPort uint16
+ dstAddr tcpip.Address
+ dstPort uint16
+ transProto tcpip.TransportProtocolNumber
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// reply creates the reply tupleID.
+func (ti tupleID) reply() tupleID {
+ return tupleID{
+ srcAddr: ti.dstAddr,
+ srcPort: ti.dstPort,
+ dstAddr: ti.srcAddr,
+ dstPort: ti.srcPort,
+ transProto: ti.transProto,
+ netProto: ti.netProto,
+ }
+}
+
+// conn is a tracked connection.
+//
+// +stateify savable
+type conn struct {
+ // original is the tuple in original direction. It is immutable.
+ original tuple
+
+ // reply is the tuple in reply direction. It is immutable.
+ reply tuple
+
+ // manip indicates if the packet should be manipulated. It is immutable.
+ manip manipType
+
+ // tcbHook indicates if the packet is inbound or outbound to
+ // update the state of tcb. It is immutable.
+ tcbHook Hook
+
+ // mu protects all mutable state.
+ mu sync.Mutex `state:"nosave"`
+ // tcb is TCB control block. It is used to keep track of states
+ // of tcp connection and is protected by mu.
+ tcb tcpconntrack.TCB
+ // lastUsed is the last time the connection saw a relevant packet, and
+ // is updated by each packet on the connection. It is protected by mu.
+ lastUsed time.Time `state:".(unixTime)"`
+}
+
+// timedOut returns whether the connection timed out based on its state.
+func (cn *conn) timedOut(now time.Time) bool {
+ const establishedTimeout = 5 * 24 * time.Hour
+ const defaultTimeout = 120 * time.Second
+ cn.mu.Lock()
+ defer cn.mu.Unlock()
+ if cn.tcb.State() == tcpconntrack.ResultAlive {
+ // Use the same default as Linux, which doesn't delete
+ // established connections for 5(!) days.
+ return now.Sub(cn.lastUsed) > establishedTimeout
+ }
+ // Use the same default as Linux, which lets connections in most states
+ // other than established remain for <= 120 seconds.
+ return now.Sub(cn.lastUsed) > defaultTimeout
+}
+
+// update the connection tracking state.
+//
+// Precondition: ct.mu must be held.
+func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) {
+ // Update the state of tcb. tcb assumes it's always initialized on the
+ // client. However, we only need to know whether the connection is
+ // established or not, so the client/server distinction isn't important.
+ // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle
+ // other tcp states.
+ if ct.tcb.IsEmpty() {
+ ct.tcb.Init(tcpHeader)
+ } else if hook == ct.tcbHook {
+ ct.tcb.UpdateStateOutbound(tcpHeader)
+ } else {
+ ct.tcb.UpdateStateInbound(tcpHeader)
+ }
+}
+
+// ConnTrack tracks all connections created for NAT rules. Most users are
+// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
+//
+// ConnTrack keeps all connections in a slice of buckets, each of which holds a
+// linked list of tuples. This gives us some desirable properties:
+// - Each bucket has its own lock, lessening lock contention.
+// - The slice is large enough that lists stay short (<10 elements on average).
+// Thus traversal is fast.
+// - During linked list traversal we reap expired connections. This amortizes
+// the cost of reaping them and makes reapUnused faster.
+//
+// Locks are ordered by their location in the buckets slice. That is, a
+// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
+//
+// +stateify savable
+type ConnTrack struct {
+ // seed is a one-time random value initialized at stack startup
+ // and is used in the calculation of hash keys for the list of buckets.
+ // It is immutable.
+ seed uint32
+
+ // mu protects the buckets slice, but not buckets' contents. Only take
+ // the write lock if you are modifying the slice or saving for S/R.
+ mu sync.RWMutex `state:"nosave"`
+
+ // buckets is protected by mu.
+ buckets []bucket
+}
+
+// +stateify savable
+type bucket struct {
+ // mu protects tuples.
+ mu sync.Mutex `state:"nosave"`
+ tuples tupleList
+}
+
+// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid
+// TCP header.
+func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) {
+ // TODO(gvisor.dev/issue/170): Need to support for other
+ // protocols as well.
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if len(netHeader) < header.IPv4MinimumSize || netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return tupleID{}, tcpip.ErrUnknownProtocol
+ }
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return tupleID{}, tcpip.ErrUnknownProtocol
+ }
+
+ return tupleID{
+ srcAddr: netHeader.SourceAddress(),
+ srcPort: tcpHeader.SourcePort(),
+ dstAddr: netHeader.DestinationAddress(),
+ dstPort: tcpHeader.DestinationPort(),
+ transProto: netHeader.TransportProtocol(),
+ netProto: header.IPv4ProtocolNumber,
+ }, nil
+}
+
+// newConn creates new connection.
+func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
+ conn := conn{
+ manip: manip,
+ tcbHook: hook,
+ lastUsed: time.Now(),
+ }
+ conn.original = tuple{conn: &conn, tupleID: orig}
+ conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
+ return &conn
+}
+
+// connFor gets the conn for pkt if it exists, or returns nil
+// if it does not. It returns an error when pkt does not contain a valid TCP
+// header.
+// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support
+// other transport protocols.
+func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil, dirOriginal
+ }
+ return ct.connForTID(tid)
+}
+
+func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) {
+ bucket := ct.bucket(tid)
+ now := time.Now()
+
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ ct.buckets[bucket].mu.Lock()
+ defer ct.buckets[bucket].mu.Unlock()
+
+ // Iterate over the tuples in a bucket, cleaning up any unused
+ // connections we find.
+ for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() {
+ // Clean up any timed-out connections we happen to find.
+ if ct.reapTupleLocked(other, bucket, now) {
+ // The tuple expired.
+ continue
+ }
+ if tid == other.tupleID {
+ return other.conn, other.direction
+ }
+ }
+
+ return nil, dirOriginal
+}
+
+func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn {
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return nil
+ }
+ if hook != Prerouting && hook != Output {
+ return nil
+ }
+
+ // Create a new connection and change the port as per the iptables
+ // rule. This tuple will be used to manipulate the packet in
+ // handlePacket.
+ replyTID := tid.reply()
+ replyTID.srcAddr = rt.MinIP
+ replyTID.srcPort = rt.MinPort
+ var manip manipType
+ switch hook {
+ case Prerouting:
+ manip = manipDstPrerouting
+ case Output:
+ manip = manipDstOutput
+ }
+ conn := newConn(tid, replyTID, manip, hook)
+ ct.insertConn(conn)
+ return conn
+}
+
+// insertConn inserts conn into the appropriate table bucket.
+func (ct *ConnTrack) insertConn(conn *conn) {
+ // Lock the buckets in the correct order.
+ tupleBucket := ct.bucket(conn.original.tupleID)
+ replyBucket := ct.bucket(conn.reply.tupleID)
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ if tupleBucket < replyBucket {
+ ct.buckets[tupleBucket].mu.Lock()
+ ct.buckets[replyBucket].mu.Lock()
+ } else if tupleBucket > replyBucket {
+ ct.buckets[replyBucket].mu.Lock()
+ ct.buckets[tupleBucket].mu.Lock()
+ } else {
+ // Both tuples are in the same bucket.
+ ct.buckets[tupleBucket].mu.Lock()
+ }
+
+ // Now that we hold the locks, ensure the tuple hasn't been inserted by
+ // another thread.
+ alreadyInserted := false
+ for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
+ if other.tupleID == conn.original.tupleID {
+ alreadyInserted = true
+ break
+ }
+ }
+
+ if !alreadyInserted {
+ // Add the tuple to the map.
+ ct.buckets[tupleBucket].tuples.PushFront(&conn.original)
+ ct.buckets[replyBucket].tuples.PushFront(&conn.reply)
+ }
+
+ // Unlocking can happen in any order.
+ ct.buckets[tupleBucket].mu.Unlock()
+ if tupleBucket != replyBucket {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
+}
+
+// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
+// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
+func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
+ // For prerouting redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified.
+ switch dir {
+ case dirOriginal:
+ port := conn.reply.srcPort
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ case dirReply:
+ port := conn.original.dstPort
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+
+ // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
+ // on inbound packets, so we don't recalculate them. However, we should
+ // support cases when they are validated, e.g. when we can't offload
+ // receive checksumming.
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// handlePacketOutput manipulates ports for packets in Output hook.
+func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
+ // If this is a noop entry, don't do anything.
+ if conn.manip == manipNone {
+ return
+ }
+
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+
+ // For output redirection, packets going in the original direction
+ // have their destinations modified and replies have their sources
+ // modified. For prerouting redirection, we only reach this point
+ // when replying, so packet sources are modified.
+ if conn.manip == manipDstOutput && dir == dirOriginal {
+ port := conn.reply.srcPort
+ tcpHeader.SetDestinationPort(port)
+ netHeader.SetDestinationAddress(conn.reply.srcAddr)
+ } else {
+ port := conn.original.dstPort
+ tcpHeader.SetSourcePort(port)
+ netHeader.SetSourceAddress(conn.original.dstAddr)
+ }
+
+ // Calculate the TCP checksum and set it.
+ tcpHeader.SetChecksum(0)
+ length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+ xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length)
+ if gso != nil && gso.NeedsCsum {
+ tcpHeader.SetChecksum(xsum)
+ } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size())
+ tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
+ }
+
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+}
+
+// handlePacket will manipulate the port and address of the packet if the
+// connection exists. Returns whether, after the packet traverses the tables,
+// it should create a new entry in the table.
+func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool {
+ if pkt.NatDone {
+ return false
+ }
+
+ if hook != Prerouting && hook != Output {
+ return false
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other transport protocols.
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ return false
+ }
+
+ conn, dir := ct.connFor(pkt)
+ // Connection or Rule not found for the packet.
+ if conn == nil {
+ return true
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader().View())
+ if len(tcpHeader) < header.TCPMinimumSize {
+ return false
+ }
+
+ switch hook {
+ case Prerouting:
+ handlePacketPrerouting(pkt, conn, dir)
+ case Output:
+ handlePacketOutput(pkt, conn, gso, r, dir)
+ }
+ pkt.NatDone = true
+
+ // Update the state of tcb.
+ // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
+ // other tcp states.
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ // Mark the connection as having been used recently so it isn't reaped.
+ conn.lastUsed = time.Now()
+ // Update connection state.
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+
+ return false
+}
+
+// maybeInsertNoop tries to insert a no-op connection entry to keep connections
+// from getting clobbered when replies arrive. It only inserts if there isn't
+// already a connection for pkt.
+//
+// This should be called after traversing iptables rules only, to ensure that
+// pkt.NatDone is set correctly.
+func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) {
+ // If there were a rule applying to this packet, it would be marked
+ // with NatDone.
+ if pkt.NatDone {
+ return
+ }
+
+ // We only track TCP connections.
+ if nh := pkt.NetworkHeader().View(); nh.IsEmpty() || header.IPv4(nh).TransportProtocol() != header.TCPProtocolNumber {
+ return
+ }
+
+ // This is the first packet we're seeing for the TCP connection. Insert
+ // the noop entry (an identity mapping) so that the response doesn't
+ // get NATed, breaking the connection.
+ tid, err := packetToTupleID(pkt)
+ if err != nil {
+ return
+ }
+ conn := newConn(tid, tid.reply(), manipNone, hook)
+ conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook)
+ ct.insertConn(conn)
+}
+
+// bucket gets the conntrack bucket for a tupleID.
+func (ct *ConnTrack) bucket(id tupleID) int {
+ h := jenkins.Sum32(ct.seed)
+ h.Write([]byte(id.srcAddr))
+ h.Write([]byte(id.dstAddr))
+ shortBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(shortBuf, id.srcPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, id.dstPort)
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
+ h.Write([]byte(shortBuf))
+ binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
+ h.Write([]byte(shortBuf))
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ return int(h.Sum32()) % len(ct.buckets)
+}
+
+// reapUnused deletes timed out entries from the conntrack map. The rules for
+// reaping are:
+// - Most reaping occurs in connFor, which is called on each packet. connFor
+// cleans up the bucket the packet's connection maps to. Thus calls to
+// reapUnused should be fast.
+// - Each call to reapUnused traverses a fraction of the conntrack table.
+// Specifically, it traverses len(ct.buckets)/fractionPerReaping.
+// - After reaping, reapUnused decides when it should next run based on the
+// ratio of expired connections to examined connections. If the ratio is
+// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
+// slightly increases the interval between runs.
+// - maxFullTraversal caps the time it takes to traverse the entire table.
+//
+// reapUnused returns the next bucket that should be checked and the time after
+// which it should be called again.
+func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
+ // TODO(gvisor.dev/issue/170): This can be more finely controlled, as
+ // it is in Linux via sysctl.
+ const fractionPerReaping = 128
+ const maxExpiredPct = 50
+ const maxFullTraversal = 60 * time.Second
+ const minInterval = 10 * time.Millisecond
+ const maxInterval = maxFullTraversal / fractionPerReaping
+
+ now := time.Now()
+ checked := 0
+ expired := 0
+ var idx int
+ ct.mu.RLock()
+ defer ct.mu.RUnlock()
+ for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
+ idx = (i + start) % len(ct.buckets)
+ ct.buckets[idx].mu.Lock()
+ for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() {
+ checked++
+ if ct.reapTupleLocked(tuple, idx, now) {
+ expired++
+ }
+ }
+ ct.buckets[idx].mu.Unlock()
+ }
+ // We already checked buckets[idx].
+ idx++
+
+ // If half or more of the connections are expired, the table has gotten
+ // stale. Reschedule quickly.
+ expiredPct := 0
+ if checked != 0 {
+ expiredPct = expired * 100 / checked
+ }
+ if expiredPct > maxExpiredPct {
+ return idx, minInterval
+ }
+ if interval := prevInterval + minInterval; interval <= maxInterval {
+ // Increment the interval between runs.
+ return idx, interval
+ }
+ // We've hit the maximum interval.
+ return idx, maxInterval
+}
+
+// reapTupleLocked tries to remove tuple and its reply from the table. It
+// returns whether the tuple's connection has timed out.
+//
+// Preconditions: ct.mu is locked for reading and bucket is locked.
+func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool {
+ if !tuple.conn.timedOut(now) {
+ return false
+ }
+
+ // To maintain lock order, we can only reap these tuples if the reply
+ // appears later in the table.
+ replyBucket := ct.bucket(tuple.reply())
+ if bucket > replyBucket {
+ return true
+ }
+
+ // Don't re-lock if both tuples are in the same bucket.
+ differentBuckets := bucket != replyBucket
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Lock()
+ }
+
+ // We have the buckets locked and can remove both tuples.
+ if tuple.direction == dirOriginal {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply)
+ } else {
+ ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original)
+ }
+ ct.buckets[bucket].tuples.Remove(tuple)
+
+ // Don't re-unlock if both tuples are in the same bucket.
+ if differentBuckets {
+ ct.buckets[replyBucket].mu.Unlock()
+ }
+
+ return true
+}
+
+func (ct *ConnTrack) originalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ // Lookup the connection. The reply's original destination
+ // describes the original address.
+ tid := tupleID{
+ srcAddr: epID.LocalAddress,
+ srcPort: epID.LocalPort,
+ dstAddr: epID.RemoteAddress,
+ dstPort: epID.RemotePort,
+ transProto: header.TCPProtocolNumber,
+ netProto: header.IPv4ProtocolNumber,
+ }
+ conn, _ := ct.connForTID(tid)
+ if conn == nil {
+ // Not a tracked connection.
+ return "", 0, tcpip.ErrNotConnected
+ } else if conn.manip == manipNone {
+ // Unmanipulated connection.
+ return "", 0, tcpip.ErrInvalidOptionValue
+ }
+
+ return conn.original.dstAddr, conn.original.dstPort, nil
+}
diff --git a/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
new file mode 100644
index 000000000..d199ded6a
--- /dev/null
+++ b/pkg/tcpip/stack/dhcpv6configurationfromndpra_string.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type DHCPv6ConfigurationFromNDPRA"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[DHCPv6NoConfiguration-1]
+ _ = x[DHCPv6ManagedAddress-2]
+ _ = x[DHCPv6OtherConfigurations-3]
+}
+
+const _DHCPv6ConfigurationFromNDPRA_name = "DHCPv6NoConfigurationDHCPv6ManagedAddressDHCPv6OtherConfigurations"
+
+var _DHCPv6ConfigurationFromNDPRA_index = [...]uint8{0, 21, 41, 66}
+
+func (i DHCPv6ConfigurationFromNDPRA) String() string {
+ i -= 1
+ if i < 0 || i >= DHCPv6ConfigurationFromNDPRA(len(_DHCPv6ConfigurationFromNDPRA_index)-1) {
+ return "DHCPv6ConfigurationFromNDPRA(" + strconv.FormatInt(int64(i+1), 10) + ")"
+ }
+ return _DHCPv6ConfigurationFromNDPRA_name[_DHCPv6ConfigurationFromNDPRA_index[i]:_DHCPv6ConfigurationFromNDPRA_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go
new file mode 100644
index 000000000..92c8cb534
--- /dev/null
+++ b/pkg/tcpip/stack/fake_time_test.go
@@ -0,0 +1,209 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "container/heap"
+ "sync"
+ "time"
+
+ "github.com/dpjacques/clockwork"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+type fakeClock struct {
+ clock clockwork.FakeClock
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ // times is min-heap of times. A heap is used for quick retrieval of the next
+ // upcoming time of scheduled work.
+ times *timeHeap
+
+ // waitGroups stores one WaitGroup for all work scheduled to execute at the
+ // same time via AfterFunc. This allows parallel execution of all functions
+ // passed to AfterFunc scheduled for the same time.
+ waitGroups map[time.Time]*sync.WaitGroup
+}
+
+func newFakeClock() *fakeClock {
+ return &fakeClock{
+ clock: clockwork.NewFakeClock(),
+ times: &timeHeap{},
+ waitGroups: make(map[time.Time]*sync.WaitGroup),
+ }
+}
+
+var _ tcpip.Clock = (*fakeClock)(nil)
+
+// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
+func (fc *fakeClock) NowNanoseconds() int64 {
+ return fc.clock.Now().UnixNano()
+}
+
+// NowMonotonic implements tcpip.Clock.NowMonotonic.
+func (fc *fakeClock) NowMonotonic() int64 {
+ return fc.NowNanoseconds()
+}
+
+// AfterFunc implements tcpip.Clock.AfterFunc.
+func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer {
+ until := fc.clock.Now().Add(d)
+ wg := fc.addWait(until)
+ return &fakeTimer{
+ clock: fc,
+ until: until,
+ timer: fc.clock.AfterFunc(d, func() {
+ defer wg.Done()
+ f()
+ }),
+ }
+}
+
+// addWait adds an additional wait to the WaitGroup for parallel execution of
+// all work scheduled for t. Returns a reference to the WaitGroup modified.
+func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup {
+ fc.mu.RLock()
+ wg, ok := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ if ok {
+ wg.Add(1)
+ return wg
+ }
+
+ fc.mu.Lock()
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+
+ wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ fc.mu.Lock()
+ fc.waitGroups[t] = wg
+ fc.mu.Unlock()
+
+ return wg
+}
+
+// removeWait removes a wait from the WaitGroup for parallel execution of all
+// work scheduled for t.
+func (fc *fakeClock) removeWait(t time.Time) {
+ fc.mu.RLock()
+ defer fc.mu.RUnlock()
+
+ wg := fc.waitGroups[t]
+ wg.Done()
+}
+
+// advance executes all work that have been scheduled to execute within d from
+// the current fake time. Blocks until all work has completed execution.
+func (fc *fakeClock) advance(d time.Duration) {
+ // Block until all the work is done
+ until := fc.clock.Now().Add(d)
+ for {
+ fc.mu.Lock()
+ if fc.times.Len() == 0 {
+ fc.mu.Unlock()
+ return
+ }
+
+ t := heap.Pop(fc.times).(time.Time)
+ if t.After(until) {
+ // No work to do
+ heap.Push(fc.times, t)
+ fc.mu.Unlock()
+ return
+ }
+ fc.mu.Unlock()
+
+ diff := t.Sub(fc.clock.Now())
+ fc.clock.Advance(diff)
+
+ fc.mu.RLock()
+ wg := fc.waitGroups[t]
+ fc.mu.RUnlock()
+
+ wg.Wait()
+
+ fc.mu.Lock()
+ delete(fc.waitGroups, t)
+ fc.mu.Unlock()
+ }
+}
+
+type fakeTimer struct {
+ clock *fakeClock
+ timer clockwork.Timer
+
+ mu sync.RWMutex
+ until time.Time
+}
+
+var _ tcpip.Timer = (*fakeTimer)(nil)
+
+// Reset implements tcpip.Timer.Reset.
+func (ft *fakeTimer) Reset(d time.Duration) {
+ if !ft.timer.Reset(d) {
+ return
+ }
+
+ ft.mu.Lock()
+ defer ft.mu.Unlock()
+
+ ft.clock.removeWait(ft.until)
+ ft.until = ft.clock.clock.Now().Add(d)
+ ft.clock.addWait(ft.until)
+}
+
+// Stop implements tcpip.Timer.Stop.
+func (ft *fakeTimer) Stop() bool {
+ if !ft.timer.Stop() {
+ return false
+ }
+
+ ft.mu.RLock()
+ defer ft.mu.RUnlock()
+
+ ft.clock.removeWait(ft.until)
+ return true
+}
+
+type timeHeap []time.Time
+
+var _ heap.Interface = (*timeHeap)(nil)
+
+func (h timeHeap) Len() int {
+ return len(h)
+}
+
+func (h timeHeap) Less(i, j int) bool {
+ return h[i].Before(h[j])
+}
+
+func (h timeHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+func (h *timeHeap) Push(x interface{}) {
+ *h = append(*h, x.(time.Time))
+}
+
+func (h *timeHeap) Pop() interface{} {
+ last := (*h)[len(*h)-1]
+ *h = (*h)[:len(*h)-1]
+ return last
+}
diff --git a/pkg/tcpip/stack/forwarder.go b/pkg/tcpip/stack/forwarder.go
new file mode 100644
index 000000000..3eff141e6
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // maxPendingResolutions is the maximum number of pending link-address
+ // resolutions.
+ maxPendingResolutions = 64
+ maxPendingPacketsPerResolution = 256
+)
+
+type pendingPacket struct {
+ nic *NIC
+ route *Route
+ proto tcpip.NetworkProtocolNumber
+ pkt *PacketBuffer
+}
+
+type forwardQueue struct {
+ sync.Mutex
+
+ // The packets to send once the resolver completes.
+ packets map[<-chan struct{}][]*pendingPacket
+
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ cancelChans []chan struct{}
+}
+
+func newForwardQueue() *forwardQueue {
+ return &forwardQueue{packets: make(map[<-chan struct{}][]*pendingPacket)}
+}
+
+func (f *forwardQueue) enqueue(ch <-chan struct{}, n *NIC, r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ shouldWait := false
+
+ f.Lock()
+ packets, ok := f.packets[ch]
+ if !ok {
+ shouldWait = true
+ }
+ for len(packets) == maxPendingPacketsPerResolution {
+ p := packets[0]
+ packets = packets[1:]
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ p.route.Release()
+ }
+ if l := len(packets); l >= maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+ }
+ f.packets[ch] = append(packets, &pendingPacket{
+ nic: n,
+ route: r,
+ proto: protocol,
+ pkt: pkt,
+ })
+ f.Unlock()
+
+ if !shouldWait {
+ return
+ }
+
+ // Wait for the link-address resolution to complete.
+ // Start a goroutine with a forwarding-cancel channel so that we can
+ // limit the maximum number of goroutines running concurrently.
+ cancel := f.newCancelChannel()
+ go func() {
+ cancelled := false
+ select {
+ case <-ch:
+ case <-cancel:
+ cancelled = true
+ }
+
+ f.Lock()
+ packets := f.packets[ch]
+ delete(f.packets, ch)
+ f.Unlock()
+
+ for _, p := range packets {
+ if cancelled {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else if _, err := p.route.Resolve(nil); err != nil {
+ p.nic.stack.stats.IP.OutgoingPacketErrors.Increment()
+ } else {
+ p.nic.forwardPacket(p.route, p.proto, p.pkt)
+ }
+ p.route.Release()
+ }
+ }()
+}
+
+// newCancelChannel creates a channel that can cancel a pending forwarding
+// activity. The oldest channel is closed if the number of open channels would
+// exceed maxPendingResolutions.
+func (f *forwardQueue) newCancelChannel() chan struct{} {
+ f.Lock()
+ defer f.Unlock()
+
+ if len(f.cancelChans) == maxPendingResolutions {
+ ch := f.cancelChans[0]
+ f.cancelChans = f.cancelChans[1:]
+ close(ch)
+ }
+ if l := len(f.cancelChans); l >= maxPendingResolutions {
+ panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
+ }
+
+ ch := make(chan struct{})
+ f.cancelChans = append(f.cancelChans, ch)
+ return ch
+}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
new file mode 100644
index 000000000..9dff23623
--- /dev/null
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -0,0 +1,648 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "encoding/binary"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ fwdTestNetHeaderLen = 12
+ fwdTestNetDefaultPrefixLen = 8
+
+ // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match
+ // the MTU of loopback interfaces on linux systems.
+ fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
+)
+
+// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
+// Headers of this protocol are fwdTestNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one byte fields to simplify parsing.
+type fwdTestNetworkEndpoint struct {
+ nicID tcpip.NICID
+ proto *fwdTestNetworkProtocol
+ dispatcher TransportDispatcher
+ ep LinkEndpoint
+}
+
+func (f *fwdTestNetworkEndpoint) MTU() uint32 {
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
+}
+
+func (f *fwdTestNetworkEndpoint) NICID() tcpip.NICID {
+ return f.nicID
+}
+
+func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 {
+ return 123
+}
+
+func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
+ // Dispatch the packet to the transport protocol.
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
+}
+
+func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
+ return f.ep.MaxHeaderLength() + fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+func (f *fwdTestNetworkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return f.ep.Capabilities()
+}
+
+func (f *fwdTestNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return f.proto.Number()
+}
+
+func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
+ // Add the protocol's header to the packet and send it to the link
+ // endpoint.
+ b := pkt.NetworkHeader().Push(fwdTestNetHeaderLen)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = r.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
+
+ return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (f *fwdTestNetworkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+func (*fwdTestNetworkEndpoint) WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (*fwdTestNetworkEndpoint) Close() {}
+
+// fwdTestNetworkProtocol is a network-layer protocol that implements Address
+// resolution.
+type fwdTestNetworkProtocol struct {
+ addrCache *linkAddrCache
+ addrResolveDelay time.Duration
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
+ onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
+}
+
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
+func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fakeNetNumber
+}
+
+func (f *fwdTestNetworkProtocol) MinimumPacketSize() int {
+ return fwdTestNetHeaderLen
+}
+
+func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
+ return fwdTestNetDefaultPrefixLen
+}
+
+func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.NetworkHeader().Consume(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true
+}
+
+func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint {
+ return &fwdTestNetworkEndpoint{
+ nicID: nicID,
+ proto: f,
+ dispatcher: dispatcher,
+ ep: ep,
+ }
+}
+
+func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func (f *fwdTestNetworkProtocol) Close() {}
+
+func (f *fwdTestNetworkProtocol) Wait() {}
+
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ if f.addrCache != nil && f.onLinkAddressResolved != nil {
+ time.AfterFunc(f.addrResolveDelay, func() {
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
+ })
+ }
+ return nil
+}
+
+func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if f.onResolveStaticAddress != nil {
+ return f.onResolveStaticAddress(addr)
+ }
+ return "", false
+}
+
+func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return fakeNetNumber
+}
+
+// fwdTestPacketInfo holds all the information about an outbound packet.
+type fwdTestPacketInfo struct {
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalLinkAddress tcpip.LinkAddress
+ Pkt *PacketBuffer
+}
+
+type fwdTestLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+ mtu uint32
+ linkAddr tcpip.LinkAddress
+
+ // C is where outbound packets are queued.
+ C chan fwdTestPacketInfo
+}
+
+// InjectInbound injects an inbound packet.
+func (e *fwdTestLinkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ e.InjectLinkAddr(protocol, "", pkt)
+}
+
+// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *fwdTestLinkEndpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remote tcpip.LinkAddress, pkt *PacketBuffer) {
+ e.dispatcher.DeliverNetworkPacket(remote, "" /* local */, protocol, pkt)
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *fwdTestLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *fwdTestLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *fwdTestLinkEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e fwdTestLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ caps := LinkEndpointCapabilities(0)
+ return caps | CapabilityResolutionRequired
+}
+
+// GSOMaxSize returns the maximum GSO packet size.
+func (*fwdTestLinkEndpoint) GSOMaxSize() uint32 {
+ return 1 << 15
+}
+
+// MaxHeaderLength returns the maximum size of the link layer header. Given it
+// doesn't have a header, it just returns 0.
+func (*fwdTestLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *fwdTestLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e fwdTestLinkEndpoint) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ RemoteLinkAddress: r.RemoteLinkAddress,
+ LocalLinkAddress: r.LocalLinkAddress,
+ Pkt: pkt,
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// WritePackets stores outbound packets into the channel.
+func (e *fwdTestLinkEndpoint) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ n := 0
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.WritePacket(r, gso, protocol, pkt)
+ n++
+ }
+
+ return n, nil
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
+func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
+ p := fwdTestPacketInfo{
+ Pkt: NewPacketBuffer(PacketBufferOptions{Data: vv}),
+ }
+
+ select {
+ case e.C <- p:
+ default:
+ }
+
+ return nil
+}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*fwdTestLinkEndpoint) Wait() {}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
+func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) {
+ // Create a stack with the network protocol and two NICs.
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{proto},
+ })
+
+ proto.addrCache = s.linkAddrCache
+
+ // Enable forwarding.
+ s.SetForwarding(proto.Number(), true)
+
+ // NIC 1 has the link address "a", and added the network address 1.
+ ep1 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "a",
+ }
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC #1 failed:", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatal("AddAddress #1 failed:", err)
+ }
+
+ // NIC 2 has the link address "b", and added the network address 2.
+ ep2 = &fwdTestLinkEndpoint{
+ C: make(chan fwdTestPacketInfo, 300),
+ mtu: fwdTestNetDefaultMTU,
+ linkAddr: "b",
+ }
+ if err := s.CreateNIC(2, ep2); err != nil {
+ t.Fatal("CreateNIC #2 failed:", err)
+ }
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatal("AddAddress #2 failed:", err)
+ }
+
+ // Route all packets to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: 2}})
+ }
+
+ return ep1, ep2
+}
+
+func TestForwardingWithStaticResolver(t *testing.T) {
+ // Create a network protocol with a static resolver.
+ proto := &fwdTestNetworkProtocol{
+ onResolveStaticAddress:
+ // The network address 3 is resolved to the link address "c".
+ func(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\x03" {
+ return "c", true
+ }
+ return "", false
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ default:
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the static address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolver(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any address will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithNoResolver(t *testing.T) {
+ // Create a network protocol without a resolver.
+ proto := &fwdTestNetworkProtocol{}
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ select {
+ case <-ep2.C:
+ t.Fatal("Packet should not be forwarded")
+ case <-time.After(time.Second):
+ }
+}
+
+func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Only packets to address 3 will be resolved to the
+ // link address "c".
+ if addr == "\x03" {
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ }
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject an inbound packet to address 4 on NIC 1. This packet should
+ // not be forwarded.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 4
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ // Inject an inbound packet to address 3 on NIC 1, and see if it is
+ // forwarded to NIC 2.
+ buf = buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+}
+
+func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ // Inject two inbound packets to address 3 on NIC 1.
+ for i := 0; i < 2; i++ {
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ for i := 0; i < 2; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
+ // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = 3
+ // Set the packet sequence number.
+ binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ for i := 0; i < maxPendingPacketsPerResolution; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ b := PayloadSince(p.Pkt.NetworkHeader())
+ if b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ if len(b) < fwdTestNetHeaderLen+2 {
+ t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b)
+ }
+ seqNumBuf := b[fwdTestNetHeaderLen:]
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
+ want := uint16(i + 5)
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
+ t.Fatalf("got the packet #%d, want = #%d", n, want)
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
+
+func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
+ // Create a network protocol with a fake resolver.
+ proto := &fwdTestNetworkProtocol{
+ addrResolveDelay: 500 * time.Millisecond,
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
+ // Any packets will be resolved to the link address "c".
+ cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
+ },
+ }
+
+ ep1, ep2 := fwdTestNetFactory(t, proto)
+
+ for i := 0; i < maxPendingResolutions+5; i++ {
+ // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1.
+ // Each packet has a different destination address (3 to
+ // maxPendingResolutions + 7).
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = byte(3 + i)
+ ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+ }
+
+ for i := 0; i < maxPendingResolutions; i++ {
+ var p fwdTestPacketInfo
+
+ select {
+ case p = <-ep2.C:
+ case <-time.After(time.Second):
+ t.Fatal("packet not forwarded")
+ }
+
+ // The first 5 packets (address 3 to 7) should not be forwarded
+ // because their address resolutions are interrupted.
+ if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset])
+ }
+
+ // Test that the address resolution happened correctly.
+ if p.RemoteLinkAddress != "c" {
+ t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress)
+ }
+ if p.LocalLinkAddress != "b" {
+ t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/headertype_string.go b/pkg/tcpip/stack/headertype_string.go
new file mode 100644
index 000000000..5efddfaaf
--- /dev/null
+++ b/pkg/tcpip/stack/headertype_string.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type headerType ."; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[linkHeader-0]
+ _ = x[networkHeader-1]
+ _ = x[transportHeader-2]
+ _ = x[numHeaderType-3]
+}
+
+const _headerType_name = "linkHeadernetworkHeadertransportHeadernumHeaderType"
+
+var _headerType_index = [...]uint8{0, 10, 23, 38, 51}
+
+func (i headerType) String() string {
+ if i < 0 || i >= headerType(len(_headerType_index)-1) {
+ return "headerType(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _headerType_name[_headerType_index[i]:_headerType_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
new file mode 100644
index 000000000..c37da814f
--- /dev/null
+++ b/pkg/tcpip/stack/iptables.go
@@ -0,0 +1,423 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// tableID is an index into IPTables.tables.
+type tableID int
+
+const (
+ natID tableID = iota
+ mangleID
+ filterID
+ numTables
+)
+
+// Table names.
+const (
+ NATTable = "nat"
+ MangleTable = "mangle"
+ FilterTable = "filter"
+)
+
+// nameToID is immutable.
+var nameToID = map[string]tableID{
+ NATTable: natID,
+ MangleTable: mangleID,
+ FilterTable: filterID,
+}
+
+// HookUnset indicates that there is no hook set for an entrypoint or
+// underflow.
+const HookUnset = -1
+
+// reaperDelay is how long to wait before starting to reap connections.
+const reaperDelay = 5 * time.Second
+
+// DefaultTables returns a default set of tables. Each chain is set to accept
+// all packets.
+func DefaultTables() *IPTables {
+ return &IPTables{
+ tables: [numTables]Table{
+ natID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: 1,
+ Forward: HookUnset,
+ Output: 2,
+ Postrouting: 3,
+ },
+ },
+ mangleID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: 0,
+ Output: 1,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: 0,
+ Input: HookUnset,
+ Forward: HookUnset,
+ Output: 1,
+ Postrouting: HookUnset,
+ },
+ },
+ filterID: Table{
+ Rules: []Rule{
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: AcceptTarget{}},
+ Rule{Target: ErrorTarget{}},
+ },
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Input: 0,
+ Forward: 1,
+ Output: 2,
+ Postrouting: HookUnset,
+ },
+ },
+ },
+ priorities: [NumHooks][]tableID{
+ Prerouting: []tableID{mangleID, natID},
+ Input: []tableID{natID, filterID},
+ Output: []tableID{mangleID, natID, filterID},
+ },
+ connections: ConnTrack{
+ seed: generateRandUint32(),
+ },
+ reaperDone: make(chan struct{}, 1),
+ }
+}
+
+// EmptyFilterTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyFilterTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Prerouting: HookUnset,
+ Postrouting: HookUnset,
+ },
+ }
+}
+
+// EmptyNATTable returns a Table with no rules and the filter table chains
+// mapped to HookUnset.
+func EmptyNATTable() Table {
+ return Table{
+ Rules: []Rule{},
+ BuiltinChains: [NumHooks]int{
+ Forward: HookUnset,
+ },
+ Underflows: [NumHooks]int{
+ Forward: HookUnset,
+ },
+ }
+}
+
+// GetTable returns a table by name.
+func (it *IPTables) GetTable(name string) (Table, bool) {
+ id, ok := nameToID[name]
+ if !ok {
+ return Table{}, false
+ }
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ return it.tables[id], true
+}
+
+// ReplaceTable replaces or inserts table by name.
+func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+ id, ok := nameToID[name]
+ if !ok {
+ return tcpip.ErrInvalidOptionValue
+ }
+ it.mu.Lock()
+ defer it.mu.Unlock()
+ // If iptables is being enabled, initialize the conntrack table and
+ // reaper.
+ if !it.modified {
+ it.connections.buckets = make([]bucket, numBuckets)
+ it.startReaper(reaperDelay)
+ }
+ it.modified = true
+ it.tables[id] = table
+ return nil
+}
+
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
+// Check runs pkt through the rules for hook. It returns true when the packet
+// should continue traversing the network stack and false when it should be
+// dropped.
+//
+// Precondition: pkt.NetworkHeader is set.
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool {
+ // Many users never configure iptables. Spare them the cost of rule
+ // traversal if rules have never been set.
+ it.mu.RLock()
+ defer it.mu.RUnlock()
+ if !it.modified {
+ return true
+ }
+
+ // Packets are manipulated only if connection and matching
+ // NAT rule exists.
+ shouldTrack := it.connections.handlePacket(pkt, hook, gso, r)
+
+ // Go through each table containing the hook.
+ priorities := it.priorities[hook]
+ for _, tableID := range priorities {
+ // If handlePacket already NATed the packet, we don't need to
+ // check the NAT table.
+ if tableID == natID && pkt.NatDone {
+ continue
+ }
+ table := it.tables[tableID]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ // If the table returns Accept, move on to the next table.
+ case chainAccept:
+ continue
+ // The Drop verdict is final.
+ case chainDrop:
+ return false
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict %v.", verdict))
+ }
+ }
+
+ // If this connection should be tracked, try to add an entry for it. If
+ // traversing the nat table didn't end in adding an entry,
+ // maybeInsertNoop will add a no-op entry for the connection. This is
+ // needeed when establishing connections so that the SYN/ACK reply to an
+ // outgoing SYN is delivered to the correct endpoint rather than being
+ // redirected by a prerouting rule.
+ //
+ // From the iptables documentation: "If there is no rule, a `null'
+ // binding is created: this usually does not map the packet, but exists
+ // to ensure we don't map another stream over an existing one."
+ if shouldTrack {
+ it.connections.maybeInsertNoop(pkt, hook)
+ }
+
+ // Every table returned Accept.
+ return true
+}
+
+// beforeSave is invoked by stateify.
+func (it *IPTables) beforeSave() {
+ // Ensure the reaper exits cleanly.
+ it.reaperDone <- struct{}{}
+ // Prevent others from modifying the connection table.
+ it.connections.mu.Lock()
+}
+
+// afterLoad is invoked by stateify.
+func (it *IPTables) afterLoad() {
+ it.startReaper(reaperDelay)
+}
+
+// startReaper starts a goroutine that wakes up periodically to reap timed out
+// connections.
+func (it *IPTables) startReaper(interval time.Duration) {
+ go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved.
+ bucket := 0
+ for {
+ select {
+ case <-it.reaperDone:
+ return
+ case <-time.After(interval):
+ bucket, interval = it.connections.reapUnused(bucket, interval)
+ }
+ }
+ }()
+}
+
+// CheckPackets runs pkts through the rules for hook and returns a map of packets that
+// should not go forward.
+//
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
+//
+// NOTE: unlike the Check API the returned map contains packets that should be
+// dropped.
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ if !pkt.NatDone {
+ if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if drop == nil {
+ drop = make(map[*PacketBuffer]struct{})
+ }
+ drop[pkt] = struct{}{}
+ }
+ if pkt.NatDone {
+ if natPkts == nil {
+ natPkts = make(map[*PacketBuffer]struct{})
+ }
+ natPkts[pkt] = struct{}{}
+ }
+ }
+ }
+ return drop, natPkts
+}
+
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict {
+ // Start from ruleIdx and walk the list of rules until a rule gives us
+ // a verdict.
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict {
+ case RuleAccept:
+ return chainAccept
+
+ case RuleDrop:
+ return chainDrop
+
+ case RuleReturn:
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
+ }
+
+ }
+
+ // We got through the entire table without a decision. Default to DROP
+ // for safety.
+ return chainDrop
+}
+
+// Preconditions:
+// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
+// - pkt.NetworkHeader is not nil.
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) {
+ rule := table.Rules[ruleIdx]
+
+ // Check whether the packet matches the IP header filter.
+ if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+
+ // Go through each rule matcher. If they all match, run
+ // the rule target.
+ for _, matcher := range rule.Matchers {
+ matches, hotdrop := matcher.Match(hook, pkt, "")
+ if hotdrop {
+ return RuleDrop, 0
+ }
+ if !matches {
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
+ }
+ }
+
+ // All the matchers matched, so run the target.
+ return rule.Target.Action(pkt, &it.connections, hook, gso, r, address)
+}
+
+// OriginalDst returns the original destination of redirected connections. It
+// returns an error if the connection doesn't exist or isn't redirected.
+func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) {
+ return it.connections.originalDst(epID)
+}
diff --git a/pkg/sentry/fsimpl/proc/filesystems.go b/pkg/tcpip/stack/iptables_state.go
index c36c4aff5..529e02a07 100644
--- a/pkg/sentry/fsimpl/proc/filesystems.go
+++ b/pkg/tcpip/stack/iptables_state.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,14 +12,29 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package proc
+package stack
+
+import (
+ "time"
+)
-// filesystemsData implements vfs.DynamicBytesSource for /proc/filesystems.
-//
// +stateify savable
-type filesystemsData struct{}
+type unixTime struct {
+ second int64
+ nano int64
+}
+
+// saveLastUsed is invoked by stateify.
+func (cn *conn) saveLastUsed() unixTime {
+ return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()}
+}
+
+// loadLastUsed is invoked by stateify.
+func (cn *conn) loadLastUsed(unix unixTime) {
+ cn.lastUsed = time.Unix(unix.second, unix.nano)
+}
-// TODO(b/138862512): Implement vfs.DynamicBytesSource.Generate for
-// filesystemsData. We would need to retrive filesystem names from
-// vfs.VirtualFilesystem. Also needs vfs replacement for
-// fs.Filesystem.AllowUserList() and fs.FilesystemRequiresDev.
+// beforeSave is invoked by stateify.
+func (ct *ConnTrack) beforeSave() {
+ ct.mu.Lock()
+}
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
new file mode 100644
index 000000000..5f1b2af64
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -0,0 +1,163 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// AcceptTarget accepts packets.
+type AcceptTarget struct{}
+
+// Action implements Target.Action.
+func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleAccept, 0
+}
+
+// DropTarget drops packets.
+type DropTarget struct{}
+
+// Action implements Target.Action.
+func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleDrop, 0
+}
+
+// ErrorTarget logs an error and drops the packet. It represents a target that
+// should be unreachable.
+type ErrorTarget struct{}
+
+// Action implements Target.Action.
+func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ log.Debugf("ErrorTarget triggered.")
+ return RuleDrop, 0
+}
+
+// UserChainTarget marks a rule as the beginning of a user chain.
+type UserChainTarget struct {
+ Name string
+}
+
+// Action implements Target.Action.
+func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ panic("UserChainTarget should never be called.")
+}
+
+// ReturnTarget returns from the current chain. If the chain is a built-in, the
+// hook's underflow should be called.
+type ReturnTarget struct{}
+
+// Action implements Target.Action.
+func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) {
+ return RuleReturn, 0
+}
+
+// RedirectTarget redirects the packet by modifying the destination port/IP.
+// Min and Max values for IP and Ports in the struct indicate the range of
+// values which can be used to redirect.
+type RedirectTarget struct {
+ // TODO(gvisor.dev/issue/170): Other flags need to be added after
+ // we support them.
+ // RangeProtoSpecified flag indicates single port is specified to
+ // redirect.
+ RangeProtoSpecified bool
+
+ // MinIP indicates address used to redirect.
+ MinIP tcpip.Address
+
+ // MaxIP indicates address used to redirect.
+ MaxIP tcpip.Address
+
+ // MinPort indicates port used to redirect.
+ MinPort uint16
+
+ // MaxPort indicates port used to redirect.
+ MaxPort uint16
+}
+
+// Action implements Target.Action.
+// TODO(gvisor.dev/issue/170): Parse headers without copying. The current
+// implementation only works for PREROUTING and calls pkt.Clone(), neither
+// of which should be the case.
+func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
+ // Packet is already manipulated.
+ if pkt.NatDone {
+ return RuleAccept, 0
+ }
+
+ // Drop the packet if network and transport header are not set.
+ if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
+ return RuleDrop, 0
+ }
+
+ // Change the address to localhost (127.0.0.1) in Output and
+ // to primary address of the incoming interface in Prerouting.
+ switch hook {
+ case Output:
+ rt.MinIP = tcpip.Address([]byte{127, 0, 0, 1})
+ rt.MaxIP = tcpip.Address([]byte{127, 0, 0, 1})
+ case Prerouting:
+ rt.MinIP = address
+ rt.MaxIP = address
+ default:
+ panic("redirect target is supported only on output and prerouting hooks")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
+ // we need to change dest address (for OUTPUT chain) or ports.
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ switch protocol := netHeader.TransportProtocol(); protocol {
+ case header.UDPProtocolNumber:
+ udpHeader := header.UDP(pkt.TransportHeader().View())
+ udpHeader.SetDestinationPort(rt.MinPort)
+
+ // Calculate UDP checksum and set it.
+ if hook == Output {
+ udpHeader.SetChecksum(0)
+ length := uint16(pkt.Size()) - uint16(netHeader.HeaderLength())
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(protocol, length)
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udpHeader.SetChecksum(0)
+ udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
+ }
+ }
+ // Change destination address.
+ netHeader.SetDestinationAddress(rt.MinIP)
+ netHeader.SetChecksum(0)
+ netHeader.SetChecksum(^netHeader.CalculateChecksum())
+ pkt.NatDone = true
+ case header.TCPProtocolNumber:
+ if ct == nil {
+ return RuleAccept, 0
+ }
+
+ // Set up conection for matching NAT rule. Only the first
+ // packet of the connection comes here. Other packets will be
+ // manipulated in connection tracking.
+ if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil {
+ ct.handlePacket(pkt, hook, gso, r)
+ }
+ default:
+ return RuleDrop, 0
+ }
+
+ return RuleAccept, 0
+}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
new file mode 100644
index 000000000..73274ada9
--- /dev/null
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -0,0 +1,262 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// A Hook specifies one of the hooks built into the network stack.
+//
+// Userspace app Userspace app
+// ^ |
+// | v
+// [Input] [Output]
+// ^ |
+// | v
+// | routing
+// | |
+// | v
+// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
+type Hook uint
+
+// These values correspond to values in include/uapi/linux/netfilter.h.
+const (
+ // Prerouting happens before a packet is routed to applications or to
+ // be forwarded.
+ Prerouting Hook = iota
+
+ // Input happens before a packet reaches an application.
+ Input
+
+ // Forward happens once it's decided that a packet should be forwarded
+ // to another host.
+ Forward
+
+ // Output happens after a packet is written by an application to be
+ // sent out.
+ Output
+
+ // Postrouting happens just before a packet goes out on the wire.
+ Postrouting
+
+ // The total number of hooks.
+ NumHooks
+)
+
+// A RuleVerdict is what a rule decides should be done with a packet.
+type RuleVerdict int
+
+const (
+ // RuleAccept indicates the packet should continue through netstack.
+ RuleAccept RuleVerdict = iota
+
+ // RuleDrop indicates the packet should be dropped.
+ RuleDrop
+
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
+
+ // RuleReturn indicates the packet should return to the previous chain.
+ RuleReturn
+)
+
+// IPTables holds all the tables for a netstack.
+//
+// +stateify savable
+type IPTables struct {
+ // mu protects tables, priorities, and modified.
+ mu sync.RWMutex
+
+ // tables maps tableIDs to tables. Holds builtin tables only, not user
+ // tables. mu must be locked for accessing.
+ tables [numTables]Table
+
+ // priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook. mu needs to be locked for accessing.
+ priorities [NumHooks][]tableID
+
+ // modified is whether tables have been modified at least once. It is
+ // used to elide the iptables performance overhead for workloads that
+ // don't utilize iptables.
+ modified bool
+
+ connections ConnTrack
+
+ // reaperDone can be signalled to stop the reaper goroutine.
+ reaperDone chan struct{}
+}
+
+// A Table defines a set of chains and hooks into the network stack. It is
+// really just a list of rules.
+//
+// +stateify savable
+type Table struct {
+ // Rules holds the rules that make up the table.
+ Rules []Rule
+
+ // BuiltinChains maps builtin chains to their entrypoint rule in Rules.
+ BuiltinChains [NumHooks]int
+
+ // Underflows maps builtin chains to their underflow rule in Rules
+ // (i.e. the rule to execute if the chain returns without a verdict).
+ Underflows [NumHooks]int
+}
+
+// ValidHooks returns a bitmap of the builtin hooks for the given table.
+func (table *Table) ValidHooks() uint32 {
+ hooks := uint32(0)
+ for hook, ruleIdx := range table.BuiltinChains {
+ if ruleIdx != HookUnset {
+ hooks |= 1 << hook
+ }
+ }
+ return hooks
+}
+
+// A Rule is a packet processing rule. It consists of two pieces. First it
+// contains zero or more matchers, each of which is a specification of which
+// packets this rule applies to. If there are no matchers in the rule, it
+// applies to any packet.
+//
+// +stateify savable
+type Rule struct {
+ // Filter holds basic IP filtering fields common to every rule.
+ Filter IPHeaderFilter
+
+ // Matchers is the list of matchers for this rule.
+ Matchers []Matcher
+
+ // Target is the action to invoke if all the matchers match the packet.
+ Target Target
+}
+
+// IPHeaderFilter holds basic IP filtering data common to every rule.
+//
+// +stateify savable
+type IPHeaderFilter struct {
+ // Protocol matches the transport protocol.
+ Protocol tcpip.TransportProtocolNumber
+
+ // Dst matches the destination IP address.
+ Dst tcpip.Address
+
+ // DstMask masks bits of the destination IP address when comparing with
+ // Dst.
+ DstMask tcpip.Address
+
+ // DstInvert inverts the meaning of the destination IP check, i.e. when
+ // true the filter will match packets that fail the destination
+ // comparison.
+ DstInvert bool
+
+ // Src matches the source IP address.
+ Src tcpip.Address
+
+ // SrcMask masks bits of the source IP address when comparing with Src.
+ SrcMask tcpip.Address
+
+ // SrcInvert inverts the meaning of the source IP check, i.e. when true the
+ // filter will match packets that fail the source comparison.
+ SrcInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the
+ // packet.
+ OutputInterface string
+
+ // OutputInterfaceMask masks the characters of the interface name when
+ // comparing with OutputInterface.
+ OutputInterfaceMask string
+
+ // OutputInterfaceInvert inverts the meaning of outgoing interface check,
+ // i.e. when true the filter will match packets that fail the outgoing
+ // interface comparison.
+ OutputInterfaceInvert bool
+}
+
+// match returns whether hdr matches the filter.
+func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool {
+ // TODO(gvisor.dev/issue/170): Support other fields of the filter.
+ // Check the transport protocol.
+ if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() {
+ return false
+ }
+
+ // Check the source and destination IPs.
+ if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) {
+ return false
+ }
+
+ // Check the output interface.
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ if hook == Output {
+ n := len(fl.OutputInterface)
+ if n == 0 {
+ return true
+ }
+
+ // If the interface name ends with '+', any interface which begins
+ // with the name should be matched.
+ ifName := fl.OutputInterface
+ matches := true
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return fl.OutputInterfaceInvert != matches
+ }
+
+ return true
+}
+
+// filterAddress returns whether addr matches the filter.
+func filterAddress(addr, mask, filterAddr tcpip.Address, invert bool) bool {
+ matches := true
+ for i := range filterAddr {
+ if addr[i]&mask[i] != filterAddr[i] {
+ matches = false
+ break
+ }
+ }
+ return matches != invert
+}
+
+// A Matcher is the interface for matching packets.
+type Matcher interface {
+ // Name returns the name of the Matcher.
+ Name() string
+
+ // Match returns whether the packet matches and whether the packet
+ // should be "hotdropped", i.e. dropped immediately. This is usually
+ // used for suspicious packets.
+ //
+ // Precondition: packet.NetworkHeader is set.
+ Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+}
+
+// A Target is the interface for taking an action for a packet.
+type Target interface {
+ // Action takes an action on the packet and returns a verdict on how
+ // traversal should (or should not) continue. If the return value is
+ // Jump, it also returns the index of the rule to jump to.
+ Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int)
+}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 267df60d1..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -16,10 +16,10 @@ package stack
import (
"fmt"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 9946b8fe8..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -16,12 +16,12 @@ package stack
import (
"fmt"
- "sync"
"sync/atomic"
"testing"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 03ddebdbd..97ca00d16 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"log"
+ "math/rand"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -32,39 +33,277 @@ const (
// Default = 1 (from RFC 4862 section 5.1)
defaultDupAddrDetectTransmits = 1
- // defaultRetransmitTimer is the default amount of time to wait between
- // sending NDP Neighbor solicitation messages.
+ // defaultMaxRtrSolicitations is the default number of Router
+ // Solicitation messages to send when a NIC becomes enabled.
//
- // Default = 1s (from RFC 4861 section 10).
- defaultRetransmitTimer = time.Second
+ // Default = 3 (from RFC 4861 section 10).
+ defaultMaxRtrSolicitations = 3
- // minimumRetransmitTimer is the minimum amount of time to wait between
- // sending NDP Neighbor solicitation messages. Note, RFC 4861 does
- // not impose a minimum Retransmit Timer, but we do here to make sure
- // the messages are not sent all at once. We also come to this value
- // because in the RetransmitTimer field of a Router Advertisement, a
- // value of 0 means unspecified, so the smallest valid value is 1.
- // Note, the unit of the RetransmitTimer field in the Router
- // Advertisement is milliseconds.
+ // defaultRtrSolicitationInterval is the default amount of time between
+ // sending Router Solicitation messages.
//
- // Min = 1ms.
- minimumRetransmitTimer = time.Millisecond
+ // Default = 4s (from 4861 section 10).
+ defaultRtrSolicitationInterval = 4 * time.Second
+
+ // defaultMaxRtrSolicitationDelay is the default maximum amount of time
+ // to wait before sending the first Router Solicitation message.
+ //
+ // Default = 1s (from 4861 section 10).
+ defaultMaxRtrSolicitationDelay = time.Second
+
+ // defaultHandleRAs is the default configuration for whether or not to
+ // handle incoming Router Advertisements as a host.
+ defaultHandleRAs = true
+
+ // defaultDiscoverDefaultRouters is the default configuration for
+ // whether or not to discover default routers from incoming Router
+ // Advertisements, as a host.
+ defaultDiscoverDefaultRouters = true
+
+ // defaultDiscoverOnLinkPrefixes is the default configuration for
+ // whether or not to discover on-link prefixes from incoming Router
+ // Advertisements' Prefix Information option, as a host.
+ defaultDiscoverOnLinkPrefixes = true
+
+ // defaultAutoGenGlobalAddresses is the default configuration for
+ // whether or not to generate global IPv6 addresses in response to
+ // receiving a new Prefix Information option with its Autonomous
+ // Address AutoConfiguration flag set, as a host.
+ //
+ // Default = true.
+ defaultAutoGenGlobalAddresses = true
+
+ // minimumRtrSolicitationInterval is the minimum amount of time to wait
+ // between sending Router Solicitation messages. This limit is imposed
+ // to make sure that Router Solicitation messages are not sent all at
+ // once, defeating the purpose of sending the initial few messages.
+ minimumRtrSolicitationInterval = 500 * time.Millisecond
+
+ // minimumMaxRtrSolicitationDelay is the minimum amount of time to wait
+ // before sending the first Router Solicitation message. It is 0 because
+ // we cannot have a negative delay.
+ minimumMaxRtrSolicitationDelay = 0
+
+ // MaxDiscoveredDefaultRouters is the maximum number of discovered
+ // default routers. The stack should stop discovering new routers after
+ // discovering MaxDiscoveredDefaultRouters routers.
+ //
+ // This value MUST be at minimum 2 as per RFC 4861 section 6.3.4, and
+ // SHOULD be more.
+ MaxDiscoveredDefaultRouters = 10
+
+ // MaxDiscoveredOnLinkPrefixes is the maximum number of discovered
+ // on-link prefixes. The stack should stop discovering new on-link
+ // prefixes after discovering MaxDiscoveredOnLinkPrefixes on-link
+ // prefixes.
+ MaxDiscoveredOnLinkPrefixes = 10
+
+ // validPrefixLenForAutoGen is the expected prefix length that an
+ // address can be generated for. Must be 64 bits as the interface
+ // identifier (IID) is 64 bits and an IPv6 address is 128 bits, so
+ // 128 - 64 = 64.
+ validPrefixLenForAutoGen = 64
+
+ // defaultAutoGenTempGlobalAddresses is the default configuration for whether
+ // or not to generate temporary SLAAC addresses.
+ defaultAutoGenTempGlobalAddresses = true
+
+ // defaultMaxTempAddrValidLifetime is the default maximum valid lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 7 days (from RFC 4941 section 5).
+ defaultMaxTempAddrValidLifetime = 7 * 24 * time.Hour
+
+ // defaultMaxTempAddrPreferredLifetime is the default preferred lifetime
+ // for temporary SLAAC addresses generated as part of RFC 4941.
+ //
+ // Default = 1 day (from RFC 4941 section 5).
+ defaultMaxTempAddrPreferredLifetime = 24 * time.Hour
+
+ // defaultRegenAdvanceDuration is the default duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ //
+ // Default = 5s (from RFC 4941 section 5).
+ defaultRegenAdvanceDuration = 5 * time.Second
+
+ // minRegenAdvanceDuration is the minimum duration before the deprecation
+ // of a temporary address when a new address will be generated.
+ minRegenAdvanceDuration = time.Duration(0)
+
+ // maxSLAACAddrLocalRegenAttempts is the maximum number of times to attempt
+ // SLAAC address regenerations in response to a NIC-local conflict.
+ maxSLAACAddrLocalRegenAttempts = 10
+)
+
+var (
+ // MinPrefixInformationValidLifetimeForUpdate is the minimum Valid
+ // Lifetime to update the valid lifetime of a generated address by
+ // SLAAC.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Min = 2hrs.
+ MinPrefixInformationValidLifetimeForUpdate = 2 * time.Hour
+
+ // MaxDesyncFactor is the upper bound for the preferred lifetime's desync
+ // factor for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // Must be greater than 0.
+ //
+ // Max = 10m (from RFC 4941 section 5).
+ MaxDesyncFactor = 10 * time.Minute
+
+ // MinMaxTempAddrPreferredLifetime is the minimum value allowed for the
+ // maximum preferred lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be preferred for at
+ // least 1hr if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrPreferredLifetime = defaultRegenAdvanceDuration + MaxDesyncFactor + time.Hour
+
+ // MinMaxTempAddrValidLifetime is the minimum value allowed for the
+ // maximum valid lifetime for temporary SLAAC addresses.
+ //
+ // This is exported as a variable (instead of a constant) so tests
+ // can update it to a smaller value.
+ //
+ // This value guarantees that a temporary address will be valid for at least
+ // 2hrs if the SLAAC prefix is valid for at least that time.
+ MinMaxTempAddrValidLifetime = 2 * time.Hour
+)
+
+// DHCPv6ConfigurationFromNDPRA is a configuration available via DHCPv6 that an
+// NDP Router Advertisement informed the Stack about.
+type DHCPv6ConfigurationFromNDPRA int
+
+const (
+ _ DHCPv6ConfigurationFromNDPRA = iota
+
+ // DHCPv6NoConfiguration indicates that no configurations are available via
+ // DHCPv6.
+ DHCPv6NoConfiguration
+
+ // DHCPv6ManagedAddress indicates that addresses are available via DHCPv6.
+ //
+ // DHCPv6ManagedAddress also implies DHCPv6OtherConfigurations because DHCPv6
+ // will return all available configuration information.
+ DHCPv6ManagedAddress
+
+ // DHCPv6OtherConfigurations indicates that other configuration information is
+ // available via DHCPv6.
+ //
+ // Other configurations are configurations other than addresses. Examples of
+ // other configurations are recursive DNS server list, DNS search lists and
+ // default gateway.
+ DHCPv6OtherConfigurations
)
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
// OnDuplicateAddressDetectionStatus will be called when the DAD process
- // for an address (addr) on a NIC (with ID nicid) completes. resolved
+ // for an address (addr) on a NIC (with ID nicID) completes. resolved
// will be set to true if DAD completed successfully (no duplicate addr
// detected); false otherwise (addr was detected to be a duplicate on
// the link the NIC is a part of, or it was stopped for some other
// reason, such as the address being removed). If an error occured
// during DAD, err will be set and resolved must be ignored.
//
- // This function is permitted to block indefinitely without interfering
- // with the stack's operation.
- OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error)
+
+ // OnDefaultRouterDiscovered will be called when a new default router is
+ // discovered. Implementations must return true if the newly discovered
+ // router should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+
+ // OnDefaultRouterInvalidated will be called when a discovered default
+ // router that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+
+ // OnOnLinkPrefixDiscovered will be called when a new on-link prefix is
+ // discovered. Implementations must return true if the newly discovered
+ // on-link prefix should be remembered.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+
+ // OnOnLinkPrefixInvalidated will be called when a discovered on-link
+ // prefix that was remembered is invalidated.
+ //
+ // This function is not permitted to block indefinitely. This function
+ // is also not permitted to call into the stack.
+ OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+
+ // OnAutoGenAddress will be called when a new prefix with its
+ // autonomous address-configuration flag set has been received and SLAAC
+ // has been performed. Implementations may prevent the stack from
+ // assigning the address to the NIC by returning false.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddress(tcpip.NICID, tcpip.AddressWithPrefix) bool
+
+ // OnAutoGenAddressDeprecated will be called when an auto-generated
+ // address (as part of SLAAC) has been deprecated, but is still
+ // considered valid. Note, if an address is invalidated at the same
+ // time it is deprecated, the deprecation event MAY be omitted.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressDeprecated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnAutoGenAddressInvalidated will be called when an auto-generated
+ // address (as part of SLAAC) has been invalidated.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnAutoGenAddressInvalidated(tcpip.NICID, tcpip.AddressWithPrefix)
+
+ // OnRecursiveDNSServerOption will be called when an NDP option with
+ // recursive DNS servers has been received. Note, addrs may contain
+ // link-local addresses.
+ //
+ // It is up to the caller to use the DNS Servers only for their valid
+ // lifetime. OnRecursiveDNSServerOption may be called for new or
+ // already known DNS servers. If called with known DNS servers, their
+ // valid lifetimes must be refreshed to lifetime (it may be increased,
+ // decreased, or completely invalidated when lifetime = 0).
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+
+ // OnDNSSearchListOption will be called when an NDP option with a DNS
+ // search list has been received.
+ //
+ // It is up to the caller to use the domain names in the search list
+ // for only their valid lifetime. OnDNSSearchListOption may be called
+ // with new or already known domain names. If called with known domain
+ // names, their valid lifetimes must be refreshed to lifetime (it may
+ // be increased, decreased or completely invalidated when lifetime = 0.
+ OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+
+ // OnDHCPv6Configuration will be called with an updated configuration that is
+ // available via DHCPv6 for a specified NIC.
+ //
+ // This function is not permitted to block indefinitely. It must not
+ // call functions on the stack itself.
+ OnDHCPv6Configuration(tcpip.NICID, DHCPv6ConfigurationFromNDPRA)
}
// NDPConfigurations is the NDP configurations for the netstack.
@@ -78,28 +317,124 @@ type NDPConfigurations struct {
// The amount of time to wait between sending Neighbor solicitation
// messages.
//
- // Must be greater than 0.5s.
+ // Must be greater than or equal to 1ms.
RetransmitTimer time.Duration
+
+ // The number of Router Solicitation messages to send when the NIC
+ // becomes enabled.
+ MaxRtrSolicitations uint8
+
+ // The amount of time between transmitting Router Solicitation messages.
+ //
+ // Must be greater than or equal to 0.5s.
+ RtrSolicitationInterval time.Duration
+
+ // The maximum amount of time before transmitting the first Router
+ // Solicitation message.
+ //
+ // Must be greater than or equal to 0s.
+ MaxRtrSolicitationDelay time.Duration
+
+ // HandleRAs determines whether or not Router Advertisements will be
+ // processed.
+ HandleRAs bool
+
+ // DiscoverDefaultRouters determines whether or not default routers will
+ // be discovered from Router Advertisements. This configuration is
+ // ignored if HandleRAs is false.
+ DiscoverDefaultRouters bool
+
+ // DiscoverOnLinkPrefixes determines whether or not on-link prefixes
+ // will be discovered from Router Advertisements' Prefix Information
+ // option. This configuration is ignored if HandleRAs is false.
+ DiscoverOnLinkPrefixes bool
+
+ // AutoGenGlobalAddresses determines whether or not global IPv6
+ // addresses will be generated for a NIC in response to receiving a new
+ // Prefix Information option with its Autonomous Address
+ // AutoConfiguration flag set, as a host, as per RFC 4862 (SLAAC).
+ //
+ // Note, if an address was already generated for some unique prefix, as
+ // part of SLAAC, this option does not affect whether or not the
+ // lifetime(s) of the generated address changes; this option only
+ // affects the generation of new addresses as part of SLAAC.
+ AutoGenGlobalAddresses bool
+
+ // AutoGenAddressConflictRetries determines how many times to attempt to retry
+ // generation of a permanent auto-generated address in response to DAD
+ // conflicts.
+ //
+ // If the method used to generate the address does not support creating
+ // alternative addresses (e.g. IIDs based on the modified EUI64 of a NIC's
+ // MAC address), then no attempt will be made to resolve the conflict.
+ AutoGenAddressConflictRetries uint8
+
+ // AutoGenTempGlobalAddresses determines whether or not temporary SLAAC
+ // addresses will be generated for a NIC as part of SLAAC privacy extensions,
+ // RFC 4941.
+ //
+ // Ignored if AutoGenGlobalAddresses is false.
+ AutoGenTempGlobalAddresses bool
+
+ // MaxTempAddrValidLifetime is the maximum valid lifetime for temporary
+ // SLAAC addresses.
+ MaxTempAddrValidLifetime time.Duration
+
+ // MaxTempAddrPreferredLifetime is the maximum preferred lifetime for
+ // temporary SLAAC addresses.
+ MaxTempAddrPreferredLifetime time.Duration
+
+ // RegenAdvanceDuration is the duration before the deprecation of a temporary
+ // address when a new address will be generated.
+ RegenAdvanceDuration time.Duration
}
// DefaultNDPConfigurations returns an NDPConfigurations populated with
// default values.
func DefaultNDPConfigurations() NDPConfigurations {
return NDPConfigurations{
- DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
- RetransmitTimer: defaultRetransmitTimer,
+ DupAddrDetectTransmits: defaultDupAddrDetectTransmits,
+ RetransmitTimer: defaultRetransmitTimer,
+ MaxRtrSolicitations: defaultMaxRtrSolicitations,
+ RtrSolicitationInterval: defaultRtrSolicitationInterval,
+ MaxRtrSolicitationDelay: defaultMaxRtrSolicitationDelay,
+ HandleRAs: defaultHandleRAs,
+ DiscoverDefaultRouters: defaultDiscoverDefaultRouters,
+ DiscoverOnLinkPrefixes: defaultDiscoverOnLinkPrefixes,
+ AutoGenGlobalAddresses: defaultAutoGenGlobalAddresses,
+ AutoGenTempGlobalAddresses: defaultAutoGenTempGlobalAddresses,
+ MaxTempAddrValidLifetime: defaultMaxTempAddrValidLifetime,
+ MaxTempAddrPreferredLifetime: defaultMaxTempAddrPreferredLifetime,
+ RegenAdvanceDuration: defaultRegenAdvanceDuration,
}
}
// validate modifies an NDPConfigurations with valid values. If invalid values
// are present in c, the corresponding default values will be used instead.
-//
-// If RetransmitTimer is less than minimumRetransmitTimer, then a value of
-// defaultRetransmitTimer will be used.
func (c *NDPConfigurations) validate() {
if c.RetransmitTimer < minimumRetransmitTimer {
c.RetransmitTimer = defaultRetransmitTimer
}
+
+ if c.RtrSolicitationInterval < minimumRtrSolicitationInterval {
+ c.RtrSolicitationInterval = defaultRtrSolicitationInterval
+ }
+
+ if c.MaxRtrSolicitationDelay < minimumMaxRtrSolicitationDelay {
+ c.MaxRtrSolicitationDelay = defaultMaxRtrSolicitationDelay
+ }
+
+ if c.MaxTempAddrValidLifetime < MinMaxTempAddrValidLifetime {
+ c.MaxTempAddrValidLifetime = MinMaxTempAddrValidLifetime
+ }
+
+ if c.MaxTempAddrPreferredLifetime < MinMaxTempAddrPreferredLifetime || c.MaxTempAddrPreferredLifetime > c.MaxTempAddrValidLifetime {
+ c.MaxTempAddrPreferredLifetime = MinMaxTempAddrPreferredLifetime
+ }
+
+ if c.RegenAdvanceDuration < minRegenAdvanceDuration {
+ c.RegenAdvanceDuration = minRegenAdvanceDuration
+ }
}
// ndpState is the per-interface NDP state.
@@ -112,13 +447,47 @@ type ndpState struct {
// The DAD state to send the next NS message, or resolve the address.
dad map[tcpip.Address]dadState
+
+ // The default routers discovered through Router Advertisements.
+ defaultRouters map[tcpip.Address]defaultRouterState
+
+ rtrSolicit struct {
+ // The timer used to send the next router solicitation message.
+ timer tcpip.Timer
+
+ // Used to let the Router Solicitation timer know that it has been stopped.
+ //
+ // Must only be read from or written to while protected by the lock of
+ // the NIC this ndpState is associated with. MUST be set when the timer is
+ // set.
+ done *bool
+ }
+
+ // The on-link prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ onLinkPrefixes map[tcpip.Subnet]onLinkPrefixState
+
+ // The SLAAC prefixes discovered through Router Advertisements' Prefix
+ // Information option.
+ slaacPrefixes map[tcpip.Subnet]slaacPrefixState
+
+ // The last learned DHCPv6 configuration from an NDP RA.
+ dhcpv6Configuration DHCPv6ConfigurationFromNDPRA
+
+ // temporaryIIDHistory is the history value used to generate a new temporary
+ // IID.
+ temporaryIIDHistory [header.IIDSize]byte
+
+ // temporaryAddressDesyncFactor is the preferred lifetime's desync factor for
+ // temporary SLAAC addresses.
+ temporaryAddressDesyncFactor time.Duration
}
// dadState holds the Duplicate Address Detection timer and channel to signal
// to the DAD goroutine that DAD should stop.
type dadState struct {
// The DAD timer to send the next NS message, or resolve the address.
- timer *time.Timer
+ timer tcpip.Timer
// Used to let the DAD timer know that it has been stopped.
//
@@ -127,6 +496,102 @@ type dadState struct {
done *bool
}
+// defaultRouterState holds data associated with a default router discovered by
+// a Router Advertisement (RA).
+type defaultRouterState struct {
+ // Job to invalidate the default router.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// onLinkPrefixState holds data associated with an on-link prefix discovered by
+// a Router Advertisement's Prefix Information option (PI) when the NDP
+// configurations was configured to do so.
+type onLinkPrefixState struct {
+ // Job to invalidate the on-link prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+}
+
+// tempSLAACAddrState holds state associated with a temporary SLAAC address.
+type tempSLAACAddrState struct {
+ // Job to deprecate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Job to regenerate the temporary SLAAC address.
+ //
+ // Must not be nil.
+ regenJob *tcpip.Job
+
+ createdAt time.Time
+
+ // The address's endpoint.
+ //
+ // Must not be nil.
+ ref *referencedNetworkEndpoint
+
+ // Has a new temporary SLAAC address already been regenerated?
+ regenerated bool
+}
+
+// slaacPrefixState holds state associated with a SLAAC prefix.
+type slaacPrefixState struct {
+ // Job to deprecate the prefix.
+ //
+ // Must not be nil.
+ deprecationJob *tcpip.Job
+
+ // Job to invalidate the prefix.
+ //
+ // Must not be nil.
+ invalidationJob *tcpip.Job
+
+ // Nonzero only when the address is not valid forever.
+ validUntil time.Time
+
+ // Nonzero only when the address is not preferred forever.
+ preferredUntil time.Time
+
+ // State associated with the stable address generated for the prefix.
+ stableAddr struct {
+ // The address's endpoint.
+ //
+ // May only be nil when the address is being (re-)generated. Otherwise,
+ // must not be nil as all SLAAC prefixes must have a stable address.
+ ref *referencedNetworkEndpoint
+
+ // The number of times an address has been generated locally where the NIC
+ // already had the generated address.
+ localGenerationFailures uint8
+ }
+
+ // The temporary (short-lived) addresses generated for the SLAAC prefix.
+ tempAddrs map[tcpip.Address]tempSLAACAddrState
+
+ // The next two fields are used by both stable and temporary addresses
+ // generated for a SLAAC prefix. This is safe as only 1 address will be
+ // in the generation and DAD process at any time. That is, no two addresses
+ // will be generated at the same time for a given SLAAC prefix.
+
+ // The number of times an address has been generated and added to the NIC.
+ //
+ // Addresses may be regenerated in reseponse to a DAD conflicts.
+ generationAttempts uint8
+
+ // The maximum number of times to attempt regeneration of a SLAAC address
+ // in response to DAD conflicts.
+ maxGenerationAttempts uint8
+}
+
// startDuplicateAddressDetection performs Duplicate Address Detection.
//
// This function must only be called by IPv6 addresses that are currently
@@ -139,87 +604,110 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return tcpip.ErrAddressFamilyNotSupported
}
- // Should not attempt to perform DAD on an address that is currently in
- // the DAD process.
+ if ref.getKind() != permanentTentative {
+ // The endpoint should be marked as tentative since we are starting DAD.
+ panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
+
+ // Should not attempt to perform DAD on an address that is currently in the
+ // DAD process.
if _, ok := ndp.dad[addr]; ok {
- // Should never happen because we should only ever call this
- // function for newly created addresses. If we attemped to
- // "add" an address that already existed, we would returned an
- // error since we attempted to add a duplicate address, or its
- // reference count would have been increased without doing the
- // work that would have been done for an address that was brand
- // new. See NIC.addPermanentAddressLocked.
+ // Should never happen because we should only ever call this function for
+ // newly created addresses. If we attemped to "add" an address that already
+ // existed, we would get an error since we attempted to add a duplicate
+ // address, or its reference count would have been increased without doing
+ // the work that would have been done for an address that was brand new.
+ // See NIC.addAddressLocked.
panic(fmt.Sprintf("ndpdad: already performing DAD for addr %s on NIC(%d)", addr, ndp.nic.ID()))
}
remaining := ndp.configs.DupAddrDetectTransmits
+ if remaining == 0 {
+ ref.setKind(permanent)
- {
- done, err := ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil {
- return err
- }
- if done {
- return nil
+ // Consider DAD to have resolved even if no DAD messages were actually
+ // transmitted.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, true, nil)
}
- }
- remaining--
+ return nil
+ }
var done bool
- var timer *time.Timer
- timer = time.AfterFunc(ndp.configs.RetransmitTimer, func() {
- var d bool
- var err *tcpip.Error
+ var timer tcpip.Timer
+ // We initially start a timer to fire immediately because some of the DAD work
+ // cannot be done while holding the NIC's lock. This is effectively the same
+ // as starting a goroutine but we use a timer that fires immediately so we can
+ // reset it for the next DAD iteration.
+ timer = ndp.nic.stack.Clock().AfterFunc(0, func() {
+ ndp.nic.mu.Lock()
+ defer ndp.nic.mu.Unlock()
- // doDadIteration does a single iteration of the DAD loop.
- //
- // Returns true if the integrator needs to be informed of DAD
- // completing.
- doDadIteration := func() bool {
- ndp.nic.mu.Lock()
- defer ndp.nic.mu.Unlock()
-
- if done {
- // If we reach this point, it means that the DAD
- // timer fired after another goroutine already
- // obtained the NIC lock and stopped DAD before
- // this function obtained the NIC lock. Simply
- // return here and do nothing further.
- return false
- }
+ if done {
+ // If we reach this point, it means that the DAD timer fired after
+ // another goroutine already obtained the NIC lock and stopped DAD
+ // before this function obtained the NIC lock. Simply return here and do
+ // nothing further.
+ return
+ }
- ref, ok := ndp.nic.endpoints[NetworkEndpointID{addr}]
- if !ok {
- // This should never happen.
- // We should have an endpoint for addr since we
- // are still performing DAD on it. If the
- // endpoint does not exist, but we are doing DAD
- // on it, then we started DAD at some point, but
- // forgot to stop it when the endpoint was
- // deleted.
- panic(fmt.Sprintf("ndpdad: unrecognized addr %s for NIC(%d)", addr, ndp.nic.ID()))
- }
+ if ref.getKind() != permanentTentative {
+ // The endpoint should still be marked as tentative since we are still
+ // performing DAD on it.
+ panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.nic.ID()))
+ }
- d, err = ndp.doDuplicateAddressDetection(addr, remaining, ref)
- if err != nil || d {
- delete(ndp.dad, addr)
+ dadDone := remaining == 0
- if err != nil {
- log.Printf("ndpdad: Error occured during DAD iteration for addr (%s) on NIC(%d); err = %s", addr, ndp.nic.ID(), err)
- }
+ var err *tcpip.Error
+ if !dadDone {
+ // Use the unspecified address as the source address when performing DAD.
+ ref := ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
- // Let the integrator know DAD has completed.
- return true
- }
+ // Do not hold the lock when sending packets which may be a long running
+ // task or may block link address resolution. We know this is safe
+ // because immediately after obtaining the lock again, we check if DAD
+ // has been stopped before doing any work with the NIC. Note, DAD would be
+ // stopped if the NIC was disabled or removed, or if the address was
+ // removed.
+ ndp.nic.mu.Unlock()
+ err = ndp.sendDADPacket(addr, ref)
+ ndp.nic.mu.Lock()
+ }
+ if done {
+ // If we reach this point, it means that DAD was stopped after we released
+ // the NIC's read lock and before we obtained the write lock.
+ return
+ }
+
+ if dadDone {
+ // DAD has resolved.
+ ref.setKind(permanent)
+ } else if err == nil {
+ // DAD is not done and we had no errors when sending the last NDP NS,
+ // schedule the next DAD timer.
remaining--
timer.Reset(ndp.nic.stack.ndpConfigs.RetransmitTimer)
- return false
+ return
+ }
+
+ // At this point we know that either DAD is done or we hit an error sending
+ // the last NDP NS. Either way, clean up addr's DAD state and let the
+ // integrator know DAD has completed.
+ delete(ndp.dad, addr)
+
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, dadDone, err)
}
- if doDadIteration() && ndp.nic.stack.ndpDisp != nil {
- ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, d, err)
+ // If DAD resolved for a stable SLAAC address, attempt generation of a
+ // temporary SLAAC address.
+ if dadDone && ref.configType == slaac {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.regenerateTempSLAACAddr(ref.addrWithPrefix().Subnet(), true /* resetGenAttempts */)
}
})
@@ -231,60 +719,58 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref
return nil
}
-// doDuplicateAddressDetection is called on every iteration of the timer, and
-// when DAD starts.
-//
-// It handles resolving the address (if there are no more NS to send), or
-// sending the next NS if there are more NS to send.
-//
-// This function must only be called by IPv6 addresses that are currently
-// tentative.
+// sendDADPacket sends a NS message to see if any nodes on ndp's NIC's link owns
+// addr.
//
-// The NIC that ndp belongs to (n) MUST be locked.
+// addr must be a tentative IPv6 address on ndp's NIC.
//
-// Returns true if DAD has resolved; false if DAD is still ongoing.
-func (ndp *ndpState) doDuplicateAddressDetection(addr tcpip.Address, remaining uint8, ref *referencedNetworkEndpoint) (bool, *tcpip.Error) {
- if ref.getKind() != permanentTentative {
- // The endpoint should still be marked as tentative
- // since we are still performing DAD on it.
- panic(fmt.Sprintf("ndpdad: addr %s is not tentative on NIC(%d)", addr, ndp.nic.ID()))
- }
+// The NIC ndp belongs to MUST NOT be locked.
+func (ndp *ndpState) sendDADPacket(addr tcpip.Address, ref *referencedNetworkEndpoint) *tcpip.Error {
+ snmc := header.SolicitedNodeAddr(addr)
- if remaining == 0 {
- // DAD has resolved.
- ref.setKind(permanent)
- return true, nil
- }
+ r := makeRoute(header.IPv6ProtocolNumber, ref.address(), snmc, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
- // Send a new NS.
- snmc := header.SolicitedNodeAddr(addr)
- snmcRef, ok := ndp.nic.endpoints[NetworkEndpointID{snmc}]
- if !ok {
- // This should never happen as if we have the
- // address, we should have the solicited-node
- // address.
- panic(fmt.Sprintf("ndpdad: NIC(%d) is not in the solicited-node multicast group (%s) but it has addr %s", ndp.nic.ID(), snmc, addr))
- }
+ // Route should resolve immediately since snmc is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err != tcpip.ErrInvalidEndpointState {
+ return err
+ }
- // Use the unspecified address as the source address when performing
- // DAD.
- r := makeRoute(header.IPv6ProtocolNumber, header.IPv6Any, snmc, ndp.nic.linkEP.LinkAddress(), snmcRef, false, false)
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP NS for DAD (%s -> %s on NIC(%d)): %s", header.IPv6Any, snmc, ndp.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP NS for DAD (%s -> %s on NIC(%d))", header.IPv6Any, snmc, ndp.nic.ID()))
+ }
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv6NeighborSolicitMinimumSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborSolicitMinimumSize))
- pkt.SetType(header.ICMPv6NeighborSolicit)
- ns := header.NDPNeighborSolicit(pkt.NDPPayload())
+ icmpData := header.ICMPv6(buffer.NewView(header.ICMPv6NeighborSolicitMinimumSize))
+ icmpData.SetType(header.ICMPv6NeighborSolicit)
+ ns := header.NDPNeighborSolicit(icmpData.NDPPayload())
ns.SetTargetAddress(addr)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
sent := r.Stats().ICMP.V6PacketsSent
- if err := r.WritePacket(nil, hdr, buffer.VectorisedView{}, NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, TOS: DefaultTOS}); err != nil {
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, pkt,
+ ); err != nil {
sent.Dropped.Increment()
- return false, err
+ return err
}
sent.NeighborSolicit.Increment()
- return false, nil
+ return nil
}
// stopDuplicateAddressDetection ends a running Duplicate Address Detection
@@ -315,7 +801,1173 @@ func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address) {
delete(ndp.dad, addr)
// Let the integrator know DAD did not resolve.
- if ndp.nic.stack.ndpDisp != nil {
- go ndp.nic.stack.ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDuplicateAddressDetectionStatus(ndp.nic.ID(), addr, false, nil)
+ }
+}
+
+// handleRA handles a Router Advertisement message that arrived on the NIC
+// this ndp is for. Does nothing if the NIC is configured to not handle RAs.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ // Is the NIC configured to handle RAs at all?
+ //
+ // Currently, the stack does not determine router interface status on a
+ // per-interface basis; it is a stack-wide configuration, so we check
+ // stack's forwarding flag to determine if the NIC is a routing
+ // interface.
+ if !ndp.configs.HandleRAs || ndp.nic.stack.Forwarding(header.IPv6ProtocolNumber) {
+ return
+ }
+
+ // Only worry about the DHCPv6 configuration if we have an NDPDispatcher as we
+ // only inform the dispatcher on configuration changes. We do nothing else
+ // with the information.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ var configuration DHCPv6ConfigurationFromNDPRA
+ switch {
+ case ra.ManagedAddrConfFlag():
+ configuration = DHCPv6ManagedAddress
+
+ case ra.OtherConfFlag():
+ configuration = DHCPv6OtherConfigurations
+
+ default:
+ configuration = DHCPv6NoConfiguration
+ }
+
+ if ndp.dhcpv6Configuration != configuration {
+ ndp.dhcpv6Configuration = configuration
+ ndpDisp.OnDHCPv6Configuration(ndp.nic.ID(), configuration)
+ }
+ }
+
+ // Is the NIC configured to discover default routers?
+ if ndp.configs.DiscoverDefaultRouters {
+ rtr, ok := ndp.defaultRouters[ip]
+ rl := ra.RouterLifetime()
+ switch {
+ case !ok && rl != 0:
+ // This is a new default router we are discovering.
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredDefaultRouters routers.
+ if len(ndp.defaultRouters) < MaxDiscoveredDefaultRouters {
+ ndp.rememberDefaultRouter(ip, rl)
+ }
+
+ case ok && rl != 0:
+ // This is an already discovered default router. Update
+ // the invalidation job.
+ rtr.invalidationJob.Cancel()
+ rtr.invalidationJob.Schedule(rl)
+ ndp.defaultRouters[ip] = rtr
+
+ case ok && rl == 0:
+ // We know about the router but it is no longer to be
+ // used as a default router so invalidate it.
+ ndp.invalidateDefaultRouter(ip)
+ }
+ }
+
+ // TODO(b/141556115): Do (RetransTimer, ReachableTime)) Parameter
+ // Discovery.
+
+ // We know the options is valid as far as wire format is concerned since
+ // we got the Router Advertisement, as documented by this fn. Given this
+ // we do not check the iterator for errors on calls to Next.
+ it, _ := ra.Options().Iter(false)
+ for opt, done, _ := it.Next(); !done; opt, done, _ = it.Next() {
+ switch opt := opt.(type) {
+ case header.NDPRecursiveDNSServer:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ addrs, _ := opt.Addresses()
+ ndp.nic.stack.ndpDisp.OnRecursiveDNSServerOption(ndp.nic.ID(), addrs, opt.Lifetime())
+
+ case header.NDPDNSSearchList:
+ if ndp.nic.stack.ndpDisp == nil {
+ continue
+ }
+
+ domainNames, _ := opt.DomainNames()
+ ndp.nic.stack.ndpDisp.OnDNSSearchListOption(ndp.nic.ID(), domainNames, opt.Lifetime())
+
+ case header.NDPPrefixInformation:
+ prefix := opt.Subnet()
+
+ // Is the prefix a link-local?
+ if header.IsV6LinkLocalAddress(prefix.ID()) {
+ // ...Yes, skip as per RFC 4861 section 6.3.4,
+ // and RFC 4862 section 5.5.3.b (for SLAAC).
+ continue
+ }
+
+ // Is the Prefix Length 0?
+ if prefix.Prefix() == 0 {
+ // ...Yes, skip as this is an invalid prefix
+ // as all IPv6 addresses cannot be on-link.
+ continue
+ }
+
+ if opt.OnLinkFlag() {
+ ndp.handleOnLinkPrefixInformation(opt)
+ }
+
+ if opt.AutonomousAddressConfigurationFlag() {
+ ndp.handleAutonomousPrefixInformation(opt)
+ }
+ }
+
+ // TODO(b/141556115): Do (MTU) Parameter Discovery.
+ }
+}
+
+// invalidateDefaultRouter invalidates a discovered default router.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) {
+ rtr, ok := ndp.defaultRouters[ip]
+
+ // Is the router still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ rtr.invalidationJob.Cancel()
+ delete(ndp.defaultRouters, ip)
+
+ // Let the integrator know a discovered default router is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnDefaultRouterInvalidated(ndp.nic.ID(), ip)
+ }
+}
+
+// rememberDefaultRouter remembers a newly discovered default router with IPv6
+// link-local address ip with lifetime rl.
+//
+// The router identified by ip MUST NOT already be known by the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered a default router.
+ if !ndpDisp.OnDefaultRouterDiscovered(ndp.nic.ID(), ip) {
+ // Informed by the integrator to not remember the router, do
+ // nothing further.
+ return
+ }
+
+ state := defaultRouterState{
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ ndp.invalidateDefaultRouter(ip)
+ }),
+ }
+
+ state.invalidationJob.Schedule(rl)
+
+ ndp.defaultRouters[ip] = state
+}
+
+// rememberOnLinkPrefix remembers a newly discovered on-link prefix with IPv6
+// address with prefix prefix with lifetime l.
+//
+// The prefix identified by prefix MUST NOT already be known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) {
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return
+ }
+
+ // Inform the integrator when we discovered an on-link prefix.
+ if !ndpDisp.OnOnLinkPrefixDiscovered(ndp.nic.ID(), prefix) {
+ // Informed by the integrator to not remember the prefix, do
+ // nothing further.
+ return
+ }
+
+ state := onLinkPrefixState{
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }),
+ }
+
+ if l < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(l)
+ }
+
+ ndp.onLinkPrefixes[prefix] = state
+}
+
+// invalidateOnLinkPrefix invalidates a discovered on-link prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) {
+ s, ok := ndp.onLinkPrefixes[prefix]
+
+ // Is the on-link prefix still discovered?
+ if !ok {
+ // ...Nope, do nothing further.
+ return
+ }
+
+ s.invalidationJob.Cancel()
+ delete(ndp.onLinkPrefixes, prefix)
+
+ // Let the integrator know a discovered on-link prefix is invalidated.
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnOnLinkPrefixInvalidated(ndp.nic.ID(), prefix)
+ }
+}
+
+// handleOnLinkPrefixInformation handles a Prefix Information option with
+// its on-link flag set, as per RFC 4861 section 6.3.4.
+//
+// handleOnLinkPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the on-link flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformation) {
+ prefix := pi.Subnet()
+ prefixState, ok := ndp.onLinkPrefixes[prefix]
+ vl := pi.ValidLifetime()
+
+ if !ok && vl == 0 {
+ // Don't know about this prefix but it has a zero valid
+ // lifetime, so just ignore.
+ return
+ }
+
+ if !ok && vl != 0 {
+ // This is a new on-link prefix we are discovering
+ //
+ // Only remember it if we currently know about less than
+ // MaxDiscoveredOnLinkPrefixes on-link prefixes.
+ if ndp.configs.DiscoverOnLinkPrefixes && len(ndp.onLinkPrefixes) < MaxDiscoveredOnLinkPrefixes {
+ ndp.rememberOnLinkPrefix(prefix, vl)
+ }
+ return
+ }
+
+ if ok && vl == 0 {
+ // We know about the on-link prefix, but it is
+ // no longer to be considered on-link, so
+ // invalidate it.
+ ndp.invalidateOnLinkPrefix(prefix)
+ return
+ }
+
+ // This is an already discovered on-link prefix with a
+ // new non-zero valid lifetime.
+ //
+ // Update the invalidation job.
+
+ prefixState.invalidationJob.Cancel()
+
+ if vl < header.NDPInfiniteLifetime {
+ // Prefix is valid for a finite lifetime, schedule the job to execute after
+ // the new valid lifetime.
+ prefixState.invalidationJob.Schedule(vl)
+ }
+
+ ndp.onLinkPrefixes[prefix] = prefixState
+}
+
+// handleAutonomousPrefixInformation handles a Prefix Information option with
+// its autonomous flag set, as per RFC 4862 section 5.5.3.
+//
+// handleAutonomousPrefixInformation assumes that the prefix this pi is for is
+// not the link-local prefix and the autonomous flag is set.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) handleAutonomousPrefixInformation(pi header.NDPPrefixInformation) {
+ vl := pi.ValidLifetime()
+ pl := pi.PreferredLifetime()
+
+ // If the preferred lifetime is greater than the valid lifetime,
+ // silently ignore the Prefix Information option, as per RFC 4862
+ // section 5.5.3.c.
+ if pl > vl {
+ return
+ }
+
+ prefix := pi.Subnet()
+
+ // Check if we already maintain SLAAC state for prefix.
+ if state, ok := ndp.slaacPrefixes[prefix]; ok {
+ // As per RFC 4862 section 5.5.3.e, refresh prefix's SLAAC lifetimes.
+ ndp.refreshSLAACPrefixLifetimes(prefix, &state, pl, vl)
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // prefix is a new SLAAC prefix. Do the work as outlined by RFC 4862 section
+ // 5.5.3.d if ndp is configured to auto-generate new addresses via SLAAC.
+ if !ndp.configs.AutoGenGlobalAddresses {
+ return
+ }
+
+ ndp.doSLAAC(prefix, pl, vl)
+}
+
+// doSLAAC generates a new SLAAC address with the provided lifetimes
+// for prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) {
+ // If we do not already have an address for this prefix and the valid
+ // lifetime is 0, no need to do anything further, as per RFC 4862
+ // section 5.5.3.d.
+ if vl == 0 {
+ return
+ }
+
+ // Make sure the prefix is valid (as far as its length is concerned) to
+ // generate a valid IPv6 address from an interface identifier (IID), as
+ // per RFC 4862 sectiion 5.5.3.d.
+ if prefix.Prefix() != validPrefixLenForAutoGen {
+ return
+ }
+
+ state := slaacPrefixState{
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix))
+ }
+
+ ndp.deprecateSLAACAddress(state.stableAddr.ref)
+ }),
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix))
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }),
+ tempAddrs: make(map[tcpip.Address]tempSLAACAddrState),
+ maxGenerationAttempts: ndp.configs.AutoGenAddressConflictRetries + 1,
+ }
+
+ now := time.Now()
+
+ // The time an address is preferred until is needed to properly generate the
+ // address.
+ if pl < header.NDPInfiniteLifetime {
+ state.preferredUntil = now.Add(pl)
+ }
+
+ if !ndp.generateSLAACAddr(prefix, &state) {
+ // We were unable to generate an address for the prefix, we do not nothing
+ // further as there is no reason to maintain state or jobs for a prefix we
+ // do not have an address for.
+ return
+ }
+
+ // Setup the initial jobs to deprecate and invalidate prefix.
+
+ if pl < header.NDPInfiniteLifetime && pl != 0 {
+ state.deprecationJob.Schedule(pl)
+ }
+
+ if vl < header.NDPInfiniteLifetime {
+ state.invalidationJob.Schedule(vl)
+ state.validUntil = now.Add(vl)
+ }
+
+ // If the address is assigned (DAD resolved), generate a temporary address.
+ if state.stableAddr.ref.getKind() == permanent {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ ndp.generateTempSLAACAddr(prefix, &state, true /* resetGenAttempts */)
+ }
+
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// addSLAACAddr adds a SLAAC address to the NIC.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) addSLAACAddr(addr tcpip.AddressWithPrefix, configType networkEndpointConfigType, deprecated bool) *referencedNetworkEndpoint {
+ // Inform the integrator that we have a new SLAAC address.
+ ndpDisp := ndp.nic.stack.ndpDisp
+ if ndpDisp == nil {
+ return nil
+ }
+
+ if !ndpDisp.OnAutoGenAddress(ndp.nic.ID(), addr) {
+ // Informed by the integrator not to add the address.
+ return nil
+ }
+
+ protocolAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr,
+ }
+
+ ref, err := ndp.nic.addAddressLocked(protocolAddr, FirstPrimaryEndpoint, permanent, configType, deprecated)
+ if err != nil {
+ panic(fmt.Sprintf("ndp: error when adding SLAAC address %+v: %s", protocolAddr, err))
+ }
+
+ return ref
+}
+
+// generateSLAACAddr generates a SLAAC address for prefix.
+//
+// Returns true if an address was successfully generated.
+//
+// Panics if the prefix is not a SLAAC prefix or it already has an address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) generateSLAACAddr(prefix tcpip.Subnet, state *slaacPrefixState) bool {
+ if r := state.stableAddr.ref; r != nil {
+ panic(fmt.Sprintf("ndp: SLAAC prefix %s already has a permenant address %s", prefix, r.addrWithPrefix()))
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if state.generationAttempts == state.maxGenerationAttempts {
+ return false
+ }
+
+ var generatedAddr tcpip.AddressWithPrefix
+ addrBytes := []byte(prefix.ID())
+
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ dadCounter := state.generationAttempts + state.stableAddr.localGenerationFailures
+ if oIID := ndp.nic.stack.opaqueIIDOpts; oIID.NICNameFromID != nil {
+ addrBytes = header.AppendOpaqueInterfaceIdentifier(
+ addrBytes[:header.IIDOffsetInIPv6Address],
+ prefix,
+ oIID.NICNameFromID(ndp.nic.ID(), ndp.nic.name),
+ dadCounter,
+ oIID.SecretKey,
+ )
+ } else if dadCounter == 0 {
+ // Modified-EUI64 based IIDs have no way to resolve DAD conflicts, so if
+ // the DAD counter is non-zero, we cannot use this method.
+ //
+ // Only attempt to generate an interface-specific IID if we have a valid
+ // link address.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ linkAddr := ndp.nic.linkEP.LinkAddress()
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return false
+ }
+
+ // Generate an address within prefix from the modified EUI-64 of ndp's
+ // NIC's Ethernet MAC address.
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ } else {
+ // We have no way to regenerate an address in response to an address
+ // conflict when addresses are not generated with opaque IIDs.
+ return false
+ }
+
+ generatedAddr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: validPrefixLenForAutoGen,
+ }
+
+ if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ break
+ }
+
+ state.stableAddr.localGenerationFailures++
+ }
+
+ if ref := ndp.addSLAACAddr(generatedAddr, slaac, time.Since(state.preferredUntil) >= 0 /* deprecated */); ref != nil {
+ state.stableAddr.ref = ref
+ state.generationAttempts++
+ return true
+ }
+
+ return false
+}
+
+// regenerateSLAACAddr regenerates an address for a SLAAC prefix.
+//
+// If generating a new address for the prefix fails, the prefix will be
+// invalidated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateSLAACAddr(prefix tcpip.Subnet) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate address for %s", prefix))
+ }
+
+ if ndp.generateSLAACAddr(prefix, &state) {
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ // We were unable to generate a permanent address for the SLAAC prefix so
+ // invalidate the prefix as there is no reason to maintain state for a
+ // SLAAC prefix we do not have an address for.
+ ndp.invalidateSLAACPrefix(prefix, state)
+}
+
+// generateTempSLAACAddr generates a new temporary SLAAC address.
+//
+// If resetGenAttempts is true, the prefix's generation counter will be reset.
+//
+// Returns true if a new address was generated.
+func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *slaacPrefixState, resetGenAttempts bool) bool {
+ // Are we configured to auto-generate new temporary global addresses for the
+ // prefix?
+ if !ndp.configs.AutoGenTempGlobalAddresses || prefix == header.IPv6LinkLocalPrefix.Subnet() {
+ return false
+ }
+
+ if resetGenAttempts {
+ prefixState.generationAttempts = 0
+ prefixState.maxGenerationAttempts = ndp.configs.AutoGenAddressConflictRetries + 1
+ }
+
+ // If we have already reached the maximum address generation attempts for the
+ // prefix, do not generate another address.
+ if prefixState.generationAttempts == prefixState.maxGenerationAttempts {
+ return false
+ }
+
+ stableAddr := prefixState.stableAddr.ref.address()
+ now := time.Now()
+
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime.
+ vl := ndp.configs.MaxTempAddrValidLifetime
+ if prefixState.validUntil != (time.Time{}) {
+ if prefixVL := prefixState.validUntil.Sub(now); vl > prefixVL {
+ vl = prefixVL
+ }
+ }
+
+ if vl <= 0 {
+ // Cannot create an address without a valid lifetime.
+ return false
+ }
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or the
+ // maximum temporary address preferred lifetime - the temporary address desync
+ // factor.
+ pl := ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor
+ if prefixState.preferredUntil != (time.Time{}) {
+ if prefixPL := prefixState.preferredUntil.Sub(now); pl > prefixPL {
+ // Respect the preferred lifetime of the prefix, as per RFC 4941 section
+ // 3.3 step 4.
+ pl = prefixPL
+ }
+ }
+
+ // As per RFC 4941 section 3.3 step 5, a temporary address is created only if
+ // the calculated preferred lifetime is greater than the advance regeneration
+ // duration. In particular, we MUST NOT create a temporary address with a zero
+ // Preferred Lifetime.
+ if pl <= ndp.configs.RegenAdvanceDuration {
+ return false
+ }
+
+ // Attempt to generate a new address that is not already assigned to the NIC.
+ var generatedAddr tcpip.AddressWithPrefix
+ for i := 0; ; i++ {
+ // If we were unable to generate an address after the maximum SLAAC address
+ // local regeneration attempts, do nothing further.
+ if i == maxSLAACAddrLocalRegenAttempts {
+ return false
+ }
+
+ generatedAddr = header.GenerateTempIPv6SLAACAddr(ndp.temporaryIIDHistory[:], stableAddr)
+ if !ndp.nic.hasPermanentAddrLocked(generatedAddr.Address) {
+ break
+ }
+ }
+
+ // As per RFC RFC 4941 section 3.3 step 5, we MUST NOT create a temporary
+ // address with a zero preferred lifetime. The checks above ensure this
+ // so we know the address is not deprecated.
+ ref := ndp.addSLAACAddr(generatedAddr, slaacTemp, false /* deprecated */)
+ if ref == nil {
+ return false
+ }
+
+ state := tempSLAACAddrState{
+ deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to deprecate temporary address %s", generatedAddr))
+ }
+
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ }),
+ invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to invalidate temporary address %s", generatedAddr))
+ }
+
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState)
+ }),
+ regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() {
+ prefixState, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr))
+ }
+
+ tempAddrState, ok := prefixState.tempAddrs[generatedAddr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to regenerate temporary address after %s", generatedAddr))
+ }
+
+ // If an address has already been regenerated for this address, don't
+ // regenerate another address.
+ if tempAddrState.regenerated {
+ return
+ }
+
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ tempAddrState.regenerated = ndp.generateTempSLAACAddr(prefix, &prefixState, true /* resetGenAttempts */)
+ prefixState.tempAddrs[generatedAddr.Address] = tempAddrState
+ ndp.slaacPrefixes[prefix] = prefixState
+ }),
+ createdAt: now,
+ ref: ref,
+ }
+
+ state.deprecationJob.Schedule(pl)
+ state.invalidationJob.Schedule(vl)
+ state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration)
+
+ prefixState.generationAttempts++
+ prefixState.tempAddrs[generatedAddr.Address] = state
+
+ return true
+}
+
+// regenerateTempSLAACAddr regenerates a temporary address for a SLAAC prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) regenerateTempSLAACAddr(prefix tcpip.Subnet, resetGenAttempts bool) {
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: SLAAC prefix state not found to regenerate temporary address for %s", prefix))
+ }
+
+ ndp.generateTempSLAACAddr(prefix, &state, resetGenAttempts)
+ ndp.slaacPrefixes[prefix] = state
+}
+
+// refreshSLAACPrefixLifetimes refreshes the lifetimes of a SLAAC prefix.
+//
+// pl is the new preferred lifetime. vl is the new valid lifetime.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixState *slaacPrefixState, pl, vl time.Duration) {
+ // If the preferred lifetime is zero, then the prefix should be deprecated.
+ deprecated := pl == 0
+ if deprecated {
+ ndp.deprecateSLAACAddress(prefixState.stableAddr.ref)
+ } else {
+ prefixState.stableAddr.ref.deprecated = false
+ }
+
+ // If prefix was preferred for some finite lifetime before, cancel the
+ // deprecation job so it can be reset.
+ prefixState.deprecationJob.Cancel()
+
+ now := time.Now()
+
+ // Schedule the deprecation job if prefix has a finite preferred lifetime.
+ if pl < header.NDPInfiniteLifetime {
+ if !deprecated {
+ prefixState.deprecationJob.Schedule(pl)
+ }
+ prefixState.preferredUntil = now.Add(pl)
+ } else {
+ prefixState.preferredUntil = time.Time{}
+ }
+
+ // As per RFC 4862 section 5.5.3.e, update the valid lifetime for prefix:
+ //
+ // 1) If the received Valid Lifetime is greater than 2 hours or greater than
+ // RemainingLifetime, set the valid lifetime of the prefix to the
+ // advertised Valid Lifetime.
+ //
+ // 2) If RemainingLifetime is less than or equal to 2 hours, ignore the
+ // advertised Valid Lifetime.
+ //
+ // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours.
+
+ if vl >= header.NDPInfiniteLifetime {
+ // Handle the infinite valid lifetime separately as we do not schedule a
+ // job in this case.
+ prefixState.invalidationJob.Cancel()
+ prefixState.validUntil = time.Time{}
+ } else {
+ var effectiveVl time.Duration
+ var rl time.Duration
+
+ // If the prefix was originally set to be valid forever, assume the
+ // remaining time to be the maximum possible value.
+ if prefixState.validUntil == (time.Time{}) {
+ rl = header.NDPInfiniteLifetime
+ } else {
+ rl = time.Until(prefixState.validUntil)
+ }
+
+ if vl > MinPrefixInformationValidLifetimeForUpdate || vl > rl {
+ effectiveVl = vl
+ } else if rl > MinPrefixInformationValidLifetimeForUpdate {
+ effectiveVl = MinPrefixInformationValidLifetimeForUpdate
+ }
+
+ if effectiveVl != 0 {
+ prefixState.invalidationJob.Cancel()
+ prefixState.invalidationJob.Schedule(effectiveVl)
+ prefixState.validUntil = now.Add(effectiveVl)
+ }
+ }
+
+ // If DAD is not yet complete on the stable address, there is no need to do
+ // work with temporary addresses.
+ if prefixState.stableAddr.ref.getKind() != permanent {
+ return
+ }
+
+ // Note, we do not need to update the entries in the temporary address map
+ // after updating the jobs because the jobs are held as pointers.
+ var regenForAddr tcpip.Address
+ allAddressesRegenerated := true
+ for tempAddr, tempAddrState := range prefixState.tempAddrs {
+ // As per RFC 4941 section 3.3 step 4, the valid lifetime of a temporary
+ // address is the lower of the valid lifetime of the stable address or the
+ // maximum temporary address valid lifetime. Note, the valid lifetime of a
+ // temporary address is relative to the address's creation time.
+ validUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrValidLifetime)
+ if prefixState.validUntil != (time.Time{}) && validUntil.Sub(prefixState.validUntil) > 0 {
+ validUntil = prefixState.validUntil
+ }
+
+ // If the address is no longer valid, invalidate it immediately. Otherwise,
+ // reset the invalidation job.
+ newValidLifetime := validUntil.Sub(now)
+ if newValidLifetime <= 0 {
+ ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState)
+ continue
+ }
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.invalidationJob.Schedule(newValidLifetime)
+
+ // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary
+ // address is the lower of the preferred lifetime of the stable address or
+ // the maximum temporary address preferred lifetime - the temporary address
+ // desync factor. Note, the preferred lifetime of a temporary address is
+ // relative to the address's creation time.
+ preferredUntil := tempAddrState.createdAt.Add(ndp.configs.MaxTempAddrPreferredLifetime - ndp.temporaryAddressDesyncFactor)
+ if prefixState.preferredUntil != (time.Time{}) && preferredUntil.Sub(prefixState.preferredUntil) > 0 {
+ preferredUntil = prefixState.preferredUntil
+ }
+
+ // If the address is no longer preferred, deprecate it immediately.
+ // Otherwise, schedule the deprecation job again.
+ newPreferredLifetime := preferredUntil.Sub(now)
+ tempAddrState.deprecationJob.Cancel()
+ if newPreferredLifetime <= 0 {
+ ndp.deprecateSLAACAddress(tempAddrState.ref)
+ } else {
+ tempAddrState.ref.deprecated = false
+ tempAddrState.deprecationJob.Schedule(newPreferredLifetime)
+ }
+
+ tempAddrState.regenJob.Cancel()
+ if tempAddrState.regenerated {
+ } else {
+ allAddressesRegenerated = false
+
+ if newPreferredLifetime <= ndp.configs.RegenAdvanceDuration {
+ // The new preferred lifetime is less than the advance regeneration
+ // duration so regenerate an address for this temporary address
+ // immediately after we finish iterating over the temporary addresses.
+ regenForAddr = tempAddr
+ } else {
+ tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration)
+ }
+ }
+ }
+
+ // Generate a new temporary address if all of the existing temporary addresses
+ // have been regenerated, or we need to immediately regenerate an address
+ // due to an update in preferred lifetime.
+ //
+ // If each temporay address has already been regenerated, no new temporary
+ // address will be generated. To ensure continuation of temporary SLAAC
+ // addresses, we manually try to regenerate an address here.
+ if len(regenForAddr) != 0 || allAddressesRegenerated {
+ // Reset the generation attempts counter as we are starting the generation
+ // of a new address for the SLAAC prefix.
+ if state, ok := prefixState.tempAddrs[regenForAddr]; ndp.generateTempSLAACAddr(prefix, prefixState, true /* resetGenAttempts */) && ok {
+ state.regenerated = true
+ prefixState.tempAddrs[regenForAddr] = state
+ }
+ }
+}
+
+// deprecateSLAACAddress marks ref as deprecated and notifies the stack's NDP
+// dispatcher that ref has been deprecated.
+//
+// deprecateSLAACAddress does nothing if ref is already deprecated.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) deprecateSLAACAddress(ref *referencedNetworkEndpoint) {
+ if ref.deprecated {
+ return
+ }
+
+ ref.deprecated = true
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressDeprecated(ndp.nic.ID(), ref.addrWithPrefix())
+ }
+}
+
+// invalidateSLAACPrefix invalidates a SLAAC prefix.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefixState) {
+ if r := state.stableAddr.ref; r != nil {
+ // Since we are already invalidating the prefix, do not invalidate the
+ // prefix when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(r, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", r.addrWithPrefix(), err))
+ }
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACAddrResourcesAndNotify cleans up an invalidated SLAAC address's
+// resources.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidatePrefix bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok || state.stableAddr.ref == nil || addr.Address != state.stableAddr.ref.address() {
+ return
+ }
+
+ if !invalidatePrefix {
+ // If the prefix is not being invalidated, disassociate the address from the
+ // prefix and do nothing further.
+ state.stableAddr.ref = nil
+ ndp.slaacPrefixes[prefix] = state
+ return
+ }
+
+ ndp.cleanupSLAACPrefixResources(prefix, state)
+}
+
+// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry.
+//
+// Panics if the SLAAC prefix is not known.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaacPrefixState) {
+ // Invalidate all temporary addresses.
+ for tempAddr, tempAddrState := range state.tempAddrs {
+ ndp.invalidateTempSLAACAddr(state.tempAddrs, tempAddr, tempAddrState)
+ }
+
+ state.stableAddr.ref = nil
+ state.deprecationJob.Cancel()
+ state.invalidationJob.Cancel()
+ delete(ndp.slaacPrefixes, prefix)
+}
+
+// invalidateTempSLAACAddr invalidates a temporary SLAAC address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ // Since we are already invalidating the address, do not invalidate the
+ // address when removing the address.
+ if err := ndp.nic.removePermanentIPv6EndpointLocked(tempAddrState.ref, false /* allowSLAACInvalidation */); err != nil {
+ panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.ref.addrWithPrefix(), err))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(tempAddrs, tempAddr, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up an invalidated temporary
+// SLAAC address's resources from ndp.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPrefix, invalidateAddr bool) {
+ if ndpDisp := ndp.nic.stack.ndpDisp; ndpDisp != nil {
+ ndpDisp.OnAutoGenAddressInvalidated(ndp.nic.ID(), addr)
+ }
+
+ if !invalidateAddr {
+ return
+ }
+
+ prefix := addr.Subnet()
+ state, ok := ndp.slaacPrefixes[prefix]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry to clean up temp addr %s resources", addr))
+ }
+
+ tempAddrState, ok := state.tempAddrs[addr.Address]
+ if !ok {
+ panic(fmt.Sprintf("ndp: must have a tempAddr entry to clean up temp addr %s resources", addr))
+ }
+
+ ndp.cleanupTempSLAACAddrResources(state.tempAddrs, addr.Address, tempAddrState)
+}
+
+// cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's
+// jobs and entry.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
+ tempAddrState.deprecationJob.Cancel()
+ tempAddrState.invalidationJob.Cancel()
+ tempAddrState.regenJob.Cancel()
+ delete(tempAddrs, tempAddr)
+}
+
+// cleanupState cleans up ndp's state.
+//
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
+//
+// The NIC that ndp belongs to MUST be locked.
+func (ndp *ndpState) cleanupState(hostOnly bool) {
+ linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
+ linkLocalPrefixes := 0
+ for prefix, state := range ndp.slaacPrefixes {
+ // RFC 4862 section 5 states that routers are also expected to generate a
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && prefix == linkLocalSubnet {
+ linkLocalPrefixes++
+ continue
+ }
+
+ ndp.invalidateSLAACPrefix(prefix, state)
+ }
+
+ if got := len(ndp.slaacPrefixes); got != linkLocalPrefixes {
+ panic(fmt.Sprintf("ndp: still have non-linklocal SLAAC prefixes after cleaning up; found = %d prefixes, of which %d are link-local", got, linkLocalPrefixes))
+ }
+
+ for prefix := range ndp.onLinkPrefixes {
+ ndp.invalidateOnLinkPrefix(prefix)
+ }
+
+ if got := len(ndp.onLinkPrefixes); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered on-link prefixes after cleaning up; found = %d", got))
+ }
+
+ for router := range ndp.defaultRouters {
+ ndp.invalidateDefaultRouter(router)
+ }
+
+ if got := len(ndp.defaultRouters); got != 0 {
+ panic(fmt.Sprintf("ndp: still have discovered default routers after cleaning up; found = %d", got))
+ }
+
+ ndp.dhcpv6Configuration = 0
+}
+
+// startSolicitingRouters starts soliciting routers, as per RFC 4861 section
+// 6.3.7. If routers are already being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) startSolicitingRouters() {
+ if ndp.rtrSolicit.timer != nil {
+ // We are already soliciting routers.
+ return
+ }
+
+ remaining := ndp.configs.MaxRtrSolicitations
+ if remaining == 0 {
+ return
+ }
+
+ // Calculate the random delay before sending our first RS, as per RFC
+ // 4861 section 6.3.7.
+ var delay time.Duration
+ if ndp.configs.MaxRtrSolicitationDelay > 0 {
+ delay = time.Duration(rand.Int63n(int64(ndp.configs.MaxRtrSolicitationDelay)))
+ }
+
+ var done bool
+ ndp.rtrSolicit.done = &done
+ ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() {
+ ndp.nic.mu.Lock()
+ if done {
+ // If we reach this point, it means that the RS timer fired after another
+ // goroutine already obtained the NIC lock and stopped solicitations.
+ // Simply return here and do nothing further.
+ ndp.nic.mu.Unlock()
+ return
+ }
+
+ // As per RFC 4861 section 4.1, the source of the RS is an address assigned
+ // to the sending interface, or the unspecified address if no address is
+ // assigned to the sending interface.
+ ref := ndp.nic.primaryIPv6EndpointRLocked(header.IPv6AllRoutersMulticastAddress)
+ if ref == nil {
+ ref = ndp.nic.getRefOrCreateTempLocked(header.IPv6ProtocolNumber, header.IPv6Any, NeverPrimaryEndpoint)
+ }
+ ndp.nic.mu.Unlock()
+
+ localAddr := ref.address()
+ r := makeRoute(header.IPv6ProtocolNumber, localAddr, header.IPv6AllRoutersMulticastAddress, ndp.nic.linkEP.LinkAddress(), ref, false, false)
+ defer r.Release()
+
+ // Route should resolve immediately since
+ // header.IPv6AllRoutersMulticastAddress is a multicast address so a
+ // remote link address can be calculated without a resolution process.
+ if c, err := r.Resolve(nil); err != nil {
+ // Do not consider the NIC being unknown or disabled as a fatal error.
+ // Since this method is required to be called when the NIC is not locked,
+ // the NIC could have been disabled or removed by another goroutine.
+ if err == tcpip.ErrUnknownNICID || err == tcpip.ErrInvalidEndpointState {
+ return
+ }
+
+ panic(fmt.Sprintf("ndp: error when resolving route to send NDP RS (%s -> %s on NIC(%d)): %s", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID(), err))
+ } else if c != nil {
+ panic(fmt.Sprintf("ndp: route resolution not immediate for route to send NDP RS (%s -> %s on NIC(%d))", header.IPv6Any, header.IPv6AllRoutersMulticastAddress, ndp.nic.ID()))
+ }
+
+ // As per RFC 4861 section 4.1, an NDP RS SHOULD include the source
+ // link-layer address option if the source address of the NDP RS is
+ // specified. This option MUST NOT be included if the source address is
+ // unspecified.
+ //
+ // TODO(b/141011931): Validate a LinkEndpoint's link address (provided by
+ // LinkEndpoint.LinkAddress) before reaching this point.
+ var optsSerializer header.NDPOptionsSerializer
+ if localAddr != header.IPv6Any && header.IsValidUnicastEthernetAddress(r.LocalLinkAddress) {
+ optsSerializer = header.NDPOptionsSerializer{
+ header.NDPSourceLinkLayerAddressOption(r.LocalLinkAddress),
+ }
+ }
+ payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize + int(optsSerializer.Length())
+ icmpData := header.ICMPv6(buffer.NewView(payloadSize))
+ icmpData.SetType(header.ICMPv6RouterSolicit)
+ rs := header.NDPRouterSolicit(icmpData.NDPPayload())
+ rs.Options().Serialize(optsSerializer)
+ icmpData.SetChecksum(header.ICMPv6Checksum(icmpData, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ pkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: buffer.View(icmpData).ToVectorisedView(),
+ })
+
+ sent := r.Stats().ICMP.V6PacketsSent
+ if err := r.WritePacket(nil,
+ NetworkHeaderParams{
+ Protocol: header.ICMPv6ProtocolNumber,
+ TTL: header.NDPHopLimit,
+ TOS: DefaultTOS,
+ }, pkt,
+ ); err != nil {
+ sent.Dropped.Increment()
+ log.Printf("startSolicitingRouters: error writing NDP router solicit message on NIC(%d); err = %s", ndp.nic.ID(), err)
+ // Don't send any more messages if we had an error.
+ remaining = 0
+ } else {
+ sent.RouterSolicit.Increment()
+ remaining--
+ }
+
+ ndp.nic.mu.Lock()
+ if done || remaining == 0 {
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+ } else if ndp.rtrSolicit.timer != nil {
+ // Note, we need to explicitly check to make sure that
+ // the timer field is not nil because if it was nil but
+ // we still reached this point, then we know the NIC
+ // was requested to stop soliciting routers so we don't
+ // need to send the next Router Solicitation message.
+ ndp.rtrSolicit.timer.Reset(ndp.configs.RtrSolicitationInterval)
+ }
+ ndp.nic.mu.Unlock()
+ })
+
+}
+
+// stopSolicitingRouters stops soliciting routers. If routers are not currently
+// being solicited, this function does nothing.
+//
+// The NIC ndp belongs to MUST be locked.
+func (ndp *ndpState) stopSolicitingRouters() {
+ if ndp.rtrSolicit.timer == nil {
+ // Nothing to do.
+ return
+ }
+
+ *ndp.rtrSolicit.done = true
+ ndp.rtrSolicit.timer.Stop()
+ ndp.rtrSolicit.timer = nil
+ ndp.rtrSolicit.done = nil
+}
+
+// initializeTempAddrState initializes state related to temporary SLAAC
+// addresses.
+func (ndp *ndpState) initializeTempAddrState() {
+ header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.nic.stack.tempIIDSeed, ndp.nic.ID())
+
+ if MaxDesyncFactor != 0 {
+ ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor)))
}
}
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 525a25218..1a6724c31 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -15,9 +15,14 @@
package stack_test
import (
+ "context"
+ "encoding/binary"
+ "fmt"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -26,74 +31,330 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const (
- addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
- linkAddr1 = "\x02\x02\x03\x04\x05\x06"
+ addr1 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ addr2 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ addr3 = tcpip.Address("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03")
+ linkAddr1 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkAddr2 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x07")
+ linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
+ linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+
+ // Extra time to use when waiting for an async event to occur.
+ defaultAsyncPositiveEventTimeout = 10 * time.Second
+
+ // Extra time to use when waiting for an async event to not occur.
+ //
+ // Since a negative check is used to make sure an event did not happen, it is
+ // okay to use a smaller timeout compared to the positive case since execution
+ // stall in regards to the monotonic clock will not affect the expected
+ // outcome.
+ defaultAsyncNegativeEventTimeout = time.Second
)
-// TestDADDisabled tests that an address successfully resolves immediately
-// when DAD is not enabled (the default for an empty stack.Options).
-func TestDADDisabled(t *testing.T) {
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+var (
+ llAddr1 = header.LinkLocalAddr(linkAddr1)
+ llAddr2 = header.LinkLocalAddr(linkAddr2)
+ llAddr3 = header.LinkLocalAddr(linkAddr3)
+ llAddr4 = header.LinkLocalAddr(linkAddr4)
+ dstAddr = tcpip.FullAddress{
+ Addr: "\x0a\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ Port: 25,
}
+)
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+func addrForSubnet(subnet tcpip.Subnet, linkAddr tcpip.LinkAddress) tcpip.AddressWithPrefix {
+ if !header.IsValidUnicastEthernetAddress(linkAddr) {
+ return tcpip.AddressWithPrefix{}
}
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ addrBytes := []byte(subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, addrBytes[header.IIDOffsetInIPv6Address:])
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
}
+}
- // Should get the address immediately since we should not have performed
- // DAD on it.
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+// prefixSubnetAddr returns a prefix (Address + Length), the prefix's equivalent
+// tcpip.Subnet, and an address where the lower half of the address is composed
+// of the EUI-64 of linkAddr if it is a valid unicast ethernet address.
+func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWithPrefix, tcpip.Subnet, tcpip.AddressWithPrefix) {
+ prefixBytes := []byte{1, 2, 3, 4, 5, 6, 7, 8 + offset, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixBytes),
+ PrefixLen: 64,
}
- // We should not have sent any NDP NS messages.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
- t.Fatalf("got NeighborSolicit = %d, want = 0", got)
- }
+ subnet := prefix.Subnet()
+
+ return prefix, subnet, addrForSubnet(subnet, linkAddr)
}
// ndpDADEvent is a set of parameters that was passed to
// ndpDispatcher.OnDuplicateAddressDetectionStatus.
type ndpDADEvent struct {
- nicid tcpip.NICID
+ nicID tcpip.NICID
addr tcpip.Address
resolved bool
err *tcpip.Error
}
+type ndpRouterEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.Address
+ // true if router was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpPrefixEvent struct {
+ nicID tcpip.NICID
+ prefix tcpip.Subnet
+ // true if prefix was discovered, false if invalidated.
+ discovered bool
+}
+
+type ndpAutoGenAddrEventType int
+
+const (
+ newAddr ndpAutoGenAddrEventType = iota
+ deprecatedAddr
+ invalidatedAddr
+)
+
+type ndpAutoGenAddrEvent struct {
+ nicID tcpip.NICID
+ addr tcpip.AddressWithPrefix
+ eventType ndpAutoGenAddrEventType
+}
+
+type ndpRDNSS struct {
+ addrs []tcpip.Address
+ lifetime time.Duration
+}
+
+type ndpRDNSSEvent struct {
+ nicID tcpip.NICID
+ rdnss ndpRDNSS
+}
+
+type ndpDNSSLEvent struct {
+ nicID tcpip.NICID
+ domainNames []string
+ lifetime time.Duration
+}
+
+type ndpDHCPv6Event struct {
+ nicID tcpip.NICID
+ configuration stack.DHCPv6ConfigurationFromNDPRA
+}
+
var _ stack.NDPDispatcher = (*ndpDispatcher)(nil)
// ndpDispatcher implements NDPDispatcher so tests can know when various NDP
// related events happen for test purposes.
type ndpDispatcher struct {
- dadC chan ndpDADEvent
+ dadC chan ndpDADEvent
+ routerC chan ndpRouterEvent
+ rememberRouter bool
+ prefixC chan ndpPrefixEvent
+ rememberPrefix bool
+ autoGenAddrC chan ndpAutoGenAddrEvent
+ rdnssC chan ndpRDNSSEvent
+ dnsslC chan ndpDNSSLEvent
+ routeTable []tcpip.Route
+ dhcpv6ConfigurationC chan ndpDHCPv6Event
}
// Implements stack.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-//
-// If the DAD event matches what we are expecting, send signal on n.dadC.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
- n.dadC <- ndpDADEvent{
- nicid,
- addr,
- resolved,
- err,
+func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) {
+ if n.dadC != nil {
+ n.dadC <- ndpDADEvent{
+ nicID,
+ addr,
+ resolved,
+ err,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterDiscovered.
+func (n *ndpDispatcher) OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ true,
+ }
+ }
+
+ return n.rememberRouter
+}
+
+// Implements stack.NDPDispatcher.OnDefaultRouterInvalidated.
+func (n *ndpDispatcher) OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) {
+ if c := n.routerC; c != nil {
+ c <- ndpRouterEvent{
+ nicID,
+ addr,
+ false,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixDiscovered.
+func (n *ndpDispatcher) OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ true,
+ }
+ }
+
+ return n.rememberPrefix
+}
+
+// Implements stack.NDPDispatcher.OnOnLinkPrefixInvalidated.
+func (n *ndpDispatcher) OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) {
+ if c := n.prefixC; c != nil {
+ c <- ndpPrefixEvent{
+ nicID,
+ prefix,
+ false,
+ }
+ }
+}
+
+func (n *ndpDispatcher) OnAutoGenAddress(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) bool {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ newAddr,
+ }
+ }
+ return true
+}
+
+func (n *ndpDispatcher) OnAutoGenAddressDeprecated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ deprecatedAddr,
+ }
+ }
+}
+
+func (n *ndpDispatcher) OnAutoGenAddressInvalidated(nicID tcpip.NICID, addr tcpip.AddressWithPrefix) {
+ if c := n.autoGenAddrC; c != nil {
+ c <- ndpAutoGenAddrEvent{
+ nicID,
+ addr,
+ invalidatedAddr,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnRecursiveDNSServerOption.
+func (n *ndpDispatcher) OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) {
+ if c := n.rdnssC; c != nil {
+ c <- ndpRDNSSEvent{
+ nicID,
+ ndpRDNSS{
+ addrs,
+ lifetime,
+ },
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDNSSearchListOption.
+func (n *ndpDispatcher) OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) {
+ if n.dnsslC != nil {
+ n.dnsslC <- ndpDNSSLEvent{
+ nicID,
+ domainNames,
+ lifetime,
+ }
+ }
+}
+
+// Implements stack.NDPDispatcher.OnDHCPv6Configuration.
+func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ if c := n.dhcpv6ConfigurationC; c != nil {
+ c <- ndpDHCPv6Event{
+ nicID,
+ configuration,
+ }
+ }
+}
+
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// resolved flag set to resolved with the specified err.
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
+ return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
+}
+
+// TestDADDisabled tests that an address successfully resolves immediately
+// when DAD is not enabled (the default for an empty stack.Options).
+func TestDADDisabled(t *testing.T) {
+ const nicID = 1
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+
+ // Should get the address immediately since we should not have performed
+ // DAD on it.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ }
+
+ // We should not have sent any NDP NS messages.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != 0 {
+ t.Fatalf("got NeighborSolicit = %d, want = 0", got)
}
}
@@ -101,23 +362,54 @@ func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicid tcpip.NICID, add
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
+ test := test
+
t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
@@ -128,101 +420,142 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
}
-
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- stat := s.Stats().ICMP.V6PacketsSent.NeighborSolicit
+ // We add a default route so the call to FindRoute below will succeed
+ // once we have an assigned address.
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: addr3,
+ NIC: nicID,
+ }})
- // Should have sent an NDP NS immediately.
- if got := stat.Value(); got != 1 {
- t.Fatalf("got NeighborSolicit = %d, want = 1", got)
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- // Address should not be considered bound to the NIC yet
- // (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ // Make sure the address does not resolve before the resolution time has
+ // passed.
+ time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ // Should not get a route even if we specify the local address as the
+ // tentative address.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
}
-
- // Wait for the remaining time - some delta (500ms), to
- // make sure the address is still not resolved.
- const delta = 500 * time.Millisecond
- time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - delta)
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != tcpip.ErrNoRoute {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false) = (%+v, %v), want = (_, %s)", nicID, addr1, addr2, header.IPv6ProtocolNumber, r, err, tcpip.ErrNoRoute)
+ }
+ r.Release()
}
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+
+ if t.Failed() {
+ t.FailNow()
}
// Wait for DAD to resolve.
select {
- case <-time.After(2 * delta):
- // We should get a resolution event after 500ms
- // (delta) since we wait for 500ms less than the
- // expected resolution time above to make sure
- // that the address did not yet resolve. Waiting
- // for 1s (2x delta) without a resolution event
- // means something is wrong.
+ case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
+ if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- if e.nicid != 1 {
- t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+ }
+ if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ } else if addr.Address != addr1 {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ }
+ // Should get a route using the address now that it is resolved.
+ {
+ r, err := s.FindRoute(nicID, "", addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, '', %s, %d, false): %s", nicID, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
}
+ r.Release()
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ {
+ r, err := s.FindRoute(nicID, addr1, addr2, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Errorf("got FindRoute(%d, %s, %s, %d, false): %s", nicID, addr1, addr2, header.IPv6ProtocolNumber, err)
+ } else if r.LocalAddress != addr1 {
+ t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, addr1)
+ }
+ r.Release()
}
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, addr1)
+
+ if t.Failed() {
+ t.FailNow()
}
// Should not have sent any more NS messages.
- if got := stat.Value(); got != uint64(test.dupAddrDetectTransmits) {
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got != uint64(test.dupAddrDetectTransmits) {
t.Fatalf("got NeighborSolicit = %d, want = %d", got, test.dupAddrDetectTransmits)
}
// Validate the sent Neighbor Solicitation messages.
for i := uint8(0); i < test.dupAddrDetectTransmits; i++ {
- p := <-e.C
+ p, _ := e.ReadContext(context.Background())
// Make sure its an IPv6 packet.
if p.Proto != header.IPv6ProtocolNumber {
t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
}
- // Check NDP packet.
- checker.IPv6(t, p.Header.ToVectorisedView().First(),
+ // Make sure the right remote link address is used.
+ snmc := header.SolicitedNodeAddr(addr1)
+ if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ // Check NDP NS packet.
+ //
+ // As per RFC 4861 section 4.3, a possible option is the Source Link
+ // Layer option, but this option MUST NOT be included when the source
+ // address of the packet is the unspecified address.
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(snmc),
checker.TTL(header.NDPHopLimit),
checker.NDPNS(
- checker.NDPNSTargetAddress(addr1)))
+ checker.NDPNSTargetAddress(addr1),
+ checker.NDPNSOptions(nil),
+ ))
+
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
+ }
}
})
}
-
}
// TestDADFail tests to make sure that the DAD process fails if another node is
@@ -230,6 +563,8 @@ func TestDADResolve(t *testing.T) {
// a node doing DAD for the same address), or if another node is detected to own
// the address already (receive an NA message for the tentative address).
func TestDADFail(t *testing.T) {
+ const nicID = 1
+
tests := []struct {
name string
makeBuf func(tgt tcpip.Address) buffer.Prependable
@@ -265,13 +600,17 @@ func TestDADFail(t *testing.T) {
{
"RxAdvert",
func(tgt tcpip.Address) buffer.Prependable {
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + naSize)
+ pkt := header.ICMPv6(hdr.Prepend(naSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
na := header.NDPNeighborAdvert(pkt.NDPPayload())
na.SetSolicitedFlag(true)
na.SetOverrideFlag(true)
na.SetTargetAddress(tgt)
+ na.Options().Serialize(header.NDPOptionsSerializer{
+ header.NDPTargetLinkLayerAddressOption(linkAddr1),
+ })
pkt.SetChecksum(header.ICMPv6Checksum(pkt, tgt, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
payloadLength := hdr.UsedLength()
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -295,7 +634,7 @@ func TestDADFail(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
ndpConfigs := stack.DefaultNDPConfigurations()
opts := stack.Options{
@@ -305,30 +644,33 @@ func TestDADFail(t *testing.T) {
}
opts.NDPConfigs.RetransmitTimer = time.Second * 2
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
// Receive a packet to simulate multiple nodes owning or
// attempting to own the same address.
hdr := test.makeBuf(addr1)
- e.Inject(header.IPv6ProtocolNumber, hdr.View().ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+ e.InjectInbound(header.IPv6ProtocolNumber, pkt)
stat := test.getStat(s.Stats().ICMP.V6PacketsReceived)
if got := stat.Value(); got != 1 {
@@ -344,102 +686,132 @@ func TestDADFail(t *testing.T) {
// something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicid != 1 {
- t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if e.resolved {
- t.Fatal("got DAD event w/ resolved = true, want = false")
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ // Attempting to add the address again should not fail if the address's
+ // state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
}
})
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ const nicID = 1
- e := channel.New(10, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ skipFinalAddrCheck bool
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
- }
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ // Tests to make sure that DAD stops when the NIC is removed.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ // The NIC is removed so we can't check its addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
}
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(1, addr1); err != nil {
- t.Fatalf("RemoveAddress(_, %s) = %s", addr1, err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicid != 1 {
- t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
- }
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
- }
- if e.resolved {
- t.Fatal("got DAD event w/ resolved = true, want = false")
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- }
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+
+ if !test.skipFinalAddrCheck {
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+ }
+
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -460,6 +832,10 @@ func TestSetNDPConfigurationFailsForBadNICID(t *testing.T) {
// configurations without affecting the default NDP configurations or other
// interfaces' configurations.
func TestSetNDPConfigurations(t *testing.T) {
+ const nicID1 = 1
+ const nicID2 = 2
+ const nicID3 = 3
+
tests := []struct {
name string
dupAddrDetectTransmits uint8
@@ -483,25 +859,36 @@ func TestSetNDPConfigurations(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent),
+ dadC: make(chan ndpDADEvent, 1),
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(0, 1280, linkAddr1)
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
NDPDisp: &ndpDisp,
})
+ expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatalf("expected DAD event for %s", addr)
+ }
+ }
+
// This NIC(1)'s NDP configurations will be updated to
// be different from the default.
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
}
// Created before updating NIC(1)'s NDP configurations
// but updating NIC(1)'s NDP configurations should not
// affect other existing NICs.
- if err := s.CreateNIC(2, e); err != nil {
- t.Fatalf("CreateNIC(2) = %s", err)
+ if err := s.CreateNIC(nicID2, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
}
// Update the NDP configurations on NIC(1) to use DAD.
@@ -509,36 +896,38 @@ func TestSetNDPConfigurations(t *testing.T) {
DupAddrDetectTransmits: test.dupAddrDetectTransmits,
RetransmitTimer: test.retransmitTimer,
}
- if err := s.SetNDPConfigurations(1, configs); err != nil {
- t.Fatalf("got SetNDPConfigurations(1, _) = %s", err)
+ if err := s.SetNDPConfigurations(nicID1, configs); err != nil {
+ t.Fatalf("got SetNDPConfigurations(%d, _) = %s", nicID1, err)
}
// Created after updating NIC(1)'s NDP configurations
// but the stack's default NDP configurations should not
// have been updated.
- if err := s.CreateNIC(3, e); err != nil {
- t.Fatalf("CreateNIC(3) = %s", err)
+ if err := s.CreateNIC(nicID3, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID3, err)
}
// Add addresses for each NIC.
- if err := s.AddAddress(1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(1, %d, %s) = %s", header.IPv6ProtocolNumber, addr1, err)
+ if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
}
- if err := s.AddAddress(2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(2, %d, %s) = %s", header.IPv6ProtocolNumber, addr2, err)
+ if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
}
- if err := s.AddAddress(3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(3, %d, %s) = %s", header.IPv6ProtocolNumber, addr3, err)
+ expectDADEvent(nicID2, addr2)
+ if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
}
+ expectDADEvent(nicID3, addr3)
// Address should not be considered bound to NIC(1) yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
}
// Should get the address on NIC(2) and NIC(3)
@@ -546,31 +935,31 @@ func TestSetNDPConfigurations(t *testing.T) {
// it as the stack was configured to not do DAD by
// default and we only updated the NDP configurations on
// NIC(1).
- addr, err = s.GetMainNICAddress(2, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(2, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(2, _) = %s, want = %s", addr, addr2)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
}
- addr, err = s.GetMainNICAddress(3, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(3, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
}
if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(3, _) = %s, want = %s", addr, addr3)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
}
// Sleep until right (500ms before) before resolution to
// make sure the address didn't resolve on NIC(1) yet.
const delta = 500 * time.Millisecond
time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
}
// Wait for DAD to resolve.
@@ -584,25 +973,4386 @@ func TestSetNDPConfigurations(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
+ if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
- if e.nicid != 1 {
- t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
+ }
+ addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
+ }
+ if addr.Address != addr1 {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
+ }
+ })
+ }
+}
+
+// raBufWithOptsAndDHCPv6 returns a valid NDP Router Advertisement with options
+// and DHCPv6 configurations specified.
+func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherConfigurations bool, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ icmpSize := header.ICMPv6HeaderSize + header.NDPRAMinimumSize + int(optSer.Length())
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize)
+ pkt := header.ICMPv6(hdr.Prepend(icmpSize))
+ pkt.SetType(header.ICMPv6RouterAdvert)
+ pkt.SetCode(0)
+ raPayload := pkt.NDPPayload()
+ ra := header.NDPRouterAdvert(raPayload)
+ // Populate the Router Lifetime.
+ binary.BigEndian.PutUint16(raPayload[2:], rl)
+ // Populate the Managed Address flag field.
+ if managedAddress {
+ // The Managed Addresses flag field is the 7th bit of byte #1 (0-indexing)
+ // of the RA payload.
+ raPayload[1] |= (1 << 7)
+ }
+ // Populate the Other Configurations flag field.
+ if otherConfigurations {
+ // The Other Configurations flag field is the 6th bit of byte #1
+ // (0-indexing) of the RA payload.
+ raPayload[1] |= (1 << 6)
+ }
+ opts := ra.Options()
+ opts.Serialize(optSer)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, ip, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ iph.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: header.NDPHopLimit,
+ SrcAddr: ip,
+ DstAddr: header.IPv6AllNodesMulticastAddress,
+ })
+
+ return stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ })
+}
+
+// raBufWithOpts returns a valid NDP Router Advertisement with options.
+//
+// Note, raBufWithOpts does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithOpts(ip tcpip.Address, rl uint16, optSer header.NDPOptionsSerializer) *stack.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, rl, false, false, optSer)
+}
+
+// raBufWithDHCPv6 returns a valid NDP Router Advertisement with DHCPv6 related
+// fields set.
+//
+// Note, raBufWithDHCPv6 does not populate any of the RA fields other than the
+// DHCPv6 related ones.
+func raBufWithDHCPv6(ip tcpip.Address, managedAddresses, otherConfiguratiosns bool) *stack.PacketBuffer {
+ return raBufWithOptsAndDHCPv6(ip, 0, managedAddresses, otherConfiguratiosns, header.NDPOptionsSerializer{})
+}
+
+// raBuf returns a valid NDP Router Advertisement.
+//
+// Note, raBuf does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBuf(ip tcpip.Address, rl uint16) *stack.PacketBuffer {
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{})
+}
+
+// raBufWithPI returns a valid NDP Router Advertisement with a single Prefix
+// Information option.
+//
+// Note, raBufWithPI does not populate any of the RA fields other than the
+// Router Lifetime.
+func raBufWithPI(ip tcpip.Address, rl uint16, prefix tcpip.AddressWithPrefix, onLink, auto bool, vl, pl uint32) *stack.PacketBuffer {
+ flags := uint8(0)
+ if onLink {
+ // The OnLink flag is the 7th bit in the flags byte.
+ flags |= 1 << 7
+ }
+ if auto {
+ // The Address Auto-Configuration flag is the 6th bit in the
+ // flags byte.
+ flags |= 1 << 6
+ }
+
+ // A valid header.NDPPrefixInformation must be 30 bytes.
+ buf := [30]byte{}
+ // The first byte in a header.NDPPrefixInformation is the Prefix Length
+ // field.
+ buf[0] = uint8(prefix.PrefixLen)
+ // The 2nd byte within a header.NDPPrefixInformation is the Flags field.
+ buf[1] = flags
+ // The Valid Lifetime field starts after the 2nd byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[2:], vl)
+ // The Preferred Lifetime field starts after the 6th byte within a
+ // header.NDPPrefixInformation.
+ binary.BigEndian.PutUint32(buf[6:], pl)
+ // The Prefix Address field starts after the 14th byte within a
+ // header.NDPPrefixInformation.
+ copy(buf[14:], prefix.Address)
+ return raBufWithOpts(ip, rl, header.NDPOptionsSerializer{
+ header.NDPPrefixInformation(buf[:]),
+ })
+}
+
+// TestNoRouterDiscovery tests that router discovery will not be performed if
+// configured not to.
+func TestNoRouterDiscovery(t *testing.T) {
+ // Being configured to discover routers means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // router discovery) - that will done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverDefaultRouters(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverDefaultRouters: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// discovered flag set to discovered.
+func checkRouterEvent(e ndpRouterEvent, addr tcpip.Address, discovered bool) string {
+ return cmp.Diff(ndpRouterEvent{nicID: 1, addr: addr, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestRouterDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered router when the dispatcher asks it not to.
+func TestRouterDiscoveryDispatcherNoRemember(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA for a router we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, lifetimeSeconds))
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr2, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the router in the first place.
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have received any router events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+func TestRouterDiscovery(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ expectRouterEvent := func(addr tcpip.Address, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, discovered); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected router discovery event")
+ }
+ }
+
+ expectAsyncRouterInvalidationEvent := func(addr tcpip.Address, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, addr, false); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for router discovery event")
+ }
+ }
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA from lladdr2 with zero lifetime. It should not be
+ // remembered.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("unexpectedly discovered a router with 0 lifetime")
+ default:
+ }
+
+ // Rx an RA from lladdr2 with a huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from another router (lladdr3) with non-zero lifetime.
+ const l3LifetimeSeconds = 6
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr3, l3LifetimeSeconds))
+ expectRouterEvent(llAddr3, true)
+
+ // Rx an RA from lladdr2 with lesser lifetime.
+ const l2LifetimeSeconds = 2
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, l2LifetimeSeconds))
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("Should not receive a router event when updating lifetimes for known routers")
+ default:
+ }
+
+ // Wait for lladdr2's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr2, l2LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+
+ // Rx an RA from lladdr2 with huge lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 1000))
+ expectRouterEvent(llAddr2, true)
+
+ // Rx an RA from lladdr2 with zero lifetime. It should be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0))
+ expectRouterEvent(llAddr2, false)
+
+ // Wait for lladdr3's router invalidation job to execute. The lifetime
+ // of the router should have been updated to the most recent (smaller)
+ // lifetime.
+ //
+ // Wait for the normal lifetime plus an extra bit for the
+ // router to get invalidated. If we don't get an invalidation
+ // event after this time, then something is wrong.
+ expectAsyncRouterInvalidationEvent(llAddr3, l3LifetimeSeconds*time.Second+defaultAsyncPositiveEventTimeout)
+}
+
+// TestRouterDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredDefaultRouters discovered routers are remembered.
+func TestRouterDiscoveryMaxRouters(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA from 2 more than the max number of discovered routers.
+ for i := 1; i <= stack.MaxDiscoveredDefaultRouters+2; i++ {
+ linkAddr := []byte{2, 2, 3, 4, 5, 0}
+ linkAddr[5] = byte(i)
+ llAddr := header.LinkLocalAddr(tcpip.LinkAddress(linkAddr))
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr, 5))
+
+ if i <= stack.MaxDiscoveredDefaultRouters {
+ select {
+ case e := <-ndpDisp.routerC:
+ if diff := checkRouterEvent(e, llAddr, true); diff != "" {
+ t.Errorf("router event mismatch (-want +got):\n%s", diff)
}
- if e.addr != addr1 {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, addr1)
+ default:
+ t.Fatal("expected router discovery event")
+ }
+
+ } else {
+ select {
+ case <-ndpDisp.routerC:
+ t.Fatal("should not have discovered a new router after we already discovered the max number of routers")
+ default:
+ }
+ }
+ }
+}
+
+// TestNoPrefixDiscovery tests that prefix discovery will not be performed if
+// configured not to.
+func TestNoPrefixDiscovery(t *testing.T) {
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+
+ // Being configured to discover prefixes means handle and
+ // discover are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, discover =
+ // true and forwarding = false (the required configuration to do
+ // prefix discovery) - that will done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ discover := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), DiscoverOnLinkPrefixes(%t), Forwarding(%t)", handle, discover, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ DiscoverOnLinkPrefixes: discover,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 10, 0))
+
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// Check e to make sure that the event is for prefix on nic with ID 1, and the
+// discovered flag set to discovered.
+func checkPrefixEvent(e ndpPrefixEvent, prefix tcpip.Subnet, discovered bool) string {
+ return cmp.Diff(ndpPrefixEvent{nicID: 1, prefix: prefix, discovered: discovered}, e, cmp.AllowUnexported(e))
+}
+
+// TestPrefixDiscoveryDispatcherNoRemember tests that the stack does not
+// remember a discovered on-link prefix when the dispatcher asks it not to.
+func TestPrefixDiscoveryDispatcherNoRemember(t *testing.T) {
+ prefix, subnet, _ := prefixSubnetAddr(0, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix that we should not remember.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, lifetimeSeconds, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+
+ // Wait for the invalidation time plus some buffer to make sure we do
+ // not actually receive any invalidation events as we should not have
+ // remembered the prefix in the first place.
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have received any prefix events")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+func TestPrefixDiscovery(t *testing.T) {
+ prefix1, subnet1, _ := prefixSubnetAddr(0, "")
+ prefix2, subnet2, _ := prefixSubnetAddr(1, "")
+ prefix3, subnet3, _ := prefixSubnetAddr(2, "")
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly discovered a prefix with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 100, 0))
+ expectPrefixEvent(subnet1, true)
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, 100, 0))
+ expectPrefixEvent(subnet2, true)
+
+ // Receive an RA with prefix3 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 100, 0))
+ expectPrefixEvent(subnet3, true)
+
+ // Receive an RA with prefix1 in a PI with lifetime = 0.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, false, 0, 0))
+ expectPrefixEvent(subnet1, false)
+
+ // Receive an RA with prefix2 in a PI with lesser lifetime.
+ lifetime := uint32(2)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, false, lifetime, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly received prefix event when updating lifetime")
+ default:
+ }
+
+ // Wait for prefix2's most recent invalidation job plus some buffer to
+ // expire.
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet2, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(lifetime)*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive RA to invalidate prefix3.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix3, true, false, 0, 0))
+ expectPrefixEvent(subnet3, false)
+}
+
+func TestPrefixDiscoveryWithInfiniteLifetime(t *testing.T) {
+ // Update the infinite lifetime value to a smaller value so we can test
+ // that when we receive a PI with such a lifetime value, we do not
+ // invalidate the prefix.
+ const testInfiniteLifetimeSeconds = 2
+ const testInfiniteLifetime = testInfiniteLifetimeSeconds * time.Second
+ saved := header.NDPInfiniteLifetime
+ header.NDPInfiniteLifetime = testInfiniteLifetime
+ defer func() {
+ header.NDPInfiniteLifetime = saved
+ }()
+
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x00"),
+ PrefixLen: 64,
+ }
+ subnet := prefix.Subnet()
+
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, 1),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectPrefixEvent := func(prefix tcpip.Subnet, discovered bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefix, discovered); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ }
+
+ // Receive an RA with prefix in an NDP Prefix Information option (PI)
+ // with infinite valid lifetime which should not get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ expectPrefixEvent(subnet, true)
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with finite lifetime.
+ // The prefix should get invalidated after 1s.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, subnet, false); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(testInfiniteLifetime):
+ t.Fatal("timed out waiting for prefix discovery event")
+ }
+
+ // Receive an RA with finite lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds-1, 0))
+ expectPrefixEvent(subnet, true)
+
+ // Receive an RA with prefix with an infinite lifetime.
+ // The prefix should not be invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After(testInfiniteLifetime + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with a prefix with a lifetime value greater than the
+ // set infinite lifetime value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, testInfiniteLifetimeSeconds+1, 0))
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("unexpectedly invalidated a prefix with infinite lifetime")
+ case <-time.After((testInfiniteLifetimeSeconds+1)*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Receive an RA with 0 lifetime.
+ // The prefix should get invalidated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, false, 0, 0))
+ expectPrefixEvent(subnet, false)
+}
+
+// TestPrefixDiscoveryMaxRouters tests that only
+// stack.MaxDiscoveredOnLinkPrefixes discovered on-link prefixes are remembered.
+func TestPrefixDiscoveryMaxOnLinkPrefixes(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ prefixC: make(chan ndpPrefixEvent, stack.MaxDiscoveredOnLinkPrefixes+3),
+ rememberPrefix: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: false,
+ DiscoverOnLinkPrefixes: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ optSer := make(header.NDPOptionsSerializer, stack.MaxDiscoveredOnLinkPrefixes+2)
+ prefixes := [stack.MaxDiscoveredOnLinkPrefixes + 2]tcpip.Subnet{}
+
+ // Receive an RA with 2 more than the max number of discovered on-link
+ // prefixes.
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ prefixAddr := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0}
+ prefixAddr[7] = byte(i)
+ prefix := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(prefixAddr[:]),
+ PrefixLen: 64,
+ }
+ prefixes[i] = prefix.Subnet()
+ buf := [30]byte{}
+ buf[0] = uint8(prefix.PrefixLen)
+ buf[1] = 128
+ binary.BigEndian.PutUint32(buf[2:], 10)
+ copy(buf[14:], prefix.Address)
+
+ optSer[i] = header.NDPPrefixInformation(buf[:])
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+ for i := 0; i < stack.MaxDiscoveredOnLinkPrefixes+2; i++ {
+ if i < stack.MaxDiscoveredOnLinkPrefixes {
+ select {
+ case e := <-ndpDisp.prefixC:
+ if diff := checkPrefixEvent(e, prefixes[i], true); diff != "" {
+ t.Errorf("prefix event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected prefix discovery event")
+ }
+ } else {
+ select {
+ case <-ndpDisp.prefixC:
+ t.Fatal("should not have discovered a new prefix after we already discovered the max number of prefixes")
+ default:
+ }
+ }
+ }
+}
+
+// Checks to see if list contains an IPv6 address, item.
+func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) bool {
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: item,
+ }
+
+ return containsAddr(list, protocolAddress)
+}
+
+// TestNoAutoGenAddr tests that SLAAC is not performed when configured not to.
+func TestNoAutoGenAddr(t *testing.T) {
+ prefix, _, _ := prefixSubnetAddr(0, "")
+
+ // Being configured to auto-generate addresses means handle and
+ // autogen are set to true and forwarding is set to false.
+ // This tests all possible combinations of the configurations,
+ // except for the configuration where handle = true, autogen =
+ // true and forwarding = false (the required configuration to do
+ // SLAAC) - that will done in other tests.
+ for i := 0; i < 7; i++ {
+ handle := i&1 != 0
+ autogen := i&2 != 0
+ forwarding := i&4 == 0
+
+ t.Run(fmt.Sprintf("HandleRAs(%t), AutoGenAddr(%t), Forwarding(%t)", handle, autogen, forwarding), func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: handle,
+ AutoGenGlobalAddresses: autogen,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ s.SetForwarding(ipv6.ProtocolNumber, forwarding)
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Rx an RA with prefix with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, false, true, 10, 0))
+
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when configured not to")
+ default:
+ }
+ })
+ }
+}
+
+// Check e to make sure that the event is for addr on nic with ID 1, and the
+// event type is set to eventType.
+func checkAutoGenAddrEvent(e ndpAutoGenAddrEvent, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) string {
+ return cmp.Diff(ndpAutoGenAddrEvent{nicID: 1, addr: addr, eventType: eventType}, e, cmp.AllowUnexported(e))
+}
+
+// TestAutoGenAddr tests that an address is properly generated and invalidated
+// when configured to do so.
+func TestAutoGenAddr(t *testing.T) {
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with 0 lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address with preferred lifetime > valid lifetime")
+ default:
+ }
+
+ // Receive an RA with prefix2 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+
+ // Refresh valid lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address when we already have an address for a prefix")
+ default:
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr1) {
+ t.Fatalf("Should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr2) {
+ t.Fatalf("Should have %s in the list of addresses", addr2)
+ }
+}
+
+func addressCheck(addrs []tcpip.ProtocolAddress, containList, notContainList []tcpip.AddressWithPrefix) string {
+ ret := ""
+ for _, c := range containList {
+ if !containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should have %s in the list of addresses\n", c)
+ }
+ }
+ for _, c := range notContainList {
+ if containsV6Addr(addrs, c) {
+ ret += fmt.Sprintf("should not have %s in the list of addresses\n", c)
+ }
+ }
+ return ret
+}
+
+// TestAutoGenTempAddr tests that temporary SLAAC addresses are generated when
+// configured to do so as part of IPv6 Privacy Extensions.
+func TestAutoGenTempAddr(t *testing.T) {
+ const (
+ nicID = 1
+ newMinVL = 5
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMinPrefixInformationValidLifetimeForUpdate := stack.MinPrefixInformationValidLifetimeForUpdate
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinPrefixInformationValidLifetimeForUpdate
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for i, test := range tests {
+ i := i
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ seed := []byte{uint8(i)}
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], seed, nicID)
+ newTempAddr := func(stableAddr tcpip.Address) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableAddr)
+ }
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 2),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: test.dupAddrTransmits,
+ RetransmitTimer: test.retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ TempIIDSeed: seed,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 0, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with 0 lifetime; event = %+v", e)
+ default:
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectDADEventAsync(addr1.Address)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ tempAddr1 := newTempAddr(addr1.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ expectDADEventAsync(tempAddr1.Address)
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in an NDP Prefix Information option (PI)
+ // with preferred lifetime > valid lifetime
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 5, 6))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly auto-generated an address with preferred lifetime > valid lifetime; event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Receive an RA with prefix2 in a PI w/ non-zero valid and preferred
+ // lifetimes.
+ tempAddr2 := newTempAddr(addr2.Address)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ expectDADEventAsync(addr2.Address)
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr)
+ expectDADEventAsync(tempAddr2.Address)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
}
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+
+ // Refresh lifetimes for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Reduce valid lifetime and deprecate addresses of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr1, tempAddr1, addr2, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for addrs of prefix1 to be invalidated. They should be
+ // invalidated at the same time.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ var nextAddr tcpip.AddressWithPrefix
+ if e.addr == addr1 {
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = tempAddr1
+ } else {
+ if diff := checkAutoGenAddrEvent(e, tempAddr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ nextAddr = addr1
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, nextAddr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
}
+
+ // Receive an RA with prefix2 in a PI w/ 0 lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 0, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr2, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr2, tempAddr2}, []tcpip.AddressWithPrefix{addr1, tempAddr1}); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrForLinkLocal test that temporary SLAAC addresses are not
+// generated for auto generated link-local addresses.
+func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
+ const nicID = 1
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ tests := []struct {
+ name string
+ dupAddrTransmits uint8
+ retransmitTimer time.Duration
+ }{
+ {
+ name: "DAD disabled",
+ },
+ {
+ name: "DAD enabled",
+ dupAddrTransmits: 1,
+ retransmitTimer: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ AutoGenIPv6LinkLocal: true,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // The stable link-local address should auto-generate and resolve DAD.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tcpip.AddressWithPrefix{Address: llAddr1, PrefixLen: header.IIDOffsetInIPv6Address * 8}, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // No new addresses should be generated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Errorf("got unxpected auto gen addr event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+ })
+}
+
+// TestNoAutoGenTempAddrWithoutStableAddr tests that a temporary SLAAC address
+// will not be generated until after DAD completes, even if a new Router
+// Advertisement is received to refresh lifetimes.
+func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
+ const (
+ nicID = 1
+ dadTransmits = 1
+ retransmitTimer = 2 * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ }()
+ stack.MaxDesyncFactor = 0
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Receive an RA to trigger SLAAC for prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // DAD on the stable address for prefix has not yet completed. Receiving a new
+ // RA that would refresh lifetimes should not generate a temporary SLAAC
+ // address for the prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ default:
+ }
+
+ // Wait for DAD to complete for the stable address then expect the temporary
+ // address to be generated.
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, tempAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+}
+
+// TestAutoGenTempAddrRegen tests that temporary SLAAC addresses are
+// regenerated.
+func TestAutoGenTempAddrRegen(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(1, _) err = %s", err)
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
}
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(1, _) = %s, want = %s", addr, addr1)
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr2, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Wait for regeneration
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1, tempAddr2, tempAddr3}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Stop generating temporary addresses
+ ndpConfigs.AutoGenTempGlobalAddresses = false
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+
+ // Wait for all the temporary addresses to get invalidated.
+ tempAddrs := []tcpip.AddressWithPrefix{tempAddr1, tempAddr2, tempAddr3}
+ invalidateAfter := newMinVLDuration - 2*regenAfter
+ for _, addr := range tempAddrs {
+ // Wait for a deprecation then invalidation event, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation jobs could execute in any
+ // order.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we shouldn't get a deprecation
+ // event after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event = %+v", e)
+ }
+ case <-time.After(invalidateAfter + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+
+ invalidateAfter = regenAfter
+ }
+ if mismatch := addressCheck(s.NICInfo()[1].ProtocolAddresses, []tcpip.AddressWithPrefix{addr}, tempAddrs); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+}
+
+// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's
+// regeneration job gets updated when refreshing the address's lifetimes.
+func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) {
+ const (
+ nicID = 1
+ regenAfter = 2 * time.Second
+ newMinVL = 10
+ newMinVLDuration = newMinVL * time.Second
+ )
+
+ savedMaxDesyncFactor := stack.MaxDesyncFactor
+ savedMinMaxTempAddrPreferredLifetime := stack.MinMaxTempAddrPreferredLifetime
+ savedMinMaxTempAddrValidLifetime := stack.MinMaxTempAddrValidLifetime
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesyncFactor
+ stack.MinMaxTempAddrPreferredLifetime = savedMinMaxTempAddrPreferredLifetime
+ stack.MinMaxTempAddrValidLifetime = savedMinMaxTempAddrValidLifetime
+ }()
+ stack.MaxDesyncFactor = 0
+ stack.MinMaxTempAddrPreferredLifetime = newMinVLDuration
+ stack.MinMaxTempAddrValidLifetime = newMinVLDuration
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+ tempAddr3 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], addr.Address)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ RegenAdvanceDuration: newMinVLDuration - regenAfter,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in an NDP Prefix Information option (PI)
+ // with non-zero valid & preferred lifetimes.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr, newAddr)
+ expectAutoGenAddrEvent(tempAddr1, newAddr)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, []tcpip.AddressWithPrefix{addr, tempAddr1}, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Deprecate the prefix.
+ //
+ // A new temporary address should be generated after the regeneration
+ // time has passed since the prefix is deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr, deprecatedAddr)
+ expectAutoGenAddrEvent(tempAddr1, deprecatedAddr)
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Prefer the prefix again.
+ //
+ // A new temporary address should immediately be generated since the
+ // regeneration time has already passed since the last address was generated
+ // - this regeneration does not depend on a job.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(tempAddr2, newAddr)
+
+ // Increase the maximum lifetimes for temporary addresses to large values
+ // then refresh the lifetimes of the prefix.
+ //
+ // A new address should not be generated after the regeneration time that was
+ // expected for the previous check. This is because the preferred lifetime for
+ // the temporary addresses has increased, so it will take more time to
+ // regenerate a new temporary address. Note, new addresses are only
+ // regenerated after the preferred lifetime - the regenerate advance duration
+ // as paased.
+ ndpConfigs.MaxTempAddrValidLifetime = 100 * time.Second
+ ndpConfigs.MaxTempAddrPreferredLifetime = 100 * time.Second
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ case <-time.After(regenAfter + defaultAsyncNegativeEventTimeout):
+ }
+
+ // Set the maximum lifetimes for temporary addresses such that on the next
+ // RA, the regeneration job gets scheduled again.
+ //
+ // The maximum lifetime is the sum of the minimum lifetimes for temporary
+ // addresses + the time that has already passed since the last address was
+ // generated so that the regeneration job is needed to generate the next
+ // address.
+ newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout
+ ndpConfigs.MaxTempAddrValidLifetime = newLifetimes
+ ndpConfigs.MaxTempAddrPreferredLifetime = newLifetimes
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEventAsync(tempAddr3, newAddr, regenAfter+defaultAsyncPositiveEventTimeout)
+}
+
+// TestMixedSLAACAddrConflictRegen tests SLAAC address regeneration in response
+// to a mix of DAD conflicts and NIC-local conflicts.
+func TestMixedSLAACAddrConflictRegen(t *testing.T) {
+ const (
+ nicID = 1
+ nicName = "nic"
+ lifetimeSeconds = 9999
+ // From stack.maxSLAACAddrLocalRegenAttempts
+ maxSLAACAddrLocalRegenAttempts = 10
+ // We use 2 more addreses than the maximum local regeneration attempts
+ // because we want to also trigger regeneration in response to a DAD
+ // conflicts for this test.
+ maxAddrs = maxSLAACAddrLocalRegenAttempts + 2
+ dupAddrTransmits = 1
+ retransmitTimer = time.Second
+ )
+
+ var tempIIDHistoryWithModifiedEUI64 [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistoryWithModifiedEUI64[:], nil, nicID)
+
+ var tempIIDHistoryWithOpaqueIID [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistoryWithOpaqueIID[:], nil, nicID)
+
+ prefix, subnet, stableAddrWithModifiedEUI64 := prefixSubnetAddr(0, linkAddr1)
+ var stableAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
+ var tempAddrsWithOpaqueIID [maxAddrs]tcpip.AddressWithPrefix
+ var tempAddrsWithModifiedEUI64 [maxAddrs]tcpip.AddressWithPrefix
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < maxAddrs; i++ {
+ stableAddrsWithOpaqueIID[i] = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, uint8(i), nil)),
+ PrefixLen: header.IIDOffsetInIPv6Address * 8,
+ }
+ // When generating temporary addresses, the resolved stable address for the
+ // SLAAC prefix will be the first address stable address generated for the
+ // prefix as we will not simulate address conflicts for the stable addresses
+ // in tests involving temporary addresses. Address conflicts for stable
+ // addresses will be done in their own tests.
+ tempAddrsWithOpaqueIID[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithOpaqueIID[:], stableAddrsWithOpaqueIID[0].Address)
+ tempAddrsWithModifiedEUI64[i] = header.GenerateTempIPv6SLAACAddr(tempIIDHistoryWithModifiedEUI64[:], stableAddrWithModifiedEUI64.Address)
+ }
+
+ tests := []struct {
+ name string
+ addrs []tcpip.AddressWithPrefix
+ tempAddrs bool
+ initialExpect tcpip.AddressWithPrefix
+ nicNameFromID func(tcpip.NICID, string) string
+ }{
+ {
+ name: "Stable addresses with opaque IIDs",
+ addrs: stableAddrsWithOpaqueIID[:],
+ nicNameFromID: func(tcpip.NICID, string) string {
+ return nicName
+ },
+ },
+ {
+ name: "Temporary addresses with opaque IIDs",
+ addrs: tempAddrsWithOpaqueIID[:],
+ tempAddrs: true,
+ initialExpect: stableAddrsWithOpaqueIID[0],
+ nicNameFromID: func(tcpip.NICID, string) string {
+ return nicName
+ },
+ },
+ {
+ name: "Temporary addresses with modified EUI64",
+ addrs: tempAddrsWithModifiedEUI64[:],
+ tempAddrs: true,
+ initialExpect: stableAddrWithModifiedEUI64,
+ },
+ }
+
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: test.tempAddrs,
+ AutoGenAddressConflictRetries: 1,
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: test.nicNameFromID,
+ },
+ })
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr2,
+ NIC: nicID,
+ }})
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ for j := 0; j < len(test.addrs)-1; j++ {
+ // The NIC will not attempt to generate an address in response to a
+ // NIC-local conflict after some maximum number of attempts. We skip
+ // creating a conflict for the address that would be generated as part
+ // of the last attempt so we can simulate a DAD conflict for this
+ // address and restart the NIC-local generation process.
+ if j == maxSLAACAddrLocalRegenAttempts-1 {
+ continue
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, test.addrs[j].Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, test.addrs[j].Address, err)
+ }
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrAsyncEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEventAsync := func(addr tcpip.Address) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dupAddrTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ // Enable DAD.
+ ndpDisp.dadC = make(chan ndpDADEvent, 2)
+ ndpConfigs.DupAddrDetectTransmits = dupAddrTransmits
+ ndpConfigs.RetransmitTimer = retransmitTimer
+ if err := s.SetNDPConfigurations(nicID, ndpConfigs); err != nil {
+ t.Fatalf("s.SetNDPConfigurations(%d, _): %s", nicID, err)
+ }
+
+ // Do SLAAC for prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ if test.initialExpect != (tcpip.AddressWithPrefix{}) {
+ expectAutoGenAddrEvent(test.initialExpect, newAddr)
+ expectDADEventAsync(test.initialExpect.Address)
+ }
+
+ // The last local generation attempt should succeed, but we introduce a
+ // DAD failure to restart the local generation process.
+ addr := test.addrs[maxSLAACAddrLocalRegenAttempts-1]
+ expectAutoGenAddrAsyncEvent(addr, newAddr)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // The last address generated should resolve DAD.
+ addr = test.addrs[len(test.addrs)-1]
+ expectAutoGenAddrAsyncEvent(addr, newAddr)
+ expectDADEventAsync(addr.Address)
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpected auto gen addr event = %+v", e)
+ default:
+ }
+ })
+ }
+}
+
+// stackAndNdpDispatcherWithDefaultRoute returns an ndpDispatcher,
+// channel.Endpoint and stack.Stack.
+//
+// stack.Stack will have a default route through the router (llAddr3) installed
+// and a static link-address (linkAddr3) added to the link address cache for the
+// router.
+func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) {
+ t.Helper()
+ ndpDisp := &ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+ return ndpDisp, e, s
+}
+
+// addrForNewConnectionTo returns the local address used when creating a new
+// connection to addr.
+func addrForNewConnectionTo(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Connect(addr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", addr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// addrForNewConnection returns the local address used when creating a new
+// connection.
+func addrForNewConnection(t *testing.T, s *stack.Stack) tcpip.Address {
+ t.Helper()
+
+ return addrForNewConnectionTo(t, s, dstAddr)
+}
+
+// addrForNewConnectionWithAddr returns the local address used when creating a
+// new connection with a specific local address.
+func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullAddress) tcpip.Address {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+ if err := ep.Bind(addr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", addr, err)
+ }
+ if err := ep.Connect(dstAddr); err != nil {
+ t.Fatalf("ep.Connect(%+v): %s", dstAddr, err)
+ }
+ got, err := ep.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("ep.GetLocalAddress(): %s", err)
+ }
+ return got.Addr
+}
+
+// TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when
+// receiving a PI with 0 preferred lifetime.
+func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Deprecate addr for prefix1 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr1, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ // addr should still be the primary endpoint as there are no other addresses.
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetimes of addr generated from prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Deprecate addr for prefix2 immedaitely.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, deprecatedAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr1 should be the primary endpoint now since addr2 is deprecated but
+ // addr1 is not.
+ expectPrimaryAddr(addr1)
+ // addr2 is deprecated but if explicitly requested, it should be used.
+ fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
+
+ // Another PI w/ 0 preferred lifetime should not result in a deprecation
+ // event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address)
+ }
+
+ // Refresh lifetimes of addr generated from prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+}
+
+// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated
+// when its preferred lifetime expires.
+func TestAutoGenAddrJobDeprecation(t *testing.T) {
+ const nicID = 1
+ const newMinVL = 2
+ newMinVLDuration := newMinVL * time.Second
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(timeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive PI for prefix2.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Receive a PI for prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr1)
+
+ // Refresh lifetime for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since addr1 is deprecated but
+ // addr2 is not.
+ expectPrimaryAddr(addr2)
+ // addr1 is deprecated but if explicitly requested, it should be used.
+ fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID}
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make
+ // sure we do not get a deprecation event again.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Refresh lifetimes for addr of prefix1.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+ // addr1 is the primary endpoint again since it is non-deprecated now.
+ expectPrimaryAddr(addr1)
+
+ // Wait for addr of prefix1 to be deprecated.
+ expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ // addr2 should be the primary endpoint now since it is not deprecated.
+ expectPrimaryAddr(addr2)
+ if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address {
+ t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout)
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+ expectPrimaryAddr(addr2)
+
+ // Refresh both lifetimes for addr of prefix2 to the same value.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ default:
+ }
+
+ // Wait for a deprecation then invalidation events, or just an invalidation
+ // event. We need to cover both cases but cannot deterministically hit both
+ // cases because the deprecation and invalidation handlers could be handled in
+ // either deprecation then invalidation, or invalidation then deprecation
+ // (which should be cancelled by the invalidation handler).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" {
+ // If we get a deprecation event first, we should get an invalidation
+ // event almost immediately after.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" {
+ // If we get an invalidation event first, we should not get a deprecation
+ // event after.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto-generated event")
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ } else {
+ t.Fatalf("got unexpected auto-generated event")
+ }
+ case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should not have %s in the list of addresses", addr2)
+ }
+ // Should not have any primary endpoints.
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+ ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err)
+ }
+ defer ep.Close()
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err)
+ }
+
+ if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute {
+ t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute)
+ }
+}
+
+// Tests transitioning a SLAAC address's valid lifetime between finite and
+// infinite values.
+func TestAutoGenAddrFiniteToInfiniteToFiniteVL(t *testing.T) {
+ const infiniteVLSeconds = 2
+ const minVLSeconds = 1
+ savedIL := header.NDPInfiniteLifetime
+ savedMinVL := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = savedMinVL
+ header.NDPInfiniteLifetime = savedIL
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = minVLSeconds * time.Second
+ header.NDPInfiniteLifetime = infiniteVLSeconds * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ infiniteVL uint32
+ }{
+ {
+ name: "EqualToInfiniteVL",
+ infiniteVL: infiniteVLSeconds,
+ },
+ // Our implementation supports changing header.NDPInfiniteLifetime for tests
+ // such that a packet can be received where the lifetime field has a value
+ // greater than header.NDPInfiniteLifetime. Because of this, we test to make
+ // sure that receiving a value greater than header.NDPInfiniteLifetime is
+ // handled the same as when receiving a value equal to
+ // header.NDPInfiniteLifetime.
+ {
+ name: "MoreThanInfiniteVL",
+ infiniteVL: infiniteVLSeconds + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with finite prefix.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive an new RA with prefix with infinite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.infiniteVL, 0))
+
+ // Receive a new RA with prefix with finite VL.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, minVLSeconds, 0))
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ case <-time.After(minVLSeconds*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrValidLifetimeUpdates tests that the valid lifetime of an
+// auto-generated address only gets updated when required to, as specified in
+// RFC 4862 section 5.5.3.e.
+func TestAutoGenAddrValidLifetimeUpdates(t *testing.T) {
+ const infiniteVL = 4294967295
+ const newMinVL = 4
+ saved := stack.MinPrefixInformationValidLifetimeForUpdate
+ defer func() {
+ stack.MinPrefixInformationValidLifetimeForUpdate = saved
+ }()
+ stack.MinPrefixInformationValidLifetimeForUpdate = newMinVL * time.Second
+
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ tests := []struct {
+ name string
+ ovl uint32
+ nvl uint32
+ evl uint32
+ }{
+ // Should update the VL to the minimum VL for updating if the
+ // new VL is less than newMinVL but was originally greater than
+ // it.
+ {
+ "LargeVLToVLLessThanMinVLForUpdate",
+ 9999,
+ 1,
+ newMinVL,
+ },
+ {
+ "LargeVLTo0",
+ 9999,
+ 0,
+ newMinVL,
+ },
+ {
+ "InfiniteVLToVLLessThanMinVLForUpdate",
+ infiniteVL,
+ 1,
+ newMinVL,
+ },
+ {
+ "InfiniteVLTo0",
+ infiniteVL,
+ 0,
+ newMinVL,
+ },
+
+ // Should not update VL if original VL was less than newMinVL
+ // and the new VL is also less than newMinVL.
+ {
+ "ShouldNotUpdateWhenBothOldAndNewAreLessThanMinVLForUpdate",
+ newMinVL - 1,
+ newMinVL - 3,
+ newMinVL - 1,
+ },
+
+ // Should take the new VL if the new VL is greater than the
+ // remaining time or is greater than newMinVL.
+ {
+ "MorethanMinVLToLesserButStillMoreThanMinVLForUpdate",
+ newMinVL + 5,
+ newMinVL + 3,
+ newMinVL + 3,
+ },
+ {
+ "SmallVLToGreaterVLButStillLessThanMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL - 1,
+ newMinVL - 1,
+ },
+ {
+ "SmallVLToGreaterVLThatIsMoreThaMinVLForUpdate",
+ newMinVL - 3,
+ newMinVL + 1,
+ newMinVL + 1,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 10),
+ }
+ e := channel.New(10, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Receive an RA with prefix with initial VL,
+ // test.ovl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.ovl, 0))
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+
+ // Receive an new RA with prefix with new VL,
+ // test.nvl.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, test.nvl, 0))
+
+ //
+ // Validate that the VL for the address got set
+ // to test.evl.
+ //
+
+ // The address should not be invalidated until the effective valid
+ // lifetime has passed.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(time.Duration(test.evl)*time.Second - defaultAsyncNegativeEventTimeout):
+ }
+
+ // Wait for the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timeout waiting for addr auto gen event")
+ }
+ })
+ }
+ })
+}
+
+// TestAutoGenAddrRemoval tests that when auto-generated addresses are removed
+// by the user, its resources will be cleaned up and an invalidation event will
+// be sent to the integrator.
+func TestAutoGenAddrRemoval(t *testing.T) {
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive a PI to auto-generate an address.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Removing the address should result in an invalidation event
+ // immediately.
+ if err := s.RemoveAddress(1, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+
+ // Wait for the original valid lifetime to make sure the original job got
+ // cancelled/cleaned up.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+}
+
+// TestAutoGenAddrAfterRemoval tests adding a SLAAC address that was previously
+// assigned to the NIC but is in the permanentExpired state.
+func TestAutoGenAddrAfterRemoval(t *testing.T) {
+ const nicID = 1
+
+ prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1)
+ ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID)
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
+ t.Helper()
+
+ if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
+ t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
+ } else if got != addr {
+ t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ }
+
+ if got := addrForNewConnection(t, s); got != addr.Address {
+ t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address)
+ }
+ }
+
+ // Receive a PI to auto-generate addr1 with a large valid and preferred
+ // lifetime.
+ const largeLifetimeSeconds = 999
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ expectPrimaryAddr(addr1)
+
+ // Add addr2 as a static address.
+ protoAddr2 := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: addr2,
+ }
+ if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil {
+ t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err)
+ }
+ // addr2 should be more preferred now since it is at the front of the primary
+ // list.
+ expectPrimaryAddr(addr2)
+
+ // Get a route using addr2 to increment its reference count then remove it
+ // to leave it in the permanentExpired state.
+ r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err)
+ }
+ defer r.Release()
+ if err := s.RemoveAddress(nicID, addr2.Address); err != nil {
+ t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err)
+ }
+ // addr1 should be preferred again since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and preferred.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr2 should be more preferred now that it is closer to the front of the
+ // primary list and not deprecated.
+ expectPrimaryAddr(addr2)
+
+ // Removing the address should result in an invalidation event immediately.
+ // It should still be in the permanentExpired state because r is still held.
+ //
+ // We remove addr2 here to make sure addr2 was marked as a SLAAC address
+ // (it was previously marked as a static address).
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ // addr1 should be more preferred since addr2 is in the expired state.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to auto-generate addr2 as valid and deprecated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ // addr1 should still be more preferred since addr2 is deprecated, even though
+ // it is closer to the front of the primary list.
+ expectPrimaryAddr(addr1)
+
+ // Receive a PI to refresh addr2's preferred lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly got an auto gen addr event")
+ default:
+ }
+ // addr2 should be more preferred now that it is not deprecated.
+ expectPrimaryAddr(addr2)
+
+ if err := s.RemoveAddress(1, addr2.Address); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err)
+ }
+ expectAutoGenAddrEvent(addr2, invalidatedAddr)
+ expectPrimaryAddr(addr1)
+}
+
+// TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that
+// is already assigned to the NIC, the static address remains.
+func TestAutoGenAddrStaticConflict(t *testing.T) {
+ prefix, _, addr := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ // Add the address as a static address before SLAAC tries to add it.
+ if err := s.AddProtocolAddress(1, tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: addr}); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Receive a PI where the generated address will be the same as the one
+ // that we already have assigned statically.
+ const lifetimeSeconds = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, 0))
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event for an address we already have statically")
+ default:
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+
+ // Should not get an invalidation event after the PI's invalidation
+ // time.
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly received an auto gen addr event")
+ case <-time.After(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout):
+ }
+ if !containsV6Addr(s.NICInfo()[1].ProtocolAddresses, addr) {
+ t.Fatalf("Should have %s in the list of addresses", addr1)
+ }
+}
+
+// TestAutoGenAddrWithOpaqueIID tests that SLAAC generated addresses will use
+// opaque interface identifiers when configured to do so.
+func TestAutoGenAddrWithOpaqueIID(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic1"
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix1, subnet1, _ := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, _ := prefixSubnetAddr(1, linkAddr1)
+ // addr1 and addr2 are the addresses that are expected to be generated when
+ // stack.Stack is configured to generate opaque interface identifiers as
+ // defined by RFC 7217.
+ addrBytes := []byte(subnet1.ID())
+ addr1 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet1, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ addrBytes = []byte(subnet2.ID())
+ addr2 := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet2, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v, _) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix1 in a PI.
+ const validLifetimeSecondPrefix1 = 1
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, validLifetimeSecondPrefix1, 0))
+ expectAutoGenAddrEvent(addr1, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+
+ // Receive an RA with prefix2 in a PI with a large valid lifetime.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0))
+ expectAutoGenAddrEvent(addr2, newAddr)
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+
+ // Wait for addr of prefix1 to be invalidated.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr1, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(validLifetimeSecondPrefix1*time.Second + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) {
+ t.Fatalf("should not have %s in the list of addresses", addr1)
+ }
+ if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) {
+ t.Fatalf("should have %s in the list of addresses", addr2)
+ }
+}
+
+func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxMaxRetries = 3
+ const lifetimeSeconds = 10
+
+ // Needed for the temporary address sub test.
+ savedMaxDesync := stack.MaxDesyncFactor
+ defer func() {
+ stack.MaxDesyncFactor = savedMaxDesync
+ }()
+ stack.MaxDesyncFactor = time.Nanosecond
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrForSubnet := func(subnet tcpip.Subnet, dadCounter uint8) tcpip.AddressWithPrefix {
+ addrBytes := []byte(subnet.ID())
+ return tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, dadCounter, secretKey)),
+ PrefixLen: 64,
+ }
+ }
+
+ expectAutoGenAddrEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ expectAutoGenAddrEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for addr auto gen event")
+ }
+ }
+
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+ }
+
+ expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+ }
+
+ stableAddrForTempAddrTest := addrForSubnet(subnet, 0)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ prepareFn func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix
+ addrGenFn func(dadCounter uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ },
+ prepareFn: func(_ *testing.T, _ *ndpDispatcher, e *channel.Endpoint, _ []byte) []tcpip.AddressWithPrefix {
+ // Receive an RA with prefix1 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+ return nil
+
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(subnet, dadCounter)
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ autoGenLinkLocal: true,
+ prepareFn: func(*testing.T, *ndpDispatcher, *channel.Endpoint, []byte) []tcpip.AddressWithPrefix {
+ return nil
+ },
+ addrGenFn: func(dadCounter uint8, _ []byte) tcpip.AddressWithPrefix {
+ return addrForSubnet(header.IPv6LinkLocalPrefix.Subnet(), dadCounter)
+ },
+ },
+ {
+ name: "Temporary address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ prepareFn: func(t *testing.T, ndpDisp *ndpDispatcher, e *channel.Endpoint, tempIIDHistory []byte) []tcpip.AddressWithPrefix {
+ header.InitialTempIID(tempIIDHistory, nil, nicID)
+
+ // Generate a stable SLAAC address so temporary addresses will be
+ // generated.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
+ expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
+ expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true)
+
+ // The stable address will be assigned throughout the test.
+ return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
+ },
+ addrGenFn: func(_ uint8, tempIIDHistory []byte) tcpip.AddressWithPrefix {
+ return header.GenerateTempIPv6SLAACAddr(tempIIDHistory, stableAddrForTempAddrTest.Address)
+ },
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the parallel
+ // tests complete and limit the number of parallel tests running at the same
+ // time to reduce flakes.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run(addrType.name, func(t *testing.T) {
+ for maxRetries := uint8(0); maxRetries <= maxMaxRetries; maxRetries++ {
+ for numFailures := uint8(0); numFailures <= maxRetries+1; numFailures++ {
+ maxRetries := maxRetries
+ numFailures := numFailures
+ addrType := addrType
+
+ t.Run(fmt.Sprintf("%d max retries and %d failures", maxRetries, numFailures), func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ ndpConfigs := addrType.ndpConfigs
+ ndpConfigs.AutoGenAddressConflictRetries = maxRetries
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: ndpConfigs,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ var tempIIDHistory [header.IIDSize]byte
+ stableAddrs := addrType.prepareFn(t, &ndpDisp, e, tempIIDHistory[:])
+
+ // Simulate DAD conflicts so the address is regenerated.
+ for i := uint8(0); i < numFailures; i++ {
+ addr := addrType.addrGenFn(i, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
+
+ // Attempting to add the address manually should not fail if the
+ // address's state was cleaned up when DAD failed.
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr.Address); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr.Address, err)
+ }
+ if err := s.RemoveAddress(nicID, addr.Address); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
+ }
+ expectDADEvent(t, &ndpDisp, addr.Address, false)
+ }
+
+ // Should not have any new addresses assigned to the NIC.
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, stableAddrs, nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+
+ // If we had less failures than generation attempts, we should have
+ // an address after DAD resolves.
+ if maxRetries+1 > numFailures {
+ addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
+ expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
+ expectDADEventAsync(t, &ndpDisp, addr.Address, true)
+ if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
+ t.Fatal(mismatch)
+ }
+ }
+
+ // Should not attempt address generation again.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrWithEUI64IIDNoDADRetries tests that a regeneration attempt is
+// not made for SLAAC addresses generated with an IID based on the NIC's link
+// address.
+func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
+ const nicID = 1
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const maxRetries = 3
+ const lifetimeSeconds = 10
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ addrTypes := []struct {
+ name string
+ ndpConfigs stack.NDPConfigurations
+ autoGenLinkLocal bool
+ subnet tcpip.Subnet
+ triggerSLAACFn func(e *channel.Endpoint)
+ }{
+ {
+ name: "Global address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ subnet: subnet,
+ triggerSLAACFn: func(e *channel.Endpoint) {
+ // Receive an RA with prefix1 in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ },
+ },
+ {
+ name: "LinkLocal address",
+ ndpConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ autoGenLinkLocal: true,
+ subnet: header.IPv6LinkLocalPrefix.Subnet(),
+ triggerSLAACFn: func(e *channel.Endpoint) {},
+ },
+ }
+
+ for _, addrType := range addrTypes {
+ addrType := addrType
+
+ t.Run(addrType.name, func(t *testing.T) {
+ t.Parallel()
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: addrType.autoGenLinkLocal,
+ NDPConfigs: addrType.ndpConfigs,
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ addrType.triggerSLAACFn(e)
+
+ addrBytes := []byte(addrType.subnet.ID())
+ header.EthernetAdddressToModifiedEUI64IntoBuf(linkAddr1, addrBytes[header.IIDOffsetInIPv6Address:])
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(addrBytes),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict.
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Should not attempt address regeneration.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ t.Fatalf("unexpectedly got an auto-generated address event = %+v", e)
+ case <-time.After(defaultAsyncNegativeEventTimeout):
+ }
+ })
+ }
+}
+
+// TestAutoGenAddrContinuesLifetimesAfterRetry tests that retrying address
+// generation in response to DAD conflicts does not refresh the lifetimes.
+func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
+ const nicID = 1
+ const nicName = "nic"
+ const dadTransmits = 1
+ const retransmitTimer = 2 * time.Second
+ const failureTimer = time.Second
+ const maxRetries = 1
+ const lifetimeSeconds = 5
+
+ var secretKeyBuf [header.OpaqueIIDSecretKeyMinBytes]byte
+ secretKey := secretKeyBuf[:]
+ n, err := rand.Read(secretKey)
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
+ }
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("got rand.Read(_) = (%d, _), want = (%d, _)", n, header.OpaqueIIDSecretKeyMinBytes)
+ }
+
+ prefix, subnet, _ := prefixSubnetAddr(0, linkAddr1)
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 2),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenAddressConflictRetries: maxRetries,
+ },
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ SecretKey: secretKey,
+ },
+ })
+ opts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
+ }
+
+ expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) {
+ t.Helper()
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
+ }
+ }
+
+ // Receive an RA with prefix in a PI.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, lifetimeSeconds, lifetimeSeconds))
+
+ addrBytes := []byte(subnet.ID())
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 0, secretKey)),
+ PrefixLen: 64,
+ }
+ expectAutoGenAddrEvent(addr, newAddr)
+
+ // Simulate a DAD conflict after some time has passed.
+ time.Sleep(failureTimer)
+ if err := s.DupTentativeAddrDetected(nicID, addr.Address); err != nil {
+ t.Fatalf("s.DupTentativeAddrDetected(%d, %s): %s", nicID, addr.Address, err)
+ }
+ expectAutoGenAddrEvent(addr, invalidatedAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DAD event")
+ }
+
+ // Let the next address resolve.
+ addr.Address = tcpip.Address(header.AppendOpaqueInterfaceIdentifier(addrBytes[:header.IIDOffsetInIPv6Address], subnet, nicName, 1, secretKey))
+ expectAutoGenAddrEvent(addr, newAddr)
+ select {
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD event")
+ }
+
+ // Address should be deprecated/invalidated after the lifetime expires.
+ //
+ // Note, the remaining lifetime is calculated from when the PI was first
+ // processed. Since we wait for some time before simulating a DAD conflict
+ // and more time for the new address to resolve, the new address is only
+ // expected to be valid for the remaining time. The DAD conflict should
+ // not have reset the lifetimes.
+ //
+ // We expect either just the invalidation event or the deprecation event
+ // followed by the invalidation event.
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if e.eventType == deprecatedAddr {
+ if diff := checkAutoGenAddrEvent(e, addr, deprecatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ case <-time.After(defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for invalidated auto gen addr event after deprecation")
+ }
+ } else {
+ if diff := checkAutoGenAddrEvent(e, addr, invalidatedAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ case <-time.After(lifetimeSeconds*time.Second - failureTimer - dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for auto gen addr event")
+ }
+}
+
+// TestNDPRecursiveDNSServerDispatch tests that we properly dispatch an event
+// to the integrator when an RA is received with the NDP Recursive DNS Server
+// option with at least one valid address.
+func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
+ tests := []struct {
+ name string
+ opt header.NDPRecursiveDNSServer
+ expected *ndpRDNSS
+ }{
+ {
+ "Unspecified",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ }),
+ nil,
+ },
+ {
+ "Multicast",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 255, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ nil,
+ },
+ {
+ "OptionTooSmall",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8,
+ }),
+ nil,
+ },
+ {
+ "0Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ }),
+ nil,
+ },
+ {
+ "Valid1Address",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ },
+ 2 * time.Second,
+ },
+ },
+ {
+ "Valid2Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ },
+ time.Second,
+ },
+ },
+ {
+ "Valid3Addresses",
+ header.NDPRecursiveDNSServer([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 1,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 2,
+ 1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 3,
+ }),
+ &ndpRDNSS{
+ []tcpip.Address{
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x01",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x02",
+ "\x01\x02\x03\x04\x05\x06\x07\x08\x00\x00\x00\x00\x00\x00\x00\x03",
+ },
+ 0,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ // We do not expect more than a single RDNSS
+ // event at any time for this test.
+ rdnssC: make(chan ndpRDNSSEvent, 1),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(1) = %s", err)
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, header.NDPOptionsSerializer{test.opt}))
+
+ if test.expected != nil {
+ select {
+ case e := <-ndpDisp.rdnssC:
+ if e.nicID != 1 {
+ t.Errorf("got rdnss nicID = %d, want = 1", e.nicID)
+ }
+ if diff := cmp.Diff(e.rdnss.addrs, test.expected.addrs); diff != "" {
+ t.Errorf("rdnss addrs mismatch (-want +got):\n%s", diff)
+ }
+ if e.rdnss.lifetime != test.expected.lifetime {
+ t.Errorf("got rdnss lifetime = %s, want = %s", e.rdnss.lifetime, test.expected.lifetime)
+ }
+ default:
+ t.Fatal("expected an RDNSS option event")
+ }
+ }
+
+ // Should have no more RDNSS options.
+ select {
+ case e := <-ndpDisp.rdnssC:
+ t.Fatalf("unexpectedly got a new RDNSS option event: %+v", e)
+ default:
+ }
+ })
+ }
+}
+
+// TestNDPDNSSearchListDispatch tests that the integrator is informed when an
+// NDP DNS Search List option is received with at least one domain name in the
+// search list.
+func TestNDPDNSSearchListDispatch(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dnsslC: make(chan ndpDNSSLEvent, 3),
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ optSer := header.NDPOptionsSerializer{
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 0,
+ 2, 'h', 'i',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 0, 1,
+ 1, 'i',
+ 0,
+ 2, 'a', 'm',
+ 2, 'm', 'e',
+ 0,
+ }),
+ header.NDPDNSSearchList([]byte{
+ 0, 0,
+ 0, 0, 1, 0,
+ 3, 'x', 'y', 'z',
+ 0,
+ 5, 'h', 'e', 'l', 'l', 'o',
+ 5, 'w', 'o', 'r', 'l', 'd',
+ 0,
+ 4, 't', 'h', 'i', 's',
+ 2, 'i', 's',
+ 1, 'a',
+ 4, 't', 'e', 's', 't',
+ 0,
+ }),
+ }
+ expected := []struct {
+ domainNames []string
+ lifetime time.Duration
+ }{
+ {
+ domainNames: []string{
+ "hi",
+ },
+ lifetime: 0,
+ },
+ {
+ domainNames: []string{
+ "i",
+ "am.me",
+ },
+ lifetime: time.Second,
+ },
+ {
+ domainNames: []string{
+ "xyz",
+ "hello.world",
+ "this.is.a.test",
+ },
+ lifetime: 256 * time.Second,
+ },
+ }
+
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithOpts(llAddr1, 0, optSer))
+
+ for i, expected := range expected {
+ select {
+ case dnssl := <-ndpDisp.dnsslC:
+ if dnssl.nicID != nicID {
+ t.Errorf("got %d-th dnssl nicID = %d, want = %d", i, dnssl.nicID, nicID)
+ }
+ if diff := cmp.Diff(dnssl.domainNames, expected.domainNames); diff != "" {
+ t.Errorf("%d-th dnssl domain names mismatch (-want +got):\n%s", i, diff)
+ }
+ if dnssl.lifetime != expected.lifetime {
+ t.Errorf("got %d-th dnssl lifetime = %s, want = %s", i, dnssl.lifetime, expected.lifetime)
+ }
+ default:
+ t.Fatal("expected a DNSSL event")
+ }
+ }
+
+ // Should have no more DNSSL options.
+ select {
+ case <-ndpDisp.dnsslC:
+ t.Fatal("unexpectedly got a DNSSL event")
+ default:
+ }
+}
+
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
+ const (
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
+ )
+
+ prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, subnet2, e1Addr2 := prefixSubnetAddr(1, linkAddr1)
+ e2Addr1 := addrForSubnet(subnet1, linkAddr2)
+ e2Addr2 := addrForSubnet(subnet2, linkAddr2)
+ llAddrWithPrefix1 := tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 64,
+ }
+ llAddrWithPrefix2 := tcpip.AddressWithPrefix{
+ Address: llAddr2,
+ PrefixLen: 64,
+ }
+
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ skipFinalAddrCheck bool
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Enable forwarding",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(ipv6.ProtocolNumber, true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
+ },
+
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "Disable NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
+
+ // A NIC should cleanup all NDP state when it is removed.
+ {
+ name: "Remove NIC",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.RemoveNIC(nicID1); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID1, err)
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ // The NICs are removed so we can't check their addresses after calling
+ // stopFn.
+ skipFinalAddrCheck: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
+
+ return false, ndpRouterEvent{}
+ }
+
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
+
+ return false, ndpPrefixEvent{}
+ }
+
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
+
+ return false, ndpAutoGenAddrEvent{}
+ }
+
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
+
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
+
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
+
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
+
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
+
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ test.cleanupFn(t, s)
+
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
+
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ if !test.skipFinalAddrCheck {
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultAsyncNegativeEventTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
+ }
+}
+
+// TestDHCPv6ConfigurationFromNDPDA tests that the NDPDispatcher is properly
+// informed when new information about what configurations are available via
+// DHCPv6 is learned.
+func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dhcpv6ConfigurationC: make(chan ndpDHCPv6Event, 1),
+ rememberRouter: true,
+ }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
+
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ expectDHCPv6Event := func(configuration stack.DHCPv6ConfigurationFromNDPRA) {
+ t.Helper()
+ select {
+ case e := <-ndpDisp.dhcpv6ConfigurationC:
+ if diff := cmp.Diff(ndpDHCPv6Event{nicID: nicID, configuration: configuration}, e, cmp.AllowUnexported(e)); diff != "" {
+ t.Errorf("dhcpv6 event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected DHCPv6 configuration event")
+ }
+ }
+
+ expectNoDHCPv6Event := func() {
+ t.Helper()
+ select {
+ case <-ndpDisp.dhcpv6ConfigurationC:
+ t.Fatal("unexpected DHCPv6 configuration event")
+ default:
+ }
+ }
+
+ // Even if the first RA reports no DHCPv6 configurations are available, the
+ // dispatcher should get an event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ // Receiving the same update again should not result in an event to the
+ // dispatcher.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to none.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectDHCPv6Event(stack.DHCPv6NoConfiguration)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, false))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Managed Address.
+ //
+ // Note, when the M flag is set, the O flag is redundant.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectDHCPv6Event(stack.DHCPv6ManagedAddress)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+ // Even though the DHCPv6 flags are different, the effective configuration is
+ // the same so we should not receive a new event.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, false))
+ expectNoDHCPv6Event()
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, true, true))
+ expectNoDHCPv6Event()
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+
+ // Cycling the NIC should cause the last DHCPv6 configuration to be cleared.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ // Receive an RA that updates the DHCPv6 configuration to Other
+ // Configurations.
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectDHCPv6Event(stack.DHCPv6OtherConfigurations)
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithDHCPv6(llAddr2, false, true))
+ expectNoDHCPv6Event()
+}
+
+// TestRouterSolicitation tests the initial Router Solicitations that are sent
+// when a NIC newly becomes enabled.
+func TestRouterSolicitation(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ linkHeaderLen uint16
+ linkAddr tcpip.LinkAddress
+ nicAddr tcpip.Address
+ expectedSrcAddr tcpip.Address
+ expectedNDPOpts []header.NDPOption
+ maxRtrSolicit uint8
+ rtrSolicitInt time.Duration
+ effectiveRtrSolicitInt time.Duration
+ maxRtrSolicitDelay time.Duration
+ effectiveMaxRtrSolicitDelay time.Duration
+ }{
+ {
+ name: "Single RS with 2s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 2 * time.Second,
+ effectiveMaxRtrSolicitDelay: 2 * time.Second,
+ },
+ {
+ name: "Single RS with 4s delay and interval",
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 4 * time.Second,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 4 * time.Second,
+ effectiveMaxRtrSolicitDelay: 4 * time.Second,
+ },
+ {
+ name: "Two RS with delay",
+ linkHeaderLen: 1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 500 * time.Millisecond,
+ effectiveMaxRtrSolicitDelay: 500 * time.Millisecond,
+ },
+ {
+ name: "Single RS without delay",
+ linkHeaderLen: 2,
+ linkAddr: linkAddr1,
+ nicAddr: llAddr1,
+ expectedSrcAddr: llAddr1,
+ expectedNDPOpts: []header.NDPOption{
+ header.NDPSourceLinkLayerAddressOption(linkAddr1),
+ },
+ maxRtrSolicit: 1,
+ rtrSolicitInt: 2 * time.Second,
+ effectiveRtrSolicitInt: 2 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: 0,
+ effectiveRtrSolicitInt: 4 * time.Second,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Three RS without delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 3,
+ rtrSolicitInt: 500 * time.Millisecond,
+ effectiveRtrSolicitInt: 500 * time.Millisecond,
+ maxRtrSolicitDelay: 0,
+ effectiveMaxRtrSolicitDelay: 0,
+ },
+ {
+ name: "Two RS with invalid negative delay",
+ linkAddr: linkAddr1,
+ expectedSrcAddr: header.IPv6Any,
+ maxRtrSolicit: 2,
+ rtrSolicitInt: time.Second,
+ effectiveRtrSolicitInt: time.Second,
+ maxRtrSolicitDelay: -3 * time.Second,
+ effectiveMaxRtrSolicitDelay: time.Second,
+ },
+ }
+
+ // This Run will not return until the parallel tests finish.
+ //
+ // We need this because we need to do some teardown work after the
+ // parallel tests complete.
+ //
+ // See https://godoc.org/testing#hdr-Subtests_and_Sub_benchmarks for
+ // more details.
+ t.Run("group", func(t *testing.T) {
+ for _, test := range tests {
+ test := test
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, test.linkAddr),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+
+ // Make sure the right remote link address is used.
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(test.expectedSrcAddr),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
+ )
+
+ if l, want := p.Pkt.AvailableHeaderBytes(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.AvailableHeaderBytes() = %d; want = %d", l, want)
+ }
+ }
+ waitForNothing := func(timeout time.Duration) {
+ t.Helper()
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet")
+ }
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: test.maxRtrSolicit,
+ RtrSolicitationInterval: test.rtrSolicitInt,
+ MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
+ },
+ })
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ if addr := test.nicAddr; addr != "" {
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr, err)
+ }
+ }
+
+ // Make sure each RS is sent at the right time.
+ remaining := test.maxRtrSolicit
+ if remaining > 0 {
+ waitForPkt(test.effectiveMaxRtrSolicitDelay + defaultAsyncPositiveEventTimeout)
+ remaining--
+ }
+
+ for ; remaining > 0; remaining-- {
+ if test.effectiveRtrSolicitInt > defaultAsyncPositiveEventTimeout {
+ waitForNothing(test.effectiveRtrSolicitInt - defaultAsyncNegativeEventTimeout)
+ waitForPkt(defaultAsyncPositiveEventTimeout)
+ } else {
+ waitForPkt(test.effectiveRtrSolicitInt + defaultAsyncPositiveEventTimeout)
+ }
+ }
+
+ // Make sure no more RS.
+ if test.effectiveRtrSolicitInt > test.effectiveMaxRtrSolicitDelay {
+ waitForNothing(test.effectiveRtrSolicitInt + defaultAsyncNegativeEventTimeout)
+ } else {
+ waitForNothing(test.effectiveMaxRtrSolicitDelay + defaultAsyncNegativeEventTimeout)
+ }
+
+ // Make sure the counter got properly
+ // incremented.
+ if got, want := s.Stats().ICMP.V6PacketsSent.RouterSolicit.Value(), uint64(test.maxRtrSolicit); got != want {
+ t.Fatalf("got sent RouterSolicit = %d, want = %d", got, want)
+ }
+ })
+ }
+ })
+}
+
+func TestStopStartSolicitingRouters(t *testing.T) {
+ const nicID = 1
+ const delay = 0
+ const interval = 500 * time.Millisecond
+ const maxRtrSolicitations = 3
+
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ // first is used to tell stopFn that it is being called for the first time
+ // after router solicitations were last enabled.
+ stopFn func(t *testing.T, s *stack.Stack, first bool)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Enable and disable forwarding",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(ipv6.ProtocolNumber, false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
+ t.Helper()
+ s.SetForwarding(ipv6.ProtocolNumber, true)
+ },
+ },
+
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "Enable and disable NIC",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack, _ bool) {
+ t.Helper()
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
+
+ // Tests that when a NIC is removed, router solicitations are stopped. We
+ // cannot start router solications on a removed NIC.
+ {
+ name: "Remove NIC",
+ stopFn: func(t *testing.T, s *stack.Stack, first bool) {
+ t.Helper()
+
+ // Only try to remove the NIC the first time stopFn is called since it's
+ // impossible to remove an already removed NIC.
+ if !first {
+ return
+ }
+
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s, true /* first */)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before solicitations were stopped.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s, false /* first */)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // If test.startFn is nil, there is no way to restart router solications.
+ if test.startFn == nil {
+ return
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ waitForPkt(interval + defaultAsyncPositiveEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultAsyncNegativeEventTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
}
})
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
new file mode 100644
index 000000000..27e1feec0
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -0,0 +1,333 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const neighborCacheSize = 512 // max entries per interface
+
+// neighborCache maps IP addresses to link addresses. It uses the Least
+// Recently Used (LRU) eviction strategy to implement a bounded cache for
+// dynmically acquired entries. It contains the state machine and configuration
+// for running Neighbor Unreachability Detection (NUD).
+//
+// There are two types of entries in the neighbor cache:
+// 1. Dynamic entries are discovered automatically by neighbor discovery
+// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm
+// reachability with the device once the entry's state becomes Stale.
+// 2. Static entries are explicitly added by a user and have no expiration.
+// Their state is always Static. The amount of static entries stored in the
+// cache is unbounded.
+//
+// neighborCache implements NUDHandler.
+type neighborCache struct {
+ nic *NIC
+ state *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ cache map[tcpip.Address]*neighborEntry
+ dynamic struct {
+ lru neighborEntryList
+
+ // count tracks the amount of dynamic entries in the cache. This is
+ // needed since static entries do not count towards the LRU cache
+ // eviction strategy.
+ count uint16
+ }
+}
+
+var _ NUDHandler = (*neighborCache)(nil)
+
+// getOrCreateEntry retrieves a cache entry associated with addr. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[remoteAddr]; ok {
+ entry.mu.RLock()
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.lru.PushFront(entry)
+ }
+ entry.mu.RUnlock()
+ return entry
+ }
+
+ // The entry that needs to be created must be dynamic since all static
+ // entries are directly added to the cache via addStaticEntry.
+ entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes)
+ if n.dynamic.count == neighborCacheSize {
+ e := n.dynamic.lru.Back()
+ e.mu.Lock()
+
+ delete(n.cache, e.neigh.Addr)
+ n.dynamic.lru.Remove(e)
+ n.dynamic.count--
+
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Unknown)
+ e.notifyWakersLocked()
+ e.mu.Unlock()
+ }
+ n.cache[remoteAddr] = entry
+ n.dynamic.lru.PushFront(entry)
+ n.dynamic.count++
+ return entry
+}
+
+// entry looks up the neighbor cache for translating address to link address
+// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there
+// is a LinkAddressResolver registered with the network protocol, the cache
+// attempts to resolve the address and returns ErrWouldBlock. If a Waker is
+// provided, it will be notified when address resolution is complete (success
+// or not).
+//
+// If address resolution is required, ErrNoLinkAddress and a notification
+// channel is returned for the top level caller to block. Channel is closed
+// once address resolution is complete (success or not).
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+ if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok {
+ e := NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ }
+ return e, nil, nil
+ }
+
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ switch s := entry.neigh.State; s {
+ case Reachable, Static:
+ return entry.neigh, nil, nil
+
+ case Unknown, Incomplete, Stale, Delay, Probe:
+ entry.addWakerLocked(w)
+
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+ }
+ entry.done = make(chan struct{})
+ }
+
+ entry.handlePacketQueuedLocked()
+ return entry.neigh, entry.done, tcpip.ErrWouldBlock
+
+ case Failed:
+ return entry.neigh, nil, tcpip.ErrNoLinkAddress
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", s))
+ }
+}
+
+// removeWaker removes a waker that has been added when link resolution for
+// addr was requested.
+func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) {
+ n.mu.Lock()
+ if entry, ok := n.cache[addr]; ok {
+ delete(entry.wakers, waker)
+ }
+ n.mu.Unlock()
+}
+
+// entries returns all entries in the neighbor cache.
+func (n *neighborCache) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(n.cache))
+ n.mu.RLock()
+ for _, entry := range n.cache {
+ entry.mu.RLock()
+ entries = append(entries, entry.neigh)
+ entry.mu.RUnlock()
+ }
+ n.mu.RUnlock()
+ return entries
+}
+
+// addStaticEntry adds a static entry to the neighbor cache, mapping an IP
+// address to a link address. If a dynamic entry exists in the neighbor cache
+// with the same address, it will be replaced with this static entry. If a
+// static entry exists with the same address but different link address, it
+// will be updated with the new link address. If a static entry exists with the
+// same address and link address, nothing will happen.
+func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if entry, ok := n.cache[addr]; ok {
+ entry.mu.Lock()
+ if entry.neigh.State != Static {
+ // Dynamic entry found with the same address.
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ } else if entry.neigh.LinkAddr == linkAddr {
+ // Static entry found with the same address and link address.
+ entry.mu.Unlock()
+ return
+ } else {
+ // Static entry found with the same address but different link address.
+ entry.neigh.LinkAddr = linkAddr
+ entry.dispatchChangeEventLocked(entry.neigh.State)
+ entry.mu.Unlock()
+ return
+ }
+
+ // Notify that resolution has been interrupted, just in case the entry was
+ // in the Incomplete or Probe state.
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state)
+ n.cache[addr] = entry
+}
+
+// removeEntryLocked removes the specified entry from the neighbor cache.
+func (n *neighborCache) removeEntryLocked(entry *neighborEntry) {
+ if entry.neigh.State != Static {
+ n.dynamic.lru.Remove(entry)
+ n.dynamic.count--
+ }
+ if entry.neigh.State != Failed {
+ entry.dispatchRemoveEventLocked()
+ }
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+
+ delete(n.cache, entry.neigh.Addr)
+}
+
+// removeEntry removes a dynamic or static entry by address from the neighbor
+// cache. Returns true if the entry was found and deleted.
+func (n *neighborCache) removeEntry(addr tcpip.Address) bool {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ entry, ok := n.cache[addr]
+ if !ok {
+ return false
+ }
+
+ entry.mu.Lock()
+ defer entry.mu.Unlock()
+
+ n.removeEntryLocked(entry)
+ return true
+}
+
+// clear removes all dynamic and static entries from the neighbor cache.
+func (n *neighborCache) clear() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ for _, entry := range n.cache {
+ entry.mu.Lock()
+ entry.dispatchRemoveEventLocked()
+ entry.setStateLocked(Unknown)
+ entry.notifyWakersLocked()
+ entry.mu.Unlock()
+ }
+
+ n.dynamic.lru = neighborEntryList{}
+ n.cache = make(map[tcpip.Address]*neighborEntry)
+ n.dynamic.count = 0
+}
+
+// config returns the NUD configuration.
+func (n *neighborCache) config() NUDConfigurations {
+ return n.state.Config()
+}
+
+// setConfig changes the NUD configuration.
+//
+// If config contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *neighborCache) setConfig(config NUDConfigurations) {
+ config.resetInvalidFields()
+ n.state.SetConfig(config)
+}
+
+// HandleProbe implements NUDHandler.HandleProbe by following the logic defined
+// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled
+// by the caller.
+func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) {
+ entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes)
+ entry.mu.Lock()
+ entry.handleProbeLocked(remoteLinkAddr)
+ entry.mu.Unlock()
+}
+
+// HandleConfirmation implements NUDHandler.HandleConfirmation by following the
+// logic defined in RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce cryptographically
+// generated addresses, as defined in RFC 3972, Cryptographically Generated
+// Addresses (CGA). This ensures that the claimed source of an NDP message is
+// the owner of the claimed address.
+func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleConfirmationLocked(linkAddr, flags)
+ entry.mu.Unlock()
+ }
+ // The confirmation SHOULD be silently discarded if the recipient did not
+ // initiate any communication with the target. This is indicated if there is
+ // no matching entry for the remote address.
+}
+
+// HandleUpperLevelConfirmation implements
+// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in
+// RFC 4861 section 7.3.1.
+func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) {
+ n.mu.RLock()
+ entry, ok := n.cache[addr]
+ n.mu.RUnlock()
+ if ok {
+ entry.mu.Lock()
+ entry.handleUpperLevelConfirmationLocked()
+ entry.mu.Unlock()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
new file mode 100644
index 000000000..b4fa69e3e
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -0,0 +1,1726 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // entryStoreSize is the default number of entries that will be generated and
+ // added to the entry store. This number needs to be larger than the size of
+ // the neighbor cache to give ample opportunity for verifying behavior during
+ // cache overflows. Four times the size of the neighbor cache allows for
+ // three complete cache overflows.
+ entryStoreSize = 4 * neighborCacheSize
+
+ // typicalLatency is the typical latency for an ARP or NDP packet to travel
+ // to a router and back.
+ typicalLatency = time.Millisecond
+
+ // testEntryBroadcastAddr is a special address that indicates a packet should
+ // be sent to all nodes.
+ testEntryBroadcastAddr = tcpip.Address("broadcast")
+
+ // testEntryLocalAddr is the source address of neighbor probes.
+ testEntryLocalAddr = tcpip.Address("local_addr")
+
+ // testEntryBroadcastLinkAddr is a special link address sent back to
+ // multicast neighbor probes.
+ testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast")
+
+ // infiniteDuration indicates that a task will not occur in our lifetime.
+ infiniteDuration = time.Duration(math.MaxInt64)
+)
+
+// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor
+// entries. The UpdatedAt field is ignored due to a lack of a deterministic
+// method to predict the time that an event will be dispatched.
+func entryDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ }
+}
+
+// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to
+// sort slices of entries for cases where ordering must be ignored.
+func entryDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b NeighborEntry) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache {
+ config.resetInvalidFields()
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ return &neighborCache{
+ nic: &NIC{
+ stack: &Stack{
+ clock: clock,
+ nudDisp: nudDisp,
+ },
+ id: 1,
+ },
+ state: NewNUDState(config, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+}
+
+// testEntryStore contains a set of IP to NeighborEntry mappings.
+type testEntryStore struct {
+ mu sync.RWMutex
+ entriesMap map[tcpip.Address]NeighborEntry
+}
+
+func toAddress(i int) tcpip.Address {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint16(i))
+ return tcpip.Address(buf.String())
+}
+
+func toLinkAddress(i int) tcpip.LinkAddress {
+ buf := new(bytes.Buffer)
+ binary.Write(buf, binary.BigEndian, uint8(1))
+ binary.Write(buf, binary.BigEndian, uint8(0))
+ binary.Write(buf, binary.BigEndian, uint32(i))
+ return tcpip.LinkAddress(buf.String())
+}
+
+// newTestEntryStore returns a testEntryStore pre-populated with entries.
+func newTestEntryStore() *testEntryStore {
+ store := &testEntryStore{
+ entriesMap: make(map[tcpip.Address]NeighborEntry),
+ }
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ linkAddr := toLinkAddress(i)
+
+ store.entriesMap[addr] = NeighborEntry{
+ Addr: addr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: linkAddr,
+ }
+ }
+ return store
+}
+
+// size returns the number of entries in the store.
+func (s *testEntryStore) size() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return len(s.entriesMap)
+}
+
+// entry returns the entry at index i. Returns an empty entry and false if i is
+// out of bounds.
+func (s *testEntryStore) entry(i int) (NeighborEntry, bool) {
+ return s.entryByAddr(toAddress(i))
+}
+
+// entryByAddr returns the entry matching addr for situations when the index is
+// not available. Returns an empty entry and false if no entries match addr.
+func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ entry, ok := s.entriesMap[addr]
+ return entry, ok
+}
+
+// entries returns all entries in the store.
+func (s *testEntryStore) entries() []NeighborEntry {
+ entries := make([]NeighborEntry, 0, len(s.entriesMap))
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for i := 0; i < entryStoreSize; i++ {
+ addr := toAddress(i)
+ if entry, ok := s.entriesMap[addr]; ok {
+ entries = append(entries, entry)
+ }
+ }
+ return entries
+}
+
+// set modifies the link addresses of an entry.
+func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) {
+ addr := toAddress(i)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if entry, ok := s.entriesMap[addr]; ok {
+ entry.LinkAddr = linkAddr
+ s.entriesMap[addr] = entry
+ }
+}
+
+// testNeighborResolver implements LinkAddressResolver to emulate sending a
+// neighbor probe.
+type testNeighborResolver struct {
+ clock tcpip.Clock
+ neigh *neighborCache
+ entries *testEntryStore
+ delay time.Duration
+ onLinkAddressRequest func()
+}
+
+var _ LinkAddressResolver = (*testNeighborResolver)(nil)
+
+func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ // Delay handling the request to emulate network latency.
+ r.clock.AfterFunc(r.delay, func() {
+ r.fakeRequest(addr)
+ })
+
+ // Execute post address resolution action, if available.
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
+ return nil
+}
+
+// fakeRequest emulates handling a response for a link address request.
+func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) {
+ if entry, ok := r.entries.entryByAddr(addr); ok {
+ r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ }
+}
+
+func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == testEntryBroadcastAddr {
+ return testEntryBroadcastLinkAddr, true
+ }
+ return "", false
+}
+
+func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 0
+}
+
+type entryEvent struct {
+ nicID tcpip.NICID
+ address tcpip.Address
+ linkAddr tcpip.LinkAddress
+ state NeighborState
+}
+
+func TestNeighborCacheGetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheSetConfig(t *testing.T) {
+ nudDisp := testNUDDispatcher{}
+ c := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+ neigh.setConfig(c)
+
+ if got, want := neigh.config(), c; got != want {
+ t.Errorf("got neigh.config() = %+v, want = %+v", got, want)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheEntry(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveEntry(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+}
+
+type testContext struct {
+ clock *fakeClock
+ neigh *neighborCache
+ store *testEntryStore
+ linkRes *testNeighborResolver
+ nudDisp *testNUDDispatcher
+}
+
+func newTestContext(c NUDConfigurations) testContext {
+ nudDisp := &testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nudDisp, c, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ return testContext{
+ clock: clock,
+ neigh: neigh,
+ store: store,
+ linkRes: linkRes,
+ nudDisp: nudDisp,
+ }
+}
+
+type overflowOptions struct {
+ startAtEntryIndex int
+ wantStaticEntries []NeighborEntry
+}
+
+func (c *testContext) overflowCache(opts overflowOptions) error {
+ // Fill the neighbor cache to capacity to verify the LRU eviction strategy is
+ // working properly after the entry removal.
+ for i := opts.startAtEntryIndex; i < c.store.size(); i++ {
+ // Add a new entry
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock {
+ return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+
+ var wantEvents []testEntryEventInfo
+
+ // When beyond the full capacity, the cache will evict an entry as per the
+ // LRU eviction strategy. Note that the number of static entries should not
+ // affect the total number of dynamic entries that can be added.
+ if i >= neighborCacheSize+opts.startAtEntryIndex {
+ removedEntry, ok := c.store.entry(i - neighborCacheSize)
+ if !ok {
+ return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize)
+ }
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ })
+ }
+
+ wantEvents = append(wantEvents, testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ }, testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ })
+
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the most recent entries. The order of entries reported
+ // by entries() is undeterministic, so entries have to be sorted before
+ // comparison.
+ wantUnsortedEntries := opts.wantStaticEntries
+ for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ {
+ entry, ok := c.store.entry(i)
+ if !ok {
+ return fmt.Errorf("c.store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ return nil
+}
+
+// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy
+// respects the dynamic entry count.
+func TestNeighborCacheOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when an entry is removed.
+func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(c.neigh.config().RetransmitTimer)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the entry
+ c.neigh.removeEntry(entry.Addr)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that
+// adding a duplicate static entry with the same link address does not dispatch
+// any events.
+func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+
+ // No more events should have been dispatched.
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that
+// adding a duplicate static entry with a different link address dispatches a
+// change event.
+func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a duplicate entry with a different link address
+ staticLinkAddr += "duplicate"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ defer c.nudDisp.mu.Unlock()
+ if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache
+// eviction strategy respects the dynamic entry count when a static entry is
+// added then removed. In this case, the dynamic entry count shouldn't have
+// been touched.
+func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a static entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Remove the static entry that was just added
+ c.neigh.removeEntry(entry.Addr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU
+// cache eviction strategy keeps count of the dynamic entry count when an entry
+// is overwritten by a static entry. Static entries should not count towards
+// the size of the LRU cache.
+func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Override the entry with a static one using the same address
+ staticLinkAddr := entry.LinkAddr + "static"
+ c.neigh.addStaticEntry(entry.Addr, staticLinkAddr)
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: staticLinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheNotifiesWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ id, ok := s.Fetch(false /* block */)
+ if !ok {
+ t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+ if id != wakerID {
+ t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheRemoveWaker(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ const wakerID = 1
+ s.AddWaker(&w, wakerID)
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh == nil {
+ t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Remove the waker before the neighbor cache has the opportunity to send a
+ // notification.
+ neigh.removeWaker(entry.Addr, &w)
+ clock.advance(typicalLatency)
+
+ select {
+ case <-doneCh:
+ default:
+ t.Fatal("expected notification from done channel")
+ }
+
+ if id, ok := s.Fetch(false /* block */); ok {
+ t.Errorf("unexpected notification from waker with id %d", id)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr)
+ e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 1,
+ wantStaticEntries: []NeighborEntry{
+ {
+ Addr: entry.Addr,
+ LocalAddr: "", // static entries don't need a local address
+ LinkAddr: entry.LinkAddr,
+ State: Static,
+ },
+ },
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheClear(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add a dynamic entry.
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Add a static entry.
+ neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1)
+
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Clear shoud remove both dynamic and static entries.
+ neigh.clear()
+
+ // Remove events dispatched from clear() have no deterministic order so they
+ // need to be sorted beforehand.
+ wantUnsortedEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Static,
+ },
+ }
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction
+// strategy keeps count of the dynamic entry count when all entries are
+// cleared.
+func TestNeighborCacheClearThenOverflow(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ c := newTestContext(config)
+
+ // Add a dynamic entry
+ entry, ok := c.store.entry(0)
+ if !ok {
+ t.Fatalf("c.store.entry(0) not found")
+ }
+ _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ c.clock.advance(typicalLatency)
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+
+ // Clear the cache.
+ c.neigh.clear()
+ {
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ c.nudDisp.mu.Lock()
+ diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...)
+ c.nudDisp.events = nil
+ c.nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ opts := overflowOptions{
+ startAtEntryIndex: 0,
+ }
+ if err := c.overflowCache(opts); err != nil {
+ t.Errorf("c.overflowCache(%+v): %s", opts, err)
+ }
+}
+
+func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ // Stay in Reachable so the cache can overflow
+ config.BaseReachableTime = infiniteDuration
+ config.MinRandomFactor = 1
+ config.MaxRandomFactor = 1
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ frequentlyUsedEntry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+
+ // The following logic is very similar to overflowCache, but
+ // periodically refreshes the frequently used entry.
+
+ // Fill the neighbor cache to capacity
+ for i := 0; i < neighborCacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Keep adding more entries
+ for i := neighborCacheSize; i < store.size(); i++ {
+ // Periodically refresh the frequently used entry
+ if i%(neighborCacheSize/2) == 0 {
+ _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err)
+ }
+ }
+
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // An entry should have been removed, as per the LRU eviction strategy
+ removedEntry, ok := store.entry(i - neighborCacheSize + 1)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1)
+ }
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestRemoved,
+ NICID: 1,
+ Addr: removedEntry.Addr,
+ LinkAddr: removedEntry.LinkAddr,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestAdded,
+ NICID: 1,
+ Addr: entry.Addr,
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: 1,
+ Addr: entry.Addr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.events = nil
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ // Expect to find only the frequently used entry and the most recent entries.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ wantUnsortedEntries := []NeighborEntry{
+ {
+ Addr: frequentlyUsedEntry.Addr,
+ LocalAddr: frequentlyUsedEntry.LocalAddr,
+ LinkAddr: frequentlyUsedEntry.LinkAddr,
+ State: Reachable,
+ },
+ }
+
+ for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Fatalf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No more events should have been dispatched.
+ nudDisp.mu.Lock()
+ defer nudDisp.mu.Unlock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheConcurrent(t *testing.T) {
+ const concurrentProcesses = 16
+
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ storeEntries := store.entries()
+ for _, entry := range storeEntries {
+ var wg sync.WaitGroup
+ for r := 0; r < concurrentProcesses; r++ {
+ wg.Add(1)
+ go func(entry NeighborEntry) {
+ defer wg.Done()
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock)
+ }
+ }(entry)
+ }
+
+ // Wait for all gorountines to send a request
+ wg.Wait()
+
+ // Process all the requests for a single entry concurrently
+ clock.advance(typicalLatency)
+ }
+
+ // All goroutines add in the same order and add more values than can fit in
+ // the cache. Our eviction strategy requires that the last entries are
+ // present, up to the size of the neighbor cache, and the rest are missing.
+ // The order of entries reported by entries() is undeterministic, so entries
+ // have to be sorted before comparison.
+ var wantUnsortedEntries []NeighborEntry
+ for i := store.size() - neighborCacheSize; i < store.size(); i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ t.Errorf("store.entry(%d) not found", i)
+ }
+ wantEntry := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ wantUnsortedEntries = append(wantUnsortedEntries, wantEntry)
+ }
+
+ if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" {
+ t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff)
+ }
+}
+
+func TestNeighborCacheReplace(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ // Add an entry
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+
+ // Verify the entry exists
+ e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ if doneCh != nil {
+ t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff)
+ }
+
+ // Notify of a link address change
+ var updatedLinkAddr tcpip.LinkAddress
+ {
+ entry, ok := store.entry(1)
+ if !ok {
+ t.Fatalf("store.entry(1) not found")
+ }
+ updatedLinkAddr = entry.LinkAddr
+ }
+ store.set(0, updatedLinkAddr)
+ neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+
+ // Requesting the entry again should start address resolution
+ {
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(config.DelayFirstProbeTime + typicalLatency)
+ select {
+ case <-doneCh:
+ default:
+ t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr)
+ }
+ }
+
+ // Verify the entry's new link address
+ {
+ e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ clock.advance(typicalLatency)
+ if err != nil {
+ t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want = NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: updatedLinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+ }
+}
+
+func TestNeighborCacheResolutionFailed(t *testing.T) {
+ config := DefaultNUDConfigurations()
+
+ nudDisp := testNUDDispatcher{}
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(&nudDisp, config, clock)
+ store := newTestEntryStore()
+
+ var requestCount uint32
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ onLinkAddressRequest: func() {
+ atomic.AddUint32(&requestCount, 1)
+ },
+ }
+
+ // First, sanity check that resolution is working
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ clock.advance(typicalLatency)
+ got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: entry.Addr,
+ LocalAddr: entry.LocalAddr,
+ LinkAddr: entry.LinkAddr,
+ State: Reachable,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff)
+ }
+
+ // Verify that address resolution for an unknown address returns ErrNoLinkAddress
+ before := atomic.LoadUint32(&requestCount)
+
+ entry.Addr += "2"
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+
+ maxAttempts := neigh.config().MaxUnicastProbes
+ if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
+}
+
+// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes
+// probes and not retrieving a confirmation before the duration defined by
+// MaxMulticastProbes * RetransmitTimer.
+func TestNeighborCacheResolutionTimeout(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ config.RetransmitTimer = time.Millisecond // small enough to cause timeout
+
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: time.Minute, // large enough to cause timeout
+ }
+
+ entry, ok := store.entry(0)
+ if !ok {
+ t.Fatalf("store.entry(0) not found")
+ }
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes)
+ clock.advance(waitFor)
+ if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress {
+ t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress)
+ }
+}
+
+// TestNeighborCacheStaticResolution checks that static link addresses are
+// resolved immediately and don't send resolution requests.
+func TestNeighborCacheStaticResolution(t *testing.T) {
+ config := DefaultNUDConfigurations()
+ clock := newFakeClock()
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: typicalLatency,
+ }
+
+ got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil)
+ if err != nil {
+ t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err)
+ }
+ want := NeighborEntry{
+ Addr: testEntryBroadcastAddr,
+ LocalAddr: testEntryLocalAddr,
+ LinkAddr: testEntryBroadcastLinkAddr,
+ State: Static,
+ }
+ if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" {
+ t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff)
+ }
+}
+
+func BenchmarkCacheClear(b *testing.B) {
+ b.StopTimer()
+ config := DefaultNUDConfigurations()
+ clock := &tcpip.StdClock{}
+ neigh := newTestNeighborCache(nil, config, clock)
+ store := newTestEntryStore()
+ linkRes := &testNeighborResolver{
+ clock: clock,
+ neigh: neigh,
+ entries: store,
+ delay: 0,
+ }
+
+ // Clear for every possible size of the cache
+ for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ {
+ // Fill the neighbor cache to capacity.
+ for i := 0; i < cacheSize; i++ {
+ entry, ok := store.entry(i)
+ if !ok {
+ b.Fatalf("store.entry(%d) not found", i)
+ }
+ _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil)
+ if err != tcpip.ErrWouldBlock {
+ b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock)
+ }
+ if doneCh != nil {
+ <-doneCh
+ }
+ }
+
+ b.StartTimer()
+ neigh.clear()
+ b.StopTimer()
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
new file mode 100644
index 000000000..0068cacb8
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -0,0 +1,482 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// NeighborEntry describes a neighboring device in the local network.
+type NeighborEntry struct {
+ Addr tcpip.Address
+ LocalAddr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+// NeighborState defines the state of a NeighborEntry within the Neighbor
+// Unreachability Detection state machine, as per RFC 4861 section 7.3.2.
+type NeighborState uint8
+
+const (
+ // Unknown means reachability has not been verified yet. This is the initial
+ // state of entries that have been created automatically by the Neighbor
+ // Unreachability Detection state machine.
+ Unknown NeighborState = iota
+ // Incomplete means that there is an outstanding request to resolve the
+ // address.
+ Incomplete
+ // Reachable means the path to the neighbor is functioning properly for both
+ // receive and transmit paths.
+ Reachable
+ // Stale means reachability to the neighbor is unknown, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Stale
+ // Delay means reachability to the neighbor is unknown and pending
+ // confirmation from an upper-level protocol like TCP, but packets are still
+ // able to be transmitted to the possibly stale link address.
+ Delay
+ // Probe means a reachability confirmation is actively being sought by
+ // periodically retransmitting reachability probes until a reachability
+ // confirmation is received, or until the max amount of probes has been sent.
+ Probe
+ // Static describes entries that have been explicitly added by the user. They
+ // do not expire and are not deleted until explicitly removed.
+ Static
+ // Failed means traffic should not be sent to this neighbor since attempts of
+ // reachability have returned inconclusive.
+ Failed
+)
+
+// neighborEntry implements a neighbor entry's individual node behavior, as per
+// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in
+// parallel with the sending of packets to a neighbor, necessitating the
+// entry's lock to be acquired for all operations.
+type neighborEntry struct {
+ neighborEntryEntry
+
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkRes provides the functionality to send reachability probes, used in
+ // Neighbor Unreachability Detection.
+ linkRes LinkAddressResolver
+
+ // nudState points to the Neighbor Unreachability Detection configuration.
+ nudState *NUDState
+
+ // mu protects the fields below.
+ mu sync.RWMutex
+
+ neigh NeighborEntry
+
+ // wakers is a set of waiters for address resolution result. Anytime state
+ // transitions out of incomplete these waiters are notified. It is nil iff
+ // address resolution is ongoing and no clients are waiting for the result.
+ wakers map[*sleep.Waker]struct{}
+
+ // done is used to allow callers to wait on address resolution. It is nil
+ // iff nudState is not Reachable and address resolution is not yet in
+ // progress.
+ done chan struct{}
+
+ isRouter bool
+ job *tcpip.Job
+}
+
+// newNeighborEntry creates a neighbor cache entry starting at the default
+// state, Unknown. Transition out of Unknown by calling either
+// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created
+// neighborEntry.
+func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry {
+ return &neighborEntry{
+ nic: nic,
+ linkRes: linkRes,
+ nudState: nudState,
+ neigh: NeighborEntry{
+ Addr: remoteAddr,
+ LocalAddr: localAddr,
+ State: Unknown,
+ },
+ }
+}
+
+// newStaticNeighborEntry creates a neighbor cache entry starting at the Static
+// state. The entry can only transition out of Static by directly calling
+// `setStateLocked`.
+func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry {
+ if nic.stack.nudDisp != nil {
+ nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now())
+ }
+ return &neighborEntry{
+ nic: nic,
+ nudState: state,
+ neigh: NeighborEntry{
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: Static,
+ UpdatedAt: time.Now(),
+ },
+ }
+}
+
+// addWaker adds w to the list of wakers waiting for address resolution.
+// Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) addWakerLocked(w *sleep.Waker) {
+ if w == nil {
+ return
+ }
+ if e.wakers == nil {
+ e.wakers = make(map[*sleep.Waker]struct{})
+ }
+ e.wakers[w] = struct{}{}
+}
+
+// notifyWakersLocked notifies those waiting for address resolution, whether it
+// succeeded or failed. Assumes the entry has already been appropriately locked.
+func (e *neighborEntry) notifyWakersLocked() {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ if ch := e.done; ch != nil {
+ close(ch)
+ e.done = nil
+ }
+}
+
+// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has
+// been added.
+func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry
+// has changed state or link-layer address.
+func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now())
+ }
+}
+
+// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry
+// has been removed.
+func (e *neighborEntry) dispatchRemoveEventLocked() {
+ if nudDisp := e.nic.stack.nudDisp; nudDisp != nil {
+ nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now())
+ }
+}
+
+// setStateLocked transitions the entry to the specified state immediately.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+//
+// e.mu MUST be locked.
+func (e *neighborEntry) setStateLocked(next NeighborState) {
+ // Cancel the previously scheduled action, if there is one. Entries in
+ // Unknown, Stale, or Static state do not have scheduled actions.
+ if timer := e.job; timer != nil {
+ timer.Cancel()
+ }
+
+ prev := e.neigh.State
+ e.neigh.State = next
+ e.neigh.UpdatedAt = time.Now()
+ config := e.nudState.Config()
+
+ switch next {
+ case Incomplete:
+ var retryCounter uint32
+ var sendMulticastProbe func()
+
+ sendMulticastProbe = func() {
+ if retryCounter == config.MaxMulticastProbes {
+ // "If no Neighbor Advertisement is received after
+ // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed.
+ // The sender MUST return ICMP destination unreachable indications with
+ // code 3 (Address Unreachable) for each packet queued awaiting address
+ // resolution." - RFC 4861 section 7.2.2
+ //
+ // There is no need to send an ICMP destination unreachable indication
+ // since the failure to resolve the address is expected to only occur
+ // on this node. Thus, redirecting traffic is currently not supported.
+ //
+ // "If the error occurs on a node other than the node originating the
+ // packet, an ICMP error message is generated. If the error occurs on
+ // the originating node, an implementation is not required to actually
+ // create and send an ICMP error packet to the source, as long as the
+ // upper-layer sender is notified through an appropriate mechanism
+ // (e.g. return value from a procedure call). Note, however, that an
+ // implementation may find it convenient in some cases to return errors
+ // to the sender by taking the offending packet, generating an ICMP
+ // error message, and then delivering it (locally) through the generic
+ // error-handling routines.' - RFC 4861 section 2.1
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil {
+ // There is no need to log the error here; the NUD implementation may
+ // assume a working link. A valid link should be the responsibility of
+ // the NIC/stack.LinkEndpoint.
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendMulticastProbe()
+
+ case Reachable:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ })
+ e.job.Schedule(e.nudState.ReachableTime())
+
+ case Delay:
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.dispatchChangeEventLocked(Probe)
+ e.setStateLocked(Probe)
+ })
+ e.job.Schedule(config.DelayFirstProbeTime)
+
+ case Probe:
+ var retryCounter uint32
+ var sendUnicastProbe func()
+
+ sendUnicastProbe = func() {
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ retryCounter++
+ if retryCounter == config.MaxUnicastProbes {
+ e.dispatchRemoveEventLocked()
+ e.setStateLocked(Failed)
+ return
+ }
+
+ e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe)
+ e.job.Schedule(config.RetransmitTimer)
+ }
+
+ sendUnicastProbe()
+
+ case Failed:
+ e.notifyWakersLocked()
+ e.job = e.nic.stack.newJob(&e.mu, func() {
+ e.nic.neigh.removeEntryLocked(e)
+ })
+ e.job.Schedule(config.UnreachableTime)
+
+ case Unknown, Stale, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next))
+ }
+}
+
+// handlePacketQueuedLocked advances the state machine according to a packet
+// being queued for outgoing transmission.
+//
+// Follows the logic defined in RFC 4861 section 7.3.3.
+func (e *neighborEntry) handlePacketQueuedLocked() {
+ switch e.neigh.State {
+ case Unknown:
+ e.dispatchAddEventLocked(Incomplete)
+ e.setStateLocked(Incomplete)
+
+ case Stale:
+ e.dispatchChangeEventLocked(Delay)
+ e.setStateLocked(Delay)
+
+ case Incomplete, Reachable, Delay, Probe, Static, Failed:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or
+// Neighbor Solicitation for ARP or NDP, respectively).
+//
+// Follows the logic defined in RFC 4861 section 7.2.3.
+func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) {
+ // Probes MUST be silently discarded if the target address is tentative, does
+ // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These
+ // checks MUST be done by the NetworkEndpoint.
+
+ switch e.neigh.State {
+ case Unknown, Incomplete, Failed:
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchAddEventLocked(Stale)
+ e.setStateLocked(Stale)
+ e.notifyWakersLocked()
+
+ case Reachable, Delay, Probe:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+
+ case Stale:
+ if e.neigh.LinkAddr != remoteLinkAddr {
+ e.neigh.LinkAddr = remoteLinkAddr
+ e.dispatchChangeEventLocked(Stale)
+ }
+
+ case Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleConfirmationLocked processes an incoming neighbor confirmation
+// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively).
+//
+// Follows the state machine defined by RFC 4861 section 7.2.5.
+//
+// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other
+// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol
+// should be deployed where preventing access to the broadcast segment might
+// not be possible. SEND uses RSA key pairs to produce Cryptographically
+// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the
+// claimed source of an NDP message is the owner of the claimed address.
+func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) {
+ switch e.neigh.State {
+ case Incomplete:
+ if len(linkAddr) == 0 {
+ // "If the link layer has addresses and no Target Link-Layer Address
+ // option is included, the receiving node SHOULD silently discard the
+ // received advertisement." - RFC 4861 section 7.2.5
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+ if flags.Solicited {
+ e.dispatchChangeEventLocked(Reachable)
+ e.setStateLocked(Reachable)
+ } else {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ e.isRouter = flags.IsRouter
+ e.notifyWakersLocked()
+
+ // "Note that the Override flag is ignored if the entry is in the
+ // INCOMPLETE state." - RFC 4861 section 7.2.5
+
+ case Reachable, Stale, Delay, Probe:
+ sameLinkAddr := e.neigh.LinkAddr == linkAddr
+
+ if !sameLinkAddr {
+ if !flags.Override {
+ if e.neigh.State == Reachable {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ }
+ break
+ }
+
+ e.neigh.LinkAddr = linkAddr
+
+ if !flags.Solicited {
+ if e.neigh.State != Stale {
+ e.dispatchChangeEventLocked(Stale)
+ e.setStateLocked(Stale)
+ } else {
+ // Notify the LinkAddr change, even though NUD state hasn't changed.
+ e.dispatchChangeEventLocked(e.neigh.State)
+ }
+ break
+ }
+ }
+
+ if flags.Solicited && (flags.Override || sameLinkAddr) {
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ }
+ // Set state to Reachable again to refresh timers.
+ e.setStateLocked(Reachable)
+ e.notifyWakersLocked()
+ }
+
+ if e.isRouter && !flags.IsRouter {
+ // "In those cases where the IsRouter flag changes from TRUE to FALSE as
+ // a result of this update, the node MUST remove that router from the
+ // Default Router List and update the Destination Cache entries for all
+ // destinations using that neighbor as a router as specified in Section
+ // 7.3.3. This is needed to detect when a node that is used as a router
+ // stops forwarding packets due to being configured as a host."
+ // - RFC 4861 section 7.2.5
+ e.nic.mu.Lock()
+ e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr)
+ e.nic.mu.Unlock()
+ }
+ e.isRouter = flags.IsRouter
+
+ case Unknown, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
+
+// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol
+// (e.g. TCP acknowledgements) reachability confirmation.
+func (e *neighborEntry) handleUpperLevelConfirmationLocked() {
+ switch e.neigh.State {
+ case Reachable, Stale, Delay, Probe:
+ if e.neigh.State != Reachable {
+ e.dispatchChangeEventLocked(Reachable)
+ // Set state to Reachable again to refresh timers.
+ }
+ e.setStateLocked(Reachable)
+
+ case Unknown, Incomplete, Failed, Static:
+ // Do nothing
+
+ default:
+ panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State))
+ }
+}
diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go
new file mode 100644
index 000000000..b769fb2fa
--- /dev/null
+++ b/pkg/tcpip/stack/neighbor_entry_test.go
@@ -0,0 +1,2870 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+
+ entryTestNICID tcpip.NICID = 1
+ entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+
+ entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01")
+ entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02")
+
+ // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests,
+ // except where another value is explicitly used. It is chosen to match the
+ // MTU of loopback interfaces on Linux systems.
+ entryTestNetDefaultMTU = 65536
+)
+
+// eventDiffOpts are the options passed to cmp.Diff to compare entry events.
+// The UpdatedAt field is ignored due to a lack of a deterministic method to
+// predict the time that an event will be dispatched.
+func eventDiffOpts() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ }
+}
+
+// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to
+// sort slices of events for cases where ordering must be ignored.
+func eventDiffOptsWithSort() []cmp.Option {
+ return []cmp.Option{
+ cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"),
+ cmpopts.SortSlices(func(a, b testEntryEventInfo) bool {
+ return strings.Compare(string(a.Addr), string(b.Addr)) < 0
+ }),
+ }
+}
+
+// The following unit tests exercise every state transition and verify its
+// behavior with RFC 4681.
+//
+// | From | To | Cause | Action | Event |
+// | ========== | ========== | ========================================== | =============== | ======= |
+// | Unknown | Unknown | Confirmation w/ unknown address | | Added |
+// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added |
+// | Unknown | Stale | Probe w/ unknown address | | Added |
+// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed |
+// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed |
+// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed |
+// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | |
+// | Reachable | Stale | Reachable timer expired | | Changed |
+// | Reachable | Stale | Probe or confirmation w/ different address | | Changed |
+// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Override confirmation | Update LinkAddr | Changed |
+// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed |
+// | Stale | Delay | Packet sent | | Changed |
+// | Delay | Reachable | Upper-layer confirmation | | Changed |
+// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Delay | Stale | Probe or confirmation w/ different address | | Changed |
+// | Delay | Probe | Delay timer expired | Send probe | Changed |
+// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed |
+// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed |
+// | Probe | Stale | Probe or confirmation w/ different address | | Changed |
+// | Probe | Probe | Retransmit timer expired | Send probe | Changed |
+// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed |
+// | Failed | | Unreachability timer expired | Delete entry | |
+
+type testEntryEventType uint8
+
+const (
+ entryTestAdded testEntryEventType = iota
+ entryTestChanged
+ entryTestRemoved
+)
+
+func (t testEntryEventType) String() string {
+ switch t {
+ case entryTestAdded:
+ return "add"
+ case entryTestChanged:
+ return "change"
+ case entryTestRemoved:
+ return "remove"
+ default:
+ return fmt.Sprintf("unknown (%d)", t)
+ }
+}
+
+// Fields are exported for use with cmp.Diff.
+type testEntryEventInfo struct {
+ EventType testEntryEventType
+ NICID tcpip.NICID
+ Addr tcpip.Address
+ LinkAddr tcpip.LinkAddress
+ State NeighborState
+ UpdatedAt time.Time
+}
+
+func (e testEntryEventInfo) String() string {
+ return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State)
+}
+
+// testNUDDispatcher implements NUDDispatcher to validate the dispatching of
+// events upon certain NUD state machine events.
+type testNUDDispatcher struct {
+ mu sync.Mutex
+ events []testEntryEventInfo
+}
+
+var _ NUDDispatcher = (*testNUDDispatcher)(nil)
+
+func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.events = append(d.events, e)
+}
+
+func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestAdded,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestChanged,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) {
+ d.queueEvent(testEntryEventInfo{
+ EventType: entryTestRemoved,
+ NICID: nicID,
+ Addr: addr,
+ LinkAddr: linkAddr,
+ State: state,
+ UpdatedAt: updatedAt,
+ })
+}
+
+type entryTestLinkResolver struct {
+ mu sync.Mutex
+ probes []entryTestProbeInfo
+}
+
+var _ LinkAddressResolver = (*entryTestLinkResolver)(nil)
+
+type entryTestProbeInfo struct {
+ RemoteAddress tcpip.Address
+ RemoteLinkAddress tcpip.LinkAddress
+ LocalAddress tcpip.Address
+}
+
+func (p entryTestProbeInfo) String() string {
+ return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress)
+}
+
+// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+// to the local network if linkAddr is the zero value.
+func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
+ p := entryTestProbeInfo{
+ RemoteAddress: addr,
+ RemoteLinkAddress: linkAddr,
+ LocalAddress: localAddr,
+ }
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.probes = append(r.probes, p)
+ return nil
+}
+
+// ResolveStaticAddress attempts to resolve address without sending requests.
+// It either resolves the name immediately or returns the empty LinkAddress.
+func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ return "", false
+}
+
+// LinkAddressProtocol returns the network protocol of the addresses this
+// resolver can resolve.
+func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return entryTestNetNumber
+}
+
+func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) {
+ clock := newFakeClock()
+ disp := testNUDDispatcher{}
+ nic := NIC{
+ id: entryTestNICID,
+ linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint
+ stack: &Stack{
+ clock: clock,
+ nudDisp: &disp,
+ },
+ }
+
+ rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+ nudState := NewNUDState(c, rng)
+ linkRes := entryTestLinkResolver{}
+ entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes)
+
+ // Stub out ndpState to verify modification of default routers.
+ nic.mu.ndp = ndpState{
+ nic: &nic,
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ }
+
+ // Stub out the neighbor cache to verify deletion from the cache.
+ nic.neigh = &neighborCache{
+ nic: &nic,
+ state: nudState,
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
+ nic.neigh.cache[entryTestAddr1] = entry
+
+ return entry, &disp, &linkRes, clock
+}
+
+// TestEntryInitiallyUnknown verifies that the state of a newly created
+// neighborEntry is Unknown.
+func TestEntryInitiallyUnknown(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Unknown; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(time.Hour)
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ // No events should have been dispatched.
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryUnknownToIncomplete(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ {
+ nudDisp.mu.Lock()
+ diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...)
+ nudDisp.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ }
+}
+
+func TestEntryUnknownToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ // No probes should have been sent.
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil))
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ updatedAt := e.neigh.UpdatedAt
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should remain the same during address resolution.
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.probes = nil
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.UpdatedAt, updatedAt; got != want {
+ t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.RetransmitTimer)
+
+ // UpdatedAt should change after failing address resolution. Timing out after
+ // sending the last probe transitions the entry to Failed.
+ {
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ }
+
+ clock.advance(c.RetransmitTimer)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant {
+ t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryAddsAndClearsWakers verifies that wakers are added when
+// addWakerLocked is called and cleared when address resolution finishes. In
+// this case, address resolution will finish when transitioning from Incomplete
+// to Reachable.
+func TestEntryAddsAndClearsWakers(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got := e.wakers; got != nil {
+ t.Errorf("got e.wakers = %v, want = nil", got)
+ }
+ e.addWakerLocked(&w)
+ if got, want := w.IsAsserted(), false; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ if e.wakers == nil {
+ t.Error("expected e.wakers to be non-nil")
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if e.wakers != nil {
+ t.Errorf("got e.wakers = %v, want = nil", e.wakers)
+ }
+ if got, want := w.IsAsserted(), true; got != want {
+ t.Errorf("waker.IsAsserted() = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+ linkRes.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToStale(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryIncompleteToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Incomplete; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The Incomplete-to-Incomplete state transition is tested here by
+ // verifying that 3 reachability probes were sent.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+type testLocker struct{}
+
+var _ sync.Locker = (*testLocker)(nil)
+
+func (*testLocker) Lock() {}
+func (*testLocker) Unlock() {}
+
+func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: true,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.isRouter, true; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{
+ invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}),
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.isRouter, false; got != want {
+ t.Errorf("got e.isRouter = %t, want = %t", got, want)
+ }
+ if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok {
+ t.Errorf("unexpected defaultRouter for %s", entryTestAddr1)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenTimeout(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr1)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaleToDelay(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleUpperLevelConfirmationLocked()
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 1
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, _ := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantProbes := []entryTestProbeInfo{
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryDelayToProbe(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ if got, want := e.neigh.State, Delay; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleProbeLocked(entryTestLinkAddr2)
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Stale; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario:
+// 1. Probe is received
+// 2. Entry is created in Stale
+// 3. Packet is queued on the entry
+// 4. Entry transitions to Delay then Probe
+// 5. Probe is sent
+func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Probe to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handleProbeLocked(entryTestLinkAddr1)
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // Probe caused by the Delay-to-Probe transition
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: true,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want {
+ t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr2,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ // Eliminate random factors from ReachableTime computation so the transition
+ // from Stale to Reachable will only take BaseReachableTime duration.
+ c.MinRandomFactor = 1
+ c.MaxRandomFactor = 1
+
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ clock.advance(c.DelayFirstProbeTime)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The second probe is caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Probe; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: true,
+ Override: false,
+ IsRouter: false,
+ })
+ if got, want := e.neigh.State, Reachable; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+
+ clock.advance(c.BaseReachableTime)
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Reachable,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+}
+
+func TestEntryProbeToFailed(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes)
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ e.mu.Lock()
+ if got, want := e.neigh.State, Failed; got != want {
+ t.Errorf("got e.neigh.State = %q, want = %q", got, want)
+ }
+ e.mu.Unlock()
+}
+
+func TestEntryFailedGetsDeleted(t *testing.T) {
+ c := DefaultNUDConfigurations()
+ c.MaxMulticastProbes = 3
+ c.MaxUnicastProbes = 3
+ e, nudDisp, linkRes, clock := entryTestSetup(c)
+
+ // Verify the cache contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok {
+ t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1)
+ }
+
+ e.mu.Lock()
+ e.handlePacketQueuedLocked()
+ e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{
+ Solicited: false,
+ Override: false,
+ IsRouter: false,
+ })
+ e.handlePacketQueuedLocked()
+ e.mu.Unlock()
+
+ waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime
+ clock.advance(waitFor)
+
+ wantProbes := []entryTestProbeInfo{
+ // The first probe is caused by the Unknown-to-Incomplete transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: tcpip.LinkAddress(""),
+ LocalAddress: entryTestAddr2,
+ },
+ // The next three probe are caused by the Delay-to-Probe transition.
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ {
+ RemoteAddress: entryTestAddr1,
+ RemoteLinkAddress: entryTestLinkAddr1,
+ LocalAddress: entryTestAddr2,
+ },
+ }
+ linkRes.mu.Lock()
+ diff := cmp.Diff(linkRes.probes, wantProbes)
+ linkRes.mu.Unlock()
+ if diff != "" {
+ t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
+ }
+
+ wantEvents := []testEntryEventInfo{
+ {
+ EventType: entryTestAdded,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: tcpip.LinkAddress(""),
+ State: Incomplete,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Stale,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Delay,
+ },
+ {
+ EventType: entryTestChanged,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ {
+ EventType: entryTestRemoved,
+ NICID: entryTestNICID,
+ Addr: entryTestAddr1,
+ LinkAddr: entryTestLinkAddr1,
+ State: Probe,
+ },
+ }
+ nudDisp.mu.Lock()
+ if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" {
+ t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff)
+ }
+ nudDisp.mu.Unlock()
+
+ // Verify the cache no longer contains the entry.
+ if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok {
+ t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1)
+ }
+}
diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go
new file mode 100644
index 000000000..aa7311ec6
--- /dev/null
+++ b/pkg/tcpip/stack/neighborstate_string.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated by "stringer -type NeighborState"; DO NOT EDIT.
+
+package stack
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[Unknown-0]
+ _ = x[Incomplete-1]
+ _ = x[Reachable-2]
+ _ = x[Stale-3]
+ _ = x[Delay-4]
+ _ = x[Probe-5]
+ _ = x[Static-6]
+ _ = x[Failed-7]
+}
+
+const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed"
+
+var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53}
+
+func (i NeighborState) String() string {
+ if i >= NeighborState(len(_NeighborState_index)-1) {
+ return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]]
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ab6798aa6..e74d2562a 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,48 +15,66 @@
package stack
import (
- "strings"
- "sync"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "sort"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
- stack *Stack
- id tcpip.NICID
- name string
- linkEP LinkEndpoint
- loopback bool
-
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- addressRanges []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
- // packetEPs is protected by mu, but the contained PacketEndpoint
- // values are not.
- packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
-
- stats NICStats
-
- // ndp is the NDP related state for NIC.
- //
- // Note, read and write operations on ndp require that the NIC is
- // appropriately locked.
- ndp ndpState
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+ context NICContext
+
+ stats NICStats
+ neigh *neighborCache
+ networkEndpoints map[tcpip.NetworkProtocolNumber]NetworkEndpoint
+
+ mu struct {
+ sync.RWMutex
+ enabled bool
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ mcastJoins map[NetworkEndpointID]uint32
+ // packetEPs is protected by mu, but the contained PacketEndpoint
+ // values are not.
+ packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint
+ ndp ndpState
+ }
}
// NICStats includes transmitted and received stats.
type NICStats struct {
Tx DirectionStats
Rx DirectionStats
+
+ DisabledRx DirectionStats
+}
+
+func makeNICStats() NICStats {
+ var s NICStats
+ tcpip.InitStatCounters(reflect.ValueOf(&s).Elem())
+ return s
}
// DirectionStats includes packet and byte counts.
@@ -85,59 +103,170 @@ const (
)
// newNIC returns a new NIC using the default NDP configurations from stack.
-func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback bool) *NIC {
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC {
// TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. For
// example, make sure that the link address it provides is a valid
// unicast ethernet address.
+
+ // TODO(b/143357959): RFC 8200 section 5 requires that IPv6 endpoints
+ // observe an MTU of at least 1280 bytes. Ensure that this requirement
+ // of IPv6 is supported on this endpoint's LinkEndpoint.
+
nic := &NIC{
- stack: stack,
- id: id,
- name: name,
- linkEP: ep,
- loopback: loopback,
- primary: make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint),
- endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
- mcastJoins: make(map[NetworkEndpointID]int32),
- packetEPs: make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint),
- stats: NICStats{
- Tx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- Rx: DirectionStats{
- Packets: &tcpip.StatCounter{},
- Bytes: &tcpip.StatCounter{},
- },
- },
- ndp: ndpState{
- configs: stack.ndpConfigs,
- dad: make(map[tcpip.Address]dadState),
- },
- }
- nic.ndp.nic = nic
+ stack: stack,
+ id: id,
+ name: name,
+ linkEP: ep,
+ context: ctx,
+ stats: makeNICStats(),
+ networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
+ }
+ nic.mu.primary = make(map[tcpip.NetworkProtocolNumber][]*referencedNetworkEndpoint)
+ nic.mu.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
+ nic.mu.mcastJoins = make(map[NetworkEndpointID]uint32)
+ nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint)
+ nic.mu.ndp = ndpState{
+ nic: nic,
+ configs: stack.ndpConfigs,
+ dad: make(map[tcpip.Address]dadState),
+ defaultRouters: make(map[tcpip.Address]defaultRouterState),
+ onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState),
+ slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState),
+ }
+ nic.mu.ndp.initializeTempAddrState()
// Register supported packet endpoint protocols.
for _, netProto := range header.Ethertypes {
- nic.packetEPs[netProto] = []PacketEndpoint{}
+ nic.mu.packetEPs[netProto] = []PacketEndpoint{}
}
for _, netProto := range stack.networkProtocols {
- nic.packetEPs[netProto.Number()] = []PacketEndpoint{}
+ netNum := netProto.Number()
+ nic.mu.packetEPs[netNum] = nil
+ nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack)
+ }
+
+ // Check for Neighbor Unreachability Detection support.
+ if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 {
+ rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds()))
+ nic.neigh = &neighborCache{
+ nic: nic,
+ state: NewNUDState(stack.nudConfigs, rng),
+ cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize),
+ }
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ err := n.disableLocked()
+ n.mu.Unlock()
+ return err
+}
+
+// disableLocked disables n.
+//
+// It undoes the work done by enable.
+//
+// n MUST be locked.
+func (n *NIC) disableLocked() *tcpip.Error {
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.address(); r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
- n.attachLinkEndpoint()
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if n.mu.enabled {
+ return nil
+ }
+
+ n.mu.enabled = true
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if err := n.AddAddress(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ return err
+ }
+
+ // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts
+ // multicast group. Note, the IANA calls the all-hosts multicast group the
+ // all-systems multicast group.
+ if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil {
return err
}
}
@@ -159,77 +288,298 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
- n.mu.Lock()
- defer n.mu.Unlock()
-
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
return err
}
- if !n.stack.autoGenIPv6LinkLocal {
- return nil
+ // Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
+ // state.
+ //
+ // Addresses may have aleady completed DAD but in the time since the NIC was
+ // last enabled, other devices may have acquired the same addresses.
+ for _, r := range n.mu.endpoints {
+ addr := r.address()
+ if k := r.getKind(); (k != permanent && k != permanentTentative) || !header.IsV6UnicastAddress(addr) {
+ continue
+ }
+
+ r.setKind(permanentTentative)
+ if err := n.mu.ndp.startDuplicateAddressDetection(addr, r); err != nil {
+ return err
+ }
}
- l2addr := n.linkEP.LinkAddress()
+ // Do not auto-generate an IPv6 link-local address for loopback devices.
+ if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
+ // The valid and preferred lifetime is infinite for the auto-generated
+ // link-local address.
+ n.mu.ndp.doSLAAC(header.IPv6LinkLocalPrefix.Subnet(), header.NDPInfiniteLifetime, header.NDPInfiniteLifetime)
+ }
- // Only attempt to generate the link-local address if we have a
- // valid MAC address.
+ // If we are operating as a router, then do not solicit routers since we
+ // won't process the RAs anyways.
//
- // TODO(b/141011931): Validate a LinkEndpoint's link address
- // (provided by LinkEndpoint.LinkAddress) before reaching this
- // point.
- if !header.IsValidUnicastEthernetAddress(l2addr) {
- return nil
+ // Routers do not process Router Advertisements (RA) the same way a host
+ // does. That is, routers do not learn from RAs (e.g. on-link prefixes
+ // and default routers). Therefore, soliciting RAs from other routers on
+ // a link is unnecessary for routers.
+ if !n.stack.Forwarding(header.IPv6ProtocolNumber) {
+ n.mu.ndp.startSolicitingRouters()
}
- addr := header.LinkLocalAddr(l2addr)
+ return nil
+}
- _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv6ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: addr,
- PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
- },
- }, CanBePrimaryEndpoint)
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.disableLocked()
+
+ // TODO(b/151378115): come up with a better way to pick an error than the
+ // first one.
+ var err *tcpip.Error
+
+ // Forcefully leave multicast groups.
+ for nid := range n.mu.mcastJoins {
+ if tempErr := n.leaveGroupLocked(nid.LocalAddress, true /* force */); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if tempErr := n.removePermanentAddressLocked(nid.LocalAddress); tempErr != nil && err == nil {
+ err = tempErr
+ }
+ }
+ }
+
+ // Release any resources the network endpoint may hold.
+ for _, ep := range n.networkEndpoints {
+ ep.Close()
+ }
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
return err
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
+// becomeIPv6Router transitions n into an IPv6 router.
+//
+// When transitioning into an IPv6 router, host-only state (NDP discovered
+// routers, discovered on-link prefixes, and auto-generated addresses) will
+// be cleaned up/invalidated and NDP router solicitations will be stopped.
+func (n *NIC) becomeIPv6Router() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.cleanupState(true /* hostOnly */)
+ n.mu.ndp.stopSolicitingRouters()
+}
+
+// becomeIPv6Host transitions n into an IPv6 host.
+//
+// When transitioning into an IPv6 host, NDP router solicitations will be
+// started.
+func (n *NIC) becomeIPv6Host() {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.startSolicitingRouters()
}
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
- n.promiscuous = enable
+ n.mu.promiscuous = enable
n.mu.Unlock()
}
func (n *NIC) isPromiscuousMode() bool {
n.mu.RLock()
- rv := n.promiscuous
+ rv := n.mu.promiscuous
n.mu.RUnlock()
return rv
}
+func (n *NIC) isLoopback() bool {
+ return n.linkEP.Capabilities()&CapabilityLoopback != 0
+}
+
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
- n.spoofing = enable
+ n.mu.spoofing = enable
n.mu.Unlock()
}
-// primaryEndpoint returns the primary endpoint of n for the given network
-// protocol.
-func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+// primaryEndpoint will return the first non-deprecated endpoint if such an
+// endpoint exists for the given protocol and remoteAddr. If no non-deprecated
+// endpoint exists, the first deprecated endpoint will be returned.
+//
+// If an IPv6 primary endpoint is requested, Source Address Selection (as
+// defined by RFC 6724 section 5) will be performed.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ if protocol == header.IPv6ProtocolNumber && remoteAddr != "" {
+ return n.primaryIPv6Endpoint(remoteAddr)
+ }
+
n.mu.RLock()
defer n.mu.RUnlock()
- for _, r := range n.primary[protocol] {
- if r.isValidForOutgoing() && r.tryIncRef() {
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, r := range n.mu.primary[protocol] {
+ if !r.isValidForOutgoingRLocked() {
+ continue
+ }
+
+ if !r.deprecated {
+ if r.tryIncRef() {
+ // r is not deprecated, so return it immediately.
+ //
+ // If we kept track of a deprecated endpoint, decrement its reference
+ // count since it was incremented when we decided to keep track of it.
+ if deprecatedEndpoint != nil {
+ deprecatedEndpoint.decRefLocked()
+ deprecatedEndpoint = nil
+ }
+
+ return r
+ }
+ } else if deprecatedEndpoint == nil && r.tryIncRef() {
+ // We prefer an endpoint that is not deprecated, but we keep track of r in
+ // case n doesn't have any non-deprecated endpoints.
+ //
+ // If we end up finding a more preferred endpoint, r's reference count
+ // will be decremented when such an endpoint is found.
+ deprecatedEndpoint = r
+ }
+ }
+
+ // n doesn't have any valid non-deprecated endpoints, so return
+ // deprecatedEndpoint (which may be nil if n doesn't have any valid deprecated
+ // endpoints either).
+ return deprecatedEndpoint
+}
+
+// ipv6AddrCandidate is an IPv6 candidate for Source Address Selection (RFC
+// 6724 section 5).
+type ipv6AddrCandidate struct {
+ ref *referencedNetworkEndpoint
+ scope header.IPv6AddressScope
+}
+
+// primaryIPv6Endpoint returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+func (n *NIC) primaryIPv6Endpoint(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ n.mu.RLock()
+ ref := n.primaryIPv6EndpointRLocked(remoteAddr)
+ n.mu.RUnlock()
+ return ref
+}
+
+// primaryIPv6EndpointLocked returns an IPv6 endpoint following Source Address
+// Selection (RFC 6724 section 5).
+//
+// Note, only rules 1-3 and 7 are followed.
+//
+// remoteAddr must be a valid IPv6 address.
+//
+// n.mu MUST be read locked.
+func (n *NIC) primaryIPv6EndpointRLocked(remoteAddr tcpip.Address) *referencedNetworkEndpoint {
+ primaryAddrs := n.mu.primary[header.IPv6ProtocolNumber]
+
+ if len(primaryAddrs) == 0 {
+ return nil
+ }
+
+ // Create a candidate set of available addresses we can potentially use as a
+ // source address.
+ cs := make([]ipv6AddrCandidate, 0, len(primaryAddrs))
+ for _, r := range primaryAddrs {
+ // If r is not valid for outgoing connections, it is not a valid endpoint.
+ if !r.isValidForOutgoingRLocked() {
+ continue
+ }
+
+ addr := r.address()
+ scope, err := header.ScopeForIPv6Address(addr)
+ if err != nil {
+ // Should never happen as we got r from the primary IPv6 endpoint list and
+ // ScopeForIPv6Address only returns an error if addr is not an IPv6
+ // address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", addr, err))
+ }
+
+ cs = append(cs, ipv6AddrCandidate{
+ ref: r,
+ scope: scope,
+ })
+ }
+
+ remoteScope, err := header.ScopeForIPv6Address(remoteAddr)
+ if err != nil {
+ // primaryIPv6Endpoint should never be called with an invalid IPv6 address.
+ panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err))
+ }
+
+ // Sort the addresses as per RFC 6724 section 5 rules 1-3.
+ //
+ // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5.
+ sort.Slice(cs, func(i, j int) bool {
+ sa := cs[i]
+ sb := cs[j]
+
+ // Prefer same address as per RFC 6724 section 5 rule 1.
+ if sa.ref.address() == remoteAddr {
+ return true
+ }
+ if sb.ref.address() == remoteAddr {
+ return false
+ }
+
+ // Prefer appropriate scope as per RFC 6724 section 5 rule 2.
+ if sa.scope < sb.scope {
+ return sa.scope >= remoteScope
+ } else if sb.scope < sa.scope {
+ return sb.scope < remoteScope
+ }
+
+ // Avoid deprecated addresses as per RFC 6724 section 5 rule 3.
+ if saDep, sbDep := sa.ref.deprecated, sb.ref.deprecated; saDep != sbDep {
+ // If sa is not deprecated, it is preferred over sb.
+ return sbDep
+ }
+
+ // Prefer temporary addresses as per RFC 6724 section 5 rule 7.
+ if saTemp, sbTemp := sa.ref.configType == slaacTemp, sb.ref.configType == slaacTemp; saTemp != sbTemp {
+ return saTemp
+ }
+
+ // sa and sb are equal, return the endpoint that is closest to the front of
+ // the primary endpoint list.
+ return i < j
+ })
+
+ // Return the most preferred address that can have its reference count
+ // incremented.
+ for _, c := range cs {
+ if r := c.ref; r.tryIncRef() {
return r
}
}
@@ -237,62 +587,87 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
return nil
}
+// hasPermanentAddrLocked returns true if n has a permanent (including currently
+// tentative) address, addr.
+func (n *NIC) hasPermanentAddrLocked(addr tcpip.Address) bool {
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
+
+ if !ok {
+ return false
+ }
+
+ kind := ref.getKind()
+
+ return kind == permanent || kind == permanentTentative
+}
+
+type getRefBehaviour int
+
+const (
+ // spoofing indicates that the NIC's spoofing flag should be observed when
+ // getting a NIC's referenced network endpoint.
+ spoofing getRefBehaviour = iota
+
+ // promiscuous indicates that the NIC's promiscuous flag should be observed
+ // when getting a NIC's referenced network endpoint.
+ promiscuous
+)
+
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous)
+ return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous)
}
// findEndpoint finds the endpoint, if any, with the given address.
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
- return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing)
+ return n.getRefOrCreateTemp(protocol, address, peb, spoofing)
}
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
-// protocol and address. If none exists a temporary one may be created if
-// we are in promiscuous mode or spoofing.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
- id := NetworkEndpointID{address}
-
+// protocol and address.
+//
+// If none exists a temporary one may be created if we are in promiscuous mode
+// or spoofing. Promiscuous mode will only be checked if promiscuous is true.
+// Similarly, spoofing will only be checked if spoofing is true.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint {
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok {
+ var spoofingOrPromiscuous bool
+ switch tempRef {
+ case spoofing:
+ spoofingOrPromiscuous = n.mu.spoofing
+ case promiscuous:
+ spoofingOrPromiscuous = n.mu.promiscuous
+ }
+
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// An endpoint with this id exists, check if it can be used and return it.
- switch ref.getKind() {
- case permanentExpired:
- if !spoofingOrPromiscuous {
- n.mu.RUnlock()
- return nil
- }
- fallthrough
- case temporary, permanent:
- if ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
- }
+ if !ref.isAssignedRLocked(spoofingOrPromiscuous) {
+ n.mu.RUnlock()
+ return nil
+ }
+
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
}
}
- // A usable reference was not found, create a temporary one if requested by
- // the caller or if the address is found in the NIC's subnets.
- createTempEP := spoofingOrPromiscuous
- if !createTempEP {
- for _, sn := range n.addressRanges {
- // Skip the subnet address.
- if address == sn.ID() {
- continue
- }
- // For now just skip the broadcast address, until we support it.
- // FIXME(b/137608825): Add support for sending/receiving directed
- // (subnet) broadcast.
- if address == sn.Broadcast() {
- continue
- }
- if sn.Contains(address) {
- createTempEP = true
- break
- }
+ // Check if address is a broadcast address for the endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ n.mu.RUnlock()
+ return ref
}
}
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
n.mu.RUnlock()
if !createTempEP {
@@ -303,11 +678,44 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok {
+ ref := n.getRefOrCreateTempLocked(protocol, address, peb)
+ n.mu.Unlock()
+ return ref
+}
+
+// getRefForBroadcastLocked returns an endpoint where address is the IPv4
+// broadcast address for the endpoint's network.
+//
+// n.mu MUST be read locked.
+func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint {
+ for _, ref := range n.mu.endpoints {
+ // Only IPv4 has a notion of broadcast addresses.
+ if ref.protocol != header.IPv4ProtocolNumber {
+ continue
+ }
+
+ addr := ref.addrWithPrefix()
+ subnet := addr.Subnet()
+ if subnet.IsBroadcast(address) && ref.tryIncRef() {
+ return ref
+ }
+ }
+
+ return nil
+}
+
+/// getRefOrCreateTempLocked returns an existing endpoint for address or creates
+/// and returns a temporary endpoint.
+//
+// If the address is the IPv4 broadcast address for an endpoint's network, that
+// endpoint will be returned.
+//
+// n.mu must be write locked.
+func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
+ if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok {
// No need to check the type as we are ok with expired endpoints at this
// point.
if ref.tryIncRef() {
- n.mu.Unlock()
return ref
}
// tryIncRef failing means the endpoint is scheduled to be removed once the
@@ -316,10 +724,18 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
n.removeEndpointLocked(ref)
}
+ // Check if address is a broadcast address for an endpoint's network.
+ //
+ // Only IPv4 has a notion of broadcast addresses.
+ if protocol == header.IPv4ProtocolNumber {
+ if ref := n.getRefForBroadcastRLocked(address); ref != nil {
+ return ref
+ }
+ }
+
// Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
- n.mu.Unlock()
return nil
}
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
@@ -328,26 +744,40 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, temporary)
-
- n.mu.Unlock()
+ }, peb, temporary, static, false)
return ref
}
-func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
- id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
- if ref, ok := n.endpoints[id]; ok {
+// addAddressLocked adds a new protocolAddress to n.
+//
+// If n already has the address in a non-permanent state, and the kind given is
+// permanent, that address will be promoted in place and its properties set to
+// the properties provided. Otherwise, it returns tcpip.ErrDuplicateAddress.
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind, configType networkEndpointConfigType, deprecated bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP addresses before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{LocalAddress: protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.mu.endpoints[id]; ok {
+ // Endpoint already exists.
+ if kind != permanent {
+ return nil, tcpip.ErrDuplicateAddress
+ }
switch ref.getKind() {
case permanentTentative, permanent:
// The NIC already have a permanent endpoint with that address.
return nil, tcpip.ErrDuplicateAddress
case permanentExpired, temporary:
- // Promote the endpoint to become permanent and respect
- // the new peb.
+ // Promote the endpoint to become permanent and respect the new peb,
+ // configType and deprecated status.
if ref.tryIncRef() {
+ // TODO(b/147748385): Perform Duplicate Address Detection when promoting
+ // an IPv6 endpoint to permanent.
ref.setKind(permanent)
+ ref.deprecated = deprecated
+ ref.configType = configType
- refs := n.primary[ref.protocol]
+ refs := n.mu.primary[ref.protocol]
for i, r := range refs {
if r == ref {
switch peb {
@@ -357,9 +787,9 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
if i == 0 {
return ref, nil
}
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
case NeverPrimaryEndpoint:
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
return ref, nil
}
}
@@ -377,44 +807,30 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
}
- return n.addAddressLocked(protocolAddress, peb, permanent)
-}
-
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
- // TODO(b/141022673): Validate IP address before adding them.
-
- // Sanity check.
- id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
- if _, ok := n.endpoints[id]; ok {
- // Endpoint already exists.
- return nil, tcpip.ErrDuplicateAddress
- }
-
- netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
+ ep, ok := n.networkEndpoints[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
- // Create the new network endpoint.
- ep, err := netProto.NewEndpoint(n.id, protocolAddress.AddressWithPrefix, n.stack, n, n.linkEP)
- if err != nil {
- return nil, err
- }
-
isIPv6Unicast := protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address)
// If the address is an IPv6 address and it is a permanent address,
- // mark it as tentative so it goes through the DAD process.
+ // mark it as tentative so it goes through the DAD process if the NIC is
+ // enabled. If the NIC is not enabled, DAD will be started when the NIC is
+ // enabled.
if isIPv6Unicast && kind == permanent {
kind = permanentTentative
}
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- kind: kind,
+ refs: 1,
+ addr: protocolAddress.AddressWithPrefix,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
+ configType: configType,
+ deprecated: deprecated,
}
// Set up cache if link address resolution exists for this protocol.
@@ -433,13 +849,13 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
- n.endpoints[id] = ref
+ n.mu.endpoints[id] = ref
n.insertPrimaryEndpointLocked(ref, peb)
- // If we are adding a tentative IPv6 address, start DAD.
- if isIPv6Unicast && kind == permanentTentative {
- if err := n.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
+ // If we are adding a tentative IPv6 address, start DAD if the NIC is enabled.
+ if isIPv6Unicast && kind == permanentTentative && n.mu.enabled {
+ if err := n.mu.ndp.startDuplicateAddressDetection(protocolAddress.AddressWithPrefix.Address, ref); err != nil {
return nil, err
}
}
@@ -452,7 +868,7 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addPermanentAddressLocked(protocolAddress, peb)
+ _, err := n.addAddressLocked(protocolAddress, peb, permanent, static, false /* deprecated */)
n.mu.Unlock()
return err
@@ -464,22 +880,18 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
- addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
- for nid, ref := range n.endpoints {
+ addrs := make([]tcpip.ProtocolAddress, 0, len(n.mu.endpoints))
+ for _, ref := range n.mu.endpoints {
// Don't include tentative, expired or temporary endpoints to
// avoid confusion and prevent the caller from using those.
switch ref.getKind() {
- case permanentTentative, permanentExpired, temporary:
- // TODO(b/140898488): Should tentative addresses be
- // returned?
+ case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: ref.protocol,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: nid.LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
+ Protocol: ref.protocol,
+ AddressWithPrefix: ref.addrWithPrefix(),
})
}
return addrs
@@ -491,7 +903,7 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
defer n.mu.RUnlock()
var addrs []tcpip.ProtocolAddress
- for proto, list := range n.primary {
+ for proto, list := range n.mu.primary {
for _, ref := range list {
// Don't include tentative, expired or tempory endpoints
// to avoid confusion and prevent the caller from using
@@ -502,59 +914,51 @@ func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
}
addrs = append(addrs, tcpip.ProtocolAddress{
- Protocol: proto,
- AddressWithPrefix: tcpip.AddressWithPrefix{
- Address: ref.ep.ID().LocalAddress,
- PrefixLen: ref.ep.PrefixLen(),
- },
+ Protocol: proto,
+ AddressWithPrefix: ref.addrWithPrefix(),
})
}
}
return addrs
}
-// AddAddressRange adds a range of addresses to n, so that it starts accepting
-// packets targeted at the given addresses and network protocol. The range is
-// given by a subnet address, and all addresses contained in the subnet are
-// used except for the subnet address itself and the subnet's broadcast
-// address.
-func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
- n.mu.Lock()
- n.addressRanges = append(n.addressRanges, subnet)
- n.mu.Unlock()
-}
+// primaryAddress returns the primary address associated with this NIC.
+//
+// primaryAddress will return the first non-deprecated address if such an
+// address exists. If no non-deprecated address exists, the first deprecated
+// address will be returned.
+func (n *NIC) primaryAddress(proto tcpip.NetworkProtocolNumber) tcpip.AddressWithPrefix {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
-// RemoveAddressRange removes the given address range from n.
-func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
- n.mu.Lock()
+ list, ok := n.mu.primary[proto]
+ if !ok {
+ return tcpip.AddressWithPrefix{}
+ }
- // Use the same underlying array.
- tmp := n.addressRanges[:0]
- for _, sub := range n.addressRanges {
- if sub != subnet {
- tmp = append(tmp, sub)
+ var deprecatedEndpoint *referencedNetworkEndpoint
+ for _, ref := range list {
+ // Don't include tentative, expired or tempory endpoints to avoid confusion
+ // and prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentTentative, permanentExpired, temporary:
+ continue
}
- }
- n.addressRanges = tmp
- n.mu.Unlock()
-}
+ if !ref.deprecated {
+ return ref.addrWithPrefix()
+ }
-// Subnets returns the Subnets associated with this NIC.
-func (n *NIC) AddressRanges() []tcpip.Subnet {
- n.mu.RLock()
- defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
- for nid := range n.endpoints {
- sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
- if err != nil {
- // This should never happen as the mask has been carefully crafted to
- // match the address.
- panic("Invalid endpoint subnet: " + err.Error())
+ if deprecatedEndpoint == nil {
+ deprecatedEndpoint = ref
}
- sns = append(sns, sn)
}
- return append(sns, n.addressRanges...)
+
+ if deprecatedEndpoint != nil {
+ return deprecatedEndpoint.addrWithPrefix()
+ }
+
+ return tcpip.AddressWithPrefix{}
}
// insertPrimaryEndpointLocked adds r to n's primary endpoint list as required
@@ -564,21 +968,21 @@ func (n *NIC) AddressRanges() []tcpip.Subnet {
func (n *NIC) insertPrimaryEndpointLocked(r *referencedNetworkEndpoint, peb PrimaryEndpointBehavior) {
switch peb {
case CanBePrimaryEndpoint:
- n.primary[r.protocol] = append(n.primary[r.protocol], r)
+ n.mu.primary[r.protocol] = append(n.mu.primary[r.protocol], r)
case FirstPrimaryEndpoint:
- n.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.primary[r.protocol]...)
+ n.mu.primary[r.protocol] = append([]*referencedNetworkEndpoint{r}, n.mu.primary[r.protocol]...)
}
}
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
- id := *r.ep.ID()
+ id := NetworkEndpointID{LocalAddress: r.address()}
// Nothing to do if the reference has already been replaced with a different
// one. This happens in the case where 1) this endpoint's ref count hit zero
// and was waiting (on the lock) to be removed and 2) the same address was
// re-added in the meantime by removing this endpoint from the list and
// adding a new one.
- if n.endpoints[id] != r {
+ if n.mu.endpoints[id] != r {
return
}
@@ -586,16 +990,15 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
panic("Reference count dropped to zero before being removed")
}
- delete(n.endpoints, id)
- refs := n.primary[r.protocol]
+ delete(n.mu.endpoints, id)
+ refs := n.mu.primary[r.protocol]
for i, ref := range refs {
if ref == r {
- n.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ n.mu.primary[r.protocol] = append(refs[:i], refs[i+1:]...)
+ refs[len(refs)-1] = nil
break
}
}
-
- r.ep.Close()
}
func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
@@ -605,7 +1008,7 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
}
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r, ok := n.endpoints[NetworkEndpointID{addr}]
+ r, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadLocalAddress
}
@@ -615,26 +1018,45 @@ func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- isIPv6Unicast := r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr)
-
- // If we are removing a tentative IPv6 unicast address, stop DAD.
- if isIPv6Unicast && kind == permanentTentative {
- n.ndp.stopDuplicateAddressDetection(addr)
+ switch r.protocol {
+ case header.IPv6ProtocolNumber:
+ return n.removePermanentIPv6EndpointLocked(r, true /* allowSLAACInvalidation */)
+ default:
+ r.expireLocked()
+ return nil
}
+}
- r.setKind(permanentExpired)
- if !r.decRefLocked() {
- // The endpoint still has references to it.
- return nil
+func (n *NIC) removePermanentIPv6EndpointLocked(r *referencedNetworkEndpoint, allowSLAACInvalidation bool) *tcpip.Error {
+ addr := r.addrWithPrefix()
+
+ isIPv6Unicast := header.IsV6UnicastAddress(addr.Address)
+
+ if isIPv6Unicast {
+ n.mu.ndp.stopDuplicateAddressDetection(addr.Address)
+
+ // If we are removing an address generated via SLAAC, cleanup
+ // its SLAAC resources and notify the integrator.
+ switch r.configType {
+ case slaac:
+ n.mu.ndp.cleanupSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ case slaacTemp:
+ n.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr, allowSLAACInvalidation)
+ }
}
+ r.expireLocked()
+
// At this point the endpoint is deleted.
// If we are removing an IPv6 unicast address, leave the solicited-node
// multicast address.
+ //
+ // We ignore the tcpip.ErrBadLocalAddress error because the solicited-node
+ // multicast group may be left by user action.
if isIPv6Unicast {
- snmc := header.SolicitedNodeAddr(addr)
- if err := n.leaveGroupLocked(snmc); err != nil {
+ snmc := header.SolicitedNodeAddr(addr.Address)
+ if err := n.leaveGroupLocked(snmc, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress {
return err
}
}
@@ -668,23 +1090,23 @@ func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.A
// outlined in RFC 3810 section 5.
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
+ joins := n.mu.mcastJoins[id]
if joins == 0 {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint); err != nil {
+ }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
- n.mcastJoins[id] = joins + 1
+ n.mu.mcastJoins[id] = joins + 1
return nil
}
@@ -694,48 +1116,75 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.leaveGroupLocked(addr)
+ return n.leaveGroupLocked(addr, false /* force */)
}
// leaveGroupLocked decrements the count for the given multicast address, and
// when it reaches zero removes the endpoint for this address. n MUST be locked
// before leaveGroupLocked is called.
-func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
+//
+// If force is true, then the count for the multicast addres is ignored and the
+// endpoint will be removed immediately.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address, force bool) *tcpip.Error {
id := NetworkEndpointID{addr}
- joins := n.mcastJoins[id]
- switch joins {
- case 0:
+ joins, ok := n.mu.mcastJoins[id]
+ if !ok {
// There are no joins with this address on this NIC.
return tcpip.ErrBadLocalAddress
- case 1:
- // This is the last one, clean up.
- if err := n.removePermanentAddressLocked(addr); err != nil {
- return err
- }
}
- n.mcastJoins[id] = joins - 1
+
+ joins--
+ if force || joins == 0 {
+ // There are no outstanding joins or we are forced to leave, clean up.
+ delete(n.mu.mcastJoins, id)
+ return n.removePermanentAddressLocked(addr)
+ }
+
+ n.mu.mcastJoins[id] = joins
return nil
}
-func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt *PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
- ref.ep.HandlePacket(&r, vv)
+
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
-// the NIC receives a packet from the physical interface.
+// the NIC receives a packet from the link endpoint.
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
-func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
+func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ // If the NIC is not yet enabled, don't receive any packets.
+ if !enabled {
+ n.mu.RUnlock()
+
+ n.stats.DisabledRx.Packets.Increment()
+ n.stats.DisabledRx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
+ return
+ }
+
n.stats.Rx.Packets.Increment()
- n.stats.Rx.Bytes.IncrementBy(uint64(vv.Size()))
+ n.stats.Rx.Bytes.IncrementBy(uint64(pkt.Data.Size()))
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
+ n.mu.RUnlock()
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
@@ -747,32 +1196,59 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
// Are any packet sockets listening for this network protocol?
- n.mu.RLock()
- packetEPs := n.packetEPs[protocol]
- // Check whether there are packet sockets listening for every protocol.
- // If we received a packet with protocol EthernetProtocolAll, then the
- // previous for loop will have handled it.
- if protocol != header.EthernetProtocolAll {
- packetEPs = append(packetEPs, n.packetEPs[header.EthernetProtocolAll]...)
- }
+ packetEPs := n.mu.packetEPs[protocol]
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...)
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, vv, linkHeader)
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketHost
+ ep.HandlePacket(n.id, local, protocol, p)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
- if len(vv.First()) < netProto.MinimumPacketSize() {
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
+ if !ok {
+ // The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View())
+
+ if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
+ // The source address is one of our own, so we never should have gotten a
+ // packet like this unless handleLocal is false. Loopback also calls this
+ // function even though the packets didn't come from the physical interface
+ // so don't drop those.
+ n.stack.stats.IP.InvalidSourceAddressesReceived.Increment()
+ return
+ }
- src, dst := netProto.ParseAddresses(vv.First())
+ // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet.
+ // Loopback traffic skips the prerouting chain.
+ if protocol == header.IPv4ProtocolNumber && !n.isLoopback() {
+ // iptables filtering.
+ ipt := n.stack.IPTables()
+ address := n.primaryAddress(protocol)
+ if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok {
+ // iptables is telling us to drop the packet.
+ return
+ }
+ }
if ref := n.getRef(protocol, dst); ref != nil {
- handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
+ handlePacket(protocol, dst, src, n.linkEP.LinkAddress(), remote, ref, pkt)
return
}
@@ -783,54 +1259,98 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
return
}
- defer r.Release()
-
- r.LocalLinkAddress = n.linkEP.LinkAddress()
- r.RemoteLinkAddress = remote
// Found a NIC.
n := r.ref.nic
n.mu.RLock()
- ref, ok := n.endpoints[NetworkEndpointID{dst}]
- ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
+ ref, ok := n.mu.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoingRLocked() && ref.tryIncRef()
n.mu.RUnlock()
if ok {
+ r.LocalLinkAddress = n.linkEP.LinkAddress()
+ r.RemoteLinkAddress = remote
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
- ref.ep.HandlePacket(&r, vv)
+ ref.ep.HandlePacket(&r, pkt)
ref.decRef()
- } else {
- // n doesn't have a destination endpoint.
- // Send the packet out of n.
- // If we want to send the packet to a link-layer,
- // we have to reserve space for an Ethernet header.
- hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength()))
- vv.RemoveFirst()
-
- // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
- // TODO(b/128629022): use route.WritePacket.
- if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- } else {
- n.stats.Tx.Packets.Increment()
- n.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + vv.Size()))
+ r.Release()
+ return
+ }
+
+ // n doesn't have a destination endpoint.
+ // Send the packet out of n.
+ // TODO(b/128629022): move this logic to route.WritePacket.
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
+ if ch, err := r.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ n.stack.forwarder.enqueue(ch, n, &r, protocol, pkt)
+ // forwarder will release route.
+ return
}
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ r.Release()
+ return
}
+
+ // The link-address resolution finished immediately.
+ n.forwardPacket(&r, protocol, pkt)
+ r.Release()
return
}
// If a packet socket handled the packet, don't treat it as invalid.
if len(packetEPs) == 0 {
- n.stack.stats.IP.InvalidAddressesReceived.Increment()
+ n.stack.stats.IP.InvalidDestinationAddressesReceived.Increment()
+ }
+}
+
+// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket.
+func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ n.mu.RLock()
+ // We do not deliver to protocol specific packet endpoints as on Linux
+ // only ETH_P_ALL endpoints get outbound packets.
+ // Add any other packet sockets that maybe listening for all protocols.
+ packetEPs := n.mu.packetEPs[header.EthernetProtocolAll]
+ n.mu.RUnlock()
+ for _, ep := range packetEPs {
+ p := pkt.Clone()
+ p.PktType = tcpip.PacketOutgoing
+ // Add the link layer header as outgoing packets are intercepted
+ // before the link layer header is created.
+ n.linkEP.AddHeader(local, remote, protocol, p)
+ ep.HandlePacket(n.id, local, protocol, p)
}
}
+func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ // TODO(b/143425874) Decrease the TTL field in forwarded packets.
+
+ // pkt may have set its header and may not have enough headroom for link-layer
+ // header for the other link to prepend. Here we create a new packet to
+ // forward.
+ fwdPkt := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: int(n.linkEP.MaxHeaderLength()),
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ })
+
+ // WritePacket takes ownership of fwdPkt, calculate numBytes first.
+ numBytes := fwdPkt.Size()
+
+ if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, fwdPkt); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return
+ }
+
+ n.stats.Tx.Packets.Increment()
+ n.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
+}
+
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
-func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
@@ -842,41 +1362,60 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
+ n.stack.demux.deliverRawPacket(r, protocol, pkt)
+
+ // TransportHeader is empty only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader().View().IsEmpty() {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader
+ // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ // ICMP packets may be longer, but until icmp.Parse is implemented, here
+ // we parse it using the minimum size.
+ if _, ok := pkt.TransportHeader().Consume(transProto.MinimumPacketSize()); !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
- if len(vv.First()) < transProto.MinimumPacketSize() {
+ if pkt.TransportHeader().View().Size() < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader().View())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
+ if n.stack.demux.deliverPacket(r, protocol, pkt, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
- if state.defaultHandler(r, id, netHeader, vv) {
+ if state.defaultHandler(r, id, pkt) {
return
}
}
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, pkt) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
-func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
@@ -887,17 +1426,18 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
- if len(vv.First()) < 8 {
+ transHeader, ok := pkt.Data.PullUp(8)
+ if !ok {
return
}
- srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+ srcPort, dstPort, err := transProto.ParsePorts(transHeader)
if err != nil {
return
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, pkt, id) {
return
}
}
@@ -907,18 +1447,31 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Name returns the name of n.
+func (n *NIC) Name() string {
+ return n.name
+}
+
// Stack returns the instance of the Stack that owns this NIC.
func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
// false. It will only return true if the address is associated with the NIC
// AND it is tentative.
func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return false
}
@@ -926,15 +1479,17 @@ func (n *NIC) isAddrTentative(addr tcpip.Address) bool {
return ref.getKind() == permanentTentative
}
-// dupTentativeAddrDetected attempts to inform n that a tentative addr
-// is a duplicate on a link.
+// dupTentativeAddrDetected attempts to inform n that a tentative addr is a
+// duplicate on a link.
//
-// dupTentativeAddrDetected will delete the tentative address if it exists.
+// dupTentativeAddrDetected will remove the tentative address if it exists. If
+// the address was generated via SLAAC, an attempt will be made to generate a
+// new address.
func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- ref, ok := n.endpoints[NetworkEndpointID{addr}]
+ ref, ok := n.mu.endpoints[NetworkEndpointID{addr}]
if !ok {
return tcpip.ErrBadAddress
}
@@ -943,7 +1498,24 @@ func (n *NIC) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- return n.removePermanentAddressLocked(addr)
+ // If the address is a SLAAC address, do not invalidate its SLAAC prefix as a
+ // new address will be generated for it.
+ if err := n.removePermanentIPv6EndpointLocked(ref, false /* allowSLAACInvalidation */); err != nil {
+ return err
+ }
+
+ prefix := ref.addrWithPrefix().Subnet()
+
+ switch ref.configType {
+ case slaac:
+ n.mu.ndp.regenerateSLAACAddr(prefix)
+ case slaacTemp:
+ // Do not reset the generation attempts counter for the prefix as the
+ // temporary address is being regenerated in response to a DAD conflict.
+ n.mu.ndp.regenerateTempSLAACAddr(prefix, false /* resetGenAttempts */)
+ }
+
+ return nil
}
// setNDPConfigs sets the NDP configurations for n.
@@ -954,10 +1526,39 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) {
c.validate()
n.mu.Lock()
- n.ndp.configs = c
+ n.mu.ndp.configs = c
n.mu.Unlock()
}
+// NUDConfigs gets the NUD configurations for n.
+func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) {
+ if n.neigh == nil {
+ return NUDConfigurations{}, tcpip.ErrNotSupported
+ }
+ return n.neigh.config(), nil
+}
+
+// setNUDConfigs sets the NUD configurations for n.
+//
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error {
+ if n.neigh == nil {
+ return tcpip.ErrNotSupported
+ }
+ c.resetInvalidFields()
+ n.neigh.setConfig(c)
+ return nil
+}
+
+// handleNDPRA handles an NDP Router Advertisement message that arrived on n.
+func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ n.mu.ndp.handleRA(ip, ra)
+}
+
type networkEndpointKind int32
const (
@@ -977,7 +1578,7 @@ const (
// removing the permanent address from the NIC.
permanent
- // An expired permanent endoint is a permanent endoint that had its address
+ // An expired permanent endpoint is a permanent endpoint that had its address
// removed from the NIC, and it is waiting to be removed once no more routes
// hold a reference to it. This is achieved by decreasing its reference count
// by 1. If its address is re-added before the endpoint is removed, its type
@@ -997,11 +1598,11 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return tcpip.ErrNotSupported
}
- n.packetEPs[netProto] = append(eps, ep)
+ n.mu.packetEPs[netProto] = append(eps, ep)
return nil
}
@@ -1010,21 +1611,40 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
n.mu.Lock()
defer n.mu.Unlock()
- eps, ok := n.packetEPs[netProto]
+ eps, ok := n.mu.packetEPs[netProto]
if !ok {
return
}
for i, epOther := range eps {
if epOther == ep {
- n.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
+ n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...)
return
}
}
}
+type networkEndpointConfigType int32
+
+const (
+ // A statically configured endpoint is an address that was added by
+ // some user-specified action (adding an explicit address, joining a
+ // multicast group).
+ static networkEndpointConfigType = iota
+
+ // A SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4862 section 5.5.3.
+ slaac
+
+ // A temporary SLAAC configured endpoint is an IPv6 endpoint that was added by
+ // SLAAC as per RFC 4941. Temporary SLAAC addresses are short-lived and are
+ // not expected to be valid (or preferred) forever; hence the term temporary.
+ slaacTemp
+)
+
type referencedNetworkEndpoint struct {
ep NetworkEndpoint
+ addr tcpip.AddressWithPrefix
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -1038,6 +1658,24 @@ type referencedNetworkEndpoint struct {
// networkEndpointKind must only be accessed using {get,set}Kind().
kind networkEndpointKind
+
+ // configType is the method that was used to configure this endpoint.
+ // This must never change except during endpoint creation and promotion to
+ // permanent.
+ configType networkEndpointConfigType
+
+ // deprecated indicates whether or not the endpoint should be considered
+ // deprecated. That is, when deprecated is true, other endpoints that are not
+ // deprecated should be preferred.
+ deprecated bool
+}
+
+func (r *referencedNetworkEndpoint) address() tcpip.Address {
+ return r.addr.Address
+}
+
+func (r *referencedNetworkEndpoint) addrWithPrefix() tcpip.AddressWithPrefix {
+ return r.addr
}
func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
@@ -1049,17 +1687,44 @@ func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in spoofing mode.
+// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// has been removed) unless the NIC is in spoofing mode, or temporary.
func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
- return r.getKind() != permanentExpired || r.nic.spoofing
+ r.nic.mu.RLock()
+ defer r.nic.mu.RUnlock()
+
+ return r.isValidForOutgoingRLocked()
}
-// isValidForIncoming returns true if the endpoint can accept an incoming
-// packet. It requires the endpoint to not be marked expired (i.e., its address
-// has been removed), or the NIC to be in promiscuous mode.
-func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
- return r.getKind() != permanentExpired || r.nic.promiscuous
+// isValidForOutgoingRLocked is the same as isValidForOutgoing but requires
+// r.nic.mu to be read locked.
+func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
+ if !r.nic.mu.enabled {
+ return false
+ }
+
+ return r.isAssignedRLocked(r.nic.mu.spoofing)
+}
+
+// isAssignedRLocked returns true if r is considered to be assigned to the NIC.
+//
+// r.nic.mu must be read locked.
+func (r *referencedNetworkEndpoint) isAssignedRLocked(spoofingOrPromiscuous bool) bool {
+ switch r.getKind() {
+ case permanentTentative:
+ return false
+ case permanentExpired:
+ return spoofingOrPromiscuous
+ default:
+ return true
+ }
+}
+
+// expireLocked decrements the reference count and marks the permanent endpoint
+// as expired.
+func (r *referencedNetworkEndpoint) expireLocked() {
+ r.setKind(permanentExpired)
+ r.decRefLocked()
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
@@ -1071,14 +1736,11 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked. Returns true if the endpoint was removed.
-func (r *referencedNetworkEndpoint) decRefLocked() bool {
+// locked.
+func (r *referencedNetworkEndpoint) decRefLocked() {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
- return true
}
-
- return false
}
// incRef increments the ref count. It must only be called when the caller is
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
new file mode 100644
index 000000000..d312a79eb
--- /dev/null
+++ b/pkg/tcpip/stack/nic_test.go
@@ -0,0 +1,316 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+var _ LinkEndpoint = (*testLinkEndpoint)(nil)
+
+// A LinkEndpoint that throws away outgoing packets.
+//
+// We use this instead of the channel endpoint as the channel package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testLinkEndpoint struct {
+ dispatcher NetworkDispatcher
+}
+
+// Attach implements LinkEndpoint.Attach.
+func (e *testLinkEndpoint) Attach(dispatcher NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements LinkEndpoint.IsAttached.
+func (e *testLinkEndpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// MTU implements LinkEndpoint.MTU.
+func (*testLinkEndpoint) MTU() uint32 {
+ return math.MaxUint16
+}
+
+// Capabilities implements LinkEndpoint.Capabilities.
+func (*testLinkEndpoint) Capabilities() LinkEndpointCapabilities {
+ return CapabilityResolutionRequired
+}
+
+// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength.
+func (*testLinkEndpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testLinkEndpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// Wait implements LinkEndpoint.Wait.
+func (*testLinkEndpoint) Wait() {}
+
+// WritePacket implements LinkEndpoint.WritePacket.
+func (e *testLinkEndpoint) WritePacket(*Route, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements LinkEndpoint.WritePackets.
+func (e *testLinkEndpoint) WritePackets(*Route, *GSO, PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteRawPacket implements LinkEndpoint.WriteRawPacket.
+func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType {
+ panic("not implemented")
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
+ panic("not implemented")
+}
+
+var _ NetworkEndpoint = (*testIPv6Endpoint)(nil)
+
+// An IPv6 NetworkEndpoint that throws away outgoing packets.
+//
+// We use this instead of ipv6.endpoint because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Endpoint struct {
+ nicID tcpip.NICID
+ linkEP LinkEndpoint
+ protocol *testIPv6Protocol
+}
+
+// DefaultTTL implements NetworkEndpoint.DefaultTTL.
+func (*testIPv6Endpoint) DefaultTTL() uint8 {
+ return 0
+}
+
+// MTU implements NetworkEndpoint.MTU.
+func (e *testIPv6Endpoint) MTU() uint32 {
+ return e.linkEP.MTU() - header.IPv6MinimumSize
+}
+
+// Capabilities implements NetworkEndpoint.Capabilities.
+func (e *testIPv6Endpoint) Capabilities() LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+// MaxHeaderLength implements NetworkEndpoint.MaxHeaderLength.
+func (e *testIPv6Endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
+}
+
+// WritePacket implements NetworkEndpoint.WritePacket.
+func (*testIPv6Endpoint) WritePacket(*Route, *GSO, NetworkHeaderParams, *PacketBuffer) *tcpip.Error {
+ return nil
+}
+
+// WritePackets implements NetworkEndpoint.WritePackets.
+func (*testIPv6Endpoint) WritePackets(*Route, *GSO, PacketBufferList, NetworkHeaderParams) (int, *tcpip.Error) {
+ // Our tests don't use this so we don't support it.
+ return 0, tcpip.ErrNotSupported
+}
+
+// WriteHeaderIncludedPacket implements
+// NetworkEndpoint.WriteHeaderIncludedPacket.
+func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip.Error {
+ // Our tests don't use this so we don't support it.
+ return tcpip.ErrNotSupported
+}
+
+// NICID implements NetworkEndpoint.NICID.
+func (e *testIPv6Endpoint) NICID() tcpip.NICID {
+ return e.nicID
+}
+
+// HandlePacket implements NetworkEndpoint.HandlePacket.
+func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) {
+}
+
+// Close implements NetworkEndpoint.Close.
+func (*testIPv6Endpoint) Close() {}
+
+// NetworkProtocolNumber implements NetworkEndpoint.NetworkProtocolNumber.
+func (*testIPv6Endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+var _ NetworkProtocol = (*testIPv6Protocol)(nil)
+
+// An IPv6 NetworkProtocol that supports the bare minimum to make a stack
+// believe it supports IPv6.
+//
+// We use this instead of ipv6.protocol because the ipv6 package depends on
+// the stack package which this test lives in, causing a cyclic dependency.
+type testIPv6Protocol struct{}
+
+// Number implements NetworkProtocol.Number.
+func (*testIPv6Protocol) Number() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// MinimumPacketSize implements NetworkProtocol.MinimumPacketSize.
+func (*testIPv6Protocol) MinimumPacketSize() int {
+ return header.IPv6MinimumSize
+}
+
+// DefaultPrefixLen implements NetworkProtocol.DefaultPrefixLen.
+func (*testIPv6Protocol) DefaultPrefixLen() int {
+ return header.IPv6AddressSize * 8
+}
+
+// ParseAddresses implements NetworkProtocol.ParseAddresses.
+func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.IPv6(v)
+ return h.SourceAddress(), h.DestinationAddress()
+}
+
+// NewEndpoint implements NetworkProtocol.NewEndpoint.
+func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint {
+ return &testIPv6Endpoint{
+ nicID: nicID,
+ linkEP: linkEP,
+ protocol: p,
+ }
+}
+
+// SetOption implements NetworkProtocol.SetOption.
+func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Option implements NetworkProtocol.Option.
+func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+ return nil
+}
+
+// Close implements NetworkProtocol.Close.
+func (*testIPv6Protocol) Close() {}
+
+// Wait implements NetworkProtocol.Wait.
+func (*testIPv6Protocol) Wait() {}
+
+// Parse implements NetworkProtocol.Parse.
+func (*testIPv6Protocol) Parse(*PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ return 0, false, false
+}
+
+var _ LinkAddressResolver = (*testIPv6Protocol)(nil)
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv6ProtocolNumber
+}
+
+// LinkAddressRequest implements LinkAddressResolver.
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
+ return nil
+}
+
+// ResolveStaticAddress implements LinkAddressResolver.
+func (*testIPv6Protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if header.IsV6MulticastAddress(addr) {
+ return header.EthernetAddressFromMulticastIPv6Address(addr), true
+ }
+ return "", false
+}
+
+// Test the race condition where a NIC is removed and an RS timer fires at the
+// same time.
+func TestRemoveNICWhileHandlingRSTimer(t *testing.T) {
+ const (
+ nicID = 1
+
+ maxRtrSolicitations = 5
+ )
+
+ e := testLinkEndpoint{}
+ s := New(Options{
+ NetworkProtocols: []NetworkProtocol{&testIPv6Protocol{}},
+ NDPConfigs: NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: minimumRtrSolicitationInterval,
+ },
+ })
+
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ s.mu.Lock()
+ // Wait for the router solicitation timer to fire and block trying to obtain
+ // the stack lock when doing link address resolution.
+ time.Sleep(minimumRtrSolicitationInterval * 2)
+ if err := s.removeNICLocked(nicID); err != nil {
+ t.Fatalf("s.removeNICLocked(%d) = %s", nicID, err)
+ }
+ s.mu.Unlock()
+}
+
+func TestDisabledRxStatsWhenNICDisabled(t *testing.T) {
+ // When the NIC is disabled, the only field that matters is the stats field.
+ // This test is limited to stats counter checks.
+ nic := NIC{
+ stats: makeNICStats(),
+ }
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 0 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ nic.DeliverNetworkPacket("", "", 0, NewPacketBuffer(PacketBufferOptions{
+ Data: buffer.View([]byte{1, 2, 3, 4}).ToVectorisedView(),
+ }))
+
+ if got := nic.stats.DisabledRx.Packets.Value(); got != 1 {
+ t.Errorf("got DisabledRx.Packets = %d, want = 1", got)
+ }
+ if got := nic.stats.DisabledRx.Bytes.Value(); got != 4 {
+ t.Errorf("got DisabledRx.Bytes = %d, want = 4", got)
+ }
+ if got := nic.stats.Rx.Packets.Value(); got != 0 {
+ t.Errorf("got Rx.Packets = %d, want = 0", got)
+ }
+ if got := nic.stats.Rx.Bytes.Value(); got != 0 {
+ t.Errorf("got Rx.Bytes = %d, want = 0", got)
+ }
+}
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go
new file mode 100644
index 000000000..e1ec15487
--- /dev/null
+++ b/pkg/tcpip/stack/nud.go
@@ -0,0 +1,466 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // defaultBaseReachableTime is the default base duration for computing the
+ // random reachable time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of a uniformly distributed random value between the minimum and
+ // maximum random factors, multiplied by the base reachable time. Using a
+ // random component eliminates the possibility that Neighbor Unreachability
+ // Detection messages will synchronize with each other.
+ //
+ // Default taken from REACHABLE_TIME of RFC 4861 section 10.
+ defaultBaseReachableTime = 30 * time.Second
+
+ // minimumBaseReachableTime is the minimum base duration for computing the
+ // random reachable time.
+ //
+ // Minimum = 1ms
+ minimumBaseReachableTime = time.Millisecond
+
+ // defaultMinRandomFactor is the default minimum value of the random factor
+ // used for computing reachable time.
+ //
+ // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10.
+ defaultMinRandomFactor = 0.5
+
+ // defaultMaxRandomFactor is the default maximum value of the random factor
+ // used for computing reachable time.
+ //
+ // The default value depends on the value of MinRandomFactor.
+ // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10,
+ // the value from the RFC will be used; otherwise, the default is
+ // MinRandomFactor multiplied by three.
+ defaultMaxRandomFactor = 1.5
+
+ // defaultRetransmitTimer is the default amount of time to wait between
+ // sending reachability probes.
+ //
+ // Default taken from RETRANS_TIMER of RFC 4861 section 10.
+ defaultRetransmitTimer = time.Second
+
+ // minimumRetransmitTimer is the minimum amount of time to wait between
+ // sending reachability probes.
+ //
+ // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here
+ // to make sure the messages are not sent all at once. We also come to this
+ // value because in the RetransmitTimer field of a Router Advertisement, a
+ // value of 0 means unspecified, so the smallest valid value is 1. Note, the
+ // unit of the RetransmitTimer field in the Router Advertisement is
+ // milliseconds.
+ minimumRetransmitTimer = time.Millisecond
+
+ // defaultDelayFirstProbeTime is the default duration to wait for a
+ // non-Neighbor-Discovery related protocol to reconfirm reachability after
+ // entering the DELAY state. After this time, a reachability probe will be
+ // sent and the entry will transition to the PROBE state.
+ //
+ // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10.
+ defaultDelayFirstProbeTime = 5 * time.Second
+
+ // defaultMaxMulticastProbes is the default number of reachabililty probes
+ // to send before concluding negative reachability and deleting the neighbor
+ // entry from the INCOMPLETE state.
+ //
+ // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10.
+ defaultMaxMulticastProbes = 3
+
+ // defaultMaxUnicastProbes is the default number of reachability probes to
+ // send before concluding retransmission from within the PROBE state should
+ // cease and the entry SHOULD be deleted.
+ //
+ // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10.
+ defaultMaxUnicastProbes = 3
+
+ // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD
+ // delay sending a response for a random time between 0 and this time, if the
+ // target address is an anycast address.
+ //
+ // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10.
+ defaultMaxAnycastDelayTime = time.Second
+
+ // defaultMaxReachbilityConfirmations is the default amount of unsolicited
+ // reachability confirmation messages a node MAY send to all-node multicast
+ // address when it determines its link-layer address has changed.
+ //
+ // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10.
+ defaultMaxReachbilityConfirmations = 3
+
+ // defaultUnreachableTime is the default duration for how long an entry will
+ // remain in the FAILED state before being removed from the neighbor cache.
+ //
+ // Note, there is no equivalent protocol constant defined in RFC 4861. It
+ // leaves the specifics of any garbage collection mechanism up to the
+ // implementation.
+ defaultUnreachableTime = 5 * time.Second
+)
+
+// NUDDispatcher is the interface integrators of netstack must implement to
+// receive and handle NUD related events.
+type NUDDispatcher interface {
+ // OnNeighborAdded will be called when a new entry is added to a NIC's (with
+ // ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID)
+ // neighbor table changes state and/or link address.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+
+ // OnNeighborRemoved will be called when an entry is removed from a NIC's
+ // (with ID nicID) neighbor table.
+ //
+ // This function is permitted to block indefinitely without interfering with
+ // the stack's operation.
+ //
+ // May be called concurrently.
+ OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time)
+}
+
+// ReachabilityConfirmationFlags describes the flags used within a reachability
+// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP,
+// respectively).
+type ReachabilityConfirmationFlags struct {
+ // Solicited indicates that the advertisement was sent in response to a
+ // reachability probe.
+ Solicited bool
+
+ // Override indicates that the reachability confirmation should override an
+ // existing neighbor cache entry and update the cached link-layer address.
+ // When Override is not set the confirmation will not update a cached
+ // link-layer address, but will update an existing neighbor cache entry for
+ // which no link-layer address is known.
+ Override bool
+
+ // IsRouter indicates that the sender is a router.
+ IsRouter bool
+}
+
+// NUDHandler communicates external events to the Neighbor Unreachability
+// Detection state machine, which is implemented per-interface. This is used by
+// network endpoints to inform the Neighbor Cache of probes and confirmations.
+type NUDHandler interface {
+ // HandleProbe processes an incoming neighbor probe (e.g. ARP request or
+ // Neighbor Solicitation for ARP or NDP, respectively). Validation of the
+ // probe needs to be performed before calling this function since the
+ // Neighbor Cache doesn't have access to view the NIC's assigned addresses.
+ HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver)
+
+ // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP
+ // reply or Neighbor Advertisement for ARP or NDP, respectively).
+ HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags)
+
+ // HandleUpperLevelConfirmation processes an incoming upper-level protocol
+ // (e.g. TCP acknowledgements) reachability confirmation.
+ HandleUpperLevelConfirmation(addr tcpip.Address)
+}
+
+// NUDConfigurations is the NUD configurations for the netstack. This is used
+// by the neighbor cache to operate the NUD state machine on each device in the
+// local network.
+type NUDConfigurations struct {
+ // BaseReachableTime is the base duration for computing the random reachable
+ // time.
+ //
+ // Reachable time is the duration for which a neighbor is considered
+ // reachable after a positive reachability confirmation is received. It is a
+ // function of uniformly distributed random value between minRandomFactor and
+ // maxRandomFactor multiplied by baseReachableTime. Using a random component
+ // eliminates the possibility that Neighbor Unreachability Detection messages
+ // will synchronize with each other.
+ //
+ // After this time, a neighbor entry will transition from REACHABLE to STALE
+ // state.
+ //
+ // Must be greater than 0.
+ BaseReachableTime time.Duration
+
+ // LearnBaseReachableTime enables learning BaseReachableTime during runtime
+ // from the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option.
+ LearnBaseReachableTime bool
+
+ // MinRandomFactor is the minimum value of the random factor used for
+ // computing reachable time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be greater than 0.
+ MinRandomFactor float32
+
+ // MaxRandomFactor is the maximum value of the random factor used for
+ // computing reachabile time.
+ //
+ // See BaseReachbleTime for more information on computing the reachable time.
+ //
+ // Must be great than or equal to MinRandomFactor.
+ MaxRandomFactor float32
+
+ // RetransmitTimer is the duration between retransmission of reachability
+ // probes in the PROBE state.
+ RetransmitTimer time.Duration
+
+ // LearnRetransmitTimer enables learning RetransmitTimer during runtime from
+ // the neighbor discovery protocol, if supported.
+ //
+ // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option.
+ LearnRetransmitTimer bool
+
+ // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery
+ // related protocol to reconfirm reachability after entering the DELAY state.
+ // After this time, a reachability probe will be sent and the entry will
+ // transition to the PROBE state.
+ //
+ // Must be greater than 0.
+ DelayFirstProbeTime time.Duration
+
+ // MaxMulticastProbes is the number of reachability probes to send before
+ // concluding negative reachability and deleting the neighbor entry from the
+ // INCOMPLETE state.
+ //
+ // Must be greater than 0.
+ MaxMulticastProbes uint32
+
+ // MaxUnicastProbes is the number of reachability probes to send before
+ // concluding retransmission from within the PROBE state should cease and
+ // entry SHOULD be deleted.
+ //
+ // Must be greater than 0.
+ MaxUnicastProbes uint32
+
+ // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a
+ // response for a random time between 0 and this time, if the target address
+ // is an anycast address.
+ //
+ // TODO(gvisor.dev/issue/2242): Use this option when sending solicited
+ // neighbor confirmations to anycast addresses and proxying neighbor
+ // confirmations.
+ MaxAnycastDelayTime time.Duration
+
+ // MaxReachabilityConfirmations is the number of unsolicited reachability
+ // confirmation messages a node MAY send to all-node multicast address when
+ // it determines its link-layer address has changed.
+ //
+ // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD
+ // configuration option is necessary.
+ MaxReachabilityConfirmations uint32
+
+ // UnreachableTime describes how long an entry will remain in the FAILED
+ // state before being removed from the neighbor cache.
+ UnreachableTime time.Duration
+}
+
+// DefaultNUDConfigurations returns a NUDConfigurations populated with default
+// values defined by RFC 4861 section 10.
+func DefaultNUDConfigurations() NUDConfigurations {
+ return NUDConfigurations{
+ BaseReachableTime: defaultBaseReachableTime,
+ LearnBaseReachableTime: true,
+ MinRandomFactor: defaultMinRandomFactor,
+ MaxRandomFactor: defaultMaxRandomFactor,
+ RetransmitTimer: defaultRetransmitTimer,
+ LearnRetransmitTimer: true,
+ DelayFirstProbeTime: defaultDelayFirstProbeTime,
+ MaxMulticastProbes: defaultMaxMulticastProbes,
+ MaxUnicastProbes: defaultMaxUnicastProbes,
+ MaxAnycastDelayTime: defaultMaxAnycastDelayTime,
+ MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations,
+ UnreachableTime: defaultUnreachableTime,
+ }
+}
+
+// resetInvalidFields modifies an invalid NDPConfigurations with valid values.
+// If invalid values are present in c, the corresponding default values will be
+// used instead. This is needed to check, and conditionally fix, user-specified
+// NUDConfigurations.
+func (c *NUDConfigurations) resetInvalidFields() {
+ if c.BaseReachableTime < minimumBaseReachableTime {
+ c.BaseReachableTime = defaultBaseReachableTime
+ }
+ if c.MinRandomFactor <= 0 {
+ c.MinRandomFactor = defaultMinRandomFactor
+ }
+ if c.MaxRandomFactor < c.MinRandomFactor {
+ c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor)
+ }
+ if c.RetransmitTimer < minimumRetransmitTimer {
+ c.RetransmitTimer = defaultRetransmitTimer
+ }
+ if c.DelayFirstProbeTime == 0 {
+ c.DelayFirstProbeTime = defaultDelayFirstProbeTime
+ }
+ if c.MaxMulticastProbes == 0 {
+ c.MaxMulticastProbes = defaultMaxMulticastProbes
+ }
+ if c.MaxUnicastProbes == 0 {
+ c.MaxUnicastProbes = defaultMaxUnicastProbes
+ }
+ if c.UnreachableTime == 0 {
+ c.UnreachableTime = defaultUnreachableTime
+ }
+}
+
+// calcMaxRandomFactor calculates the maximum value of the random factor used
+// for computing reachable time. This function is necessary for when the
+// default specified in RFC 4861 section 10 is less than the current
+// MinRandomFactor.
+//
+// Assumes minRandomFactor is positive since validation of the minimum value
+// should come before the validation of the maximum.
+func calcMaxRandomFactor(minRandomFactor float32) float32 {
+ if minRandomFactor > defaultMaxRandomFactor {
+ return minRandomFactor * 3
+ }
+ return defaultMaxRandomFactor
+}
+
+// A Rand is a source of random numbers.
+type Rand interface {
+ // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0).
+ Float32() float32
+}
+
+// NUDState stores states needed for calculating reachable time.
+type NUDState struct {
+ rng Rand
+
+ // mu protects the fields below.
+ //
+ // It is necessary for NUDState to handle its own locking since neighbor
+ // entries may access the NUD state from within the goroutine spawned by
+ // time.AfterFunc(). This goroutine may run concurrently with the main
+ // process for controlling the neighbor cache and would otherwise introduce
+ // race conditions if NUDState was not locked properly.
+ mu sync.RWMutex
+
+ config NUDConfigurations
+
+ // reachableTime is the duration to wait for a REACHABLE entry to
+ // transition into STALE after inactivity. This value is calculated with
+ // the algorithm defined in RFC 4861 section 6.3.2.
+ reachableTime time.Duration
+
+ expiration time.Time
+ prevBaseReachableTime time.Duration
+ prevMinRandomFactor float32
+ prevMaxRandomFactor float32
+}
+
+// NewNUDState returns new NUDState using c as configuration and the specified
+// random number generator for use in recomputing ReachableTime.
+func NewNUDState(c NUDConfigurations, rng Rand) *NUDState {
+ s := &NUDState{
+ rng: rng,
+ }
+ s.config = c
+ return s
+}
+
+// Config returns the NUD configuration.
+func (s *NUDState) Config() NUDConfigurations {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.config
+}
+
+// SetConfig replaces the existing NUD configurations with c.
+func (s *NUDState) SetConfig(c NUDConfigurations) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.config = c
+}
+
+// ReachableTime returns the duration to wait for a REACHABLE entry to
+// transition into STALE after inactivity. This value is recalculated for new
+// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the
+// algorithm defined in RFC 4861 section 6.3.2.
+func (s *NUDState) ReachableTime() time.Duration {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if time.Now().After(s.expiration) ||
+ s.config.BaseReachableTime != s.prevBaseReachableTime ||
+ s.config.MinRandomFactor != s.prevMinRandomFactor ||
+ s.config.MaxRandomFactor != s.prevMaxRandomFactor {
+ return s.recomputeReachableTimeLocked()
+ }
+ return s.reachableTime
+}
+
+// recomputeReachableTimeLocked forces a recalculation of ReachableTime using
+// the algorithm defined in RFC 4861 section 6.3.2.
+//
+// This SHOULD automatically be invoked during certain situations, as per
+// RFC 4861 section 6.3.4:
+//
+// If the received Reachable Time value is non-zero, the host SHOULD set its
+// BaseReachableTime variable to the received value. If the new value
+// differs from the previous value, the host SHOULD re-compute a new random
+// ReachableTime value. ReachableTime is computed as a uniformly
+// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR
+// times the BaseReachableTime. Using a random component eliminates the
+// possibility that Neighbor Unreachability Detection messages will
+// synchronize with each other.
+//
+// In most cases, the advertised Reachable Time value will be the same in
+// consecutive Router Advertisements, and a host's BaseReachableTime rarely
+// changes. In such cases, an implementation SHOULD ensure that a new
+// random value gets re-computed at least once every few hours.
+//
+// s.mu MUST be locked for writing.
+func (s *NUDState) recomputeReachableTimeLocked() time.Duration {
+ s.prevBaseReachableTime = s.config.BaseReachableTime
+ s.prevMinRandomFactor = s.config.MinRandomFactor
+ s.prevMaxRandomFactor = s.config.MaxRandomFactor
+
+ randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor)
+
+ // Check for overflow, given that minRandomFactor and maxRandomFactor are
+ // guaranteed to be positive numbers.
+ if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) {
+ s.reachableTime = time.Duration(math.MaxInt64)
+ } else if randomFactor == 1 {
+ // Avoid loss of precision when a large base reachable time is used.
+ s.reachableTime = s.config.BaseReachableTime
+ } else {
+ reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor)
+ s.reachableTime = time.Duration(reachableTime)
+ }
+
+ s.expiration = time.Now().Add(2 * time.Hour)
+ return s.reachableTime
+}
diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go
new file mode 100644
index 000000000..2494ee610
--- /dev/null
+++ b/pkg/tcpip/stack/nud_test.go
@@ -0,0 +1,795 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ defaultBaseReachableTime = 30 * time.Second
+ minimumBaseReachableTime = time.Millisecond
+ defaultMinRandomFactor = 0.5
+ defaultMaxRandomFactor = 1.5
+ defaultRetransmitTimer = time.Second
+ minimumRetransmitTimer = time.Millisecond
+ defaultDelayFirstProbeTime = 5 * time.Second
+ defaultMaxMulticastProbes = 3
+ defaultMaxUnicastProbes = 3
+ defaultMaxAnycastDelayTime = time.Second
+ defaultMaxReachbilityConfirmations = 3
+ defaultUnreachableTime = 5 * time.Second
+
+ defaultFakeRandomNum = 0.5
+)
+
+// fakeRand is a deterministic random number generator.
+type fakeRand struct {
+ num float32
+}
+
+var _ stack.Rand = (*fakeRand)(nil)
+
+func (f *fakeRand) Float32() float32 {
+ return f.num
+}
+
+// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if
+// we attempt to update NUD configurations using an invalid NICID.
+func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) {
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+
+ // No NIC with ID 1 yet.
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to retrieve NUD configurations when the
+// stack doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestNUDConfigurationFailsForNotSupported tests to make sure we get a
+// NotSupported error if we attempt to set NUD configurations when the stack
+// doesn't support NUD.
+//
+// The stack will report to not support NUD if a neighbor cache for a given NIC
+// is not allocated. The networking stack will only allocate neighbor caches if
+// a protocol providing link address resolution is specified (e.g. ARP, IPv6).
+func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ config := stack.NUDConfigurations{}
+ if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported {
+ t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported)
+ }
+}
+
+// TestDefaultNUDConfigurationIsValid verifies that calling
+// resetInvalidFields() on the result of DefaultNUDConfigurations() does not
+// change anything. DefaultNUDConfigurations() should return a valid
+// NUDConfigurations.
+func TestDefaultNUDConfigurations(t *testing.T) {
+ const nicID = 1
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The networking
+ // stack will only allocate neighbor caches if a protocol providing link
+ // address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: stack.DefaultNUDConfigurations(),
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ c, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got, want := c, stack.DefaultNUDConfigurations(); got != want {
+ t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want)
+ }
+}
+
+func TestNUDConfigurationsBaseReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ baseReachableTime: 0,
+ want: defaultBaseReachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ baseReachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ {
+ name: "MoreThanDefaultBaseReachableTime",
+ baseReachableTime: 2 * defaultBaseReachableTime,
+ want: 2 * defaultBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = test.baseReachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.BaseReachableTime; got != test.want {
+ t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMinRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: -1,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: 0,
+ want: defaultMinRandomFactor,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ minRandomFactor: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MinRandomFactor; got != test.want {
+ t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxRandomFactor(t *testing.T) {
+ tests := []struct {
+ name string
+ minRandomFactor float32
+ maxRandomFactor float32
+ want float32
+ }{
+ // Invalid cases
+ {
+ name: "LessThanZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: -1,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "EqualToZero",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: 0,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "LessThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 0.99,
+ want: defaultMaxRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault",
+ minRandomFactor: defaultMaxRandomFactor * 2,
+ maxRandomFactor: defaultMaxRandomFactor,
+ want: defaultMaxRandomFactor * 6,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor,
+ want: defaultMinRandomFactor,
+ },
+ {
+ name: "MoreThanMinRandomFactor",
+ minRandomFactor: defaultMinRandomFactor,
+ maxRandomFactor: defaultMinRandomFactor * 1.1,
+ want: defaultMinRandomFactor * 1.1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxRandomFactor; got != test.want {
+ t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsRetransmitTimer(t *testing.T) {
+ tests := []struct {
+ name string
+ retransmitTimer time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ retransmitTimer: 0,
+ want: defaultRetransmitTimer,
+ },
+ {
+ name: "LessThanMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer - time.Nanosecond,
+ want: defaultRetransmitTimer,
+ },
+ // Valid cases
+ {
+ name: "EqualToMinimumRetransmitTimer",
+ retransmitTimer: minimumRetransmitTimer,
+ want: minimumBaseReachableTime,
+ },
+ {
+ name: "LargetThanMinimumRetransmitTimer",
+ retransmitTimer: 2 * minimumBaseReachableTime,
+ want: 2 * minimumBaseReachableTime,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.RetransmitTimer = test.retransmitTimer
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.RetransmitTimer; got != test.want {
+ t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) {
+ tests := []struct {
+ name string
+ delayFirstProbeTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ delayFirstProbeTime: 0,
+ want: defaultDelayFirstProbeTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ delayFirstProbeTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.DelayFirstProbeTime = test.delayFirstProbeTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.DelayFirstProbeTime; got != test.want {
+ t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxMulticastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxMulticastProbes: 0,
+ want: defaultMaxMulticastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxMulticastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxMulticastProbes = test.maxMulticastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxMulticastProbes; got != test.want {
+ t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) {
+ tests := []struct {
+ name string
+ maxUnicastProbes uint32
+ want uint32
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ maxUnicastProbes: 0,
+ want: defaultMaxUnicastProbes,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ maxUnicastProbes: 1,
+ want: 1,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.MaxUnicastProbes = test.maxUnicastProbes
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.MaxUnicastProbes; got != test.want {
+ t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+func TestNUDConfigurationsUnreachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ unreachableTime time.Duration
+ want time.Duration
+ }{
+ // Invalid cases
+ {
+ name: "EqualToZero",
+ unreachableTime: 0,
+ want: defaultUnreachableTime,
+ },
+ // Valid cases
+ {
+ name: "MoreThanZero",
+ unreachableTime: time.Millisecond,
+ want: time.Millisecond,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ const nicID = 1
+
+ c := stack.DefaultNUDConfigurations()
+ c.UnreachableTime = test.unreachableTime
+
+ e := channel.New(0, 1280, linkAddr1)
+ e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+
+ s := stack.New(stack.Options{
+ // A neighbor cache is required to store NUDConfigurations. The
+ // networking stack will only allocate neighbor caches if a protocol
+ // providing link address resolution is specified (e.g. ARP or IPv6).
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NUDConfigs: c,
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ sc, err := s.NUDConfigurations(nicID)
+ if err != nil {
+ t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err)
+ }
+ if got := sc.UnreachableTime; got != test.want {
+ t.Errorf("got UnreachableTime = %q, want = %q", got, test.want)
+ }
+ })
+ }
+}
+
+// TestNUDStateReachableTime verifies the correctness of the ReachableTime
+// computation.
+func TestNUDStateReachableTime(t *testing.T) {
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "AllZeros",
+ baseReachableTime: 0,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMaxRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 0,
+ want: 0,
+ },
+ {
+ name: "ZeroMinRandomFactor",
+ baseReachableTime: time.Second,
+ minRandomFactor: 0,
+ maxRandomFactor: 1,
+ want: time.Duration(defaultFakeRandomNum * float32(time.Second)),
+ },
+ {
+ name: "FractionalRandomFactor",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 0.001,
+ maxRandomFactor: 0.002,
+ want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)),
+ },
+ {
+ name: "MinAndMaxRandomFactorsEqual",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Second,
+ },
+ {
+ name: "MinAndMaxRandomFactorsDifferent",
+ baseReachableTime: time.Second,
+ minRandomFactor: 1,
+ maxRandomFactor: 2,
+ want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)),
+ },
+ {
+ name: "MaxInt64",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1,
+ maxRandomFactor: 1,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "Overflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 1.5,
+ maxRandomFactor: 1.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ {
+ name: "DoubleOverflow",
+ baseReachableTime: time.Duration(math.MaxInt64),
+ minRandomFactor: 2.5,
+ maxRandomFactor: 2.5,
+ want: time.Duration(math.MaxInt64),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.NUDConfigurations{
+ BaseReachableTime: test.baseReachableTime,
+ MinRandomFactor: test.minRandomFactor,
+ MaxRandomFactor: test.maxRandomFactor,
+ }
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
+
+// TestNUDStateRecomputeReachableTime exercises the ReachableTime function
+// twice to verify recomputation of reachable time when the min random factor,
+// max random factor, or base reachable time changes.
+func TestNUDStateRecomputeReachableTime(t *testing.T) {
+ const defaultBase = time.Second
+ const defaultMin = 2.0 * defaultMaxRandomFactor
+ const defaultMax = 3.0 * defaultMaxRandomFactor
+
+ tests := []struct {
+ name string
+ baseReachableTime time.Duration
+ minRandomFactor float32
+ maxRandomFactor float32
+ want time.Duration
+ }{
+ {
+ name: "BaseReachableTime",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMax,
+ want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ {
+ name: "MinRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMax,
+ maxRandomFactor: defaultMax,
+ want: time.Duration(defaultMax * float32(defaultBase)),
+ },
+ {
+ name: "MaxRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: defaultMin,
+ maxRandomFactor: defaultMin,
+ want: time.Duration(defaultMin * float32(defaultBase)),
+ },
+ {
+ name: "BothRandomFactor",
+ baseReachableTime: defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)),
+ },
+ {
+ name: "BaseReachableTimeAndBothRandomFactors",
+ baseReachableTime: 2 * defaultBase,
+ minRandomFactor: 2 * defaultMin,
+ maxRandomFactor: 2 * defaultMax,
+ want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := stack.DefaultNUDConfigurations()
+ c.BaseReachableTime = defaultBase
+ c.MinRandomFactor = defaultMin
+ c.MaxRandomFactor = defaultMax
+
+ // A fake random number generator is used to ensure deterministic
+ // results.
+ rng := fakeRand{
+ num: defaultFakeRandomNum,
+ }
+ s := stack.NewNUDState(c, &rng)
+ old := s.ReachableTime()
+
+ if got, want := s.ReachableTime(), old; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Check for recomputation when changing the min random factor, the max
+ // random factor, the base reachability time, or any permutation of those
+ // three options.
+ c.BaseReachableTime = test.baseReachableTime
+ c.MinRandomFactor = test.minRandomFactor
+ c.MaxRandomFactor = test.maxRandomFactor
+ s.SetConfig(c)
+
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+
+ // Verify that ReachableTime isn't recomputed when none of the
+ // configuration options change. The random factor is changed so that if
+ // a recompution were to occur, ReachableTime would change.
+ rng.num = defaultFakeRandomNum / 2.0
+ if got, want := s.ReachableTime(), test.want; got != want {
+ t.Errorf("got ReachableTime = %q, want = %q", got, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
new file mode 100644
index 000000000..17b8beebb
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -0,0 +1,299 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+type headerType int
+
+const (
+ linkHeader headerType = iota
+ networkHeader
+ transportHeader
+ numHeaderType
+)
+
+// PacketBufferOptions specifies options for PacketBuffer creation.
+type PacketBufferOptions struct {
+ // ReserveHeaderBytes is the number of bytes to reserve for headers. Total
+ // number of bytes pushed onto the headers must not exceed this value.
+ ReserveHeaderBytes int
+
+ // Data is the initial unparsed data for the new packet. If set, it will be
+ // owned by the new packet.
+ Data buffer.VectorisedView
+}
+
+// A PacketBuffer contains all the data of a network packet.
+//
+// As a PacketBuffer traverses up the stack, it may be necessary to pass it to
+// multiple endpoints.
+//
+// The whole packet is expected to be a series of bytes in the following order:
+// LinkHeader, NetworkHeader, TransportHeader, and Data. Any of them can be
+// empty. Use of PacketBuffer in any other order is unsupported.
+//
+// PacketBuffer must be created with NewPacketBuffer.
+type PacketBuffer struct {
+ _ sync.NoCopy
+
+ // PacketBufferEntry is used to build an intrusive list of
+ // PacketBuffers.
+ PacketBufferEntry
+
+ // Data holds the payload of the packet.
+ //
+ // For inbound packets, Data is initially the whole packet. Then gets moved to
+ // headers via PacketHeader.Consume, when the packet is being parsed.
+ //
+ // For outbound packets, Data is the innermost layer, defined by the protocol.
+ // Headers are pushed in front of it via PacketHeader.Push.
+ //
+ // The bytes backing Data are immutable, a.k.a. users shouldn't write to its
+ // backing storage.
+ Data buffer.VectorisedView
+
+ // headers stores metadata about each header.
+ headers [numHeaderType]headerInfo
+
+ // header is the internal storage for outbound packets. Headers will be pushed
+ // (prepended) on this storage as the packet is being constructed.
+ //
+ // TODO(gvisor.dev/issue/2404): Switch to an implementation that header and
+ // data are held in the same underlying buffer storage.
+ header buffer.Prependable
+
+ // NetworkProtocol is only valid when NetworkHeader is set.
+ // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
+ // numbers in registration APIs that take a PacketBuffer.
+ NetworkProtocolNumber tcpip.NetworkProtocolNumber
+
+ // Hash is the transport layer hash of this packet. A value of zero
+ // indicates no valid hash has been set.
+ Hash uint32
+
+ // Owner is implemented by task to get the uid and gid.
+ // Only set for locally generated packets.
+ Owner tcpip.PacketOwner
+
+ // The following fields are only set by the qdisc layer when the packet
+ // is added to a queue.
+ EgressRoute *Route
+ GSOOptions *GSO
+
+ // NatDone indicates if the packet has been manipulated as per NAT
+ // iptables rule.
+ NatDone bool
+
+ // PktType indicates the SockAddrLink.PacketType of the packet as defined in
+ // https://www.man7.org/linux/man-pages/man7/packet.7.html.
+ PktType tcpip.PacketType
+}
+
+// NewPacketBuffer creates a new PacketBuffer with opts.
+func NewPacketBuffer(opts PacketBufferOptions) *PacketBuffer {
+ pk := &PacketBuffer{
+ Data: opts.Data,
+ }
+ if opts.ReserveHeaderBytes != 0 {
+ pk.header = buffer.NewPrependable(opts.ReserveHeaderBytes)
+ }
+ return pk
+}
+
+// ReservedHeaderBytes returns the number of bytes initially reserved for
+// headers.
+func (pk *PacketBuffer) ReservedHeaderBytes() int {
+ return pk.header.UsedLength() + pk.header.AvailableLength()
+}
+
+// AvailableHeaderBytes returns the number of bytes currently available for
+// headers. This is relevant to PacketHeader.Push method only.
+func (pk *PacketBuffer) AvailableHeaderBytes() int {
+ return pk.header.AvailableLength()
+}
+
+// LinkHeader returns the handle to link-layer header.
+func (pk *PacketBuffer) LinkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: linkHeader,
+ }
+}
+
+// NetworkHeader returns the handle to network-layer header.
+func (pk *PacketBuffer) NetworkHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: networkHeader,
+ }
+}
+
+// TransportHeader returns the handle to transport-layer header.
+func (pk *PacketBuffer) TransportHeader() PacketHeader {
+ return PacketHeader{
+ pk: pk,
+ typ: transportHeader,
+ }
+}
+
+// HeaderSize returns the total size of all headers in bytes.
+func (pk *PacketBuffer) HeaderSize() int {
+ // Note for inbound packets (Consume called), headers are not stored in
+ // pk.header. Thus, calculation of size of each header is needed.
+ var size int
+ for i := range pk.headers {
+ size += len(pk.headers[i].buf)
+ }
+ return size
+}
+
+// Size returns the size of packet in bytes.
+func (pk *PacketBuffer) Size() int {
+ return pk.HeaderSize() + pk.Data.Size()
+}
+
+// Views returns the underlying storage of the whole packet.
+func (pk *PacketBuffer) Views() []buffer.View {
+ // Optimization for outbound packets that headers are in pk.header.
+ useHeader := true
+ for i := range pk.headers {
+ if !canUseHeader(&pk.headers[i]) {
+ useHeader = false
+ break
+ }
+ }
+
+ dataViews := pk.Data.Views()
+
+ var vs []buffer.View
+ if useHeader {
+ vs = make([]buffer.View, 0, 1+len(dataViews))
+ vs = append(vs, pk.header.View())
+ } else {
+ vs = make([]buffer.View, 0, len(pk.headers)+len(dataViews))
+ for i := range pk.headers {
+ if v := pk.headers[i].buf; len(v) > 0 {
+ vs = append(vs, v)
+ }
+ }
+ }
+ return append(vs, dataViews...)
+}
+
+func canUseHeader(h *headerInfo) bool {
+ // h.offset will be negative if the header was pushed in to prependable
+ // portion, or doesn't matter when it's empty.
+ return len(h.buf) == 0 || h.offset < 0
+}
+
+func (pk *PacketBuffer) push(typ headerType, size int) buffer.View {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("push must not be called twice: type %s", typ))
+ }
+ h.buf = buffer.View(pk.header.Prepend(size))
+ h.offset = -pk.header.UsedLength()
+ return h.buf
+}
+
+func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consumed bool) {
+ h := &pk.headers[typ]
+ if h.buf != nil {
+ panic(fmt.Sprintf("consume must not be called twice: type %s", typ))
+ }
+ v, ok := pk.Data.PullUp(size)
+ if !ok {
+ return
+ }
+ pk.Data.TrimFront(size)
+ h.buf = v
+ return h.buf, true
+}
+
+// Clone makes a shallow copy of pk.
+//
+// Clone should be called in such cases so that no modifications is done to
+// underlying packet payload.
+func (pk *PacketBuffer) Clone() *PacketBuffer {
+ newPk := &PacketBuffer{
+ PacketBufferEntry: pk.PacketBufferEntry,
+ Data: pk.Data.Clone(nil),
+ headers: pk.headers,
+ header: pk.header,
+ Hash: pk.Hash,
+ Owner: pk.Owner,
+ EgressRoute: pk.EgressRoute,
+ GSOOptions: pk.GSOOptions,
+ NetworkProtocolNumber: pk.NetworkProtocolNumber,
+ NatDone: pk.NatDone,
+ }
+ return newPk
+}
+
+// headerInfo stores metadata about a header in a packet.
+type headerInfo struct {
+ // buf is the memorized slice for both prepended and consumed header.
+ // When header is prepended, buf serves as memorized value, which is a slice
+ // of pk.header. When header is consumed, buf is the slice pulled out from
+ // pk.Data, which is the only place to hold this header.
+ buf buffer.View
+
+ // offset will be a negative number denoting the offset where this header is
+ // from the end of pk.header, if it is prepended. Otherwise, zero.
+ offset int
+}
+
+// PacketHeader is a handle object to a header in the underlying packet.
+type PacketHeader struct {
+ pk *PacketBuffer
+ typ headerType
+}
+
+// View returns the underlying storage of h.
+func (h PacketHeader) View() buffer.View {
+ return h.pk.headers[h.typ].buf
+}
+
+// Push pushes size bytes in the front of its residing packet, and returns the
+// backing storage. Callers may only call one of Push or Consume once on each
+// header in the lifetime of the underlying packet.
+func (h PacketHeader) Push(size int) buffer.View {
+ return h.pk.push(h.typ, size)
+}
+
+// Consume moves the first size bytes of the unparsed data portion in the packet
+// to h, and returns the backing storage. In the case of data is shorter than
+// size, consumed will be false, and the state of h will not be affected.
+// Callers may only call one of Push or Consume once on each header in the
+// lifetime of the underlying packet.
+func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) {
+ return h.pk.consume(h.typ, size)
+}
+
+// PayloadSince returns packet payload starting from and including a particular
+// header. This method isn't optimized and should be used in test only.
+func PayloadSince(h PacketHeader) buffer.View {
+ var v buffer.View
+ for _, hinfo := range h.pk.headers[h.typ:] {
+ v = append(v, hinfo.buf...)
+ }
+ return append(v, h.pk.Data.ToView()...)
+}
diff --git a/pkg/tcpip/stack/packet_buffer_test.go b/pkg/tcpip/stack/packet_buffer_test.go
new file mode 100644
index 000000000..c6fa8da5f
--- /dev/null
+++ b/pkg/tcpip/stack/packet_buffer_test.go
@@ -0,0 +1,397 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "bytes"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+func TestPacketHeaderPush(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reserved int
+ link []byte
+ network []byte
+ transport []byte
+ data []byte
+ }{
+ {
+ name: "construct empty packet",
+ },
+ {
+ name: "construct link header only packet",
+ reserved: 60,
+ link: makeView(10),
+ },
+ {
+ name: "construct link and network header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ },
+ {
+ name: "construct header only packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ },
+ {
+ name: "construct data only packet",
+ data: makeView(40),
+ },
+ {
+ name: "construct L3 packet",
+ reserved: 60,
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ {
+ name: "construct L2 packet",
+ reserved: 60,
+ link: makeView(10),
+ network: makeView(20),
+ transport: makeView(30),
+ data: makeView(40),
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ allHdrSize := len(test.link) + len(test.network) + len(test.transport)
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ ReserveHeaderBytes: test.reserved,
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Push headers.
+ if v := test.transport; len(v) > 0 {
+ copy(pk.TransportHeader().Push(len(v)), v)
+ }
+ if v := test.network; len(v) > 0 {
+ copy(pk.NetworkHeader().Push(len(v)), v)
+ }
+ if v := test.link; len(v) > 0 {
+ copy(pk.LinkHeader().Push(len(v)), v)
+ }
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), test.reserved; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), test.reserved-allHdrSize; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), allHdrSize+len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), test.data)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...),
+ concatViews(test.link, test.network, test.transport, test.data))
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), test.link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), test.network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), test.transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(test.link, test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(test.network, test.transport, test.data))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(test.transport, test.data))
+ })
+ }
+}
+
+func TestPacketHeaderConsume(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ data []byte
+ link int
+ network int
+ transport int
+ }{
+ {
+ name: "parse L2 packet",
+ data: concatViews(makeView(10), makeView(20), makeView(30), makeView(40)),
+ link: 10,
+ network: 20,
+ transport: 30,
+ },
+ {
+ name: "parse L3 packet",
+ data: concatViews(makeView(20), makeView(30), makeView(40)),
+ network: 20,
+ transport: 30,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(test.data).ToVectorisedView(),
+ })
+
+ // Check the initial values for packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(test.data).ToVectorisedView(),
+ })
+
+ // Consume headers.
+ if size := test.link; size > 0 {
+ if _, ok := pk.LinkHeader().Consume(size); !ok {
+ t.Fatalf("pk.LinkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.network; size > 0 {
+ if _, ok := pk.NetworkHeader().Consume(size); !ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = false, want true")
+ }
+ }
+ if size := test.transport; size > 0 {
+ if _, ok := pk.TransportHeader().Consume(size); !ok {
+ t.Fatalf("pk.TransportHeader().Consume() = false, want true")
+ }
+ }
+
+ allHdrSize := test.link + test.network + test.transport
+
+ // Check the after values for packet.
+ if got, want := pk.ReservedHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), 0; got != want {
+ t.Errorf("After pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), allHdrSize; got != want {
+ t.Errorf("After pk.HeaderSize() = %d, want %d", got, want)
+ }
+ if got, want := pk.Size(), len(test.data); got != want {
+ t.Errorf("After pk.Size() = %d, want %d", got, want)
+ }
+ // After state of pk.
+ var (
+ link = test.data[:test.link]
+ network = test.data[test.link:][:test.network]
+ transport = test.data[test.link+test.network:][:test.transport]
+ payload = test.data[allHdrSize:]
+ )
+ checkViewEqual(t, "After pk.Data.Views()", concatViews(pk.Data.Views()...), payload)
+ checkViewEqual(t, "After pk.Views()", concatViews(pk.Views()...), test.data)
+ // Check the after values for each header.
+ checkPacketHeader(t, "After pk.LinkHeader", pk.LinkHeader(), link)
+ checkPacketHeader(t, "After pk.NetworkHeader", pk.NetworkHeader(), network)
+ checkPacketHeader(t, "After pk.TransportHeader", pk.TransportHeader(), transport)
+ // Check the after values for PayloadSince.
+ checkViewEqual(t, "After PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()),
+ concatViews(link, network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()),
+ concatViews(network, transport, payload))
+ checkViewEqual(t, "After PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()),
+ concatViews(transport, payload))
+ })
+ }
+}
+
+func TestPacketHeaderConsumeDataTooShort(t *testing.T) {
+ data := makeView(10)
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ // Make a copy of data to make sure our truth data won't be taint by
+ // PacketBuffer.
+ Data: buffer.NewViewFromBytes(data).ToVectorisedView(),
+ })
+
+ // Consume should fail if pkt.Data is too short.
+ if _, ok := pk.LinkHeader().Consume(11); ok {
+ t.Fatalf("pk.LinkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.NetworkHeader().Consume(11); ok {
+ t.Fatalf("pk.NetworkHeader().Consume() = _, true; want _, false")
+ }
+ if _, ok := pk.TransportHeader().Consume(11); ok {
+ t.Fatalf("pk.TransportHeader().Consume() = _, true; want _, false")
+ }
+
+ // Check packet should look the same as initial packet.
+ checkInitialPacketBuffer(t, pk, PacketBufferOptions{
+ Data: buffer.View(data).ToVectorisedView(),
+ })
+}
+
+func TestPacketHeaderPushCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run("PushedTwice/"+h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Second push should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeCalledAtMostOnce(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run("ConsumedTwice/"+h.typ.String(), func(t *testing.T) {
+ if _, ok := h.Consume(headerSize); !ok {
+ t.Fatal("First consume should succeed")
+ }
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Second consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderPushThenConsumePanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ ReserveHeaderBytes: headerSize * int(numHeaderType),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.TransportHeader(),
+ pk.NetworkHeader(),
+ pk.LinkHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Push(headerSize)
+
+ defer func() { recover() }()
+ h.Consume(headerSize)
+ t.Fatal("Consume should have panicked")
+ })
+ }
+}
+
+func TestPacketHeaderConsumeThenPushPanics(t *testing.T) {
+ const headerSize = 10
+
+ pk := NewPacketBuffer(PacketBufferOptions{
+ Data: makeView(headerSize * int(numHeaderType)).ToVectorisedView(),
+ })
+
+ for _, h := range []PacketHeader{
+ pk.LinkHeader(),
+ pk.NetworkHeader(),
+ pk.TransportHeader(),
+ } {
+ t.Run(h.typ.String(), func(t *testing.T) {
+ h.Consume(headerSize)
+
+ defer func() { recover() }()
+ h.Push(headerSize)
+ t.Fatal("Push should have panicked")
+ })
+ }
+}
+
+func checkInitialPacketBuffer(t *testing.T, pk *PacketBuffer, opts PacketBufferOptions) {
+ t.Helper()
+ reserved := opts.ReserveHeaderBytes
+ if got, want := pk.ReservedHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.ReservedHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.AvailableHeaderBytes(), reserved; got != want {
+ t.Errorf("Initial pk.AvailableHeaderBytes() = %d, want %d", got, want)
+ }
+ if got, want := pk.HeaderSize(), 0; got != want {
+ t.Errorf("Initial pk.HeaderSize() = %d, want %d", got, want)
+ }
+ data := opts.Data.ToView()
+ if got, want := pk.Size(), len(data); got != want {
+ t.Errorf("Initial pk.Size() = %d, want %d", got, want)
+ }
+ checkViewEqual(t, "Initial pk.Data.Views()", concatViews(pk.Data.Views()...), data)
+ checkViewEqual(t, "Initial pk.Views()", concatViews(pk.Views()...), data)
+ // Check the initial values for each header.
+ checkPacketHeader(t, "Initial pk.LinkHeader", pk.LinkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.NetworkHeader", pk.NetworkHeader(), nil)
+ checkPacketHeader(t, "Initial pk.TransportHeader", pk.TransportHeader(), nil)
+ // Check the initial valies for PayloadSince.
+ checkViewEqual(t, "Initial PayloadSince(LinkHeader)",
+ PayloadSince(pk.LinkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(NetworkHeader)",
+ PayloadSince(pk.NetworkHeader()), data)
+ checkViewEqual(t, "Initial PayloadSince(TransportHeader)",
+ PayloadSince(pk.TransportHeader()), data)
+}
+
+func checkPacketHeader(t *testing.T, name string, h PacketHeader, want []byte) {
+ t.Helper()
+ checkViewEqual(t, name+".View()", h.View(), want)
+}
+
+func checkViewEqual(t *testing.T, what string, got, want buffer.View) {
+ t.Helper()
+ if !bytes.Equal(got, want) {
+ t.Errorf("%s = %x, want %x", what, got, want)
+ }
+}
+
+func makeView(size int) buffer.View {
+ b := byte(size)
+ return bytes.Repeat([]byte{b}, size)
+}
+
+func concatViews(views ...buffer.View) buffer.View {
+ var all buffer.View
+ for _, v := range views {
+ all = append(all, v...)
+ }
+ return all
+}
diff --git a/pkg/tcpip/stack/rand.go b/pkg/tcpip/stack/rand.go
new file mode 100644
index 000000000..421fb5c15
--- /dev/null
+++ b/pkg/tcpip/stack/rand.go
@@ -0,0 +1,40 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ mathrand "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// lockedRandomSource provides a threadsafe rand.Source.
+type lockedRandomSource struct {
+ mu sync.Mutex
+ src mathrand.Source
+}
+
+func (r *lockedRandomSource) Int63() (n int64) {
+ r.mu.Lock()
+ n = r.src.Int63()
+ r.mu.Unlock()
+ return n
+}
+
+func (r *lockedRandomSource) Seed(seed int64) {
+ r.mu.Lock()
+ r.src.Seed(seed)
+ r.mu.Unlock()
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0869fb084..aca2f77f8 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -51,8 +52,11 @@ type TransportEndpointID struct {
type ControlType int
// The following are the allowed values for ControlType values.
+// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages.
const (
- ControlPacketTooBig ControlType = iota
+ ControlNetworkUnreachable ControlType = iota
+ ControlNoRoute
+ ControlPacketTooBig
ControlPortUnreachable
ControlUnknown
)
@@ -60,13 +64,34 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns an unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
- // this transport endpoint.
- HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
+ // this transport endpoint. It sets pkt.TransportHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer)
- // HandleControlPacket is called by the stack when new control (e.g.,
+ // HandleControlPacket is called by the stack when new control (e.g.
// ICMP) packets arrive to this transport endpoint.
- HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ // HandleControlPacket takes ownership of pkt.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer)
+
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // An endpoint can be requested to stop its worker goroutines by calling
+ // its Close method.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// RawTransportEndpoint is the interface that needs to be implemented by raw
@@ -77,7 +102,9 @@ type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
// layer up.
- HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt *PacketBuffer)
}
// PacketEndpoint is the interface that needs to be implemented by packet
@@ -93,7 +120,9 @@ type PacketEndpoint interface {
//
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
- HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View)
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(nicID tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// TransportProtocol is the interface that needs to be implemented by transport
@@ -123,7 +152,9 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ //
+ // HandleUnknownDestinationPacket takes ownership of pkt.
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -134,6 +165,18 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
+ // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
+ // MinimumPacketSize()
+ Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -141,13 +184,21 @@ type TransportProtocol interface {
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
- // transport protocol endpoint. It also returns the network layer
- // header for the enpoint to inspect or pass up the stack.
- DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
+ // transport protocol endpoint.
+ //
+ // pkt.NetworkHeader must be set before calling DeliverTransportPacket.
+ //
+ // DeliverTransportPacket takes ownership of pkt.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
- DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
+ //
+ // pkt.NetworkHeader must be set before calling
+ // DeliverTransportControlPacket.
+ //
+ // DeliverTransportControlPacket takes ownership of pkt.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer)
}
// PacketLooping specifies where an outbound packet should be sent.
@@ -198,32 +249,34 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have
+ // already been set.
+ WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
- // protocol.
- WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) (int, *tcpip.Error)
+ // protocol. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error)
// WriteHeaderIncludedPacket writes a packet that includes a network
- // header to the given destination address.
- WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
-
- // ID returns the network protocol endpoint ID.
- ID() *NetworkEndpointID
-
- // PrefixLen returns the network endpoint's subnet prefix length in bits.
- PrefixLen() int
+ // header to the given destination address. It takes ownership of pkt.
+ WriteHeaderIncludedPacket(r *Route, pkt *PacketBuffer) *tcpip.Error
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
// HandlePacket is called by the link layer when new packets arrive to
- // this network endpoint.
- HandlePacket(r *Route, vv buffer.VectorisedView)
+ // this network endpoint. It sets pkt.NetworkHeader.
+ //
+ // HandlePacket takes ownership of pkt.
+ HandlePacket(r *Route, pkt *PacketBuffer)
// Close is called when the endpoint is reomved from a stack.
Close()
+
+ // NetworkProtocolNumber returns the tcpip.NetworkProtocolNumber for
+ // this endpoint.
+ NetworkProtocolNumber() tcpip.NetworkProtocolNumber
}
// NetworkProtocol is the interface that needs to be implemented by network
@@ -240,12 +293,12 @@ type NetworkProtocol interface {
// DefaultPrefixLen returns the protocol's default prefix length.
DefaultPrefixLen() int
- // ParsePorts returns the source and destination addresses stored in a
+ // ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
- NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+ NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -256,16 +309,45 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
-// packets to the appropriate network endpoint after it has been handled by
-// the data link layer.
+// inbound/outbound packets to the appropriate network/packet(if any) endpoints.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
- // and hands the packet over for further processing. linkHeader may have
- // length 0 when the caller does not have ethernet data.
- DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View)
+ // and hands the packet over for further processing.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverNetworkPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverNetworkPacket takes ownership of pkt.
+ DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
+
+ // DeliverOutboundPacket is called by link layer when a packet is being
+ // sent out.
+ //
+ // pkt.LinkHeader may or may not be set before calling
+ // DeliverOutboundPacket. Some packets do not have link headers (e.g.
+ // packets sent via loopback), and won't have the field set.
+ //
+ // DeliverOutboundPacket takes ownership of pkt.
+ DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// LinkEndpointCapabilities is the type associated with the capabilities
@@ -296,7 +378,9 @@ const (
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
-// out through the implementer's data link endpoint.
+// out through the implementer's data link endpoint. When a link header exists,
+// it sets each PacketBuffer's LinkHeader field before passing it up the
+// stack.
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
@@ -318,28 +402,33 @@ type LinkEndpoint interface {
// link endpoint.
LinkAddress() tcpip.LinkAddress
- // WritePacket writes a packet with the given protocol through the given
- // route.
+ // WritePacket writes a packet with the given protocol through the
+ // given route. It takes ownership of pkt. pkt.NetworkHeader and
+ // pkt.TransportHeader must have already been set.
//
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
- WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+ WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets with the given protocol through the
- // given route.
+ // given route. pkts must not be zero length. It takes ownership of pkts and
+ // underlying packets.
//
// Right now, WritePackets is used only when the software segmentation
// offload is enabled. If it will be used for something else, it may
// require to change syscall filters.
- WritePackets(r *Route, gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
+ WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error)
// WriteRawPacket writes a packet directly to the link. The packet
- // should already have an ethernet header.
- WriteRawPacket(packet buffer.VectorisedView) *tcpip.Error
+ // should already have an ethernet header. It takes ownership of vv.
+ WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
+ //
+ // Attach will be called with a nil dispatcher if the receiver's associated
+ // NIC is being removed.
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
@@ -354,6 +443,15 @@ type LinkEndpoint interface {
// Wait will not block if the endpoint hasn't started any goroutines
// yet, even if it might later.
Wait()
+
+ // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint.
+ //
+ // See:
+ // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30
+ ARPHardwareType() header.ARPHardwareType
+
+ // AddHeader adds a link layer header to pkt if required.
+ AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -362,7 +460,7 @@ type InjectableLinkEndpoint interface {
LinkEndpoint
// InjectInbound injects an inbound packet.
- InjectInbound(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
+ InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer)
// InjectOutbound writes a fully formed outbound packet directly to the
// link.
@@ -374,12 +472,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
@@ -397,10 +496,10 @@ type LinkAddressResolver interface {
type LinkAddressCache interface {
// CheckLocalAddress determines if the given local address exists, and if it
// does not exist.
- CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+ CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
// AddLinkAddress adds a link address to the cache.
- AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+ AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
// GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
// If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
@@ -411,10 +510,10 @@ type LinkAddressCache interface {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
- GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
// RemoveWaker removes a waker that has been added in GetLinkAddress().
- RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+ RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
// RawFactory produces endpoints for writing various types of raw packets.
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1a0a51b57..e267bebb0 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -17,7 +17,6 @@ package stack
import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -49,6 +48,10 @@ type Route struct {
// Loop controls where WritePacket should send packets.
Loop PacketLooping
+
+ // directedBroadcast indicates whether this route is sending a directed
+ // broadcast packet.
+ directedBroadcast bool
}
// makeRoute initializes a new route. It takes ownership of the provided
@@ -107,6 +110,12 @@ func (r *Route) GSOMaxSize() uint32 {
return 0
}
+// ResolveWith immediately resolves a route with the specified remote link
+// address.
+func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
+ r.RemoteLinkAddress = addr
+}
+
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
@@ -114,6 +123,8 @@ func (r *Route) GSOMaxSize() uint32 {
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
+//
+// The NIC r uses must not be locked.
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
@@ -149,77 +160,72 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
+//
+// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
-func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error {
+func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.Loop)
+ // WritePacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Size()
+
+ err := r.ref.ep.WritePacket(r, gso, params, pkt)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
} else {
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdr.UsedLength() + payload.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
}
return err
}
-// PacketDescriptor is a packet descriptor which contains a packet header and
-// offset and size of packet data in a payload view.
-type PacketDescriptor struct {
- Hdr buffer.Prependable
- Off int
- Size int
-}
-
-// NewPacketDescriptors allocates a set of packet descriptors.
-func NewPacketDescriptors(n int, hdrSize int) []PacketDescriptor {
- buf := make([]byte, n*hdrSize)
- hdrs := make([]PacketDescriptor, n)
- for i := range hdrs {
- hdrs[i].Hdr = buffer.NewEmptyPrependableFromView(buf[i*hdrSize:][:hdrSize])
- }
- return hdrs
-}
-
-// WritePackets writes the set of packets through the given route.
-func (r *Route) WritePackets(gso *GSO, hdrs []PacketDescriptor, payload buffer.VectorisedView, params NetworkHeaderParams) (int, *tcpip.Error) {
+// WritePackets writes a list of n packets through the given route and returns
+// the number of packets written.
+func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) {
if !r.ref.isValidForOutgoing() {
return 0, tcpip.ErrInvalidEndpointState
}
- n, err := r.ref.ep.WritePackets(r, gso, hdrs, payload, params, r.Loop)
+ // WritePackets takes ownership of pkt, calculate length first.
+ numPkts := pkts.Len()
+
+ n, err := r.ref.ep.WritePackets(r, gso, pkts, params)
if err != nil {
- r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(len(hdrs) - n))
+ r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(numPkts - n))
}
r.ref.nic.stats.Tx.Packets.IncrementBy(uint64(n))
- payloadSize := 0
- for i := 0; i < n; i++ {
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(hdrs[i].Hdr.UsedLength()))
- payloadSize += hdrs[i].Size
+
+ writtenBytes := 0
+ for i, pb := 0, pkts.Front(); i < n && pb != nil; i, pb = i+1, pb.Next() {
+ writtenBytes += pb.Size()
}
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payloadSize))
+
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(writtenBytes))
return n, err
}
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
-func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error {
if !r.ref.isValidForOutgoing() {
return tcpip.ErrInvalidEndpointState
}
- if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.Loop); err != nil {
+ // WriteHeaderIncludedPacket takes ownership of pkt, calculate numBytes first.
+ numBytes := pkt.Data.Size()
+
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
}
r.ref.nic.stats.Tx.Packets.Increment()
- r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size()))
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(numBytes))
return nil
}
@@ -233,6 +239,12 @@ func (r *Route) MTU() uint32 {
return r.ref.ep.MTU()
}
+// NetworkProtocolNumber returns the NetworkProtocolNumber of the underlying
+// network endpoint.
+func (r *Route) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return r.ref.ep.NetworkProtocolNumber()
+}
+
// Release frees all resources associated with the route.
func (r *Route) Release() {
if r.ref != nil {
@@ -244,7 +256,9 @@ func (r *Route) Release() {
// Clone Clone a route such that the original one can be released and the new
// one will remain valid.
func (r *Route) Clone() Route {
- r.ref.incRef()
+ if r.ref != nil {
+ r.ref.incRef()
+ }
return *r
}
@@ -269,3 +283,36 @@ func (r *Route) MakeLoopedRoute() Route {
func (r *Route) Stack() *Stack {
return r.ref.stack()
}
+
+// IsOutboundBroadcast returns true if the route is for an outbound broadcast
+// packet.
+func (r *Route) IsOutboundBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast
+}
+
+// IsInboundBroadcast returns true if the route is for an inbound broadcast
+// packet.
+func (r *Route) IsInboundBroadcast() bool {
+ // Only IPv4 has a notion of broadcast.
+ if r.LocalAddress == header.IPv4Broadcast {
+ return true
+ }
+
+ addr := r.ref.addrWithPrefix()
+ subnet := addr.Subnet()
+ return subnet.IsBroadcast(r.LocalAddress)
+}
+
+// ReverseRoute returns new route with given source and destination address.
+func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route {
+ return Route{
+ NetProto: r.NetProto,
+ LocalAddress: dst,
+ LocalLinkAddress: r.RemoteLinkAddress,
+ RemoteAddress: src,
+ RemoteLinkAddress: r.LocalLinkAddress,
+ ref: r.ref,
+ Loop: r.Loop,
+ }
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 71e0618f4..814b3e94a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -20,19 +20,20 @@
package stack
import (
+ "bytes"
"encoding/binary"
"math"
- "sync"
+ mathrand "math/rand"
"sync/atomic"
"time"
"golang.org/x/time/rate"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/waiter"
@@ -54,7 +55,7 @@ const (
// fakeNetNumber is used as a protocol number in tests.
//
// This constant should match fakeNetNumber in stack_test.go.
- fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
)
type forwardingFlag uint32
@@ -77,8 +78,7 @@ func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
case header.IPv6ProtocolNumber:
flag = forwardingIPv6
case fakeNetNumber:
- // This network protocol number is used in stack_test to test
- // packet forwarding.
+ // This network protocol number is used to test packet forwarding.
flag = forwardingFake
default:
// We only support forwarding for IPv4 and IPv6.
@@ -88,7 +88,7 @@ func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
type transportProtocolState struct {
proto TransportProtocol
- defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
+ defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
@@ -109,6 +109,16 @@ type TCPCubicState struct {
WEst float64
}
+// TCPRACKState is used to hold a copy of the internal RACK state when the
+// TCPProbeFunc is invoked.
+type TCPRACKState struct {
+ XmitTime time.Time
+ EndSequence seqnum.Value
+ FACK seqnum.Value
+ RTT time.Duration
+ Reord bool
+}
+
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
type TCPEndpointID struct {
// LocalPort is the local port associated with the endpoint.
@@ -248,6 +258,9 @@ type TCPSenderState struct {
// Cubic holds the state related to CUBIC congestion control.
Cubic TCPCubicState
+
+ // RACKState holds the state related to RACK loss detection algorithm.
+ RACKState TCPRACKState
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
@@ -271,11 +284,11 @@ type RcvBufAutoTuneParams struct {
// was started.
MeasureTime time.Time
- // CopiedBytes is the number of bytes copied to user space since
+ // CopiedBytes is the number of bytes copied to userspace since
// this measure began.
CopiedBytes int
- // PrevCopiedBytes is the number of bytes copied to user space in
+ // PrevCopiedBytes is the number of bytes copied to userspace in
// the previous RTT period.
PrevCopiedBytes int
@@ -382,6 +395,45 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
+// NICNameFromID is a function that returns a stable name for the specified NIC,
+// even if different NIC IDs are used to refer to the same NIC in different
+// program runs. It is used when generating opaque interface identifiers (IIDs).
+// If the NIC was created with a name, it will be passed to NICNameFromID.
+//
+// NICNameFromID SHOULD return unique NIC names so unique opaque IIDs are
+// generated for the same prefix on differnt NICs.
+type NICNameFromID func(tcpip.NICID, string) string
+
+// OpaqueInterfaceIdentifierOptions holds the options related to the generation
+// of opaque interface indentifiers (IIDs) as defined by RFC 7217.
+type OpaqueInterfaceIdentifierOptions struct {
+ // NICNameFromID is a function that returns a stable name for a specified NIC,
+ // even if the NIC ID changes over time.
+ //
+ // Must be specified to generate the opaque IID.
+ NICNameFromID NICNameFromID
+
+ // SecretKey is a pseudo-random number used as the secret key when generating
+ // opaque IIDs as defined by RFC 7217. The key SHOULD be at least
+ // header.OpaqueIIDSecretKeyMinBytes bytes and MUST follow minimum randomness
+ // requirements for security as outlined by RFC 4086. SecretKey MUST NOT
+ // change between program runs, unless explicitly changed.
+ //
+ // OpaqueInterfaceIdentifierOptions takes ownership of SecretKey. SecretKey
+ // MUST NOT be modified after Stack is created.
+ //
+ // May be nil, but a nil value is highly discouraged to maintain
+ // some level of randomness between nodes.
+ SecretKey []byte
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -399,12 +451,17 @@ type Stack struct {
linkAddrCache *linkAddrCache
- mu sync.RWMutex
- nics map[tcpip.NICID]*NIC
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
// forwarding contains the enable bits for packet forwarding for different
// network protocols.
- forwarding uint32
+ forwarding struct {
+ sync.RWMutex
+ flag forwardingFlag
+ }
+
+ cleanupEndpoints map[TransportEndpoint]struct{}
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -424,7 +481,8 @@ type Stack struct {
handleLocal bool
// tables are the iptables packet filtering and manipulation rules.
- tables iptables.IPTables
+ // TODO(gvisor.dev/issue/170): S/R this field.
+ tables *IPTables
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
@@ -434,23 +492,62 @@ type Stack struct {
// by the stack.
icmpRateLimiter *ICMPRateLimiter
- // portSeed is a one-time random value initialized at stack startup
+ // seed is a one-time random value initialized at stack startup
// and is used to seed the TCP port picking on active connections
//
// TODO(gvisor.dev/issue/940): S/R this field.
- portSeed uint32
+ seed uint32
// ndpConfigs is the default NDP configurations used by interfaces.
ndpConfigs NDPConfigurations
+ // nudConfigs is the default NUD configurations used by interfaces.
+ nudConfigs NUDConfigurations
+
// autoGenIPv6LinkLocal determines whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
- // See the AutoGenIPv6LinkLocal field of Options for more details.
+ // to auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs. See the AutoGenIPv6LinkLocal field of Options for more details.
autoGenIPv6LinkLocal bool
// ndpDisp is the NDP event dispatcher that is used to send the netstack
// integrator NDP related events.
ndpDisp NDPDispatcher
+
+ // nudDisp is the NUD event dispatcher that is used to send the netstack
+ // integrator NUD related events.
+ nudDisp NUDDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+
+ // opaqueIIDOpts hold the options for generating opaque interface identifiers
+ // (IIDs) as outlined by RFC 7217.
+ opaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // tempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ tempIIDSeed []byte
+
+ // forwarder holds the packets that wait for their link-address resolutions
+ // to complete, and forwards them when each resolution is done.
+ forwarder *forwardQueue
+
+ // randomGenerator is an injectable pseudo random generator that can be
+ // used when a random number is required.
+ randomGenerator *mathrand.Rand
+
+ // sendBufferSize holds the min/default/max send buffer sizes for
+ // endpoints other than TCP.
+ sendBufferSize SendBufferSizeOption
+
+ // receiveBufferSize holds the min/default/max receive buffer sizes for
+ // endpoints other than TCP.
+ receiveBufferSize ReceiveBufferSizeOption
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -474,6 +571,9 @@ type Options struct {
// stack (false).
HandleLocal bool
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
// NDPConfigs is the default NDP configurations used by interfaces.
//
// By default, NDPConfigs will have a zero value for its
@@ -481,13 +581,18 @@ type Options struct {
// before assigning an address to a NIC.
NDPConfigs NDPConfigurations
- // AutoGenIPv6LinkLocal determins whether or not the stack will attempt
- // to auto-generate an IPv6 link-local address for newly enabled NICs.
+ // NUDConfigs is the default NUD configurations used by interfaces.
+ NUDConfigs NUDConfigurations
+
+ // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to
+ // auto-generate an IPv6 link-local address for newly enabled non-loopback
+ // NICs.
+ //
// Note, setting this to true does not mean that a link-local address
- // will be assigned right away, or at all. If Duplicate Address
- // Detection is enabled, an address will only be assigned if it
- // successfully resolves. If it fails, no further attempt will be made
- // to auto-generate an IPv6 link-local address.
+ // will be assigned right away, or at all. If Duplicate Address Detection
+ // is enabled, an address will only be assigned if it successfully resolves.
+ // If it fails, no further attempt will be made to auto-generate an IPv6
+ // link-local address.
//
// The generated link-local address will follow RFC 4291 Appendix A
// guidelines.
@@ -497,9 +602,39 @@ type Options struct {
// receive NDP related events.
NDPDisp NDPDispatcher
+ // NUDDisp is the NUD event dispatcher that an integrator can provide to
+ // receive NUD related events.
+ NUDDisp NUDDispatcher
+
// RawFactory produces raw endpoints. Raw endpoints are enabled only if
// this is non-nil.
RawFactory RawFactory
+
+ // OpaqueIIDOpts hold the options for generating opaque interface
+ // identifiers (IIDs) as outlined by RFC 7217.
+ OpaqueIIDOpts OpaqueInterfaceIdentifierOptions
+
+ // RandSource is an optional source to use to generate random
+ // numbers. If omitted it defaults to a Source seeded by the data
+ // returned by rand.Read().
+ //
+ // RandSource must be thread-safe.
+ RandSource mathrand.Source
+
+ // TempIIDSeed is used to seed the initial temporary interface identifier
+ // history value used to generate IIDs for temporary SLAAC addresses.
+ //
+ // Temporary SLAAC adresses are short-lived addresses which are unpredictable
+ // and random from the perspective of other nodes on the network. It is
+ // recommended that the seed be a random byte buffer of at least
+ // header.IIDSize bytes to make sure that temporary SLAAC addresses are
+ // sufficiently random. It should follow minimum randomness requirements for
+ // security as outlined by RFC 4086.
+ //
+ // Note: using a nil value, the same seed across netstack program runs, or a
+ // seed that is too small would reduce randomness and increase predictability,
+ // defeating the purpose of temporary SLAAC addresses.
+ TempIIDSeed []byte
}
// TransportEndpointInfo holds useful information about a transport endpoint
@@ -526,6 +661,51 @@ type TransportEndpointInfo struct {
RegisterNICID tcpip.NICID
}
+// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6
+// address and returns the network protocol number to be used to communicate
+// with the specified address. It returns an error if the passed address is
+// incompatible with the receiver.
+//
+// Preconditon: the parent endpoint mu must be held while calling this method.
+func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.NetProto
+ switch len(addr.Addr) {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ if header.IsV4MappedAddress(addr.Addr) {
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == header.IPv4Any {
+ addr.Addr = ""
+ }
+ }
+ }
+
+ switch len(e.ID.LocalAddress) {
+ case header.IPv4AddressSize:
+ if len(addr.Addr) == header.IPv6AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+ case header.IPv6AddressSize:
+ if len(addr.Addr) == header.IPv4AddressSize {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ switch {
+ case netProto == e.NetProto:
+ case netProto == header.IPv4ProtocolNumber && e.NetProto == header.IPv6ProtocolNumber:
+ if v6only {
+ return tcpip.FullAddress{}, 0, tcpip.ErrNoRoute
+ }
+ default:
+ return tcpip.FullAddress{}, 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return addr, netProto, nil
+}
+
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
// marker interface.
func (*TransportEndpointInfo) IsEndpointInfo() {}
@@ -546,24 +726,56 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
+ randSrc := opts.RandSource
+ if randSrc == nil {
+ // Source provided by mathrand.NewSource is not thread-safe so
+ // we wrap it in a simple thread-safe version.
+ randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())}
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
+ opts.NUDConfigs.resetInvalidFields()
+
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
+ cleanupEndpoints: make(map[TransportEndpoint]struct{}),
linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
+ tables: DefaultTables(),
icmpRateLimiter: NewICMPRateLimiter(),
- portSeed: generateRandUint32(),
+ seed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
+ nudConfigs: opts.NUDConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
+ nudDisp: opts.NUDDisp,
+ opaqueIIDOpts: opts.OpaqueIIDOpts,
+ tempIIDSeed: opts.TempIIDSeed,
+ forwarder: newForwardQueue(),
+ randomGenerator: mathrand.New(randSrc),
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
+ receiveBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultBufferSize,
+ Max: DefaultMaxBufferSize,
+ },
}
// Add specified network protocols.
@@ -590,6 +802,16 @@ func New(opts Options) *Stack {
return s
}
+// newJob returns a tcpip.Job using the Stack clock.
+func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job {
+ return tcpip.NewJob(s.clock, l, f)
+}
+
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
@@ -651,16 +873,17 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber,
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
-func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.View, buffer.VectorisedView) bool) {
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
}
}
-// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
-func (s *Stack) NowNanoseconds() int64 {
- return s.clock.NowNanoseconds()
+// Clock returns the Stack's clock for retrieving the current time and
+// scheduling work.
+func (s *Stack) Clock() tcpip.Clock {
+ return s.clock
}
// Stats returns a mutable copy of the current stats.
@@ -673,30 +896,55 @@ func (s *Stack) Stats() tcpip.Stats {
// SetForwarding enables or disables packet forwarding between NICs.
func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ s.forwarding.Lock()
+ defer s.forwarding.Unlock()
+
+ // If this stack does not support the protocol, do nothing.
+ if _, ok := s.networkProtocols[protocol]; !ok {
+ return
+ }
+
flag := getForwardingFlag(protocol)
- for {
- forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
- var newValue forwardingFlag
+
+ // If the forwarding value for this protocol hasn't changed then do
+ // nothing.
+ if s.forwarding.flag&getForwardingFlag(protocol) != 0 == enable {
+ return
+ }
+
+ var newValue forwardingFlag
+ if enable {
+ newValue = s.forwarding.flag | flag
+ } else {
+ newValue = s.forwarding.flag & ^flag
+ }
+ s.forwarding.flag = newValue
+
+ // Enable or disable NDP for IPv6.
+ if protocol == header.IPv6ProtocolNumber {
if enable {
- newValue = forwarding | flag
+ for _, nic := range s.nics {
+ nic.becomeIPv6Router()
+ }
} else {
- newValue = forwarding & ^flag
- }
- if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) {
- break
+ for _, nic := range s.nics {
+ nic.becomeIPv6Host()
+ }
}
}
}
// Forwarding returns if packet forwarding between NICs is enabled.
func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
- flag := getForwardingFlag(protocol)
- forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
- return forwarding & flag != 0
+ s.forwarding.RLock()
+ defer s.forwarding.RUnlock()
+ return s.forwarding.flag&getForwardingFlag(protocol) != 0
}
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NIC to use for given destination address ranges.
+//
+// This method takes ownership of the table.
func (s *Stack) SetRouteTable(table []tcpip.Route) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -711,6 +959,13 @@ func (s *Stack) GetRouteTable() []tcpip.Route {
return append([]tcpip.Route(nil), s.routeTable...)
}
+// AddRoute appends a route to the route table.
+func (s *Stack) AddRoute(route tcpip.Route) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.routeTable = append(s.routeTable, route)
+}
+
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
@@ -751,9 +1006,32 @@ func (s *Stack) NewPacketEndpoint(cooked bool, netProto tcpip.NetworkProtocolNum
return s.rawFactory.NewPacketEndpoint(s, cooked, netProto, waiterQueue)
}
-// createNIC creates a NIC with the provided id and link-layer endpoint, and
-// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
+// NICContext is an opaque pointer used to store client-supplied NIC metadata.
+type NICContext interface{}
+
+// NICOptions specifies the configuration of a NIC as it is being created.
+// The zero value creates an enabled, unnamed NIC.
+type NICOptions struct {
+ // Name specifies the name of the NIC.
+ Name string
+
+ // Disabled specifies whether to avoid calling Attach on the passed
+ // LinkEndpoint.
+ Disabled bool
+
+ // Context specifies user-defined data that will be returned in stack.NICInfo
+ // for the NIC. Clients of this library can use it to add metadata that
+ // should be tracked alongside a NIC, to avoid having to keep a
+ // map[tcpip.NICID]metadata mirroring stack.Stack's nic map.
+ Context NICContext
+}
+
+// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
+// NICOptions. See the documentation on type NICOptions for details on how
+// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
+func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -762,44 +1040,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled,
return tcpip.ErrDuplicateNICID
}
- n := newNIC(s, id, name, ep, loopback)
+ // Make sure name is unique, unless unnamed.
+ if opts.Name != "" {
+ for _, n := range s.nics {
+ if n.Name() == opts.Name {
+ return tcpip.ErrDuplicateNICID
+ }
+ }
+ }
+ n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
- if enabled {
+ if !opts.Disabled {
return n.enable()
}
return nil
}
-// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, true, false)
+ return s.CreateNICWithOptions(id, ep, NICOptions{})
}
-// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
-// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, false)
-}
-
-// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
-// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, true, true)
-}
-
-// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
-// but leave it disable. Stack.EnableNIC must be called before the link-layer
-// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, "", ep, false, false)
-}
-
-// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
-// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
- return s.createNIC(id, name, ep, false, false)
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -808,36 +1082,72 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
-// CheckNIC checks if a NIC is usable.
-func (s *Stack) CheckNIC(id tcpip.NICID) bool {
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return tcpip.ErrUnknownNICID
}
- return false
+
+ return nic.disable()
}
-// NICSubnets returns a map of NICIDs to their associated subnets.
-func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
+// CheckNIC checks if a NIC is usable.
+func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
defer s.mu.RUnlock()
- nics := map[tcpip.NICID][]tcpip.Subnet{}
+ nic, ok := s.nics[id]
+ if !ok {
+ return false
+ }
- for id, nic := range s.nics {
- nics[id] = append(nics[id], nic.AddressRanges()...)
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.removeNICLocked(id)
+}
+
+// removeNICLocked removes NIC and all related routes from the network stack.
+//
+// s.mu must be locked.
+func (s *Stack) removeNICLocked(id tcpip.NICID) *tcpip.Error {
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
}
- return nics
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ s.routeTable[i] = tcpip.Route{}
+ if r.NIC != id {
+ // Keep this route.
+ s.routeTable[n] = r
+ n++
+ }
+ }
+
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICInfo captures the name and addresses assigned to a NIC.
@@ -853,6 +1163,23 @@ type NICInfo struct {
MTU uint32
Stats NICStats
+
+ // Context is user-supplied data optionally supplied in CreateNICWithOptions.
+ // See type NICOptions for more details.
+ Context NICContext
+
+ // ARPHardwareType holds the ARP Hardware type of the NIC. This is the
+ // value sent in haType field of an ARP Request sent by this NIC and the
+ // value expected in the haType field of an ARP response.
+ ARPHardwareType header.ARPHardwareType
+}
+
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
}
// NICInfo returns a map of NICIDs to their associated information.
@@ -864,9 +1191,9 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
- Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
+ Loopback: nic.isLoopback(),
}
nics[id] = NICInfo{
Name: nic.name,
@@ -875,6 +1202,8 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
+ Context: nic.context,
+ ARPHardwareType: nic.linkEP.ARPHardwareType(),
}
}
return nics
@@ -936,35 +1265,6 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return nic.AddAddress(protocolAddress, peb)
}
-// AddAddressRange adds a range of addresses to the specified NIC. The range is
-// given by a subnet address, and all addresses contained in the subnet are
-// used except for the subnet address itself and the subnet's broadcast
-// address.
-func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic, ok := s.nics[id]; ok {
- nic.AddAddressRange(protocol, subnet)
- return nil
- }
-
- return tcpip.ErrUnknownNICID
-}
-
-// RemoveAddressRange removes the range of addresses from the specified NIC.
-func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if nic, ok := s.nics[id]; ok {
- nic.RemoveAddressRange(subnet)
- return nil
- }
-
- return tcpip.ErrUnknownNICID
-}
-
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
@@ -991,9 +1291,11 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
return nics
}
-// GetMainNICAddress returns the first primary address and prefix for the given
-// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
-// value if the NIC doesn't have a primary address for the given protocol.
+// GetMainNICAddress returns the first non-deprecated primary address and prefix
+// for the given NIC and protocol. If no non-deprecated primary address exists,
+// a deprecated primary address and prefix will be returned. Returns an error if
+// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
+// address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1003,17 +1305,12 @@ func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocol
return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- for _, a := range nic.PrimaryAddresses() {
- if a.Protocol == protocol {
- return a.AddressWithPrefix, nil
- }
- }
- return tcpip.AddressWithPrefix{}, nil
+ return nic.primaryAddress(protocol), nil
}
-func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
+func (s *Stack) getRefEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
if len(localAddr) == 0 {
- return nic.primaryEndpoint(netProto)
+ return nic.primaryEndpoint(netProto, remoteAddr)
}
return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
}
@@ -1024,13 +1321,13 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isBroadcast := remoteAddr == header.IPv4Broadcast
+ isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
- needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
+ needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
- return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback), nil
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
+ return makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
}
} else {
@@ -1038,18 +1335,25 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
- if ref := s.getRefEP(nic, localAddr, netProto); ref != nil {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
+ if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
- remoteAddr = ref.ep.ID().LocalAddress
+ remoteAddr = ref.address()
}
- r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.loopback, multicastLoop && !nic.loopback)
- if needRoute {
- r.NextHop = route.Gateway
+ r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback())
+ r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr)
+
+ if len(route.Gateway) > 0 {
+ if needRoute {
+ r.NextHop = route.Gateway
+ }
+ } else if r.directedBroadcast {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
+
return r, nil
}
}
@@ -1073,13 +1377,13 @@ func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
-func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
// If a NIC is specified, we try to find the address there only.
- if nicid != 0 {
- nic := s.nics[nicid]
+ if nicID != 0 {
+ nic := s.nics[nicID]
if nic == nil {
return 0
}
@@ -1138,35 +1442,35 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
}
// AddLinkAddress adds a link address to the stack link cache.
-func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
// TODO: provide a way for a transport endpoint to receive a signal
// that AddLinkAddress for a particular address has been called.
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
-func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
- nic := s.nics[nicid]
+ nic := s.nics[nicID]
if nic == nil {
s.mu.RUnlock()
return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
linkRes := s.linkAddrResolvers[protocol]
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
// RemoveWaker implements LinkAddressCache.RemoveWaker.
-func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic := s.nics[nicid]; nic == nil {
- fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+ if nic := s.nics[nicID]; nic == nil {
+ fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr}
s.linkAddrCache.removeWaker(fullAddr, waker)
}
}
@@ -1175,14 +1479,45 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
+// the stack transport dispatcher.
+func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// StartTransportEndpointCleanup removes the endpoint with the given id from
+// the stack transport dispatcher. It also transitions it to the cleanup stage.
+func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.cleanupEndpoints[ep] = struct{}{}
+
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
+}
+
+// CompleteTransportEndpointCleanup removes the endpoint from the cleanup
+// stage.
+func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) {
+ s.mu.Lock()
+ delete(s.cleanupEndpoints, ep)
+ s.mu.Unlock()
+}
+
+// FindTransportEndpoint finds an endpoint that most closely matches the provided
+// id. If no endpoint is found it returns nil.
+func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ return s.demux.findTransportEndpoint(netProto, transProto, id, r)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
@@ -1206,6 +1541,81 @@ func (s *Stack) RegisterRestoredEndpoint(e ResumableEndpoint) {
s.mu.Unlock()
}
+// RegisteredEndpoints returns all endpoints which are currently registered.
+func (s *Stack) RegisteredEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var es []TransportEndpoint
+ for _, e := range s.demux.protocol {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// CleanupEndpoints returns endpoints currently in the cleanup state.
+func (s *Stack) CleanupEndpoints() []TransportEndpoint {
+ s.mu.Lock()
+ es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints))
+ for e := range s.cleanupEndpoints {
+ es = append(es, e)
+ }
+ s.mu.Unlock()
+ return es
+}
+
+// RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful
+// for restoring a stack after a save.
+func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
+ s.mu.Lock()
+ for _, e := range es {
+ s.cleanupEndpoints[e] = struct{}{}
+ }
+ s.mu.Unlock()
+}
+
+// Close closes all currently registered transport endpoints.
+//
+// Endpoints created or modified during this call may not get closed.
+func (s *Stack) Close() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
+ }
+}
+
+// Wait waits for all transport and link endpoints to halt their worker
+// goroutines.
+//
+// Endpoints created or modified during this call may not get waited on.
+//
+// Note that link endpoints must be stopped via an implementation specific
+// mechanism.
+func (s *Stack) Wait() {
+ for _, e := range s.RegisteredEndpoints() {
+ e.Wait()
+ }
+ for _, e := range s.CleanupEndpoints() {
+ e.Wait()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
+
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, n := range s.nics {
+ n.linkEP.Wait()
+ }
+}
+
// Resume restarts the stack after a restore. This must be called after the
// entire system has been restored.
func (s *Stack) Resume() {
@@ -1280,9 +1690,9 @@ func (s *Stack) unregisterPacketEndpointLocked(nicID tcpip.NICID, netProto tcpip
// WritePacket writes data directly to the specified NIC. It adds an ethernet
// header based on the arguments.
-func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
+func (s *Stack) WritePacket(nicID tcpip.NICID, dst tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
- nic, ok := s.nics[nicid]
+ nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
@@ -1296,10 +1706,10 @@ func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto t
}
fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
fakeHeader.Encode(&ethFields)
- ethHeader := buffer.View(fakeHeader).ToVectorisedView()
- ethHeader.Append(payload)
+ vv := buffer.View(fakeHeader).ToVectorisedView()
+ vv.Append(payload)
- if err := nic.linkEP.WriteRawPacket(ethHeader); err != nil {
+ if err := nic.linkEP.WriteRawPacket(vv); err != nil {
return err
}
@@ -1308,9 +1718,9 @@ func (s *Stack) WritePacket(nicid tcpip.NICID, dst tcpip.LinkAddress, netProto t
// WriteRawPacket writes data directly to the specified NIC without adding any
// headers.
-func (s *Stack) WriteRawPacket(nicid tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
+func (s *Stack) WriteRawPacket(nicID tcpip.NICID, payload buffer.VectorisedView) *tcpip.Error {
s.mu.Lock()
- nic, ok := s.nics[nicid]
+ nic, ok := s.nics[nicID]
s.mu.Unlock()
if !ok {
return tcpip.ErrUnknownDevice
@@ -1403,14 +1813,21 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
-// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() iptables.IPTables {
- return s.tables
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
}
-// SetIPTables sets the stack's iptables.
-func (s *Stack) SetIPTables(ipt iptables.IPTables) {
- s.tables = ipt
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() *IPTables {
+ return s.tables
}
// ICMPLimit returns the maximum number of ICMP messages that can be sent
@@ -1489,16 +1906,66 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip
}
nic.setNDPConfigs(c)
+ return nil
+}
+
+// NUDConfigurations gets the per-interface NUD configurations.
+func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return NUDConfigurations{}, tcpip.ErrUnknownNICID
+ }
+
+ return nic.NUDConfigs()
+}
+
+// SetNUDConfigurations sets the per-interface NUD configurations.
+//
+// Note, if c contains invalid NUD configuration values, it will be fixed to
+// use default values for the erroneous values.
+func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error {
+ s.mu.RLock()
+ nic, ok := s.nics[id]
+ s.mu.RUnlock()
+
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.setNUDConfigs(c)
+}
+
+// HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement
+// message that it needs to handle.
+func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ nic.handleNDPRA(ip, ra)
return nil
}
-// PortSeed returns a 32 bit value that can be used as a seed value for port
-// picking.
+// Seed returns a 32 bit value that can be used as a seed value for port
+// picking, ISN generation etc.
//
// NOTE: The seed is generated once during stack initialization only.
-func (s *Stack) PortSeed() uint32 {
- return s.portSeed
+func (s *Stack) Seed() uint32 {
+ return s.seed
+}
+
+// Rand returns a reference to a pseudo random generator that can be used
+// to generate random numbers as required.
+func (s *Stack) Rand() *mathrand.Rand {
+ return s.randomGenerator
}
func generateRandUint32() uint32 {
@@ -1508,3 +1975,49 @@ func generateRandUint32() uint32 {
}
return binary.LittleEndian.Uint32(b)
}
+
+func generateRandInt64() int64 {
+ b := make([]byte, 8)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ buf := bytes.NewReader(b)
+ var v int64
+ if err := binary.Read(buf, binary.LittleEndian, &v); err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// FindNetworkEndpoint returns the network endpoint for the given address.
+func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ for _, nic := range s.nics {
+ id := NetworkEndpointID{address}
+
+ if ref, ok := nic.mu.endpoints[id]; ok {
+ nic.mu.RLock()
+ defer nic.mu.RUnlock()
+
+ // An endpoint with this id exists, check if it can be
+ // used and return it.
+ return ref.ep, nil
+ }
+ }
+ return nil, tcpip.ErrBadAddress
+}
+
+// FindNICNameFromID returns the name of the nic for the given NICID.
+func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return ""
+ }
+
+ return nic.Name()
+}
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
new file mode 100644
index 000000000..0b093e6c5
--- /dev/null
+++ b/pkg/tcpip/stack/stack_options.go
@@ -0,0 +1,106 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import "gvisor.dev/gvisor/pkg/tcpip"
+
+const (
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4 KiB
+
+ // DefaultBufferSize is the default size of the send/recv buffer for a
+ // transport endpoint.
+ DefaultBufferSize = 212 << 10 // 212 KiB
+
+ // DefaultMaxBufferSize is the default maximum permitted size of a
+ // send/receive buffer.
+ DefaultMaxBufferSize = 4 << 20 // 4 MiB
+)
+
+// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max send buffer sizes.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
+// get/set the default, min and max receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// SetOption allows setting stack wide options.
+func (s *Stack) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SendBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below minimum
+ // required for stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.sendBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ // Make sure we don't allow lowering the buffer below minimum
+ // required for stack to work.
+ if v.Min < MinBufferSize {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ if v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ s.mu.Lock()
+ s.receiveBufferSize = v
+ s.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option allows retrieving stack wide options.
+func (s *Stack) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SendBufferSizeOption:
+ s.mu.RLock()
+ *v = s.sendBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ s.mu.RLock()
+ *v = s.receiveBufferSize
+ s.mu.RUnlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index ef3d1beb0..f168be402 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -21,18 +21,24 @@ import (
"bytes"
"fmt"
"math"
+ "net"
"sort"
- "strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
@@ -51,6 +57,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -61,9 +71,7 @@ const (
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
- nicid tcpip.NICID
- id stack.NetworkEndpointID
- prefixLen int
+ nicID tcpip.NICID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
ep stack.LinkEndpoint
@@ -74,43 +82,35 @@ func (f *fakeNetworkEndpoint) MTU() uint32 {
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
- return f.nicid
-}
-
-func (f *fakeNetworkEndpoint) PrefixLen() int {
- return f.prefixLen
+ return f.nicID
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
-func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
- return &f.id
-}
-
-func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) {
// Increment the received packet count in the protocol descriptor.
- f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
-
- // Consume the network header.
- b := vv.First()
- vv.TrimFront(fakeNetHeaderLen)
+ f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++
// Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
- nb := vv.First()
- if len(nb) < fakeNetHeaderLen {
+ if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) {
+ nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
return
}
-
- vv.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), buffer.View([]byte{}), vv)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -125,37 +125,37 @@ func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.ep.Capabilities()
}
-func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error {
+func (f *fakeNetworkEndpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber {
+ return f.proto.Number()
+}
+
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := hdr.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
-
- if loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(payload.Views()))
- views[0] = hdr.View()
- views = append(views, payload.Views()...)
- vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
- f.HandlePacket(r, vv)
- }
- if loop&stack.PacketOut == 0 {
+ hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen)
+ hdr[dstAddrOffset] = r.RemoteAddress[0]
+ hdr[srcAddrOffset] = r.LocalAddress[0]
+ hdr[protocolNumberOffset] = byte(params.Protocol)
+
+ if r.Loop&stack.PacketLoop != 0 {
+ f.HandlePacket(r, pkt)
+ }
+ if r.Loop&stack.PacketOut == 0 {
return nil
}
- return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, fakeNetNumber, pkt)
}
// WritePackets implements stack.LinkEndpoint.WritePackets.
-func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, hdrs []stack.PacketDescriptor, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) (int, *tcpip.Error) {
+func (f *fakeNetworkEndpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, params stack.NetworkHeaderParams) (int, *tcpip.Error) {
panic("not implemented")
}
-func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
return tcpip.ErrNotSupported
}
@@ -197,18 +197,16 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint {
return &fakeNetworkEndpoint{
- nicid: nicid,
- id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
- prefixLen: addrWithPrefix.PrefixLen,
+ nicID: nicID,
proto: f,
dispatcher: dispatcher,
ep: ep,
- }, nil
+ }
}
func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
@@ -233,10 +231,53 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
+// Parse implements TransportProtocol.Parse.
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.NetworkHeader().Consume(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = d != nil
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
+// Checks to see if list contains an address.
+func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool {
+ for _, i := range list {
+ if i == item {
+ return true
+ }
+ }
+
+ return false
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -261,8 +302,10 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
- buf[0] = 3
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ buf[dstAddrOffset] = 3
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -271,8 +314,10 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
- buf[0] = 1
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ buf[dstAddrOffset] = 1
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -281,8 +326,10 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
- buf[0] = 2
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ buf[dstAddrOffset] = 2
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -291,7 +338,9 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber-1, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -301,7 +350,9 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -320,8 +371,10 @@ func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Erro
}
func send(r stack.Route, payload buffer.View) *tcpip.Error {
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(r.MaxHeaderLength()),
+ Data: payload.ToVectorisedView(),
+ }))
}
func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
@@ -376,7 +429,9 @@ func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte b
func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -493,6 +548,340 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRemoveUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.RemoveNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.RemoveNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestRemoveNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
+
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if err := s.RemoveNIC(nicID); err != nil {
+ t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
+ }
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+}
+
+func TestRouteWithDownNIC(t *testing.T) {
+ tests := []struct {
+ name string
+ downFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ upFn func(s *stack.Stack, nicID tcpip.NICID) *tcpip.Error
+ }{
+ {
+ name: "Disabled NIC",
+ downFn: (*stack.Stack).DisableNIC,
+ upFn: (*stack.Stack).EnableNIC,
+ },
+
+ // Once a NIC is removed, it cannot be brought up.
+ {
+ name: "Removed NIC",
+ downFn: (*stack.Stack).RemoveNIC,
+ },
+ }
+
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+ const addr1 = tcpip.Address("\x01")
+ const addr2 = tcpip.Address("\x02")
+ const nic1Dst = tcpip.Address("\x05")
+ const nic2Dst = tcpip.Address("\x06")
+
+ setup := func(t *testing.T) (*stack.Stack, *channel.Endpoint, *channel.Endpoint) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ return s, ep1, ep2
+ }
+
+ // Tests that routes through a down NIC are not used when looking up a route
+ // for a destination.
+ t.Run("Find", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, _, _ := setup(t)
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Bringing NIC1 down should result in no routes to odd addresses. Routes to
+ // even addresses should continue to be available as NIC2 is still up.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Bringing NIC2 down should result in no routes to even addresses. No
+ // route should be available to any address as routes to odd addresses
+ // were made unavailable by bringing NIC1 down above.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ if upFn := test.upFn; upFn != nil {
+ // Bringing NIC1 up should make routes to odd addresses available
+ // again. Routes to even addresses should continue to be unavailable
+ // as NIC2 is still down.
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+ }
+ })
+ }
+ })
+
+ // Tests that writing a packet using a Route through a down NIC fails.
+ t.Run("WritePacket", func(t *testing.T) {
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s, ep1, ep2 := setup(t)
+
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use NIC1 after being brought down should fail.
+ if err := test.downFn(s, nicID1); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use NIC2 after being brought down should fail.
+ if err := test.downFn(s, nicID2); err != nil {
+ t.Fatalf("test.downFn(_, %d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ if upFn := test.upFn; upFn != nil {
+ // Writes with Routes that use NIC1 after being brought up should
+ // succeed.
+ //
+ // TODO(gvisor.dev/issue/1491): Should we instead completely
+ // invalidate all Routes that were bound to a NIC that was brought
+ // down at some point?
+ if err := upFn(s, nicID1); err != nil {
+ t.Fatalf("test.upFn(_, %d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+ }
+ })
+ }
+ })
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -602,7 +991,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -652,7 +1041,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -671,11 +1060,11 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
}
-func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+func verifyAddress(t *testing.T, s *stack.Stack, nicID tcpip.NICID, addr tcpip.Address) {
t.Helper()
- info, ok := s.NICInfo()[nicid]
+ info, ok := s.NICInfo()[nicID]
if !ok {
- t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ t.Fatalf("NICInfo() failed to find nicID=%d", nicID)
}
if len(addr) == 0 {
// No address given, verify that there is no address assigned to the NIC.
@@ -708,7 +1097,7 @@ func TestEndpointExpiration(t *testing.T) {
localAddrByte byte = 0x01
remoteAddr tcpip.Address = "\x03"
noAddr tcpip.Address = ""
- nicid tcpip.NICID = 1
+ nicID tcpip.NICID = 1
)
localAddr := tcpip.Address([]byte{localAddrByte})
@@ -720,7 +1109,7 @@ func TestEndpointExpiration(t *testing.T) {
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -734,16 +1123,16 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
if promiscuous {
- if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ if err := s.SetPromiscuousMode(nicID, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
}
if spoofing {
- if err := s.SetSpoofing(nicid, true); err != nil {
+ if err := s.SetSpoofing(nicID, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
}
@@ -751,7 +1140,7 @@ func TestEndpointExpiration(t *testing.T) {
// 1. No Address yet, send should only work for spoofing, receive for
// promiscuous mode.
//-----------------------
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -766,20 +1155,20 @@ func TestEndpointExpiration(t *testing.T) {
// 2. Add Address, everything should work.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 3. Remove the address, send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -794,10 +1183,10 @@ func TestEndpointExpiration(t *testing.T) {
// 4. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -815,10 +1204,10 @@ func TestEndpointExpiration(t *testing.T) {
// 6. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -834,10 +1223,10 @@ func TestEndpointExpiration(t *testing.T) {
// 7. Add Address back, everything should work again.
//-----------------------
- if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
testSend(t, r, ep, nil)
@@ -845,17 +1234,17 @@ func TestEndpointExpiration(t *testing.T) {
// 8. Remove the route, sendTo/recv should still work.
//-----------------------
r.Release()
- verifyAddress(t, s, nicid, localAddr)
+ verifyAddress(t, s, nicID, localAddr)
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
// 9. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
//-----------------------
- if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ if err := s.RemoveAddress(nicID, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- verifyAddress(t, s, nicid, noAddr)
+ verifyAddress(t, s, nicID, noAddr)
if promiscuous {
testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
@@ -897,7 +1286,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1071,19 +1460,19 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) {
protoAddr := tcpip.ProtocolAddress{Protocol: fakeNetNumber, AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Any, 0}}
if err := s.AddProtocolAddress(1, protoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", protoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", protoAddr, err)
}
r, err := s.FindRoute(1, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err)
}
// If the NIC doesn't exist, it won't work.
if _, err := s.FindRoute(2, header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */); err != tcpip.ErrNetworkUnreachable {
- t.Fatalf("got FindRoute(2, %s, %s, %d) = %s want = %s", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
+ t.Fatalf("got FindRoute(2, %v, %v, %d) = %v want = %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err, tcpip.ErrNetworkUnreachable)
}
}
@@ -1109,12 +1498,12 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
}
nic1ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic1Addr}
if err := s.AddProtocolAddress(1, nic1ProtoAddr); err != nil {
- t.Fatalf("AddProtocolAddress(1, %s) failed: %s", nic1ProtoAddr, err)
+ t.Fatalf("AddProtocolAddress(1, %v) failed: %v", nic1ProtoAddr, err)
}
nic2ProtoAddr := tcpip.ProtocolAddress{fakeNetNumber, nic2Addr}
if err := s.AddProtocolAddress(2, nic2ProtoAddr); err != nil {
- t.Fatalf("AddAddress(2, %s) failed: %s", nic2ProtoAddr, err)
+ t.Fatalf("AddAddress(2, %v) failed: %v", nic2ProtoAddr, err)
}
// Set the initial route table.
@@ -1129,10 +1518,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) {
// When an interface is given, the route for a broadcast goes through it.
r, err := s.FindRoute(1, nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatalf("FindRoute(1, %s, %s, %d) failed: %s", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
if err := verifyRoute(r, stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil {
- t.Errorf("FindRoute(1, %s, %s, %d) returned unexpected Route: %s)", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
+ t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err)
}
// When an interface is not given, it consults the route table.
@@ -1254,149 +1643,6 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-// Add a range of addresses, then check that a packet is delivered.
-func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[0] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- testRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
-func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
- t.Helper()
-
- // Loop over all addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
-
- addrBytes := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- addr := tcpip.Address(addrBytes)
- wantNicID := nicID
- // The subnet and broadcast addresses are skipped.
- if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
- wantNicID = 0
- }
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
- }
- addrBytes[0]++
- }
-
- // Trying the next address should always fail since it is outside the range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
- }
-}
-
-// Set a range of addresses, then remove it again, and check at each step that
-// CheckLocalAddress returns the correct NIC for each address or zero if not
-// existent.
-func TestCheckLocalAddressForSubnet(t *testing.T) {
- const nicID tcpip.NICID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID}})
- }
-
- subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-
- if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
-
- if err := s.RemoveAddressRange(nicID, subnet); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
-}
-
-// Set a range of addresses, then send a packet to a destination outside the
-// range and then check it doesn't get delivered.
-func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
-
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- {
- subnet, err := tcpip.NewSubnet("\x00", "\x00")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
- }
-
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- const localAddrByte byte = 0x01
- buf[0] = localAddrByte
- subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
- if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
- testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
-}
-
func TestNetworkOptions(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
@@ -1440,56 +1686,6 @@ func TestNetworkOptions(t *testing.T) {
}
}
-func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
- ranges, ok := s.NICAddressRanges()[id]
- if !ok {
- return false
- }
- for _, r := range ranges {
- if r == addrRange {
- return true
- }
- }
- return false
-}
-
-func TestAddresRangeAddRemove(t *testing.T) {
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep); err != nil {
- t.Fatal("CreateNIC failed:", err)
- }
-
- addr := tcpip.Address("\x01\x01\x01\x01")
- mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- addrRange, err := tcpip.NewSubnet(addr, mask)
- if err != nil {
- t.Fatal("NewSubnet failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
- t.Fatal("AddAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-
- if err := s.RemoveAddressRange(1, addrRange); err != nil {
- t.Fatal("RemoveAddressRange failed:", err)
- }
-
- if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
- t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
- }
-}
-
func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
for _, addrLen := range []int{4, 16} {
t.Run(fmt.Sprintf("addrLen=%d", addrLen), func(t *testing.T) {
@@ -1648,12 +1844,12 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
}
func TestAddAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1661,7 +1857,7 @@ func TestAddAddress(t *testing.T) {
expectedAddresses := make([]tcpip.ProtocolAddress, 0, 2)
for _, addrLen := range []int{4, 16} {
address := addrGen.next(addrLen)
- if err := s.AddAddress(nicid, fakeNetNumber, address); err != nil {
+ if err := s.AddAddress(nicID, fakeNetNumber, address); err != nil {
t.Fatalf("AddAddress(address=%s) failed: %s", address, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1670,17 +1866,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1697,24 +1893,24 @@ func TestAddProtocolAddress(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddress(nicid, protocolAddress); err != nil {
+ if err := s.AddProtocolAddress(nicID, protocolAddress); err != nil {
t.Errorf("AddProtocolAddress(%+v) failed: %s", protocolAddress, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1725,7 +1921,7 @@ func TestAddAddressWithOptions(t *testing.T) {
for _, addrLen := range addrLenRange {
for _, behavior := range behaviorRange {
address := addrGen.next(addrLen)
- if err := s.AddAddressWithOptions(nicid, fakeNetNumber, address, behavior); err != nil {
+ if err := s.AddAddressWithOptions(nicID, fakeNetNumber, address, behavior); err != nil {
t.Fatalf("AddAddressWithOptions(address=%s, behavior=%d) failed: %s", address, behavior, err)
}
expectedAddresses = append(expectedAddresses, tcpip.ProtocolAddress{
@@ -1735,17 +1931,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
- const nicid = 1
+ const nicID = 1
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
ep := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, ep); err != nil {
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1764,7 +1960,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
PrefixLen: prefixLen,
},
}
- if err := s.AddProtocolAddressWithOptions(nicid, protocolAddress, behavior); err != nil {
+ if err := s.AddProtocolAddressWithOptions(nicID, protocolAddress, behavior); err != nil {
t.Fatalf("AddProtocolAddressWithOptions(%+v, %d) failed: %s", protocolAddress, behavior, err)
}
expectedAddresses = append(expectedAddresses, protocolAddress)
@@ -1772,10 +1968,95 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.AllAddresses()[nicid]
+ gotAddresses := s.AllAddresses()[nicID]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
+func TestCreateNICWithOptions(t *testing.T) {
+ type callArgsAndExpect struct {
+ nicID tcpip.NICID
+ opts stack.NICOptions
+ err *tcpip.Error
+ }
+
+ tests := []struct {
+ desc string
+ calls []callArgsAndExpect
+ }{
+ {
+ desc: "DuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth1"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth2"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "DuplicateName",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "lo"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{Name: "lo"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "Unnamed",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ },
+ },
+ {
+ desc: "UnnamedDuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ for _, call := range test.calls {
+ if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
+ t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
+ }
+ }
+ })
+ }
+}
+
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
@@ -1798,7 +2079,9 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1823,150 +2106,386 @@ func TestNICStats(t *testing.T) {
}
func TestNICForwarding(t *testing.T) {
- // Create a stack with the fake network protocol, two NICs, each with
- // an address.
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
- })
- s.SetForwarding(fakeNetNumber, true)
+ const nicID1 = 1
+ const nicID2 = 2
+ const dstAddr = tcpip.Address("\x03")
- ep1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, ep1); err != nil {
- t.Fatal("CreateNIC #1 failed:", err)
- }
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
- t.Fatal("AddAddress #1 failed:", err)
+ tests := []struct {
+ name string
+ headerLen uint16
+ }{
+ {
+ name: "Zero header length",
+ },
+ {
+ name: "Non-zero header length",
+ headerLen: 16,
+ },
}
- ep2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, ep2); err != nil {
- t.Fatal("CreateNIC #2 failed:", err)
- }
- if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
- t.Fatal("AddAddress #2 failed:", err)
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ s.SetForwarding(fakeNetNumber, true)
+
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddAddress(nicID1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x01): %s", nicID1, fakeNetNumber, err)
+ }
+
+ ep2 := channelLinkWithHeaderLength{
+ Endpoint: channel.New(10, defaultMTU, ""),
+ headerLength: test.headerLen,
+ }
+ if err := s.CreateNIC(nicID2, &ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+ if err := s.AddAddress(nicID2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress(%d, %d, 0x02): %s", nicID2, fakeNetNumber, err)
+ }
+
+ // Route all packets to dstAddr to NIC 2.
+ {
+ subnet, err := tcpip.NewSubnet(dstAddr, "\xff")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: nicID2}})
+ }
+
+ // Send a packet to dstAddr.
+ buf := buffer.NewView(30)
+ buf[dstAddrOffset] = dstAddr[0]
+ ep1.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ pkt, ok := ep2.Read()
+ if !ok {
+ t.Fatal("packet not forwarded")
+ }
+
+ // Test that the link's MaxHeaderLength is honoured.
+ if capacity, want := pkt.Pkt.AvailableHeaderBytes(), int(test.headerLen); capacity != want {
+ t.Errorf("got LinkHeader.AvailableLength() = %d, want = %d", capacity, want)
+ }
+
+ // Test that forwarding increments Tx stats correctly.
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Packets.Value(), uint64(1); got != want {
+ t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ }
+
+ if got, want := s.NICInfo()[nicID2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
+ t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ }
+ })
}
+}
- // Route all packets to address 3 to NIC 2.
- {
- subnet, err := tcpip.NewSubnet("\x03", "\xff")
- if err != nil {
- t.Fatal(err)
- }
- s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 2}})
+// TestNICContextPreservation tests that you can read out via stack.NICInfo the
+// Context data you pass via NICContext.Context in stack.CreateNICWithOptions.
+func TestNICContextPreservation(t *testing.T) {
+ var ctx *int
+ tests := []struct {
+ name string
+ opts stack.NICOptions
+ want stack.NICContext
+ }{
+ {
+ "context_set",
+ stack.NICOptions{Context: ctx},
+ ctx,
+ },
+ {
+ "context_not_set",
+ stack.NICOptions{},
+ nil,
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ id := tcpip.NICID(1)
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ if err := s.CreateNICWithOptions(id, ep, test.opts); err != nil {
+ t.Fatalf("got stack.CreateNICWithOptions(%d, %+v, %+v) = %s, want nil", id, ep, test.opts, err)
+ }
+ nicinfos := s.NICInfo()
+ nicinfo, ok := nicinfos[id]
+ if !ok {
+ t.Fatalf("got nicinfos[%d] = _, %t, want _, true; nicinfos = %+v", id, ok, nicinfos)
+ }
+ if got, want := nicinfo.Context == test.want, true; got != want {
+ t.Fatalf("got nicinfo.Context == ctx = %t, want %t; nicinfo.Context = %p, ctx = %p", got, want, nicinfo.Context, test.want)
+ }
+ })
}
+}
- // Send a packet to address 3.
- buf := buffer.NewView(30)
- buf[0] = 3
- ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
+// TestNICAutoGenLinkLocalAddr tests the auto-generation of IPv6 link-local
+// addresses.
+func TestNICAutoGenLinkLocalAddr(t *testing.T) {
+ const nicID = 1
- select {
- case <-ep2.C:
- default:
- t.Fatal("Packet not forwarded")
+ var secretKey [header.OpaqueIIDSecretKeyMinBytes]byte
+ n, err := rand.Read(secretKey[:])
+ if err != nil {
+ t.Fatalf("rand.Read(_): %s", err)
}
-
- // Test that forwarding increments Tx stats correctly.
- if got, want := s.NICInfo()[2].Stats.Tx.Packets.Value(), uint64(1); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, want = %d", got, want)
+ if n != header.OpaqueIIDSecretKeyMinBytes {
+ t.Fatalf("expected rand.Read to read %d bytes, read %d bytes", header.OpaqueIIDSecretKeyMinBytes, n)
}
- if got, want := s.NICInfo()[2].Stats.Tx.Bytes.Value(), uint64(len(buf)); got != want {
- t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
+ nicNameFunc := func(_ tcpip.NICID, name string) string {
+ return name
}
-}
-// TestNICAutoGenAddr tests the auto-generation of IPv6 link-local addresses
-// (or lack there-of if disabled (default)). Note, DAD will be disabled in
-// these tests.
-func TestNICAutoGenAddr(t *testing.T) {
tests := []struct {
- name string
- autoGen bool
- linkAddr tcpip.LinkAddress
- shouldGen bool
+ name string
+ nicName string
+ autoGen bool
+ linkAddr tcpip.LinkAddress
+ iidOpts stack.OpaqueInterfaceIdentifierOptions
+ shouldGen bool
+ expectedAddr tcpip.Address
}{
{
- "Disabled",
- false,
- linkAddr1,
- false,
+ name: "Disabled",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ shouldGen: false,
+ },
+ {
+ name: "Disabled without OIID options",
+ nicName: "nic1",
+ autoGen: false,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: false,
+ },
+
+ // Tests for EUI64 based addresses.
+ {
+ name: "EUI64 Enabled",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddr(linkAddr1),
+ },
+ {
+ name: "EUI64 Empty MAC",
+ autoGen: true,
+ shouldGen: false,
},
{
- "Enabled",
- true,
- linkAddr1,
- true,
+ name: "EUI64 Invalid MAC",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03",
+ shouldGen: false,
},
{
- "Nil MAC",
- true,
- tcpip.LinkAddress([]byte(nil)),
- false,
+ name: "EUI64 Multicast MAC",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ shouldGen: false,
},
{
- "Empty MAC",
- true,
- tcpip.LinkAddress(""),
- false,
+ name: "EUI64 Unspecified MAC",
+ autoGen: true,
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ shouldGen: false,
},
+
+ // Tests for Opaque IID based addresses.
+ {
+ name: "OIID Enabled",
+ nicName: "nic1",
+ autoGen: true,
+ linkAddr: linkAddr1,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("nic1", 0, secretKey[:]),
+ },
+ // These are all cases where we would not have generated a
+ // link-local address if opaque IIDs were disabled.
{
- "Invalid MAC",
- true,
- tcpip.LinkAddress("\x01\x02\x03"),
- false,
+ name: "OIID Empty MAC and empty nicName",
+ autoGen: true,
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:1],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("", 0, secretKey[:1]),
},
{
- "Multicast MAC",
- true,
- tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06"),
- false,
+ name: "OIID Invalid MAC",
+ nicName: "test",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:2],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test", 0, secretKey[:2]),
},
{
- "Unspecified MAC",
- true,
- tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"),
- false,
+ name: "OIID Multicast MAC",
+ nicName: "test2",
+ autoGen: true,
+ linkAddr: "\x01\x02\x03\x04\x05\x06",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ SecretKey: secretKey[:3],
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test2", 0, secretKey[:3]),
+ },
+ {
+ name: "OIID Unspecified MAC and nil SecretKey",
+ nicName: "test3",
+ autoGen: true,
+ linkAddr: "\x00\x00\x00\x00\x00\x00",
+ iidOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: nicNameFunc,
+ },
+ shouldGen: true,
+ expectedAddr: header.LinkLocalAddrWithOpaqueIID("test3", 0, nil),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1),
+ }
opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: test.autoGen,
+ NDPDisp: &ndpDisp,
+ OpaqueIIDOpts: test.iidOpts,
}
- if test.autoGen {
- // Only set opts.AutoGenIPv6LinkLocal when
- // test.autoGen is true because
- // opts.AutoGenIPv6LinkLocal should be false by
- // default.
- opts.AutoGenIPv6LinkLocal = true
+ e := channel.New(0, 1280, test.linkAddr)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- e := channel.New(10, 1280, test.linkAddr)
- s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
}
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
}
+ var expectedMainAddr tcpip.AddressWithPrefix
if test.shouldGen {
- // Should have auto-generated an address and
- // resolved immediately (DAD is disabled).
- if want := (tcpip.AddressWithPrefix{Address: header.LinkLocalAddr(test.linkAddr), PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ expectedMainAddr = tcpip.AddressWithPrefix{
+ Address: test.expectedAddr,
+ PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen,
+ }
+
+ // Should have auto-generated an address and resolved immediately (DAD
+ // is disabled).
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ if diff := checkAutoGenAddrEvent(e, expectedMainAddr, newAddr); diff != "" {
+ t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff)
+ }
+ default:
+ t.Fatal("expected addr auto gen event")
}
} else {
// Should not have auto-generated an address.
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Fatal("unexpectedly auto-generated an address")
+ default:
}
}
+
+ gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if gotMainAddr != expectedMainAddr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ }
+ })
+ }
+}
+
+// TestNoLinkLocalAutoGenForLoopbackNIC tests that IPv6 link-local addresses are
+// not auto-generated for loopback NICs.
+func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
+ const nicID = 1
+ const nicName = "nicName"
+
+ tests := []struct {
+ name string
+ opaqueIIDOpts stack.OpaqueInterfaceIdentifierOptions
+ }{
+ {
+ name: "IID From MAC",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{},
+ },
+ {
+ name: "Opaque IID",
+ opaqueIIDOpts: stack.OpaqueInterfaceIdentifierOptions{
+ NICNameFromID: func(_ tcpip.NICID, nicName string) string {
+ return nicName
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ OpaqueIIDOpts: test.opaqueIIDOpts,
+ }
+
+ e := loopback.New()
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Name: nicName}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ }
})
}
}
@@ -1974,6 +2493,8 @@ func TestNICAutoGenAddr(t *testing.T) {
// TestNICAutoGenAddrDoesDAD tests that the successful auto-generation of IPv6
// link-local addresses will only be assigned after the DAD process resolves.
func TestNICAutoGenAddrDoesDAD(t *testing.T) {
+ const nicID = 1
+
ndpDisp := ndpDispatcher{
dadC: make(chan ndpDADEvent),
}
@@ -1985,20 +2506,20 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
NDPDisp: &ndpDisp,
}
- e := channel.New(10, 1280, linkAddr1)
+ e := channel.New(int(ndpConfigs.DupAddrDetectTransmits), 1280, linkAddr1)
s := stack.New(opts)
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Address should not be considered bound to the
// NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (_, %v), want = (_, nil)", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = (%s, nil), want = (%s, nil)", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
@@ -2012,25 +2533,16 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if e.err != nil {
- t.Fatal("got DAD error: ", e.err)
- }
- if e.nicid != 1 {
- t.Fatalf("got DAD event w/ nicid = %d, want = 1", e.nicid)
- }
- if e.addr != linkLocalAddr {
- t.Fatalf("got DAD event w/ addr = %s, want = %s", addr, linkLocalAddr)
- }
- if !e.resolved {
- t.Fatal("got DAD event w/ resolved = false, want = true")
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
}
if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr, want)
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
}
}
@@ -2078,7 +2590,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
- t.Fatalf("NewSubnet failed:", err)
+ t.Fatalf("NewSubnet failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
@@ -2092,11 +2604,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// permanentExpired kind.
r, err := s.FindRoute(1, "\x01", "\x02", fakeNetNumber, false)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ t.Fatalf("FindRoute failed: %v", err)
}
defer r.Release()
if err := s.RemoveAddress(1, "\x01"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
//
@@ -2108,7 +2620,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// Add some other address with peb set to
// FirstPrimaryEndpoint.
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x03", stack.FirstPrimaryEndpoint); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
@@ -2116,7 +2628,7 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// make sure the new peb was respected.
// (The address should just be promoted now).
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", ps); err != nil {
- t.Fatal("AddAddressWithOptions failed:", err)
+ t.Fatalf("AddAddressWithOptions failed: %v", err)
}
var primaryAddrs []tcpip.Address
for _, pa := range s.NICInfo()[1].ProtocolAddresses {
@@ -2149,11 +2661,11 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
// GetMainNICAddress; else, our original address
// should be returned.
if err := s.RemoveAddress(1, "\x03"); err != nil {
- t.Fatalf("RemoveAddress failed:", err)
+ t.Fatalf("RemoveAddress failed: %v", err)
}
addr, err = s.GetMainNICAddress(1, fakeNetNumber)
if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ t.Fatalf("s.GetMainNICAddress failed: %v", err)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
@@ -2169,3 +2681,858 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
}
}
}
+
+func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
+ const (
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
+ lifetimeSeconds = 9999
+ )
+
+ prefix1, _, stableGlobalAddr1 := prefixSubnetAddr(0, linkAddr1)
+ prefix2, _, stableGlobalAddr2 := prefixSubnetAddr(1, linkAddr1)
+
+ var tempIIDHistory [header.IIDSize]byte
+ header.InitialTempIID(tempIIDHistory[:], nil, nicID)
+ tempGlobalAddr1 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr1.Address).Address
+ tempGlobalAddr2 := header.GenerateTempIPv6SLAACAddr(tempIIDHistory[:], stableGlobalAddr2.Address).Address
+
+ // Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
+ tests := []struct {
+ name string
+ slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix
+ nicAddrs []tcpip.Address
+ slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix
+ connectAddr tcpip.Address
+ expectedLocalAddr tcpip.Address
+ }{
+ // Test Rule 1 of RFC 6724 section 5.
+ {
+ name: "Same Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr1,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr1,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Same Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr1,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 2 of RFC 6724 section 5.
+ {
+ name: "Global most preferred (last address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: globalAddr1,
+ },
+ {
+ name: "Link Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (last address)",
+ nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Unique Local most preferred (first address)",
+ nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+
+ // Test Rule 7 of RFC 6724 section 5.
+ {
+ name: "Temp Global most preferred (last address)",
+ slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: tempGlobalAddr1,
+ },
+ {
+ name: "Temp Global most preferred (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ slaacPrefixForTempAddrAfterNICAddrAdd: prefix1,
+ connectAddr: globalAddr2,
+ expectedLocalAddr: tempGlobalAddr1,
+ },
+
+ // Test returning the endpoint that is closest to the front when
+ // candidate addresses are "equal" from the perspective of RFC 6724
+ // section 5.
+ {
+ name: "Unique Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: uniqueLocalAddr1,
+ },
+ {
+ name: "Link Local for Global",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: globalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local for Unique Local",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2},
+ connectAddr: uniqueLocalAddr2,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Temp Global for Global",
+ slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1,
+ slaacPrefixForTempAddrAfterNICAddrAdd: prefix2,
+ connectAddr: globalAddr1,
+ expectedLocalAddr: tempGlobalAddr2,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ AutoGenGlobalAddresses: true,
+ AutoGenTempGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDispatcher{},
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv6EmptySubnet,
+ Gateway: llAddr3,
+ NIC: nicID,
+ }})
+ s.AddLinkAddress(nicID, llAddr3, linkAddr3)
+
+ if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) {
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
+ }
+
+ for _, a := range test.nicAddrs {
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, a); err != nil {
+ t.Errorf("s.AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, a, err)
+ }
+ }
+
+ if test.slaacPrefixForTempAddrAfterNICAddrAdd != (tcpip.AddressWithPrefix{}) {
+ e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrAfterNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds))
+ }
+
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr {
+ t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr)
+ }
+ })
+ }
+}
+
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+ broadcastAddr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if !containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr)
+ }
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ {
+ allStackAddrs := s.AllAddresses()
+ if allNICAddrs, ok := allStackAddrs[nicID]; !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ } else if containsAddr(allNICAddrs, broadcastAddr) {
+ t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr)
+ }
+ }
+}
+
+// TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval tests that removing an IPv6
+// address after leaving its solicited node multicast address does not result in
+// an error.
+func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, addr1, err)
+ }
+
+ // The NIC should have joined addr1's solicited node multicast address.
+ snmc := header.SolicitedNodeAddr(addr1)
+ in, err := s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if !in {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, snmc)
+ }
+
+ if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, snmc); err != nil {
+ t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, snmc, err)
+ }
+ in, err = s.IsInGroup(nicID, snmc)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, snmc, err)
+ }
+ if in {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, snmc)
+ }
+
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
+ }
+}
+
+func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ }{
+ {
+ name: "IPv6 All-Nodes",
+ proto: header.IPv6ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ name: "IPv4 All-Systems",
+ proto: header.IPv4ProtocolNumber,
+ addr: header.IPv4AllSystems,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the multicast group yet because the NIC has not been
+ // enabled yet.
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // The multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr)
+ }
+
+ // Leaving the group before disabling the NIC should not cause an error.
+ if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil {
+ t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err)
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+
+ if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err)
+ } else if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr)
+ }
+ })
+ }
+}
+
+// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
+// was disabled have DAD performed on them when the NIC is enabled.
+func TestDoDADWhenNICEnabled(t *testing.T) {
+ const dadTransmits = 1
+ const retransmitTimer = time.Second
+ const nicID = 1
+
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent),
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ DupAddrDetectTransmits: dadTransmits,
+ RetransmitTimer: retransmitTimer,
+ },
+ NDPDisp: &ndpDisp,
+ }
+
+ e := channel.New(dadTransmits, 1280, linkAddr1)
+ s := stack.New(opts)
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: llAddr1,
+ PrefixLen: 128,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ // Address should be in the list of all addresses.
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should be tentative so it should not be a main address.
+ got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Enabling the NIC should start DAD for the address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); got != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ }
+
+ // Wait for DAD to resolve.
+ select {
+ case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
+ t.Fatal("timed out waiting for DAD resolution")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+
+ // Enabling the NIC again should be a no-op.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
+ t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
+ }
+ got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if got != addr.AddressWithPrefix {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ }
+}
+
+func TestStackReceiveBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ rs stack.ReceiveBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.rs); err != tc.err {
+ t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err)
+ }
+ var rs stack.ReceiveBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&rs); err != nil {
+ t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err)
+ }
+ if got, want := rs, tc.rs; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestStackSendBufferSizeOption(t *testing.T) {
+ const sMin = stack.MinBufferSize
+ testCases := []struct {
+ name string
+ ss stack.SendBufferSizeOption
+ err *tcpip.Error
+ }{
+ // Invalid configurations.
+ {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+
+ // Valid Configurations
+ {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ defer s.Close()
+ if err := s.SetOption(tc.ss); err != tc.err {
+ t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
+ }
+ var ss stack.SendBufferSizeOption
+ if tc.err == nil {
+ if err := s.Option(&ss); err != nil {
+ t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
+ }
+ if got, want := ss, tc.ss; got != want {
+ t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want)
+ }
+ }
+ })
+ }
+}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID1 = 1
+ )
+
+ defaultAddr := tcpip.AddressWithPrefix{
+ Address: header.IPv4Any,
+ PrefixLen: 0,
+ }
+ defaultSubnet := defaultAddr.Subnet()
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ expectedRoute stack.Route
+ }{
+ // Broadcast to a locally attached subnet populates the broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: ipv4SubnetBcast,
+ RemoteLinkAddress: header.EthernetBroadcastAddress,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /31 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix31.Address,
+ RemoteAddress: ipv4Subnet31Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a locally attached /32 subnet does not populate the
+ // broadcast MAC.
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4AddrPrefix32.Address,
+ RemoteAddress: ipv4Subnet32Bcast,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv6Addr.Address,
+ RemoteAddress: ipv6SubnetBcast,
+ NetProto: header.IPv6ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to a remote subnet in the route table is send to the next-hop
+ // gateway.
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ // Broadcast to an unknown subnet follows the default route. Note that this
+ // is essentially just routing an unknown destination IP, because w/o any
+ // subnet prefix information a subnet broadcast address is just a normal IP.
+ {
+ name: "IPv4 Broadcast to unknown subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: defaultSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ expectedRoute: stack.Route{
+ LocalAddress: ipv4Addr.Address,
+ RemoteAddress: remNetSubnetBcast,
+ NextHop: ipv4Gateway,
+ NetProto: header.IPv4ProtocolNumber,
+ Loop: stack.PacketOut,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err)
+ } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" {
+ t.Errorf("route mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestResolveWith(t *testing.T) {
+ const (
+ unspecifiedNICID = 0
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ })
+ ep := channel.New(0, defaultMTU, "")
+ ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ if err := s.CreateNIC(nicID, ep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ addr := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ },
+ }
+ if err := s.AddProtocolAddress(nicID, addr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, addr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}})
+
+ remoteAddr := tcpip.Address(net.ParseIP("192.168.1.59").To4())
+ r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err)
+ }
+ defer r.Release()
+
+ // Should initially require resolution.
+ if !r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = false, want = true")
+ }
+
+ // Manually resolving the route should no longer require resolution.
+ r.ResolveWith("\x01")
+ if r.IsResolutionRequired() {
+ t.Fatal("got r.IsResolutionRequired() = true, want = false")
+ }
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 97a1aec4b..b902c6ca9 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,12 +17,12 @@ package stack
import (
"fmt"
"math/rand"
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
)
type protocolIDs struct {
@@ -35,28 +35,130 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]*endpointsByNic
+ endpoints map[TransportEndpointID]*endpointsByNIC
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
-type endpointsByNic struct {
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets.
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ eps.mu.Lock()
+ defer eps.mu.Unlock()
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return
+ }
+ if !epsByNIC.unregisterEndpoint(bindToDevice, ep, flags) {
+ return
+ }
+ delete(eps.endpoints, id)
+}
+
+func (eps *transportEndpoints) transportEndpoints() []TransportEndpoint {
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+ es := make([]TransportEndpoint, 0, len(eps.endpoints))
+ for _, e := range eps.endpoints {
+ es = append(es, e.transportEndpoints()...)
+ }
+ return es
+}
+
+// iterEndpointsLocked yields all endpointsByNIC in eps that match id, in
+// descending order of match quality. If a call to yield returns false,
+// iterEndpointsLocked stops iteration and returns immediately.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) iterEndpointsLocked(id TransportEndpointID, yield func(*endpointsByNIC) bool) {
+ // Try to find a match with the id as provided.
+ if ep, ok := eps.endpoints[id]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the local address.
+ nid := id
+
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with the id minus the remote part.
+ nid.LocalAddress = id.LocalAddress
+ nid.RemoteAddress = ""
+ nid.RemotePort = 0
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+
+ // Try to find a match with only the local port.
+ nid.LocalAddress = ""
+ if ep, ok := eps.endpoints[nid]; ok {
+ if !yield(ep) {
+ return
+ }
+ }
+}
+
+// findAllEndpointsLocked returns all endpointsByNIC in eps that match id, in
+// descending order of match quality.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findAllEndpointsLocked(id TransportEndpointID) []*endpointsByNIC {
+ var matchedEPs []*endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEPs = append(matchedEPs, ep)
+ return true
+ })
+ return matchedEPs
+}
+
+// findEndpointLocked returns the endpoint that most closely matches the given id.
+//
+// Preconditions: eps.mu must be locked.
+func (eps *transportEndpoints) findEndpointLocked(id TransportEndpointID) *endpointsByNIC {
+ var matchedEP *endpointsByNIC
+ eps.iterEndpointsLocked(id, func(ep *endpointsByNIC) bool {
+ matchedEP = ep
+ return false
+ })
+ return matchedEP
+}
+
+type endpointsByNIC struct {
mu sync.RWMutex
endpoints map[tcpip.NICID]*multiPortEndpoint
// seed is a random secret for a jenkins hash.
seed uint32
}
+func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+ var eps []TransportEndpoint
+ for _, ep := range epsByNIC.endpoints {
+ eps = append(eps, ep.transportEndpoints()...)
+ }
+ return eps
+}
+
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- epsByNic.mu.RLock()
+func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
+ epsByNIC.mu.RLock()
- mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
if !ok {
- if mpep, ok = epsByNic.endpoints[0]; !ok {
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
}
@@ -64,24 +166,30 @@ func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, v
// If this is a broadcast or multicast datagram, deliver the datagram to all
// endpoints bound to the right device.
if isMulticastOrBroadcast(id.LocalAddress) {
- mpep.handlePacketAll(r, id, vv)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ mpep.handlePacketAll(r, id, pkt)
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
return
}
-
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
- epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ transEP := selectEndpoint(id, mpep, epsByNIC.seed)
+ if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue {
+ queuedProtocol.QueuePacket(r, transEP, id, pkt)
+ epsByNIC.mu.RUnlock()
+ return
+ }
+
+ transEP.HandlePacket(r, id, pkt)
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- epsByNic.mu.RLock()
- defer epsByNic.mu.RUnlock()
+func (epsByNIC *endpointsByNIC) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, pkt *PacketBuffer) {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
- mpep, ok := epsByNic.endpoints[n.ID()]
+ mpep, ok := epsByNIC.endpoints[n.ID()]
if !ok {
- mpep, ok = epsByNic.endpoints[0]
+ mpep, ok = epsByNIC.endpoints[0]
}
if !ok {
return
@@ -91,55 +199,52 @@ func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpoint
// broadcast like we are doing with handlePacket above?
// multiPortEndpoints are guaranteed to have at least one element.
- selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+ selectEndpoint(id, mpep, epsByNIC.seed).HandleControlPacket(id, typ, extra, pkt)
}
// registerEndpoint returns true if it succeeds. It fails and returns
// false if ep already has an element with the same key.
-func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
+func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, t TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
- if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
- // There was already a bind.
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
+ if !ok {
+ multiPortEp = &multiPortEndpoint{
+ demux: d,
+ netProto: netProto,
+ transProto: transProto,
+ }
+ epsByNIC.endpoints[bindToDevice] = multiPortEp
}
- // This is a new binding.
- multiPortEp := &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.reuse = reusePort
- epsByNic.endpoints[bindToDevice] = multiPortEp
- return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ return multiPortEp.singleRegisterEndpoint(t, flags)
}
-// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
-func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
- epsByNic.mu.Lock()
- defer epsByNic.mu.Unlock()
- multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNIC.mu.RLock()
+ defer epsByNIC.mu.RUnlock()
+
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
- return false
- }
- if multiPortEp.unregisterEndpoint(t) {
- delete(epsByNic.endpoints, bindToDevice)
+ return nil
}
- return len(epsByNic.endpoints) == 0
+
+ return multiPortEp.singleCheckEndpoint(flags)
}
-// unregisterEndpoint unregisters the endpoint with the given id such that it
-// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
- eps.mu.Lock()
- defer eps.mu.Unlock()
- epsByNic, ok := eps.endpoints[id]
+// unregisterEndpoint returns true if endpointsByNIC has to be unregistered.
+func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool {
+ epsByNIC.mu.Lock()
+ defer epsByNIC.mu.Unlock()
+ multiPortEp, ok := epsByNIC.endpoints[bindToDevice]
if !ok {
- return
+ return false
}
- if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
- return
+ if multiPortEp.unregisterEndpoint(t, flags) {
+ delete(epsByNIC.endpoints, bindToDevice)
}
- delete(eps.endpoints, id)
+ return len(epsByNIC.endpoints) == 0
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
@@ -149,17 +254,33 @@ func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep Tra
// newTransportDemuxer.
type transportDemuxer struct {
// protocol is immutable.
- protocol map[protocolIDs]*transportEndpoints
+ protocol map[protocolIDs]*transportEndpoints
+ queuedProtocols map[protocolIDs]queuedTransportProtocol
+}
+
+// queuedTransportProtocol if supported by a protocol implementation will cause
+// the dispatcher to delivery packets to the QueuePacket method instead of
+// calling HandlePacket directly on the endpoint.
+type queuedTransportProtocol interface {
+ QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer)
}
func newTransportDemuxer(stack *Stack) *transportDemuxer {
- d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+ d := &transportDemuxer{
+ protocol: make(map[protocolIDs]*transportEndpoints),
+ queuedProtocols: make(map[protocolIDs]queuedTransportProtocol),
+ }
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
- d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]*endpointsByNic),
+ protoIDs := protocolIDs{netProto, proto}
+ d.protocol[protoIDs] = &transportEndpoints{
+ endpoints: make(map[TransportEndpointID]*endpointsByNIC),
+ }
+ qTransProto, isQueued := (stack.transportProtocols[proto].proto).(queuedTransportProtocol)
+ if isQueued {
+ d.queuedProtocols[protoIDs] = qTransProto
}
}
}
@@ -169,10 +290,21 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, flags, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, flags, bindToDevice)
+ return err
+ }
+ }
+
+ return nil
+}
+
+// checkEndpoint checks if an endpoint can be registered with the dispatcher.
+func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ for _, n := range netProtos {
+ if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil {
return err
}
}
@@ -183,12 +315,29 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
// multiPortEndpoint is a container for TransportEndpoints which are bound to
// the same pair of address and port. endpointsArr always has at least one
// element.
+//
+// FIXME(gvisor.dev/issue/873): Restore this properly. Currently, we just save
+// this to ensure that the underlying endpoints get saved/restored, but not not
+// use the restored copy.
+//
+// +stateify savable
type multiPortEndpoint struct {
- mu sync.RWMutex
- endpointsArr []TransportEndpoint
- endpointsMap map[TransportEndpoint]int
- // reuse indicates if more than one endpoint is allowed.
- reuse bool
+ mu sync.RWMutex `state:"nosave"`
+ demux *transportDemuxer
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+
+ // endpoints stores the transport endpoints in the order in which they
+ // were bound. This is required for UDP SO_REUSEADDR.
+ endpoints []TransportEndpoint
+ flags ports.FlagCounter
+}
+
+func (ep *multiPortEndpoint) transportEndpoints() []TransportEndpoint {
+ ep.mu.RLock()
+ eps := append([]TransportEndpoint(nil), ep.endpoints...)
+ ep.mu.RUnlock()
+ return eps
}
// reciprocalScale scales a value into range [0, n).
@@ -203,8 +352,12 @@ func reciprocalScale(val, n uint32) uint32 {
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
- if len(mpep.endpointsArr) == 1 {
- return mpep.endpointsArr[0]
+ if len(mpep.endpoints) == 1 {
+ return mpep.endpoints[0]
+ }
+
+ if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ return mpep.endpoints[len(mpep.endpoints)-1]
}
payload := []byte{
@@ -220,72 +373,89 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
- return mpep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpoints)))
+ return mpep.endpoints[idx]
}
-func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) {
ep.mu.RLock()
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy except for
- // the final one.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
+ queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}]
+ // HandlePacket takes ownership of pkt, so each endpoint needs
+ // its own copy except for the final one.
+ for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] {
+ if mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone())
+ } else {
+ endpoint.HandlePacket(r, id, pkt.Clone())
}
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+ }
+ if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue {
+ queuedProtocol.QueuePacket(r, endpoint, id, pkt)
+ } else {
+ endpoint.HandlePacket(r, id, pkt)
}
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags ports.Flags) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if len(ep.endpointsArr) > 0 {
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ ep.endpoints = append(ep.endpoints, t)
+ ep.flags.AddRef(bits)
+
+ return nil
+}
+
+func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error {
+ ep.mu.RLock()
+ defer ep.mu.RUnlock()
+
+ bits := flags.Bits() & ports.MultiBindFlagMask
+
+ if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if !ep.reuse || !reusePort {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
return tcpip.ErrPortInUse
}
}
- // A new endpoint is added into endpointsArr and its index there is saved in
- // endpointsMap. This will allow us to remove endpoint from the array fast.
- ep.endpointsMap[t] = len(ep.endpointsArr)
- ep.endpointsArr = append(ep.endpointsArr, t)
return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
-func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
+func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool {
ep.mu.Lock()
defer ep.mu.Unlock()
- idx, ok := ep.endpointsMap[t]
- if !ok {
- return false
- }
- delete(ep.endpointsMap, t)
- l := len(ep.endpointsArr)
- if l > 1 {
- // The last endpoint in endpointsArr is moved instead of the deleted one.
- lastEp := ep.endpointsArr[l-1]
- ep.endpointsArr[idx] = lastEp
- ep.endpointsMap[lastEp] = idx
- ep.endpointsArr = ep.endpointsArr[0 : l-1]
- return false
+ for i, endpoint := range ep.endpoints {
+ if endpoint == t {
+ copy(ep.endpoints[i:], ep.endpoints[i+1:])
+ ep.endpoints[len(ep.endpoints)-1] = nil
+ ep.endpoints = ep.endpoints[:len(ep.endpoints)-1]
+
+ ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask)
+ break
+ }
}
- return true
+ return len(ep.endpoints) == 0
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
- // TODO(eyalsoha): Why?
- reusePort = false
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
@@ -296,82 +466,109 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol
eps.mu.Lock()
defer eps.mu.Unlock()
- if epsByNic, ok := eps.endpoints[id]; ok {
- // There was already a binding.
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ epsByNIC = &endpointsByNIC{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
+ }
+ eps.endpoints[id] = epsByNIC
+ }
+
+ return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice)
+}
+
+func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
+ eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
}
- // This is a new binding.
- epsByNic := &endpointsByNic{
- endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
- seed: rand.Uint32(),
+ eps.mu.RLock()
+ defer eps.mu.RUnlock()
+
+ epsByNIC, ok := eps.endpoints[id]
+ if !ok {
+ return nil
}
- eps.endpoints[id] = epsByNic
- return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
+ return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+ if id.RemotePort != 0 {
+ // SO_REUSEPORT only applies to bound/listening endpoints.
+ flags.LoadBalanced = false
+ }
+
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep, bindToDevice)
+ eps.unregisterEndpoint(id, ep, flags, bindToDevice)
}
}
}
-var loopbackSubnet = func() tcpip.Subnet {
- sn, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
- if err != nil {
- panic(err)
- }
- return sn
-}()
-
// deliverPacket attempts to find one or more matching transport endpoints, and
-// then, if matches are found, delivers the packet to them. Returns true if it
-// found one or more endpoints, false otherwise.
-func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView, id TransportEndpointID) bool {
+// then, if matches are found, delivers the packet to them. Returns true if
+// the packet no longer needs to be handled.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
- eps.mu.RLock()
-
- // Determine which transport endpoint or endpoints to deliver this packet to.
- // If the packet is a broadcast or multicast, then find all matching
+ // If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints.
- var destEps []*endpointsByNic
if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, vv, id)
- } else if ep := d.findEndpointLocked(eps, vv, id); ep != nil {
- destEps = append(destEps, ep)
+ eps.mu.RLock()
+ destEPs := eps.findAllEndpointsLocked(id)
+ eps.mu.RUnlock()
+ // Fail if we didn't find at least one matching transport endpoint.
+ if len(destEPs) == 0 {
+ r.Stats().UDP.UnknownPortErrors.Increment()
+ return false
+ }
+ // handlePacket takes ownership of pkt, so each endpoint needs its own
+ // copy except for the final one.
+ for _, ep := range destEPs[:len(destEPs)-1] {
+ ep.handlePacket(r, id, pkt.Clone())
+ }
+ destEPs[len(destEPs)-1].handlePacket(r, id, pkt)
+ return true
}
- eps.mu.RUnlock()
+ // If the packet is a TCP packet with a non-unicast source or destination
+ // address, then do nothing further and instruct the caller to do the same.
+ if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single source and a
+ // single destination; the addresses must be unicast.
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
+ }
- // Fail if we didn't find at least one matching transport endpoint.
- if len(destEps) == 0 {
- // UDP packet could not be delivered to an unknown destination port.
+ eps.mu.RLock()
+ ep := eps.findEndpointLocked(id)
+ eps.mu.RUnlock()
+ if ep == nil {
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
-
- // Deliver the packet.
- for _, ep := range destEps {
- ep.handlePacket(r, id, vv)
- }
-
+ ep.handlePacket(r, id, pkt)
return true
}
// deliverRawPacket attempts to deliver the given packet and returns whether it
// was delivered successfully.
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
@@ -385,7 +582,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
for _, rawEP := range eps.rawEndpoints {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
- rawEP.HandlePacket(r, buffer.NewViewFromBytes(netHeader), vv.ToView().ToVectorisedView())
+ rawEP.HandlePacket(r, pkt)
foundRaw = true
}
eps.mu.RUnlock()
@@ -395,67 +592,51 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, pkt *PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
}
- // Try to find the endpoint.
eps.mu.RLock()
- ep := d.findEndpointLocked(eps, vv, id)
+ ep := eps.findEndpointLocked(id)
eps.mu.RUnlock()
-
- // Fail if we didn't find one.
if ep == nil {
return false
}
- // Deliver the packet.
- ep.handleControlPacket(n, id, typ, extra, vv)
-
+ ep.handleControlPacket(n, id, typ, extra, pkt)
return true
}
-func (d *transportDemuxer) findAllEndpointsLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) []*endpointsByNic {
- var matchedEPs []*endpointsByNic
- // Try to find a match with the id as provided.
- if ep, ok := eps.endpoints[id]; ok {
- matchedEPs = append(matchedEPs, ep)
+// findTransportEndpoint find a single endpoint that most closely matches the provided id.
+func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint {
+ eps, ok := d.protocol[protocolIDs{netProto, transProto}]
+ if !ok {
+ return nil
}
- // Try to find a match with the id minus the local address.
- nid := id
-
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ eps.mu.RLock()
+ epsByNIC := eps.findEndpointLocked(id)
+ if epsByNIC == nil {
+ eps.mu.RUnlock()
+ return nil
}
- // Try to find a match with the id minus the remote part.
- nid.LocalAddress = id.LocalAddress
- nid.RemoteAddress = ""
- nid.RemotePort = 0
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
- }
+ epsByNIC.mu.RLock()
+ eps.mu.RUnlock()
- // Try to find a match with only the local port.
- nid.LocalAddress = ""
- if ep, ok := eps.endpoints[nid]; ok {
- matchedEPs = append(matchedEPs, ep)
+ mpep, ok := epsByNIC.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNIC.endpoints[0]; !ok {
+ epsByNIC.mu.RUnlock() // Don't use defer for performance reasons.
+ return nil
+ }
}
- return matchedEPs
-}
-
-// findEndpointLocked returns the endpoint that most closely matches the given
-// id.
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
- if matchedEPs := d.findAllEndpointsLocked(eps, vv, id); len(matchedEPs) > 0 {
- return matchedEPs[0]
- }
- return nil
+ ep := selectEndpoint(id, mpep, epsByNIC.seed)
+ epsByNIC.mu.RUnlock()
+ return ep
}
// registerRawEndpoint registers the given endpoint with the dispatcher such
@@ -469,8 +650,8 @@ func (d *transportDemuxer) registerRawEndpoint(netProto tcpip.NetworkProtocolNum
}
eps.mu.Lock()
- defer eps.mu.Unlock()
eps.rawEndpoints = append(eps.rawEndpoints, ep)
+ eps.mu.Unlock()
return nil
}
@@ -484,15 +665,22 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
}
eps.mu.Lock()
- defer eps.mu.Unlock()
for i, rawEP := range eps.rawEndpoints {
if rawEP == ep {
- eps.rawEndpoints = append(eps.rawEndpoints[:i], eps.rawEndpoints[i+1:]...)
- return
+ lastIdx := len(eps.rawEndpoints) - 1
+ eps.rawEndpoints[i] = eps.rawEndpoints[lastIdx]
+ eps.rawEndpoints[lastIdx] = nil
+ eps.rawEndpoints = eps.rawEndpoints[:lastIdx]
+ break
}
}
+ eps.mu.Unlock()
}
func isMulticastOrBroadcast(addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}
+
+func isUnicast(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 210233dc0..1339edc2d 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -25,96 +25,65 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
- stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ testSrcAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testDstAddrV6 = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
- stackAddr = "\x0a\x00\x00\x01"
- stackPort = 1234
- testPort = 4096
+ testSrcAddrV4 = "\x0a\x00\x00\x01"
+ testDstAddrV4 = "\x0a\x00\x00\x02"
+
+ testDstPort = 1234
+ testSrcPort = 4096
)
type testContext struct {
- t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
-
- ep tcpip.Endpoint
- wq waiter.Queue
+ wq waiter.Queue
}
-func (c *testContext) cleanup() {
- if c.ep != nil {
- c.ep.Close()
- }
-}
-
-func (c *testContext) createV6Endpoint(v6only bool) {
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.ep.SetSockOpt(v); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
-}
-
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicid := tcpip.NICID(i + 1)
- if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
+ t.Fatalf("CreateNIC failed: %s", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress IPv4 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %s", err)
}
- if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress IPv6 failed: %v", err)
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %s", err)
}
}
s.SetRouteTable([]tcpip.Route{
- {
- Destination: header.IPv4EmptySubnet,
- NIC: 1,
- },
- {
- Destination: header.IPv6EmptySubnet,
- NIC: 1,
- },
+ {Destination: header.IPv4EmptySubnet, NIC: 1},
+ {Destination: header.IPv6EmptySubnet, NIC: 1},
})
return &testContext{
- t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
type headers struct {
- srcPort uint16
- dstPort uint16
+ srcPort, dstPort uint16
}
func newPayload() []byte {
@@ -125,7 +94,47 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv4(buf)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TOS: 0x80,
+ TotalLength: uint16(len(buf)),
+ TTL: 65,
+ Protocol: uint8(udp.ProtocolNumber),
+ SrcAddr: testSrcAddrV4,
+ DstAddr: testDstAddrV4,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ })
+ c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -136,8 +145,8 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: testV6Addr,
- DstAddr: stackV6Addr,
+ SrcAddr: testSrcAddrV6,
+ DstAddr: testDstAddrV6,
})
// Initialize the UDP header.
@@ -149,14 +158,17 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
})
// Calculate the UDP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
// Calculate the UDP checksum and set it.
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ })
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
}
func TestTransportDemuxerRegister(t *testing.T) {
@@ -171,95 +183,105 @@ func TestTransportDemuxerRegister(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
- TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
- t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tEP, ok := ep.(stack.TransportEndpoint)
+ if !ok {
+ t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
+ }
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
}
}
-// TestReuseBindToDevice injects varied packets on input devices and checks that
+// TestBindToDeviceDistribution injects varied packets on input devices and checks that
// the distribution of packets received matches expectations.
-func TestDistribution(t *testing.T) {
+func TestBindToDeviceDistribution(t *testing.T) {
type endpointSockopts struct {
- reuse int
- bindToDevice string
+ reuse bool
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
// endpoints will received the inject packets.
endpoints []endpointSockopts
- // wantedDistribution is the wanted ratio of packets received on each
+ // wantDistributions is the want ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
- endpointSockopts{1, ""},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
+ {reuse: true, bindToDevice: 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{0, "dev0"},
- endpointSockopts{0, "dev1"},
- endpointSockopts{0, "dev2"},
+ {reuse: false, bindToDevice: 1},
+ {reuse: false, bindToDevice: 2},
+ {reuse: false, bindToDevice: 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": []float64{1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": []float64{0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": []float64{0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev0"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, "dev1"},
- endpointSockopts{1, ""},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 1},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 2},
+ {reuse: true, bindToDevice: 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": []float64{0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
- t.Run(test.name, func(t *testing.T) {
- for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
- for d := range test.wantedDistributions {
+ for protoName, netProtoNum := range map[string]tcpip.NetworkProtocolNumber{
+ "IPv4": ipv4.ProtocolNumber,
+ "IPv6": ipv6.ProtocolNumber,
+ } {
+ for device, wantDistribution := range test.wantDistributions {
+ t.Run(test.name+protoName+string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
+ for d := range test.wantDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
- defer c.cleanup()
-
- c.createV6Endpoint(false)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
eps := make(map[tcpip.Endpoint]int)
@@ -273,9 +295,9 @@ func TestDistribution(t *testing.T) {
defer close(ch)
var err *tcpip.Error
- ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, netProtoNum, &wq)
if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps[ep] = i
@@ -286,22 +308,31 @@ func TestDistribution(t *testing.T) {
}(ep)
defer ep.Close()
- reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
- if err := ep.SetSockOpt(reusePortOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ if err := ep.SetSockOptBool(tcpip.ReusePortOption, endpoint.reuse); err != nil {
+ t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err)
}
bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
- c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err)
+ }
+
+ var dstAddr tcpip.Address
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ dstAddr = testDstAddrV4
+ case ipv6.ProtocolNumber:
+ dstAddr = testDstAddrV6
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
}
- if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
}
}
npackets := 100000
nports := 10000
- if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ if got, want := len(test.endpoints), len(wantDistribution); got != want {
t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
}
ports := make(map[uint16]tcpip.Endpoint)
@@ -310,17 +341,22 @@ func TestDistribution(t *testing.T) {
// Send a packet.
port := uint16(i % nports)
payload := newPayload()
- c.sendV6Packet(payload,
- &headers{
- srcPort: testPort + port,
- dstPort: stackPort},
- device)
+ hdrs := &headers{
+ srcPort: testSrcPort + port,
+ dstPort: testDstPort,
+ }
+ switch netProtoNum {
+ case ipv4.ProtocolNumber:
+ c.sendV4Packet(payload, hdrs, device)
+ case ipv6.ProtocolNumber:
+ c.sendV6Packet(payload, hdrs, device)
+ default:
+ t.Fatalf("unexpected protocol number: %d", netProtoNum)
+ }
- var addr tcpip.FullAddress
ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ if _, _, err := ep.Read(nil); err != nil {
+ t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
if i < nports {
@@ -336,17 +372,17 @@ func TestDistribution(t *testing.T) {
// Check that a packet distribution is as expected.
for ep, i := range eps {
- wantedRatio := wantedDistribution[i]
- wantedRecv := wantedRatio * float64(npackets)
+ wantRatio := wantDistribution[i]
+ wantRecv := wantRatio * float64(npackets)
actualRecv := stats[ep]
actualRatio := float64(stats[ep]) / float64(npackets)
// The deviation is less than 10%.
- if math.Abs(actualRatio-wantedRatio) > 0.05 {
- t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ if math.Abs(actualRatio-wantRatio) > 0.05 {
+ t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
}
}
})
}
- })
+ }
}
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 6d3daed24..fa4b14ba6 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -19,9 +19,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,14 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
+}
+
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
}
func (f *fakeTransportEndpoint) Close() {
@@ -77,12 +84,16 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
- if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
+ Data: buffer.View(v).ToVectorisedView(),
+ })
+ _ = pkt.TransportHeader().Push(fakeTransHeaderLen)
+ if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, pkt); err != nil {
return 0, nil, err
}
@@ -98,13 +109,23 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptBool(tcpip.SockOptBool, bool) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// SetSockOptInt sets a socket option. Currently not supported.
-func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOptInt, int) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (*fakeTransportEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -134,7 +155,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, false /* reuse */, 0 /* bindToDevice */)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -144,6 +165,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -175,8 +200,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false, /* reuse */
- 0, /* bindtoDevice */
+ ports.Flags{},
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -192,7 +217,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) {
// Increment the number of received packets.
f.proto.packetCount++
if f.acceptQueue != nil {
@@ -209,7 +234,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE
}
}
-func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) {
// Increment the number of received control packets.
f.proto.controlCount++
}
@@ -218,15 +243,15 @@ func (f *fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {
-}
+func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) {
- return iptables.IPTables{}, nil
+func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
+ return stack.IPTables{}, nil
}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {
-}
+func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+
+func (f *fakeTransportEndpoint) Wait() {}
type fakeTransportGoodOption bool
@@ -251,10 +276,10 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -266,7 +291,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
@@ -292,6 +317,21 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(fakeTransHeaderLen)
+ return ok
+}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
@@ -337,7 +377,9 @@ func TestTransportReceive(t *testing.T) {
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -346,7 +388,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
@@ -355,7 +399,9 @@ func TestTransportReceive(t *testing.T) {
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
@@ -408,7 +454,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -417,7 +465,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
@@ -426,7 +476,9 @@ func TestTransportControlReceive(t *testing.T) {
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ linkEP.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
@@ -579,7 +631,9 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- ep2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.InjectInbound(fakeNetNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: req.ToVectorisedView(),
+ }))
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -591,17 +645,16 @@ func TestTransportForwarding(t *testing.T) {
t.Fatalf("Write failed: %v", err)
}
- var p channel.PacketInfo
- select {
- case p = <-ep2.C:
- default:
+ p, ok := ep2.Read()
+ if !ok {
t.Fatal("Response packet not forwarded")
}
- if dst := p.Header[0]; dst != 3 {
+ nh := stack.PayloadSince(p.Pkt.NetworkHeader())
+ if dst := nh[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := p.Header[1]; src != 1 {
+ if src := nh[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 03be7d3d4..07c85ce59 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -35,15 +35,17 @@ import (
"reflect"
"strconv"
"strings"
- "sync"
"sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/waiter"
)
+// Using header.IPv4AddressSize would cause an import cycle.
+const ipv4AddressSize = 4
+
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
@@ -111,6 +113,71 @@ var (
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
)
+var messageToError map[string]*Error
+
+var populate sync.Once
+
+// StringToError converts an error message to the error.
+func StringToError(s string) *Error {
+ populate.Do(func() {
+ var errors = []*Error{
+ ErrUnknownProtocol,
+ ErrUnknownNICID,
+ ErrUnknownDevice,
+ ErrUnknownProtocolOption,
+ ErrDuplicateNICID,
+ ErrDuplicateAddress,
+ ErrNoRoute,
+ ErrBadLinkEndpoint,
+ ErrAlreadyBound,
+ ErrInvalidEndpointState,
+ ErrAlreadyConnecting,
+ ErrAlreadyConnected,
+ ErrNoPortAvailable,
+ ErrPortInUse,
+ ErrBadLocalAddress,
+ ErrClosedForSend,
+ ErrClosedForReceive,
+ ErrWouldBlock,
+ ErrConnectionRefused,
+ ErrTimeout,
+ ErrAborted,
+ ErrConnectStarted,
+ ErrDestinationRequired,
+ ErrNotSupported,
+ ErrQueueSizeNotSupported,
+ ErrNotConnected,
+ ErrConnectionReset,
+ ErrConnectionAborted,
+ ErrNoSuchFile,
+ ErrInvalidOptionValue,
+ ErrNoLinkAddress,
+ ErrBadAddress,
+ ErrNetworkUnreachable,
+ ErrMessageTooLong,
+ ErrNoBufferSpace,
+ ErrBroadcastDisabled,
+ ErrNotPermitted,
+ ErrAddressFamilyNotSupported,
+ }
+
+ messageToError = make(map[string]*Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
+
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
@@ -128,7 +195,7 @@ func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
-// A Clock provides the current time.
+// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
@@ -139,6 +206,31 @@ type Clock interface {
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
+
+ // AfterFunc waits for the duration to elapse and then calls f in its own
+ // goroutine. It returns a Timer that can be used to cancel the call using
+ // its Stop method.
+ AfterFunc(d time.Duration, f func()) Timer
+}
+
+// Timer represents a single event. A Timer must be created with
+// Clock.AfterFunc.
+type Timer interface {
+ // Stop prevents the Timer from firing. It returns true if the call stops the
+ // timer, false if the timer has already expired or been stopped.
+ //
+ // If Stop returns false, then the timer has already expired and the function
+ // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
+ // does not wait for f to complete before returning. If the caller needs to
+ // know whether f is completed, it must coordinate with f explicitly.
+ Stop() bool
+
+ // Reset changes the timer to expire after duration d.
+ //
+ // Reset should be invoked only on stopped or expired timers. If the timer is
+ // known to have expired, Reset can be used directly. Otherwise, the caller
+ // must coordinate with the function f of Clock.AfterFunc(d, f).
+ Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
@@ -231,6 +323,36 @@ func (s *Subnet) Broadcast() Address {
return Address(addr)
}
+// IsBroadcast returns true if the address is considered a broadcast address.
+func (s *Subnet) IsBroadcast(address Address) bool {
+ // Only IPv4 supports the notion of a broadcast address.
+ if len(address) != ipv4AddressSize {
+ return false
+ }
+
+ // Normally, we would just compare address with the subnet's broadcast
+ // address but there is an exception where a simple comparison is not
+ // correct. This exception is for /31 and /32 IPv4 subnets where all
+ // addresses are considered valid host addresses.
+ //
+ // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that
+ // both addresses in a /31 subnet "MUST be interpreted as host addresses."
+ //
+ // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32
+ // subnets. However, the same reasoning applies - if an exception is not
+ // made, then there do not exist any host addresses in a /32 subnet. RFC
+ // 4632 Section 3.1 also vaguely implies this interpretation by referring
+ // to addresses in /32 subnets as "host routes."
+ return s.Prefix() <= 30 && s.Broadcast() == address
+}
+
+// Equal returns true if s equals o.
+//
+// Needed to use cmp.Equal on Subnet as its fields are unexported.
+func (s Subnet) Equal(o Subnet) bool {
+ return s == o
+}
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -245,6 +367,28 @@ const (
ShutdownWrite
)
+// PacketType is used to indicate the destination of the packet.
+type PacketType uint8
+
+const (
+ // PacketHost indicates a packet addressed to the local host.
+ PacketHost PacketType = iota
+
+ // PacketOtherHost indicates an outgoing packet addressed to
+ // another host caught by a NIC in promiscuous mode.
+ PacketOtherHost
+
+ // PacketOutgoing for a packet originating from the local host
+ // that is looped back to a packet socket.
+ PacketOutgoing
+
+ // PacketBroadcast indicates a link layer broadcast packet.
+ PacketBroadcast
+
+ // PacketMulticast indicates a link layer multicast packet.
+ PacketMulticast
+)
+
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
@@ -301,7 +445,7 @@ type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
- // Timestamp is the time (in ns) that the last packed used to create
+ // Timestamp is the time (in ns) that the last packet used to create
// the read data was received.
Timestamp int64
@@ -310,6 +454,33 @@ type ControlMessages struct {
// Inq is the number of bytes ready to be received.
Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
+}
+
+// PacketOwner is used to get UID and GID of the packet.
+type PacketOwner interface {
+ // UID returns UID of the packet.
+ UID() uint32
+
+ // GID returns GID of the packet.
+ GID() uint32
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -317,9 +488,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
@@ -404,17 +581,25 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptBool sets a socket option, for simple cases where a value
+ // has the bool type.
+ SetSockOptBool(opt SockOptBool, v bool) *Error
+
// SetSockOptInt sets a socket option, for simple cases where a value
// has the int type.
- SetSockOptInt(opt SockOpt, v int) *Error
+ SetSockOptInt(opt SockOptInt, v int) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
+ // GetSockOptBool gets a socket option for simple cases where a return
+ // value has the bool type.
+ GetSockOptBool(SockOptBool) (bool, *Error)
+
// GetSockOptInt gets a socket option for simple cases where a return
// value has the int type.
- GetSockOptInt(SockOpt) (int, *Error)
+ GetSockOptInt(SockOptInt) (int, *Error)
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
@@ -427,14 +612,36 @@ type Endpoint interface {
// NOTE: This method is a no-op for sockets other than TCP.
ModerateRecvBuf(copied int)
- // IPTables returns the iptables for this endpoint's stack.
- IPTables() (iptables.IPTables, error)
-
// Info returns a copy to the transport endpoint info.
Info() EndpointInfo
// Stats returns a reference to the endpoint stats.
Stats() EndpointStats
+
+ // SetOwner sets the task owner to the endpoint owner.
+ SetOwner(owner PacketOwner)
+}
+
+// LinkPacketInfo holds Link layer information for a received packet.
+//
+// +stateify savable
+type LinkPacketInfo struct {
+ // Protocol is the NetworkProtocolNumber for the packet.
+ Protocol NetworkProtocolNumber
+
+ // PktType is used to indicate the destination of the packet.
+ PktType PacketType
+}
+
+// PacketEndpoint are additional methods that are only implemented by Packet
+// endpoints.
+type PacketEndpoint interface {
+ // ReadPacket reads a datagram/packet from the endpoint and optionally
+ // returns the sender and additional LinkPacketInfo.
+ //
+ // This method does not block if there is no data pending. It will also
+ // either return an error or data, never both.
+ ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
}
// EndpointInfo is the interface implemented by each endpoint info struct.
@@ -469,13 +676,117 @@ type WriteOptions struct {
Atomic bool
}
-// SockOpt represents socket options which values have the int type.
-type SockOpt int
+// SockOptBool represents socket options which values have the bool type.
+type SockOptBool int
const (
+ // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether datagram sockets are allowed to send packets to a broadcast
+ // address.
+ BroadcastOption SockOptBool = iota
+
+ // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be held until segments are full by the TCP transport
+ // protocol.
+ CorkOption
+
+ // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if
+ // data should be sent out immediately by the transport protocol. For
+ // TCP, it determines if the Nagle algorithm is on or off.
+ DelayOption
+
+ // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether TCP keepalive is enabled for this socket.
+ KeepaliveEnabledOption
+
+ // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether multicast packets sent over a non-loopback interface
+ // will be looped back.
+ MulticastLoopOption
+
+ // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether UDP checksum is disabled for this socket.
+ NoChecksumOption
+
+ // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether SCM_CREDENTIALS socket control messages are enabled.
+ //
+ // Only supported on Unix sockets.
+ PasscredOption
+
+ // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
+ QuickAckOption
+
+ // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if the IPV6_TCLASS ancillary message is passed with incoming
+ // packets.
+ ReceiveTClassOption
+
+ // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify
+ // if the TOS ancillary message is passed with incoming packets.
+ ReceiveTOSOption
+
+ // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to
+ // specify if more inforamtion is provided with incoming packets such as
+ // interface index and address.
+ ReceiveIPPacketInfoOption
+
+ // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to
+ // specify whether Bind() should allow reuse of local address.
+ ReuseAddressOption
+
+ // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit
+ // multiple sockets to be bound to an identical socket address.
+ ReusePortOption
+
+ // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify
+ // whether an IPv6 socket is to be restricted to sending and receiving
+ // IPv6 packets only.
+ V6OnlyOption
+
+ // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw
+ // endpoint that all packets being written have an IP header and the
+ // endpoint should not attach an IP header.
+ IPHdrIncludedOption
+)
+
+// SockOptInt represents socket options which values have the int type.
+type SockOptInt int
+
+const (
+ // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the number of un-ACKed TCP keepalives that will be sent
+ // before the connection is closed.
+ KeepaliveCountOption SockOptInt = iota
+
+ // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS
+ // for all subsequent outgoing IPv4 packets from the endpoint.
+ IPv4TOSOption
+
+ // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to
+ // specify TOS for all subsequent outgoing IPv6 packets from the
+ // endpoint.
+ IPv6TrafficClassOption
+
+ // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the
+ // current Maximum Segment Size(MSS) value as specified using the
+ // TCP_MAXSEG option.
+ MaxSegOption
+
+ // MTUDiscoverOption is used to set/get the path MTU discovery setting.
+ //
+ // NOTE: Setting this option to any other value than PMTUDiscoveryDont
+ // is not supported and will fail as such, and getting this option will
+ // always return PMTUDiscoveryDont.
+ MTUDiscoverOption
+
+ // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control
+ // the default TTL value for multicast messages. The default is 1.
+ MulticastTTLOption
+
// ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
// number of unread bytes in the input buffer should be returned.
- ReceiveQueueSizeOption SockOpt = iota
+ ReceiveQueueSizeOption
// SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
// specify the send buffer size option.
@@ -489,47 +800,52 @@ const (
// number of unread bytes in the output buffer should be returned.
SendQueueSizeOption
- // DelayOption is used by SetSockOpt/GetSockOpt to specify if data
- // should be sent out immediately by the transport protocol. For TCP,
- // it determines if the Nagle algorithm is on or off.
- DelayOption
+ // TTLOption is used by SetSockOptInt/GetSockOptInt to control the
+ // default TTL/hop limit value for unicast messages. The default is
+ // protocol specific.
+ //
+ // A zero value indicates the default.
+ TTLOption
- // TODO(b/137664753): convert all int socket options to be handled via
- // GetSockOptInt.
+ // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify
+ // the number of SYN retransmits that TCP should send before aborting
+ // the attempt to connect. It cannot exceed 255.
+ //
+ // NOTE: This option is currently only stubbed out and is no-op.
+ TCPSynCountOption
+
+ // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound
+ // the size of the advertised window to this value.
+ //
+ // NOTE: This option is currently only stubed out and is a no-op
+ TCPWindowClampOption
)
-// ErrorOption is used in GetSockOpt to specify that the last error reported by
-// the endpoint should be cleared and returned.
-type ErrorOption struct{}
+const (
+ // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use
+ // per-route settings.
+ PMTUDiscoveryWant int = iota
-// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
-// socket is to be restricted to sending and receiving IPv6 packets only.
-type V6OnlyOption int
+ // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable
+ // path MTU discovery.
+ PMTUDiscoveryDont
-// CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be
-// held until segments are full by the TCP transport protocol.
-type CorkOption int
+ // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do
+ // path MTU discovery.
+ PMTUDiscoveryDo
-// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
-// should allow reuse of local address.
-type ReuseAddressOption int
+ // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF
+ // but ignore path MTU.
+ PMTUDiscoveryProbe
+)
-// ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets
-// to be bound to an identical socket address.
-type ReusePortOption int
+// ErrorOption is used in GetSockOpt to specify that the last error reported by
+// the endpoint should be cleared and returned.
+type ErrorOption struct{}
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
-
-// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
-type QuickAckOption int
-
-// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
-// SCM_CREDENTIALS socket control messages are enabled.
-//
-// Only supported on Unix sockets.
-type PasscredOption int
+type BindToDeviceOption NICID
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
@@ -539,10 +855,6 @@ type TCPInfoOption struct {
RTTVar time.Duration
}
-// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
-// TCP keepalive is enabled for this socket.
-type KeepaliveEnabledOption int
-
// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
@@ -552,10 +864,10 @@ type KeepaliveIdleOption time.Duration
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
-// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
-// of un-ACKed TCP keepalives that will be sent before the connection is
-// closed.
-type KeepaliveCountOption int
+// TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user
+// specified timeout for a given TCP connection.
+// See: RFC5482 for details.
+type TCPUserTimeoutOption time.Duration
// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get
// the current congestion control algorithm.
@@ -565,23 +877,45 @@ type CongestionControlOption string
// control algorithms.
type AvailableCongestionControlOption string
-// ModerateReceiveBufferOption allows the caller to enable/disable TCP receive
-// buffer moderation.
+// ModerateReceiveBufferOption is used by buffer moderation.
type ModerateReceiveBufferOption bool
-// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
-// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option.
-type MaxSegOption int
+// TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state
+// before being marked closed.
+type TCPLingerTimeoutOption time.Duration
-// TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop
-// limit value for unicast messages. The default is protocol specific.
-//
-// A zero value indicates the default.
-type TTLOption uint8
+// TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum duration for which a socket lingers in the TIME_WAIT state
+// before being marked closed.
+type TCPTimeWaitTimeoutOption time.Duration
+
+// TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a
+// accept to return a completed connection only when there is data to be
+// read. This usually means the listening socket will drop the final ACK
+// for a handshake till the specified timeout until a segment with data arrives.
+type TCPDeferAcceptOption time.Duration
+
+// TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
+// default MinRTO used by the Stack.
+type TCPMinRTOOption time.Duration
+
+// TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding
+// default MaxRTO used by the Stack.
+type TCPMaxRTOOption time.Duration
-// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
-// TTL value for multicast messages. The default is 1.
-type MulticastTTLOption uint8
+// TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the
+// maximum number of retransmits after which we time out the connection.
+type TCPMaxRetriesOption uint64
+
+// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
+// the number of endpoints that can be in SYN-RCVD state before the stack
+// switches to using SYN cookies.
+type TCPSynRcvdCountThresholdOption uint64
+
+// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
+// default for number of times SYN is retransmitted before aborting a connect.
+type TCPSynRetriesOption uint8
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
@@ -590,10 +924,6 @@ type MulticastInterfaceOption struct {
InterfaceAddr Address
}
-// MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether
-// multicast packets sent over a non-loopback interface will be looped back.
-type MulticastLoopOption bool
-
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
@@ -616,21 +946,51 @@ type RemoveMembershipOption MembershipOption
// TCP out-of-band data is delivered along with the normal in-band data.
type OutOfBandInlineOption int
-// BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether
-// datagram sockets are allowed to send packets to a broadcast address.
-type BroadcastOption int
-
// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify
// a default TTL.
type DefaultTTLOption uint8
-// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv4 packets from the endpoint.
-type IPv4TOSOption uint8
+// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached
+// classic BPF filter on a given endpoint.
+type SocketDetachFilterOption int
+
+// OriginalDestinationOption is used to get the original destination address
+// and port of a redirected packet.
+type OriginalDestinationOption FullAddress
+
+// TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to
+// specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for
+// new connections when it is safe from protocol viewpoint.
+type TCPTimeWaitReuseOption uint8
+
+const (
+ // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot
+ // be reused for new connections.
+ TCPTimeWaitReuseDisabled TCPTimeWaitReuseOption = iota
+
+ // TCPTimeWaitReuseGlobal indicates reuse of port bound by endponts in TIME-WAIT can
+ // be reused for new connections irrespective of the src/dest addresses.
+ TCPTimeWaitReuseGlobal
+
+ // TCPTimeWaitReuseLoopbackOnly indicates reuse of port bound by endpoint in TIME-WAIT can
+ // only be reused if the connection was a connection over loopback. i.e src/dest adddresses
+ // are loopback addresses.
+ TCPTimeWaitReuseLoopbackOnly
+)
+
+// IPPacketInfo is the message structure for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
-// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS
-// for all subsequent outgoing IPv6 packets from the endpoint.
-type IPv6TrafficClassOption uint8
+ // DestinationAddr is the destination address found in the IP header.
+ DestinationAddr Address
+}
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
@@ -852,9 +1212,13 @@ type IPStats struct {
// link layer in nic.DeliverNetworkPacket.
PacketsReceived *StatCounter
- // InvalidAddressesReceived is the total number of IP packets received
- // with an unknown or invalid destination address.
- InvalidAddressesReceived *StatCounter
+ // InvalidDestinationAddressesReceived is the total number of IP packets
+ // received with an unknown or invalid destination address.
+ InvalidDestinationAddressesReceived *StatCounter
+
+ // InvalidSourceAddressesReceived is the total number of IP packets received
+ // with a source address that should never have been received on the wire.
+ InvalidSourceAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
// are successfully delivered to the transport layer via HandlePacket.
@@ -887,14 +1251,26 @@ type TCPStats struct {
PassiveConnectionOpenings *StatCounter
// CurrentEstablished is the number of TCP connections for which the
- // current state is either ESTABLISHED or CLOSE-WAIT.
+ // current state is ESTABLISHED.
CurrentEstablished *StatCounter
+ // CurrentConnected is the number of TCP connections that
+ // are in connected state.
+ CurrentConnected *StatCounter
+
// EstablishedResets is the number of times TCP connections have made
// a direct transition to the CLOSED state from either the
// ESTABLISHED state or the CLOSE-WAIT state.
EstablishedResets *StatCounter
+ // EstablishedClosed is the number of times established TCP connections
+ // made a transition to CLOSED state.
+ EstablishedClosed *StatCounter
+
+ // EstablishedTimedout is the number of times an established connection
+ // was reset because of keep-alive time out.
+ EstablishedTimedout *StatCounter
+
// ListenOverflowSynDrop is the number of times the listen queue overflowed
// and a SYN was dropped.
ListenOverflowSynDrop *StatCounter
@@ -987,6 +1363,12 @@ type UDPStats struct {
// PacketSendErrors is the number of datagrams failed to be sent.
PacketSendErrors *StatCounter
+
+ // ChecksumErrors is the number of datagrams dropped due to bad checksums.
+ ChecksumErrors *StatCounter
+
+ // InvalidSourceAddress is the number of invalid sourced datagrams dropped.
+ InvalidSourceAddress *StatCounter
}
// Stats holds statistics about the networking stack.
@@ -1030,6 +1412,9 @@ type ReceiveErrors struct {
// ClosedReceiver is the number of received packets dropped because
// of receiving endpoint state being closed.
ClosedReceiver StatCounter
+
+ // ChecksumErrors is the number of packets dropped due to bad checksums.
+ ChecksumErrors StatCounter
}
// SendErrors collects packet send errors within the transport layer for
@@ -1055,6 +1440,10 @@ type ReadErrors struct {
// InvalidEndpointState is the number of times we found the endpoint state
// to be unexpected.
InvalidEndpointState StatCounter
+
+ // NotConnected is the number of times we tried to read but found that the
+ // endpoint was not connected.
+ NotConnected StatCounter
}
// WriteErrors collects packet write errors from an endpoint write call.
@@ -1097,7 +1486,9 @@ type TransportEndpointStats struct {
// marker interface.
func (*TransportEndpointStats) IsEndpointStats() {}
-func fillIn(v reflect.Value) {
+// InitStatCounters initializes v's fields with nil StatCounter fields to new
+// StatCounters.
+func InitStatCounters(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
if s, ok := v.Addr().Interface().(**StatCounter); ok {
@@ -1105,14 +1496,14 @@ func fillIn(v reflect.Value) {
*s = new(StatCounter)
}
} else {
- fillIn(v)
+ InitStatCounters(v)
}
}
}
// FillIn returns a copy of s with nil fields initialized to new StatCounters.
func (s Stats) FillIn() Stats {
- fillIn(reflect.ValueOf(&s).Elem())
+ InitStatCounters(reflect.ValueOf(&s).Elem())
return s
}
@@ -1322,8 +1713,8 @@ var (
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
- es := make([]Endpoint, 0, len(danglingEndpoints))
danglingEndpointsMu.Lock()
+ es := make([]Endpoint, 0, len(danglingEndpoints))
for e := range danglingEndpoints {
es = append(es, e)
}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 8c0aacffa..1c8e2bc34 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -218,7 +218,7 @@ func TestAddressWithPrefixSubnet(t *testing.T) {
gotSubnet := ap.Subnet()
wantSubnet, err := NewSubnet(tt.subnetAddr, tt.subnetMask)
if err != nil {
- t.Error("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
+ t.Errorf("NewSubnet(%q, %q) failed: %s", tt.subnetAddr, tt.subnetMask, err)
continue
}
if gotSubnet != wantSubnet {
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
new file mode 100644
index 000000000..6d52af98a
--- /dev/null
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+go_test(
+ name = "integration_test",
+ size = "small",
+ srcs = ["multicast_broadcast_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
new file mode 100644
index 000000000..9f0dd4d6d
--- /dev/null
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -0,0 +1,438 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package integration_test
+
+import (
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ defaultMTU = 1280
+ ttl = 255
+)
+
+var (
+ ipv4Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("192.168.1.58").To4()),
+ PrefixLen: 24,
+ }
+ ipv4Subnet = ipv4Addr.Subnet()
+ ipv4SubnetBcast = ipv4Subnet.Broadcast()
+
+ ipv6Addr = tcpip.AddressWithPrefix{
+ Address: tcpip.Address(net.ParseIP("200a::1").To16()),
+ PrefixLen: 64,
+ }
+ ipv6Subnet = ipv6Addr.Subnet()
+ ipv6SubnetBcast = ipv6Subnet.Broadcast()
+
+ // Remote addrs.
+ remoteIPv4Addr = tcpip.Address(net.ParseIP("10.0.0.1").To4())
+ remoteIPv6Addr = tcpip.Address(net.ParseIP("200b::1").To16())
+)
+
+// TestPingMulticastBroadcast tests that responding to an Echo Request destined
+// to a multicast or broadcast address uses a unicast source address for the
+// reply.
+func TestPingMulticastBroadcast(t *testing.T) {
+ const nicID = 1
+
+ rxIPv4ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4Echo)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(^header.Checksum(pkt, 0))
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(icmp.ProtocolNumber4),
+ TTL: ttl,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6ICMP := func(e *channel.Endpoint, dst tcpip.Address) {
+ totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ pkt.SetType(header.ICMPv6EchoRequest)
+ pkt.SetCode(0)
+ pkt.SetChecksum(0)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{}))
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: header.ICMPv6MinimumSize,
+ NextHeader: uint8(icmp.ProtocolNumber6),
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ dstAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ dstAddr: ipv4Addr.Address,
+ },
+ {
+ name: "IPv4 directed broadcast",
+ dstAddr: ipv4SubnetBcast,
+ },
+ {
+ name: "IPv4 broadcast",
+ dstAddr: header.IPv4Broadcast,
+ },
+ {
+ name: "IPv4 all-systems multicast",
+ dstAddr: header.IPv4AllSystems,
+ },
+ {
+ name: "IPv6 unicast",
+ dstAddr: ipv6Addr.Address,
+ },
+ {
+ name: "IPv6 all-nodes multicast",
+ dstAddr: header.IPv6AllNodesMulticastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ipv4Proto := ipv4.NewProtocol()
+ ipv6Proto := ipv6.NewProtocol()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4Proto, ipv6Proto},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4(), icmp.NewProtocol6()},
+ })
+ // We only expect a single packet in response to our ICMP Echo Request.
+ e := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ }
+
+ // Default routes for IPv4 and IPv6 so ICMP can find a route to the remote
+ // node when attempting to send the ICMP Echo Reply.
+ s.SetRouteTable([]tcpip.Route{
+ tcpip.Route{
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID,
+ },
+ tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID,
+ },
+ })
+
+ var rxICMP func(*channel.Endpoint, tcpip.Address)
+ var expectedSrc tcpip.Address
+ var expectedDst tcpip.Address
+ var proto stack.NetworkProtocol
+ switch l := len(test.dstAddr); l {
+ case header.IPv4AddressSize:
+ rxICMP = rxIPv4ICMP
+ expectedSrc = ipv4Addr.Address
+ expectedDst = remoteIPv4Addr
+ proto = ipv4Proto
+ case header.IPv6AddressSize:
+ rxICMP = rxIPv6ICMP
+ expectedSrc = ipv6Addr.Address
+ expectedDst = remoteIPv6Addr
+ proto = ipv6Proto
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ rxICMP(e, test.dstAddr)
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected ICMP response")
+ }
+
+ if pkt.Route.LocalAddress != expectedSrc {
+ t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, expectedSrc)
+ }
+ if pkt.Route.RemoteAddress != expectedDst {
+ t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, expectedDst)
+ }
+
+ src, dst := proto.ParseAddresses(pkt.Pkt.NetworkHeader().View())
+ if src != expectedSrc {
+ t.Errorf("got pkt source = %s, want = %s", src, expectedSrc)
+ }
+ if dst != expectedDst {
+ t.Errorf("got pkt destination = %s, want = %s", dst, expectedDst)
+ }
+ })
+ }
+
+}
+
+// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some
+// multicast or broadcast address.
+func TestIncomingMulticastAndBroadcast(t *testing.T) {
+ const (
+ nicID = 1
+ remotePort = 5555
+ localPort = 80
+ )
+
+ data := []byte{1, 2, 3, 4}
+
+ rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ totalLen := header.IPv4MinimumSize + payloadLen
+ hdr := buffer.NewPrependable(totalLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(udp.ProtocolNumber),
+ TTL: ttl,
+ SrcAddr: remoteIPv4Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) {
+ payloadLen := header.UDPMinimumSize + len(data)
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen)
+ u := header.UDP(hdr.Prepend(payloadLen))
+ u.Encode(&header.UDPFields{
+ SrcPort: remotePort,
+ DstPort: localPort,
+ Length: uint16(payloadLen),
+ })
+ copy(u.Payload(), data)
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen))
+ sum = header.Checksum(data, sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLen),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: ttl,
+ SrcAddr: remoteIPv6Addr,
+ DstAddr: dst,
+ })
+
+ e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: hdr.View().ToVectorisedView(),
+ }))
+ }
+
+ tests := []struct {
+ name string
+ bindAddr tcpip.Address
+ dstAddr tcpip.Address
+ expectRx bool
+ }{
+ {
+ name: "IPv4 unicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 unicast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4Addr.Address,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 unicast binding to wildcard",
+ dstAddr: ipv4Addr.Address,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 directed broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 directed broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: ipv4SubnetBcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 directed broadcast binding to wildcard",
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 broadcast binding to broadcast",
+ bindAddr: header.IPv4Broadcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 broadcast binding to subnet broadcast",
+ bindAddr: ipv4SubnetBcast,
+ dstAddr: header.IPv4Broadcast,
+ expectRx: false,
+ },
+ {
+ name: "IPv4 broadcast binding to wildcard",
+ dstAddr: ipv4SubnetBcast,
+ expectRx: true,
+ },
+
+ {
+ name: "IPv4 all-systems multicast binding to all-systems multicast",
+ bindAddr: header.IPv4AllSystems,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to wildcard",
+ dstAddr: header.IPv4AllSystems,
+ expectRx: true,
+ },
+ {
+ name: "IPv4 all-systems multicast binding to unicast",
+ bindAddr: ipv4Addr.Address,
+ dstAddr: header.IPv4AllSystems,
+ expectRx: false,
+ },
+
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 unicast binding to wildcard",
+ dstAddr: ipv6Addr.Address,
+ expectRx: true,
+ },
+ {
+ name: "IPv6 broadcast-like address binding to wildcard",
+ dstAddr: ipv6SubnetBcast,
+ expectRx: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr}
+ if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err)
+ }
+ ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr}
+ if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err)
+ }
+
+ var netproto tcpip.NetworkProtocolNumber
+ var rxUDP func(*channel.Endpoint, tcpip.Address)
+ switch l := len(test.dstAddr); l {
+ case header.IPv4AddressSize:
+ netproto = header.IPv4ProtocolNumber
+ rxUDP = rxIPv4UDP
+ case header.IPv6AddressSize:
+ netproto = header.IPv6ProtocolNumber
+ rxUDP = rxIPv6UDP
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err)
+ }
+ defer ep.Close()
+
+ bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort}
+ if err := ep.Bind(bindAddr); err != nil {
+ t.Fatalf("ep.Bind(%+v): %s", bindAddr, err)
+ }
+
+ rxUDP(e, test.dstAddr)
+ if gotPayload, _, err := ep.Read(nil); test.expectRx {
+ if err != nil {
+ t.Fatalf("Read(nil): %s", err)
+ }
+ if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" {
+ t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff)
+ }
+ } else {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index a52262e87..f32d58091 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -13,18 +13,20 @@
// limitations under the License.
// +build go1.9
-// +build !go1.14
+// +build !go1.16
// Check go:linkname function signatures when updating Go version.
package tcpip
import (
- _ "time" // Used with go:linkname.
+ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
@@ -43,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}
+
+// AfterFunc implements Clock.AfterFunc.
+func (*StdClock) AfterFunc(d time.Duration, f func()) Timer {
+ return &stdTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+type stdTimer struct {
+ t *time.Timer
+}
+
+var _ Timer = (*stdTimer)(nil)
+
+// Stop implements Timer.Stop.
+func (st *stdTimer) Stop() bool {
+ return st.t.Stop()
+}
+
+// Reset implements Timer.Reset.
+func (st *stdTimer) Reset(d time.Duration) {
+ st.t.Reset(d)
+}
+
+// NewStdTimer returns a Timer implemented with the time package.
+func NewStdTimer(t *time.Timer) Timer {
+ return &stdTimer{t: t}
+}
diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go
new file mode 100644
index 000000000..f1dd7c310
--- /dev/null
+++ b/pkg/tcpip/timer.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// jobInstance is a specific instance of Job.
+//
+// Different instances are created each time Job is scheduled so each timer has
+// its own earlyReturn signal. This is to address a bug when a Job is stopped
+// and reset in quick succession resulting in a timer instance's earlyReturn
+// signal being affected or seen by another timer instance.
+//
+// Consider the following sceneario where timer instances share a common
+// earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a
+// lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second
+// (B), third (C), and fourth (D) instance of the timer firing, respectively):
+// T1: Obtain L
+// T1: Create a new Job w/ lock L (create instance A)
+// T2: instance A fires, blocked trying to obtain L.
+// T1: Attempt to stop instance A (set earlyReturn = true)
+// T1: Schedule timer (create instance B)
+// T3: instance B fires, blocked trying to obtain L.
+// T1: Attempt to stop instance B (set earlyReturn = true)
+// T1: Schedule timer (create instance C)
+// T4: instance C fires, blocked trying to obtain L.
+// T1: Attempt to stop instance C (set earlyReturn = true)
+// T1: Schedule timer (create instance D)
+// T5: instance D fires, blocked trying to obtain L.
+// T1: Release L
+//
+// Now that T1 has released L, any of the 4 timer instances can take L and
+// check earlyReturn. If the timers simply check earlyReturn and then do
+// nothing further, then instance D will never early return even though it was
+// not requested to stop. If the timers reset earlyReturn before early
+// returning, then all but one of the timers will do work when only one was
+// expected to. If Job resets earlyReturn when resetting, then all the timers
+// will fire (again, when only one was expected to).
+//
+// To address the above concerns the simplest solution was to give each timer
+// its own earlyReturn signal.
+type jobInstance struct {
+ timer Timer
+
+ // Used to inform the timer to early return when it gets stopped while the
+ // lock the timer tries to obtain when fired is held (T1 is a goroutine that
+ // tries to cancel the timer and T2 is the goroutine that handles the timer
+ // firing):
+ // T1: Obtain the lock, then call Cancel()
+ // T2: timer fires, and gets blocked on obtaining the lock
+ // T1: Releases lock
+ // T2: Obtains lock does unintended work
+ //
+ // To resolve this, T1 will check to see if the timer already fired, and
+ // inform the timer using earlyReturn to return early so that once T2 obtains
+ // the lock, it will see that it is set to true and do nothing further.
+ earlyReturn *bool
+}
+
+// stop stops the job instance j from firing if it hasn't fired already. If it
+// has fired and is blocked at obtaining the lock, earlyReturn will be set to
+// true so that it will early return when it obtains the lock.
+func (j *jobInstance) stop() {
+ if j.timer != nil {
+ j.timer.Stop()
+ *j.earlyReturn = true
+ }
+}
+
+// Job represents some work that can be scheduled for execution. The work can
+// be safely cancelled when it fires at the same time some "related work" is
+// being done.
+//
+// The term "related work" is defined as some work that needs to be done while
+// holding some lock that the timer must also hold while doing some work.
+//
+// Note, it is not safe to copy a Job as its timer instance creates
+// a closure over the address of the Job.
+type Job struct {
+ _ sync.NoCopy
+
+ // The clock used to schedule the backing timer
+ clock Clock
+
+ // The active instance of a cancellable timer.
+ instance jobInstance
+
+ // locker is the lock taken by the timer immediately after it fires and must
+ // be held when attempting to stop the timer.
+ //
+ // Must never change after being assigned.
+ locker sync.Locker
+
+ // fn is the function that will be called when a timer fires and has not been
+ // signaled to early return.
+ //
+ // fn MUST NOT attempt to lock locker.
+ //
+ // Must never change after being assigned.
+ fn func()
+}
+
+// Cancel prevents the Job from executing if it has not executed already.
+//
+// Cancel requires appropriate locking to be in place for any resources managed
+// by the Job. If the Job is blocked on obtaining the lock when Cancel is
+// called, it will early return.
+//
+// Note, t will be modified.
+//
+// j.locker MUST be locked.
+func (j *Job) Cancel() {
+ j.instance.stop()
+
+ // Nothing to do with the stopped instance anymore.
+ j.instance = jobInstance{}
+}
+
+// Schedule schedules the Job for execution after duration d. This can be
+// called on cancelled or completed Jobs to schedule them again.
+//
+// Schedule should be invoked only on unscheduled, cancelled, or completed
+// Jobs. To be safe, callers should always call Cancel before calling Schedule.
+//
+// Note, j will be modified.
+func (j *Job) Schedule(d time.Duration) {
+ // Create a new instance.
+ earlyReturn := false
+
+ // Capture the locker so that updating the timer does not cause a data race
+ // when a timer fires and tries to obtain the lock (read the timer's locker).
+ locker := j.locker
+ j.instance = jobInstance{
+ timer: j.clock.AfterFunc(d, func() {
+ locker.Lock()
+ defer locker.Unlock()
+
+ if earlyReturn {
+ // If we reach this point, it means that the timer fired while another
+ // goroutine called Cancel while it had the lock. Simply return here
+ // and do nothing further.
+ earlyReturn = false
+ return
+ }
+
+ j.fn()
+ }),
+ earlyReturn: &earlyReturn,
+ }
+}
+
+// NewJob returns a new Job that can be used to schedule f to run in its own
+// gorountine. l will be locked before calling f then unlocked after f returns.
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// message = "bar"
+// mu.Unlock()
+//
+// // Output: bar
+//
+// f MUST NOT attempt to lock l.
+//
+// l MUST be locked prior to calling the returned job's Cancel().
+//
+// var clock tcpip.StdClock
+// var mu sync.Mutex
+// message := "foo"
+// job := tcpip.NewJob(&clock, &mu, func() {
+// fmt.Println(message)
+// })
+// job.Schedule(time.Second)
+//
+// mu.Lock()
+// job.Cancel()
+// mu.Unlock()
+func NewJob(c Clock, l sync.Locker, f func()) *Job {
+ return &Job{
+ clock: c,
+ locker: l,
+ fn: f,
+ }
+}
diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go
new file mode 100644
index 000000000..a82384c49
--- /dev/null
+++ b/pkg/tcpip/timer_test.go
@@ -0,0 +1,268 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcpip_test
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ shortDuration = 1 * time.Nanosecond
+ middleDuration = 100 * time.Millisecond
+ longDuration = 1 * time.Second
+)
+
+func TestJobReschedule(t *testing.T) {
+ var clock tcpip.StdClock
+ var wg sync.WaitGroup
+ var lock sync.Mutex
+
+ for i := 0; i < 2; i++ {
+ wg.Add(1)
+
+ go func() {
+ lock.Lock()
+ // Assigning a new timer value updates the timer's locker and function.
+ // This test makes sure there is no data race when reassigning a timer
+ // that has an active timer (even if it has been stopped as a stopped
+ // timer may be blocked on a lock before it can check if it has been
+ // stopped while another goroutine holds the same lock).
+ job := tcpip.NewJob(&clock, &lock, func() {
+ wg.Done()
+ })
+ job.Schedule(shortDuration)
+ lock.Unlock()
+ }()
+ }
+ wg.Wait()
+}
+
+func TestJobExecution(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ job := tcpip.NewJob(&clock, &lock, func() {
+ ch <- struct{}{}
+ })
+ job.Schedule(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestCancellableTimerResetFromLongDuration(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(middleDuration)
+
+ lock.Lock()
+ job.Cancel()
+ lock.Unlock()
+
+ job.Schedule(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestJobRescheduleFromShortDuration(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ lock.Lock()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
+ lock.Unlock()
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+
+ job.Schedule(shortDuration)
+
+ // Wait for timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestJobImmediatelyCancel(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ for i := 0; i < 1000; i++ {
+ lock.Lock()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
+ lock.Unlock()
+ }
+
+ // Wait for timer to fire if it wasn't correctly stopped.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestJobCancelledRescheduleWithoutLock(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ lock.Lock()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ job.Cancel()
+ lock.Unlock()
+
+ for i := 0; i < 10; i++ {
+ job.Schedule(middleDuration)
+
+ lock.Lock()
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration * 2)
+ job.Cancel()
+ lock.Unlock()
+ }
+
+ // Wait for double the duration so timers that weren't correctly stopped can
+ // fire.
+ select {
+ case <-ch:
+ t.Fatal("timer fired after being stopped")
+ case <-time.After(middleDuration * 2):
+ }
+}
+
+func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ lock.Lock()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ for i := 0; i < 10; i++ {
+ // Sleep until the timer fires and gets blocked trying to take the lock.
+ time.Sleep(middleDuration)
+ job.Cancel()
+ job.Schedule(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for double the duration for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
+
+func TestManyJobReschedulesUnderLock(t *testing.T) {
+ t.Parallel()
+
+ var clock tcpip.StdClock
+ var lock sync.Mutex
+ ch := make(chan struct{})
+
+ lock.Lock()
+ job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} })
+ job.Schedule(shortDuration)
+ for i := 0; i < 10; i++ {
+ job.Cancel()
+ job.Schedule(shortDuration)
+ }
+ lock.Unlock()
+
+ // Wait for double the duration for the last timer to fire.
+ select {
+ case <-ch:
+ case <-time.After(middleDuration):
+ t.Fatal("timed out waiting for timer to fire")
+ }
+
+ // The timer should have fired only once.
+ select {
+ case <-ch:
+ t.Fatal("no other timers should have fired")
+ case <-time.After(middleDuration):
+ }
+}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 9254c3dea..7e5c79776 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,26 +23,18 @@ go_library(
"icmp_packet_list.go",
"protocol.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/icmp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "icmp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 043467519..bd6f49eb8 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,12 +15,11 @@
package icmp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -31,9 +30,6 @@ type icmpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
type endpointState int
@@ -58,6 +54,7 @@ type endpoint struct {
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -77,6 +74,9 @@ type endpoint struct {
route stack.Route `state:"manual"`
ttl uint8
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -90,9 +90,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: stateInitial,
+ uniqueID: s.UniqueID(),
}, nil
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -100,7 +111,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, 0 /* bindToDevice */)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -126,9 +137,8 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
}
// Read reads data from the endpoint. This method does not block if
@@ -274,24 +284,22 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
- toCopy := *to
- to = &toCopy
- netProto, err := e.checkV4Mapped(to, true)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- // Find the enpoint.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, to.Addr, netProto, false /* multicastLoop */)
+ // Find the endpoint.
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */)
if err != nil {
return 0, nil, err
}
@@ -316,7 +324,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
switch e.NetProto {
case header.IPv4ProtocolNumber:
- err = send4(route, e.ID.LocalPort, v, e.ttl)
+ err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
case header.IPv6ProtocolNumber:
err = send6(route, e.ID.LocalPort, v, e.ttl)
@@ -336,23 +344,43 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(o)
- e.mu.Unlock()
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
}
+ return nil
+}
+// SetSockOptBool sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return nil
}
// SetSockOptInt sets a socket option. Currently not supported.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ e.ttl = uint8(v)
+ e.mu.Unlock()
+
+ }
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -375,39 +403,39 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
+ case tcpip.TTLOption:
+ e.rcvMu.Lock()
+ v := int(e.ttl)
+ e.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
}
- return -1, tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.TTLOption:
- e.rcvMu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.rcvMu.Unlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
+func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()),
+ })
+ pkt.Owner = owner
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize))
copy(icmpv4, data)
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
@@ -422,10 +450,12 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+ pkt.Data = data.ToVectorisedView()
+
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error {
@@ -433,9 +463,11 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
return tcpip.ErrInvalidEndpointState
}
- hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()),
+ })
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
+ icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize))
copy(icmpv6, data)
// Set the ident. Sequence number is provided by the user.
icmpv6.SetIdent(ident)
@@ -447,26 +479,22 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err
dataVV := data.ToVectorisedView()
icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV))
+ pkt.Data = dataVV
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, hdr, dataVV, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS})
+ return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- return 0, tcpip.ErrNoRoute
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
}
-
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -479,31 +507,32 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
localPort := uint16(0)
switch e.state {
+ case stateInitial:
case stateBound, stateConnected:
localPort = e.ID.LocalPort
if e.BindNICID == 0 {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
@@ -520,14 +549,14 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, err = e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
e.route = r.Clone()
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.state = stateConnected
@@ -578,18 +607,18 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.TransProto, id, e, false /* reuse */, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
@@ -610,7 +639,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, false)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -714,19 +743,23 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h := header.ICMPv4(vv.First())
- if h.Type() != header.ICMPv4EchoReply {
+ h := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h := header.ICMPv6(vv.First())
- if h.Type() != header.ICMPv6EchoReply {
+ h := header.ICMPv6(pkt.TransportHeader().View())
+ // TODO(b/129292233): Determine if len(h) check is still needed after early
+ // parsing.
+ if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
@@ -753,19 +786,21 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &icmpPacket{
+ packet := &icmpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
},
}
- pkt.data = vv.Clone(pkt.views[:])
+ // ICMP socket's data includes ICMP header.
+ packet.data = pkt.TransportHeader().View().ToVectorisedView()
+ packet.data.Append(pkt.Data)
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += pkt.data.Size()
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
- pkt.timestamp = e.stack.NowNanoseconds()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
@@ -776,7 +811,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
}
// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't
@@ -798,3 +833,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index bfb16f7c3..74ef6541e 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,36 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TODO(gvisor.dev/issue/170): Implement parsing of ICMP.
+ //
+ // Right now, the Parse() method is tied to enabled protocols passed into
+ // stack.New. This works for UDP and TCP, but we handle ICMP traffic even
+ // when netstack users don't pass ICMP as a supported protocol.
+ return false
+}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD
index 8ea2e6ee5..b989b1209 100644
--- a/pkg/tcpip/transport/packet/BUILD
+++ b/pkg/tcpip/transport/packet/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,25 +22,16 @@ go_library(
"endpoint_state.go",
"packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/packet",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 73cdaa265..1b03ad6bb 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,12 +25,12 @@
package packet
import (
- "sync"
+ "fmt"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -41,14 +41,13 @@ type packet struct {
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
senderAddr tcpip.FullAddress
+ // packetInfo holds additional information like the protocol
+ // of the packet etc.
+ packetInfo tcpip.LinkPacketInfo
}
// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal
@@ -77,10 +76,17 @@ type endpoint struct {
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- stats tcpip.TransportEndpointStats `state:"nosave"`
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -97,12 +103,28 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
sndBufSize: 32 * 1024,
}
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ ep.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ ep.rcvBufSizeMax = rs.Default
+ }
+
if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil {
return nil, err
}
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (ep *endpoint) Abort() {
+ ep.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
@@ -125,19 +147,15 @@ func (ep *endpoint) Close() {
}
ep.closed = true
+ ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
- return ep.stack.IPTables(), nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.PacketEndpoint.ReadPacket.
+func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -162,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
*addr = packet.senderAddr
}
+ if info != nil {
+ *info = packet.packetInfo
+ }
+
return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
}
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ return ep.ReadPacket(addr, nil)
+}
+
func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- // TODO(b/129292371): Implement.
+ // TODO(gvisor.dev/issue/173): Implement.
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -216,7 +243,27 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
// - packet(7).
- return tcpip.ErrNotSupported
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
+ }
+
+ // Unregister endpoint with all the nics.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+ ep.boundNIC = addr.NIC
+
+ return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
@@ -251,26 +298,119 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := ep.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ ep.mu.Lock()
+ ep.sndBufSizeMax = v
+ ep.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := ep.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ ep.rcvMu.Lock()
+ ep.rcvBufSizeMax = v
+ ep.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSizeMax
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -292,45 +432,73 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
- // TODO(b/129292371): Return network protocol.
- if len(ethHeader) > 0 {
+ // TODO(gvisor.dev/issue/173): Return network protocol.
+ if !pkt.LinkHeader().View().IsEmpty() {
// Get info directly from the ethernet header.
- hdr := header.Ethernet(ethHeader)
+ hdr := header.Ethernet(pkt.LinkHeader().View())
packet.senderAddr = tcpip.FullAddress{
- NIC: nicid,
+ NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
- NIC: nicid,
+ NIC: nicID,
Addr: tcpip.Address(localAddr),
}
+ packet.packetInfo.Protocol = netProto
+ packet.packetInfo.PktType = pkt.PktType
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = vv.Clone(packet.views[:])
+ switch pkt.PktType {
+ case tcpip.PacketHost:
+ packet.data = pkt.Data
+ case tcpip.PacketOutgoing:
+ // Strip Link Header.
+ var combinedVV buffer.VectorisedView
+ if v := pkt.NetworkHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ if v := pkt.TransportHeader().View(); !v.IsEmpty() {
+ combinedVV.AppendView(v)
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ default:
+ panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt))
+ }
+
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
- if len(ethHeader) == 0 {
- // We weren't provided with an actual ethernet header,
- // so fake one.
- ethFields := header.EthernetFields{
- SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
- DstAddr: localAddr,
- Type: netProto,
+ var linkHeader buffer.View
+ if pkt.PktType != tcpip.PacketOutgoing {
+ if pkt.LinkHeader().View().IsEmpty() {
+ // We weren't provided with an actual ethernet header,
+ // so fake one.
+ ethFields := header.EthernetFields{
+ SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
+ DstAddr: localAddr,
+ Type: netProto,
+ }
+ fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
+ fakeHeader.Encode(&ethFields)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader().View()...)
}
- fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
- fakeHeader.Encode(&ethFields)
- ethHeader = buffer.View(fakeHeader)
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ } else {
+ packet.data = buffer.NewVectorisedView(pkt.Size(), pkt.Views())
}
- combinedVV := buffer.View(ethHeader).ToVectorisedView()
- combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
}
- packet.timestampNS = ep.stack.NowNanoseconds()
+ packet.timestampNS = ep.stack.Clock().NowNanoseconds()
ep.rcvList.PushBack(&packet)
ep.rcvBufSize += packet.data.Size()
@@ -361,3 +529,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo {
func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+
+func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 4af49218c..2eab09088 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -1,5 +1,5 @@
+load("//tools:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -23,26 +23,17 @@ go_library(
"protocol.go",
"raw_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/packet",
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "raw_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 308f10d24..edc2b5b61 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -26,12 +26,12 @@
package raw
import (
- "sync"
+ "fmt"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -42,10 +42,6 @@ type rawPacket struct {
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -67,25 +63,30 @@ type endpoint struct {
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
associated bool
+ hdrIncluded bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
rcvMu sync.Mutex `state:"nosave"`
rcvList rawPacketList
- rcvBufSizeMax int `state:".(int)"`
rcvBufSize int
+ rcvBufSizeMax int `state:".(int)"`
rcvClosed bool
// The following fields are protected by mu.
- mu sync.RWMutex `state:"nosave"`
- sndBufSize int
- closed bool
- connected bool
- bound bool
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ sndBufSizeMax int
+ closed bool
+ connected bool
+ bound bool
// route is the route to a remote network endpoint. It is set via
// Connect(), and is valid only when conneted is true.
route stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// NewEndpoint returns a raw endpoint for the given protocols.
@@ -94,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -106,8 +107,20 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
},
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
associated: associated,
+ hdrIncluded: !associated,
+ }
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
}
// Unassociated endpoints are write-only and users call Write() with IP
@@ -126,6 +139,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
@@ -160,17 +178,12 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
}
// Read implements tcpip.Endpoint.Read.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if !e.associated {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
- }
-
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -200,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -243,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// If this is an unassociated socket and callee provided a nonzero
// destination address, route using that address.
- if !e.associated {
+ if e.hdrIncluded {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
e.mu.RUnlock()
@@ -304,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -339,21 +351,26 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
- return 0, nil, err
- }
- break
+ if e.hdrIncluded {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ if err := route.WriteHeaderIncludedPacket(pkt); err != nil {
+ return 0, nil, err
}
- hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
- if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil {
+ } else {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: int(route.MaxHeaderLength()),
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ })
+ pkt.Owner = e.owner
+ if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: e.TransProto,
+ TTL: route.DefaultTTL(),
+ TOS: stack.DefaultTOS,
+ }, pkt); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -378,11 +395,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -448,14 +460,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -505,16 +511,101 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.SocketDetachFilterOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ e.hdrIncluded = v
+ e.mu.Unlock()
+ return nil
+ }
return tcpip.ErrUnknownProtocolOption
}
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return tcpip.ErrUnknownProtocolOption
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err))
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err))
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ if v < rs.Min {
+ v = rs.Min
+ }
+ e.rcvMu.Lock()
+ e.rcvBufSizeMax = v
+ e.rcvMu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.IPHdrIncludedOption:
+ e.mu.Lock()
+ v := e.hdrIncluded
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -528,7 +619,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -538,32 +629,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvMu.Unlock()
return v, nil
- }
-
- return -1, tcpip.ErrUnknownProtocolOption
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
- return tcpip.ErrUnknownProtocolOption
+ return -1, tcpip.ErrUnknownProtocolOption
}
}
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
-func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
e.rcvMu.Lock()
- // Drop the packet if our buffer is currently full.
- if e.rcvClosed {
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
e.rcvMu.Unlock()
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.ClosedReceiver.Increment()
@@ -600,21 +683,33 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &rawPacket{
+ packet := &rawPacket{
senderAddr: tcpip.FullAddress{
NIC: route.NICID(),
Addr: route.RemoteAddress,
},
}
- combinedVV := netHeader.ToVectorisedView()
- combinedVV.Append(vv)
- pkt.data = combinedVV.Clone(pkt.views[:])
- pkt.timestampNS = e.stack.NowNanoseconds()
-
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += pkt.data.Size()
-
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
+ packet.timestampNS = e.stack.Clock().NowNanoseconds()
+
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
e.rcvMu.Unlock()
e.stats.PacketsReceived.Increment()
// Notify waiters that there's data to be read.
@@ -641,3 +736,6 @@ func (e *endpoint) Info() tcpip.EndpointInfo {
func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index f1dbc6f91..234fb95ce 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -16,18 +15,35 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "tcp_endpoint_list",
+ out = "tcp_endpoint_list.go",
+ package = "tcp",
+ prefix = "endpoint",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*endpoint",
+ "Linker": "*endpoint",
+ },
+)
+
go_library(
name = "tcp",
srcs = [
"accept.go",
"connect.go",
+ "connect_unsafe.go",
"cubic.go",
"cubic_state.go",
+ "dispatcher.go",
"endpoint.go",
"endpoint_state.go",
"forwarder.go",
"protocol.go",
+ "rack.go",
+ "rack_state.go",
"rcv.go",
+ "rcv_state.go",
"reno.go",
"sack.go",
"sack_scoreboard.go",
@@ -35,55 +51,49 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
+ "tcp_endpoint_list.go",
"tcp_segment_list.go",
"timer.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/log",
"//pkg/rand",
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
- "//pkg/tmutex",
"//pkg/waiter",
"@com_github_google_btree//:go_default_library",
],
)
-filegroup(
- name = "autogen",
- srcs = [
- "tcp_segment_list.go",
- ],
- visibility = ["//:sandbox"],
-)
-
go_test(
- name = "tcp_test",
- size = "small",
+ name = "tcp_x_test",
+ size = "medium",
srcs = [
"dual_stack_test.go",
"sack_scoreboard_test.go",
"tcp_noracedetector_test.go",
+ "tcp_rack_test.go",
"tcp_sack_test.go",
"tcp_test.go",
"tcp_timestamp_test.go",
],
- # FIXME(b/68809571)
- tags = ["flaky"],
+ shard_count = 10,
deps = [
":tcp",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
@@ -96,6 +106,25 @@ go_test(
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
+ "//pkg/test/testutil",
"//pkg/waiter",
],
)
+
+go_test(
+ name = "rcv_test",
+ size = "small",
+ srcs = ["rcv_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
+
+go_test(
+ name = "tcp_test",
+ size = "small",
+ srcs = ["timer_test.go"],
+ library = ":tcp",
+ deps = ["//pkg/sleep"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 65c346046..b706438bd 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -17,13 +17,14 @@ package tcp
import (
"crypto/sha1"
"encoding/binary"
+ "fmt"
"hash"
"io"
- "sync"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -48,17 +49,14 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-)
-var (
- // SynRcvdCountThreshold is the global maximum number of connections
- // that are allowed to be in SYN-RCVD state before TCP starts using SYN
- // cookies to accept connections.
- //
- // It is an exported variable only for testing, and should not otherwise
- // be used by importers of this package.
+ // SynRcvdCountThreshold is the default global maximum number of
+ // connections that are allowed to be in SYN-RCVD state before TCP
+ // starts using SYN cookies to accept connections.
SynRcvdCountThreshold uint64 = 1000
+)
+var (
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits.
mssTable = []uint16{536, 1300, 1440, 1460}
@@ -73,29 +71,42 @@ func encodeMSS(mss uint16) uint32 {
return 0
}
-// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
-// protected by a mutex so that we can increment only when it's guaranteed not
-// to go above a threshold.
-var synRcvdCount struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
-}
-
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
- rcvWnd seqnum.Size
- nonce [2][sha1.BlockSize]byte
+ stack *stack.Stack
+
+ // synRcvdCount is a reference to the stack level synRcvdCount.
+ synRcvdCount *synRcvdCounter
+
+ // rcvWnd is the receive window that is sent by this listening context
+ // in the initial SYN-ACK.
+ rcvWnd seqnum.Size
+
+ // nonce are random bytes that are initialized once when the context
+ // is created and used to seed the hash function when generating
+ // the SYN cookie.
+ nonce [2][sha1.BlockSize]byte
+
+ // listenEP is a reference to the listening endpoint associated with
+ // this context. Can be nil if the context is created by the forwarder.
listenEP *endpoint
+ // hasherMu protects hasher.
hasherMu sync.Mutex
- hasher hash.Hash
- v6only bool
+ // hasher is the hash function used to generate a SYN cookie.
+ hasher hash.Hash
+
+ // v6Only is true if listenEP is a dual stack socket and has the
+ // IPV6_V6ONLY option set.
+ v6Only bool
+
+ // netProto indicates the network protocol(IPv4/v6) for the listening
+ // endpoint.
netProto tcpip.NetworkProtocolNumber
+
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
@@ -114,55 +125,22 @@ func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask
}
-// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
-// state. It succeeds if the increment doesn't make the count go beyond the
-// threshold, and fails otherwise.
-func incSynRcvdCount() bool {
- synRcvdCount.Lock()
-
- if synRcvdCount.value >= SynRcvdCountThreshold {
- synRcvdCount.Unlock()
- return false
- }
-
- synRcvdCount.pending.Add(1)
- synRcvdCount.value++
-
- synRcvdCount.Unlock()
- return true
-}
-
-// decSynRcvdCount atomically decrements the global number of endpoints in
-// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
-// succeeded.
-func decSynRcvdCount() {
- synRcvdCount.Lock()
-
- synRcvdCount.value--
- synRcvdCount.pending.Done()
- synRcvdCount.Unlock()
-}
-
-// synCookiesInUse() returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func synCookiesInUse() bool {
- synRcvdCount.Lock()
- v := synRcvdCount.value
- synRcvdCount.Unlock()
- return v >= SynRcvdCountThreshold
-}
-
// newListenContext creates a new listen context.
-func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stk,
rcvWnd: rcvWnd,
hasher: sha1.New(),
- v6only: v6only,
+ v6Only: v6Only,
netProto: netProto,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
+ p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+ if !ok {
+ panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+ }
+ l.synRcvdCount = p.SynRcvdCounter()
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
@@ -221,85 +199,119 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
- n := newEndpoint(l.stack, netProto, nil)
- n.v6only = l.v6only
+ n := newEndpoint(l.stack, netProto, queue)
+ n.v6only = l.v6Only
n.ID = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
- n.amss = mssForRoute(&n.route)
+ n.amss = calculateAdvertisedMSS(n.userMSS, n.route)
+ n.setEndpointState(StateConnecting)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
n.initGSO()
- // Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.reusePort, n.bindToDevice); err != nil {
- n.Close()
- return nil, err
- }
-
- n.isRegistered = true
-
- // Create sender and receiver.
- //
- // The receiver at least temporarily has a zero receive window scale,
- // but the caller may change it (before starting the protocol loop).
- n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
- n.rcv = newReceiver(n, irs, seqnum.Size(n.initialReceiveWindow()), 0, seqnum.Size(n.receiveBufferSize()))
// Bootstrap the auto tuning algorithm. Starting at zero will result in
// a large step function on the first window adjustment causing the
// window to grow to a really large value.
n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
- return n, nil
+ return n
}
-// createEndpoint creates a new endpoint in connected state and then performs
-// the TCP 3-way handshake.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+//
+// The new endpoint is returned with e.mu held.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
- cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
- ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
- if err != nil {
- return nil, err
- }
+ isn := generateSecureISN(s.id, l.stack.Seed())
+ ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+ // Lock the endpoint before registering to ensure that no out of
+ // band changes are possible due to incoming packets etc till
+ // the endpoint is done initializing.
+ ep.mu.Lock()
+ ep.owner = owner
// listenEP is nil when listenContext is used by tcp.Forwarder.
+ deferAccept := time.Duration(0)
if l.listenEP != nil {
l.listenEP.mu.Lock()
- if l.listenEP.state != StateListen {
+ if l.listenEP.EndpointState() != StateListen {
+
l.listenEP.mu.Unlock()
+ // Ensure we release any registrations done by the newly
+ // created endpoint.
+ ep.mu.Unlock()
+ ep.Close()
+
return nil, tcpip.ErrConnectionAborted
}
l.addPendingEndpoint(ep)
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ l.listenEP.propagateInheritableOptionsLocked(ep)
+
+ if !ep.reserveTupleLocked() {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
+ return nil, tcpip.ErrConnectionAborted
+ }
+
+ deferAccept = l.listenEP.deferAccept
l.listenEP.mu.Unlock()
}
- // Perform the 3-way handshake.
- h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
+ // Register new endpoint so that packets are routed to it.
+ if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+ ep.mu.Unlock()
+ ep.Close()
+
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
- h.resetToSynRcvd(cookie, irs, opts)
+ ep.drainClosingSegmentQueue()
+
+ return nil, err
+ }
+
+ ep.isRegistered = true
+
+ // Perform the 3-way handshake.
+ h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
+ ep.mu.Unlock()
ep.Close()
+ ep.notifyAborted()
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
+
+ ep.drainClosingSegmentQueue()
+
return nil, err
}
- ep.mu.Lock()
- ep.stack.Stats().TCP.CurrentEstablished.Increment()
- ep.state = StateEstablished
- ep.mu.Unlock()
+ ep.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
@@ -333,23 +345,78 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state, the new endpoint is closed
-// instead.
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.state
e.pendingAccepted.Add(1)
- defer e.pendingAccepted.Done()
- acceptedChan := e.acceptedChan
e.mu.Unlock()
- if state == StateListen {
- acceptedChan <- n
- e.waiterQueue.Notify(waiter.EventIn)
- } else {
- n.Close()
+ defer e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.notifyProtocolGoroutine(notifyReset)
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
}
}
+// propagateInheritableOptionsLocked propagates any options set on the listening
+// endpoint to the newly created endpoint.
+//
+// Precondition: e.mu and n.mu must be held.
+func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
+ n.userTimeout = e.userTimeout
+ n.portFlags = e.portFlags
+ n.boundBindToDevice = e.boundBindToDevice
+ n.boundPortFlags = e.boundPortFlags
+ n.userMSS = e.userMSS
+}
+
+// reserveTupleLocked reserves an accepted endpoint's tuple.
+//
+// Preconditions:
+// * propagateInheritableOptionsLocked has been called.
+// * e.mu is held.
+func (e *endpoint) reserveTupleLocked() bool {
+ dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
+ if !e.stack.ReserveTuple(
+ e.effectiveNetProtos,
+ ProtocolNumber,
+ e.ID.LocalAddress,
+ e.ID.LocalPort,
+ e.boundPortFlags,
+ e.boundBindToDevice,
+ dest,
+ ) {
+ return false
+ }
+
+ e.isPortReserved = true
+ e.boundDest = dest
+ return true
+}
+
+// notifyAborted wakes up any waiters on registered, but not accepted
+// endpoints.
+//
+// This is strictly not required normally as a socket that was never accepted
+// can't really have any registered waiters except when stack.Wait() is called
+// which waits for all registered endpoints to stop and expects an EventHUp.
+func (e *endpoint) notifyAborted() {
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
@@ -357,53 +424,68 @@ func (e *endpoint) deliverAccepted(n *endpoint) {
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer decSynRcvdCount()
- defer e.decSynRcvdCount()
+ defer ctx.synRcvdCount.dec()
+ defer func() {
+ e.mu.Lock()
+ e.decSynRcvdCount()
+ e.mu.Unlock()
+ }()
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
ctx.removePendingEndpoint(n)
+ n.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
e.deliverAccepted(n)
}
func (e *endpoint) incSynRcvdCount() bool {
- e.mu.Lock()
- if e.synRcvdCount >= cap(e.acceptedChan) {
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
}
- e.synRcvdCount++
- e.mu.Unlock()
- return true
+ return canInc
}
func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
e.synRcvdCount--
- e.mu.Unlock()
}
func (e *endpoint) acceptQueueIsFull() bool {
- e.mu.Lock()
- if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- e.mu.Unlock()
- return true
- }
- e.mu.Unlock()
- return false
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
- switch s.flags {
- case header.TCPFlagSyn:
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
+ if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+ // If the endpoint is shutdown, reply with reset.
+ //
+ // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+ // must be sent in response to a SYN-ACK while in the listen
+ // state to prevent completing a handshake from an old SYN.
+ replyWithReset(s, e.sendTOS, e.ttl)
+ return
+ }
+
+ switch {
+ case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
- if incSynRcvdCount() {
+ if ctx.synRcvdCount.inc() {
// Only handle the syn if the following conditions hold
// - accept queue is not full.
// - number of connections in synRcvd state is less than the
@@ -413,7 +495,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
return
}
- decSynRcvdCount()
+ ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -430,23 +512,33 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN without window scaling because we currently
- // dont't encode this information in the cookie.
+ // don't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
- mss := mssForRoute(&s.route)
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
- TSVal: tcpTimeStamp(timeStampOffset()),
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
TSEcr: opts.TSVal,
- MSS: uint16(mss),
+ MSS: calculateAdvertisedMSS(e.userMSS, s.route),
}
- e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.sendSynTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }, synOpts)
e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
}
- case header.TCPFlagAck:
+ case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
// Silently drop the ack as the application can't accept
// the connection at this point. The ack will be
@@ -459,7 +551,15 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- if !synCookiesInUse() {
+ if !ctx.synRcvdCount.synCookiesInUse() {
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
// Send a reset as this is an ACK for which there is no
// half open connections and we are not using cookies
// yet.
@@ -467,10 +567,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// The only time we should reach here when a connection
// was opened and closed really quickly and a delayed
// ACK was received from the sender.
- replyWithReset(s)
+ replyWithReset(s, e.sendTOS, e.ttl)
return
}
+ iss := s.ackNumber - 1
+ irs := s.sequenceNumber - 1
+
// Since SYN cookies are in use this is potentially an ACK to a
// SYN-ACK we sent but don't have a half open connection state
// as cookies are being used to protect against a potential SYN
@@ -481,7 +584,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// when under a potential syn flood attack.
//
// Validate the cookie.
- data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ data, ok := ctx.isCookieValid(s.id, iss, irs)
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -506,13 +609,35 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
- n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
- if err != nil {
+ n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+ n.mu.Lock()
+
+ // Propagate any inheritable options from the listening endpoint
+ // to the newly created endpoint.
+ e.propagateInheritableOptionsLocked(n)
+
+ if !n.reserveTupleLocked() {
+ n.mu.Unlock()
+ n.Close()
+
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+ n.mu.Unlock()
+ n.Close()
+
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
return
}
+ n.isRegistered = true
+
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
@@ -520,8 +645,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
n.tsOffset = 0
// Switch state to connected.
- n.stack.Stats().TCP.CurrentEstablished.Increment()
- n.state = StateEstablished
+ n.isConnectNotified = true
+ n.transitionToStateEstablishedLocked(&handshake{
+ ep: n,
+ iss: iss,
+ ackNum: irs + 1,
+ rcvWnd: seqnum.Size(n.initialReceiveWindow()),
+ sndWnd: s.window,
+ rcvWndScale: e.rcvWndScaleForHandshake(),
+ sndWndScale: rcvdSynOptions.WS,
+ mss: rcvdSynOptions.MSS,
+ })
// Do the delivery in a separate goroutine so
// that we don't block the listen loop in case
@@ -532,6 +666,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// number of goroutines as we do check before
// entering here that there was at least some
// space available in the backlog.
+
+ // Start the protocol goroutine.
+ n.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
go e.deliverAccepted(n)
}
}
@@ -540,16 +678,14 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
- v6only := e.v6only
- e.mu.Unlock()
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.NetProto)
+ v6Only := e.v6only
+ ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
- e.mu.Lock()
- e.state = StateClose
+ e.setEndpointState(StateClose)
// close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
@@ -562,15 +698,20 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
}
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
for {
- switch index, _ := s.Fetch(true); index {
+ e.mu.Unlock()
+ index, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch index {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
@@ -583,7 +724,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
s.decRef()
}
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 790e89cc3..290172ac9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -15,13 +15,15 @@
package tcp
import (
- "sync"
+ "encoding/binary"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -59,6 +61,9 @@ const (
)
// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
type handshake struct {
ep *endpoint
state handshakeState
@@ -84,32 +89,38 @@ type handshake struct {
// rcvWndScale is the receive window scale, as defined in RFC 1323.
rcvWndScale int
-}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- rcvWndScale := ep.rcvWndScaleForHandshake()
+ // startTime is the time at which the first SYN/SYN-ACK was sent.
+ startTime time.Time
- // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
- // window offered in SYN won't be reduced due to the loss of precision if
- // window scaling is enabled after the handshake.
- rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+ // deferAccept if non-zero will drop the final ACK for a passive
+ // handshake till an ACK segment with data is received or the timeout is
+ // hit.
+ deferAccept time.Duration
- // Ensure we can always accept at least 1 byte if the scale specified
- // was too high for the provided rcvWnd.
- if rcvWnd == 0 {
- rcvWnd = 1
- }
+ // acked is true if the the final ACK for a 3-way handshake has
+ // been received. This is required to stop retransmitting the
+ // original SYN-ACK when deferAccept is enabled.
+ acked bool
+}
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
h := handshake{
ep: ep,
active: true,
rcvWnd: rcvWnd,
- rcvWndScale: int(rcvWndScale),
+ rcvWndScale: ep.rcvWndScaleForHandshake(),
}
h.resetState()
return h
}
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+ h := newHandshake(ep, rcvWnd)
+ h.resetToSynRcvd(isn, irs, opts, deferAccept)
+ return h
+}
+
// FindWndScale determines the window scale to use for the given maximum window
// size.
func FindWndScale(wnd seqnum.Size) int {
@@ -139,7 +150,32 @@ func (h *handshake) resetState() {
h.flags = header.TCPFlagSyn
h.ackNum = 0
h.mss = 0
- h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+ h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+ isnHasher := jenkins.Sum32(seed)
+ isnHasher.Write([]byte(id.LocalAddress))
+ isnHasher.Write([]byte(id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+ isnHasher.Write(portBuf)
+ binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+ isnHasher.Write(portBuf)
+ // The time period here is 64ns. This is similar to what linux uses
+ // generate a sequence number that overlaps less than one
+ // time per MSL (2 minutes).
+ //
+ // A 64ns clock ticks 10^9/64 = 15625000) times in a second.
+ // To wrap the whole 32 bit space would require
+ // 2^32/1562500 ~ 274 seconds.
+ //
+ // Which sort of guarantees that we won't reuse the ISN for a new
+ // connection for the same tuple for at least 274s.
+ isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+ return seqnum.Value(isn)
}
// effectiveRcvWndScale returns the effective receive window scale to be used.
@@ -154,7 +190,7 @@ func (h *handshake) effectiveRcvWndScale() uint8 {
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state.
-func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
h.active = false
h.state = handshakeSynRcvd
h.flags = header.TCPFlagSyn | header.TCPFlagAck
@@ -162,9 +198,8 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea
h.ackNum = irs + 1
h.mss = opts.MSS
h.sndWndScale = opts.WS
- h.ep.mu.Lock()
- h.ep.state = StateSynRecv
- h.ep.mu.Unlock()
+ h.deferAccept = deferAccept
+ h.ep.setEndpointState(StateSynRecv)
}
// checkAck checks if the ACK number, if present, of a segment received during
@@ -191,6 +226,12 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// acceptable if the ack field acknowledges the SYN.
if s.flagIsSet(header.TCPFlagRst) {
if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ // RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
+ // was acceptable then signal the user "error: connection reset", drop
+ // the segment, enter CLOSED state, delete TCB, and return."
+ h.ep.workerCleanup = true
+ // Although the RFC above calls out ECONNRESET, Linux actually returns
+ // ECONNREFUSED here so we do as well.
return tcpip.ErrConnectionRefused
}
return nil
@@ -225,6 +266,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// and the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
h.state = handshakeCompleted
+
+ h.ep.transitionToStateEstablishedLocked(h)
+
h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
return nil
}
@@ -233,26 +277,33 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error {
// but resend our own SYN and wait for it to be acknowledged in the
// SYN-RCVD state.
h.state = handshakeSynRcvd
- h.ep.mu.Lock()
- h.ep.state = StateSynRecv
ttl := h.ep.ttl
- h.ep.mu.Unlock()
+ amss := h.ep.amss
+ h.ep.setEndpointState(StateSynRecv)
synOpts := header.TCPSynOptions{
WS: int(h.effectiveRcvWndScale()),
TS: rcvSynOpts.TS,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
// We only send SACKPermitted if the other side indicated it
// permits SACK. This is not explicitly defined in the RFC but
// this is the behaviour implemented by Linux.
SACKPermitted: rcvSynOpts.SACKPermitted,
- MSS: h.ep.amss,
+ MSS: amss,
}
if ttl == 0 {
ttl = s.route.DefaultTTL()
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
@@ -272,6 +323,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
+ // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a
+ // sequence number outside of the window causes an ACK with the proper seq
+ // number and "After sending the acknowledgment, drop the unacceptable
+ // segment and return."
+ if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd)
+ return nil
+ }
+
if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
// We received two SYN segments with different sequence
// numbers, so we reset this and restart the whole
@@ -292,17 +352,33 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
WS: h.rcvWndScale,
TS: h.ep.sendTSOk,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: h.ep.sackPermitted,
MSS: h.ep.amss,
}
- h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ h.ep.sendSynTCP(&s.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
return nil
}
// We have previously received (and acknowledged) the peer's SYN. If the
// peer acknowledges our SYN, the handshake is completed.
if s.flagIsSet(header.TCPFlagAck) {
+ // If deferAccept is not zero and this is a bare ACK and the
+ // timeout is not hit then drop the ACK.
+ if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept {
+ h.acked = true
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// If the timestamp option is negotiated and the segment does
// not carry a timestamp option then the segment must be dropped
// as per https://tools.ietf.org/html/rfc7323#section-3.2.
@@ -316,6 +392,15 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
}
h.state = handshakeCompleted
+
+ h.ep.transitionToStateEstablishedLocked(h)
+
+ // If the segment has data then requeue it for the receiver
+ // to process it again once main loop is started.
+ if s.data.Size() > 0 {
+ s.incRef()
+ h.ep.enqueueSegment(s)
+ }
return nil
}
@@ -401,7 +486,12 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
if n&notifyDrain != 0 {
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
}
}
@@ -418,12 +508,11 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.startTime = time.Now()
// Initialize the resend timer.
resendWaker := sleep.Waker{}
timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, func() {
- resendWaker.Assert()
- })
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
defer rt.Stop()
// Set up the wakers.
@@ -442,13 +531,13 @@ func (h *handshake) execute() *tcpip.Error {
// Send the initial SYN segment and loop until the handshake is
// completed.
- h.ep.amss = mssForRoute(&h.ep.route)
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
TSVal: h.ep.timestamp(),
- TSEcr: h.ep.recentTS,
+ TSEcr: h.ep.recentTimestamp(),
SACKPermitted: bool(sackEnabled),
MSS: h.ep.amss,
}
@@ -465,21 +554,52 @@ func (h *handshake) execute() *tcpip.Error {
synOpts.WS = -1
}
}
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
for h.state != handshakeCompleted {
- switch index, _ := s.Fetch(true); index {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
case wakerForResend:
timeOut *= 2
- if timeOut > 60*time.Second {
+ if timeOut > MaxRTO {
return tcpip.ErrTimeout
}
rt.Reset(timeOut)
- h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ // Resend the SYN/SYN-ACK only if the following conditions hold.
+ // - It's an active handshake (deferAccept does not apply)
+ // - It's a passive handshake and we have not yet got the final-ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+ }
case wakerForNotification:
n := h.ep.fetchNotifications()
- if n&notifyClose != 0 {
+ if (n&notifyClose)|(n&notifyAbort) != 0 {
return tcpip.ErrAborted
}
if n&notifyDrain != 0 {
@@ -495,7 +615,12 @@ func (h *handshake) execute() *tcpip.Error {
}
}
close(h.ep.drainDone)
+ h.ep.mu.Unlock()
<-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+ if n&notifyError != 0 {
+ return h.ep.takeLastError()
}
case wakerForNewSegment:
@@ -519,17 +644,17 @@ func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
var optionPool = sync.Pool{
New: func() interface{} {
- return make([]byte, maxOptionSize)
+ return &[maxOptionSize]byte{}
},
}
func getOptions() []byte {
- return optionPool.Get().([]byte)
+ return (*optionPool.Get().(*[maxOptionSize]byte))[:]
}
func putOptions(options []byte) {
// Reslice to full capacity.
- optionPool.Put(options[0:cap(options)])
+ optionPool.Put(optionsToArray(options))
}
func makeSynOptions(opts header.TCPSynOptions) []byte {
@@ -585,18 +710,33 @@ func makeSynOptions(opts header.TCPSynOptions) []byte {
return options[:offset]
}
-func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
- options := makeSynOptions(opts)
+// tcpFields is a struct to carry different parameters required by the
+// send*TCP variant functions below.
+type tcpFields struct {
+ id stack.TransportEndpointID
+ ttl uint8
+ tos uint8
+ flags byte
+ seq seqnum.Value
+ ack seqnum.Value
+ rcvWnd seqnum.Size
+ opts []byte
+ txHash uint32
+}
+
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+ tf.opts = makeSynOptions(opts)
// We ignore SYN send errors and let the callers re-attempt send.
- if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil {
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
e.stats.SendErrors.SynSendToNetworkFailed.Increment()
}
- putOptions(options)
+ putOptions(tf.opts)
return nil
}
-func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil {
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ tf.txHash = e.txHash
+ if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
return err
}
@@ -604,26 +744,21 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu
return nil
}
-func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDescriptor, data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) {
- optLen := len(opts)
- hdr := &d.Hdr
- packetSize := d.Size
- off := d.Off
- // Initialize the header.
- tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ optLen := len(tf.opts)
+ tcp := header.TCP(pkt.TransportHeader().Push(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{
- SrcPort: id.LocalPort,
- DstPort: id.RemotePort,
- SeqNum: uint32(seq),
- AckNum: uint32(ack),
+ SrcPort: tf.id.LocalPort,
+ DstPort: tf.id.RemotePort,
+ SeqNum: uint32(tf.seq),
+ AckNum: uint32(tf.ack),
DataOffset: uint8(header.TCPMinimumSize + optLen),
- Flags: flags,
- WindowSize: uint16(rcvWnd),
+ Flags: tf.flags,
+ WindowSize: uint16(tf.rcvWnd),
})
- copy(tcp[header.TCPMinimumSize:], opts)
+ copy(tcp[header.TCPMinimumSize:], tf.opts)
- length := uint16(hdr.UsedLength() + packetSize)
- xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, uint16(pkt.Size()))
// Only calculate the checksum if offloading isn't supported.
if gso != nil && gso.NeedsCsum {
// This is called CHECKSUM_PARTIAL in the Linux kernel. We
@@ -632,41 +767,53 @@ func buildTCPHdr(r *stack.Route, id stack.TransportEndpointID, d *stack.PacketDe
// header and data and get the right sum of the TCP packet.
tcp.SetChecksum(xsum)
} else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
- xsum = header.ChecksumVVWithOffset(data, xsum, off, packetSize)
+ xsum = header.ChecksumVV(pkt.Data, xsum)
tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
}
-
}
-func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ // We need to shallow clone the VectorisedView here as ReadToView will
+ // split the VectorisedView and Trim underlying views as it splits. Not
+ // doing the clone here will cause the underlying views of data itself
+ // to be altered.
+ data = data.Clone(nil)
+
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
mss := int(gso.MSS)
n := (data.Size() + mss - 1) / mss
- hdrs := stack.NewPacketDescriptors(n, header.TCPMinimumSize+int(r.MaxHeaderLength())+optLen)
-
size := data.Size()
- off := 0
+ hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen
+ var pkts stack.PacketBufferList
for i := 0; i < n; i++ {
packetSize := mss
if packetSize > size {
packetSize = size
}
size -= packetSize
- hdrs[i].Off = off
- hdrs[i].Size = packetSize
- buildTCPHdr(r, id, &hdrs[i], data, flags, seq, ack, rcvWnd, opts, gso)
- off += packetSize
- seq = seq.Add(seqnum.Size(packetSize))
- }
- if ttl == 0 {
- ttl = r.DefaultTTL()
- }
- sent, err := r.WritePackets(gso, hdrs, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos})
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: hdrSize,
+ })
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = r.NetworkProtocolNumber()
+ data.ReadToVV(&pkt.Data, packetSize)
+ buildTCPHdr(r, tf, pkt, gso)
+ tf.seq = tf.seq.Add(seqnum.Size(packetSize))
+ pkts.PushBack(pkt)
+ }
+
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
+ }
+ sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos})
if err != nil {
r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent))
}
@@ -676,32 +823,33 @@ func sendTCPBatch(r *stack.Route, id stack.TransportEndpointID, data buffer.Vect
// sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity.
-func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
- optLen := len(opts)
- if rcvWnd > 0xffff {
- rcvWnd = 0xffff
+func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ optLen := len(tf.opts)
+ if tf.rcvWnd > 0xffff {
+ tf.rcvWnd = 0xffff
}
if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() {
- return sendTCPBatch(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso)
+ return sendTCPBatch(r, tf, data, gso, owner)
}
- d := &stack.PacketDescriptor{
- Hdr: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen),
- Off: 0,
- Size: data.Size(),
- }
- buildTCPHdr(r, id, d, data, flags, seq, ack, rcvWnd, opts, gso)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen,
+ Data: data,
+ })
+ pkt.Hash = tf.txHash
+ pkt.Owner = owner
+ buildTCPHdr(r, tf, pkt, gso)
- if ttl == 0 {
- ttl = r.DefaultTTL()
+ if tf.ttl == 0 {
+ tf.ttl = r.DefaultTTL()
}
- if err := r.WritePacket(gso, d.Hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil {
r.Stats().TCP.SegmentSendErrors.Increment()
return err
}
r.Stats().TCP.SegmentsSent.Increment()
- if (flags & header.TCPFlagRst) != 0 {
+ if (tf.flags & header.TCPFlagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment()
}
return nil
@@ -730,7 +878,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:])
- offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
}
if e.sackPermitted && len(sackBlocks) > 0 {
offset += header.EncodeNOP(options[offset:])
@@ -749,11 +897,20 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
// sendRaw sends a TCP segment to the endpoint's peer.
func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
var sackBlocks []header.SACKBlock
- if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
}
options := e.makeOptions(sackBlocks)
- err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso)
+ err := e.sendTCP(&e.route, tcpFields{
+ id: e.ID,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: rcvWnd,
+ opts: options,
+ }, data, e.gso)
putOptions(options)
return err
}
@@ -768,7 +925,6 @@ func (e *endpoint) handleWrite() *tcpip.Error {
first := e.sndQueue.Front()
if first != nil {
e.snd.writeList.PushBackList(&e.sndQueue)
- e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
e.sndBufInQueue = 0
}
@@ -786,6 +942,9 @@ func (e *endpoint) handleWrite() *tcpip.Error {
}
func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
// Drain the send queue.
e.handleWrite()
@@ -802,69 +961,194 @@ func (e *endpoint) handleClose() *tcpip.Error {
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
- if e.state == StateEstablished || e.state == StateCloseWait {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- }
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
- if err != tcpip.ErrConnectionReset {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
}
}
// completeWorkerLocked is called by the worker goroutine when it's about to
-// exit. It marks the worker as completed and performs cleanup work if requested
-// by Close().
+// exit.
func (e *endpoint) completeWorkerLocked() {
+ // Worker is terminating(either due to moving to
+ // CLOSED or ERROR state, ensure we release all
+ // registrations port reservations even if the socket
+ // itself is not yet closed by the application.
e.workerRunning = false
if e.workerCleanup {
e.cleanupLocked()
}
}
-// handleSegments pulls segments from the queue and processes them. It returns
-// no error if the protocol loop should continue, an error otherwise.
-func (e *endpoint) handleSegments() *tcpip.Error {
+// transitionToStateEstablisedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
+ e.setEndpointState(StateEstablished)
+}
+
+// transitionToStateCloseLocked ensures that the endpoint is
+// cleaned up from the transport demuxer, "before" moving to
+// StateClose. This will ensure that no packet will be
+// delivered to this endpoint from the demuxer when the endpoint
+// is transitioned to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ s := e.EndpointState()
+ if s == StateClose {
+ return
+ }
+
+ if s.connected() {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+ }
+
+ // Mark the endpoint as fully closed for reads/writes.
+ e.cleanupLocked()
+ e.setEndpointState(StateClose)
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to any other endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the segment
+// to any other listening endpoint. We reply with RST if we cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ // Dual-stack socket, try IPv4.
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ }
+ if ep == nil {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ s.decRef()
+ return
+ }
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueing segments to itself")
+ }
+
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
+ }
+}
+
+// Drain segment queue from the endpoint and try to re-match the segment to a
+// different endpoint. This is used when the current endpoint is transitioned to
+// StateClose and has been unregistered from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ switch e.EndpointState() {
+ // In case of a RST in CLOSE-WAIT linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.transitionToStateCloseLocked()
+ e.HardError = tcpip.ErrAborted
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ return false, nil
+ default:
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+
+ // Notify protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState().closed() {
+ return nil
+ }
s := e.segmentQueue.dequeue()
if s == nil {
checkRequeue = false
break
}
- // Invoke the tcp probe if installed.
- if e.probe != nil {
- e.probe(e.completeState())
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
}
-
- if s.flagIsSet(header.TCPFlagRst) {
- if e.rcv.acceptable(s.sequenceNumber, 0) {
- // RFC 793, page 37 states that "in all states
- // except SYN-SENT, all reset (RST) segments are
- // validated by checking their SEQ-fields." So
- // we only process it if it's acceptable.
- s.decRef()
- return tcpip.ErrConnectionReset
- }
- } else if s.flagIsSet(header.TCPFlagAck) {
- // Patch the window size in the segment according to the
- // send window scale.
- s.window <<= e.snd.sndWndScale
-
- // RFC 793, page 41 states that "once in the ESTABLISHED
- // state all segments must carry current acknowledgment
- // information."
- e.rcv.handleRcvdSegment(s)
- e.snd.handleRcvdSegment(s)
+ if !cont {
+ s.decRef()
+ return nil
}
- s.decRef()
}
- // If the queue is not empty, make sure we'll wake up in the next
- // iteration.
- if checkRequeue && !e.segmentQueue.empty() {
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
@@ -873,23 +1157,114 @@ func (e *endpoint) handleSegments() *tcpip.Error {
e.snd.sendAck()
}
- e.resetKeepaliveTimer(true)
+ e.resetKeepaliveTimer(true /* receivedData */)
return nil
}
+func (e *endpoint) probeSegment() {
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine if
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed. The tcp probe function will update
+ // the TCPEndpointState after the segment is processed.
+ defer e.probeSegment()
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+
+ // This RST will confirm that the remote peer has indeed closed the
+ // previous connection. Upon receipt of a valid RST, the local TCP
+ // endpoint MUST terminate its connection. The local TCP endpoint
+ // should then rely on SYN retransmission from the remote end to
+ // re-establish the connection.
+
+ e.snd.sendAck()
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ drop, err := e.rcv.handleRcvdSegment(s)
+ if err != nil {
+ return false, err
+ }
+ if drop {
+ return true, nil
+ }
+
+ // Now check if the received segment has caused us to transition
+ // to a CLOSED state, if yes then terminate processing and do
+ // not invoke the sender.
+ state := e.state
+ if state == StateClose {
+ // When we get into StateClose while processing from the queue,
+ // return immediately and let the protocolMainloop handle it.
+ //
+ // We can reach StateClose only while processing a previous segment
+ // or a notification from the protocolMainLoop (caller goroutine).
+ // This means that with this return, the segment dequeue below can
+ // never occur on a closed endpoint.
+ s.decRef()
+ return false, nil
+ }
+
+ e.snd.handleRcvdSegment(s)
+ }
+
+ return true, nil
+}
+
// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
// keepalive packets periodically when the connection is idle. If we don't hear
// from the other side after a number of tries, we terminate the connection.
func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ userTimeout := e.userTimeout
+
e.keepalive.Lock()
if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
e.keepalive.Unlock()
return nil
}
+ // If a userTimeout is set then abort the connection if it is
+ // exceeded.
+ if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 {
+ e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
@@ -906,7 +1281,6 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
// whether it is enabled for this endpoint.
func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
e.keepalive.Lock()
- defer e.keepalive.Unlock()
if receivedData {
e.keepalive.unacked = 0
}
@@ -914,6 +1288,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
// data to send.
if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
e.keepalive.timer.disable()
+ e.keepalive.Unlock()
return
}
if e.keepalive.unacked > 0 {
@@ -921,6 +1296,7 @@ func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
} else {
e.keepalive.timer.enable(e.keepalive.idle)
}
+ e.keepalive.Unlock()
}
// disableKeepaliveTimer stops the keepalive timer.
@@ -933,7 +1309,8 @@ func (e *endpoint) disableKeepaliveTimer() {
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
// goroutine and is responsible for sending segments and handling received
// segments.
-func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
var closeTimer *time.Timer
var closeWaker sleep.Waker
@@ -956,6 +1333,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.mu.Unlock()
+ e.drainClosingSegmentQueue()
+
// When the protocol loop exits we should wake up our waiters.
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -966,61 +1345,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
// completion.
initialRcvWnd := e.initialReceiveWindow()
h := newHandshake(e, seqnum.Size(initialRcvWnd))
- e.mu.Lock()
- h.ep.state = StateSynSent
- e.mu.Unlock()
+ h.ep.setEndpointState(StateSynSent)
if err := h.execute(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
- e.mu.Lock()
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- e.state = StateError
+ e.setEndpointState(StateError)
e.HardError = err
+ e.workerCleanup = true
// Lock released below.
epilogue()
-
return err
}
-
- // Transfer handshake state to TCP connection. We disable
- // receive window scaling if the peer doesn't support it
- // (indicated by a negative send window scale).
- e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
-
- rcvBufSize := seqnum.Size(e.receiveBufferSize())
- e.rcvListMu.Lock()
- e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
- // boot strap the auto tuning algorithm. Starting at zero will
- // result in a large step function on the first proper causing
- // the window to just go to a really large value after the first
- // RTT itself.
- e.rcvAutoParams.prevCopied = initialRcvWnd
- e.rcvListMu.Unlock()
}
e.keepalive.timer.init(&e.keepalive.waker)
defer e.keepalive.timer.cleanup()
- // Tell waiters that the endpoint is connected and writable.
- e.mu.Lock()
- if e.state != StateEstablished {
- e.stack.Stats().TCP.CurrentEstablished.Increment()
- e.state = StateEstablished
- }
drained := e.drainDone != nil
- e.mu.Unlock()
if drained {
close(e.drainDone)
<-e.undrain
}
- e.waiterQueue.Notify(waiter.EventOut)
-
// Set up the functions that will be called when the main protocol loop
// wakes up.
funcs := []struct {
@@ -1036,25 +1386,33 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
f: e.handleClose,
},
{
- w: &e.newSegmentWaker,
- f: e.handleSegments,
- },
- {
w: &closeWaker,
f: func() *tcpip.Error {
- return tcpip.ErrConnectionAborted
+ // This means the socket is being closed due
+ // to the TCP-FIN-WAIT2 timeout was hit. Just
+ // mark the socket as closed.
+ e.transitionToStateCloseLocked()
+ e.workerCleanup = true
+ return nil
},
},
{
w: &e.snd.resendWaker,
f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
return tcpip.ErrTimeout
}
return nil
},
},
{
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
w: &e.keepalive.waker,
f: e.keepaliveTimerExpired,
},
@@ -1080,22 +1438,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
- e.mu.Lock()
- e.resetConnectionLocked(tcpip.ErrConnectionAborted)
- e.mu.Unlock()
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
}
+
if n&notifyClose != 0 && closeTimer == nil {
- // Reset the connection 3 seconds after
- // the endpoint has been closed.
- //
- // The timer could fire in background
- // when the endpoint is drained. That's
- // OK as the loop here will not honor
- // the firing until the undrain arrives.
- closeTimer = time.AfterFunc(3*time.Second, func() {
- closeWaker.Assert()
- })
+ if e.EndpointState() == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
}
if n&notifyKeepaliveChanged != 0 {
@@ -1107,16 +1464,26 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
- if err := e.handleSegments(); err != nil {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
return err
}
}
- if e.state != StateError {
+ if !e.EndpointState().closed() {
+ // Only block the worker if the endpoint
+ // is not in closed state or error state.
close(e.drainDone)
+ e.mu.Unlock()
<-e.undrain
+ e.mu.Lock()
}
}
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
return nil
},
},
@@ -1128,14 +1495,21 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
s.AddWaker(funcs[i].w, i)
}
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
// The following assertions and notifications are needed for restored
// endpoints. Fresh newly created endpoints have empty states and should
// not invoke any.
- e.segmentQueue.mu.Lock()
- if !e.segmentQueue.list.Empty() {
+ if !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
- e.segmentQueue.mu.Unlock()
e.rcvListMu.Lock()
if !e.rcvList.Empty() {
@@ -1143,41 +1517,209 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
}
e.rcvListMu.Unlock()
- e.mu.RLock()
if e.workerCleanup {
e.notifyProtocolGoroutine(notifyClose)
}
- e.mu.RUnlock()
// Main loop. Handle segments until both send and receive ends of the
// connection have completed.
- for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
- e.workMu.Unlock()
- v, _ := s.Fetch(true)
- e.workMu.Lock()
- if err := funcs[v].f(); err != nil {
- e.mu.Lock()
- // Ensure we release all endpoint registration and route
- // references as the connection is now in an error
- // state.
- e.workerCleanup = true
+ cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.workerCleanup = true
+ if err != nil {
e.resetConnectionLocked(err)
- // Lock released below.
- epilogue()
+ }
+ // Lock released below.
+ epilogue()
+ }
+loop:
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+
+ // We need to double check here because the notification may be
+ // stale by the time we got around to processing it.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
}
}
- // Mark endpoint as closed.
- e.mu.Lock()
- if e.state != StateError {
- e.stack.Stats().TCP.EstablishedResets.Increment()
- e.stack.Stats().TCP.CurrentEstablished.Decrement()
- e.state = StateClose
+ var reuseTW func()
+ if e.EndpointState() == StateTimeWait {
+ // Disable close timer as we now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.workerCleanup = true
+ reuseTW = e.doTimeWait()
+ }
+
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
}
+
+ e.transitionToStateCloseLocked()
+
// Lock released below.
epilogue()
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue
+ if reuseTW != nil {
+ reuseTW()
+ }
+
return nil
}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than one see on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ defer s.Done()
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyAbort != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/log/glog_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go
index ea17ae349..cfc304616 100644
--- a/pkg/log/glog_unsafe.go
+++ b/pkg/tcpip/transport/tcp/connect_unsafe.go
@@ -12,21 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package log
+package tcp
import (
"reflect"
"unsafe"
)
-// unsafeString returns a string that points to the given byte array.
-// The byte array must be preserved until the string is disposed.
-func unsafeString(data []byte) (s string) {
- if len(data) == 0 {
- return
- }
-
- (*reflect.StringHeader)(unsafe.Pointer(&s)).Data = uintptr(unsafe.Pointer(&data[0]))
- (*reflect.StringHeader)(unsafe.Pointer(&s)).Len = len(data)
- return
+// optionsToArray converts a slice of capacity >-= maxOptionSize to an array.
+//
+// optionsToArray panics if the capacity of options is smaller than
+// maxOptionSize.
+func optionsToArray(options []byte) *[maxOptionSize]byte {
+ // Reslice to full capacity.
+ options = options[0:maxOptionSize]
+ return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data))
}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..98aecab9e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// epQueue is a queue of endpoints.
+type epQueue struct {
+ mu sync.Mutex
+ list endpointList
+}
+
+// enqueue adds e to the queue if the endpoint is not already on the queue.
+func (q *epQueue) enqueue(e *endpoint) {
+ q.mu.Lock()
+ if e.pendingProcessing {
+ q.mu.Unlock()
+ return
+ }
+ q.list.PushBack(e)
+ e.pendingProcessing = true
+ q.mu.Unlock()
+}
+
+// dequeue removes and returns the first element from the queue if available,
+// returns nil otherwise.
+func (q *epQueue) dequeue() *endpoint {
+ q.mu.Lock()
+ if e := q.list.Front(); e != nil {
+ q.list.Remove(e)
+ e.pendingProcessing = false
+ q.mu.Unlock()
+ return e
+ }
+ q.mu.Unlock()
+ return nil
+}
+
+// empty returns true if the queue is empty, false otherwise.
+func (q *epQueue) empty() bool {
+ q.mu.Lock()
+ v := q.list.Empty()
+ q.mu.Unlock()
+ return v
+}
+
+// processor is responsible for processing packets queued to a tcp endpoint.
+type processor struct {
+ epQ epQueue
+ sleeper sleep.Sleeper
+ newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
+}
+
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) queueEndpoint(ep *endpoint) {
+ // Queue an endpoint for processing by the processor goroutine.
+ p.epQ.enqueue(ep)
+ p.newEndpointWaker.Assert()
+}
+
+const (
+ newEndpointWaker = 1
+ closeWaker = 2
+)
+
+func (p *processor) start(wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer p.sleeper.Done()
+
+ for {
+ if id, _ := p.sleeper.Fetch(true); id == closeWaker {
+ break
+ }
+ for {
+ ep := p.epQ.dequeue()
+ if ep == nil {
+ break
+ }
+ if ep.segmentQueue.empty() {
+ continue
+ }
+
+ // If socket has transitioned out of connected state then just let the
+ // worker handle the packet.
+ //
+ // NOTE: We read this outside of e.mu lock which means that by the time
+ // we get to handleSegments the endpoint may not be in ESTABLISHED. But
+ // this should be fine as all normal shutdown states are handled by
+ // handleSegments and if the endpoint moves to a CLOSED/ERROR state
+ // then handleSegments is a noop.
+ if ep.EndpointState() == StateEstablished && ep.mu.TryLock() {
+ // If the endpoint is in a connected state then we do direct delivery
+ // to ensure low latency and avoid scheduler interactions.
+ switch err := ep.handleSegments(true /* fastPath */); {
+ case err != nil:
+ // Send any active resets if required.
+ ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
+ }
+ ep.mu.Unlock()
+ } else {
+ ep.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines do full tcp processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []processor
+ seed uint32
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
+ }
+}
+
+func (d *dispatcher) close() {
+ for i := range d.processors {
+ d.processors[i].close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ d.wg.Wait()
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in established state let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload[:])
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index dfaa4a559..804e95aea 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -391,9 +391,8 @@ func testV4Accept(t *testing.T, c *context.Context) {
// Make sure we get the same error when calling the original ep and the
// new one. This validates that v4-mapped endpoints are still able to
// query the V6Only flag, whereas pure v4 endpoints are not.
- var v tcpip.V6OnlyOption
- expected := c.EP.GetSockOpt(&v)
- if err := nep.GetSockOpt(&v); err != expected {
+ _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption)
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected {
t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
}
@@ -531,8 +530,7 @@ func TestV6AcceptOnV6(t *testing.T) {
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
- var v tcpip.V6OnlyOption
- if err := nep.GetSockOpt(&v); err != nil {
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
t.Fatalf("GetSockOpt failed failed: %v", err)
}
@@ -570,11 +568,10 @@ func TestV4AcceptOnV4(t *testing.T) {
func testV4ListenClose(t *testing.T, c *context.Context) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ }
+
const n = uint16(32)
// Start listening.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6ca0d73a9..1ccedebcc 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -18,21 +18,21 @@ import (
"encoding/binary"
"fmt"
"math"
+ "runtime"
"strings"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
- "gvisor.dev/gvisor/pkg/tmutex"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -63,7 +63,8 @@ const (
StateClosing
)
-// connected is the set of states where an endpoint is connected to a peer.
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
func (s EndpointState) connected() bool {
switch s {
case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
@@ -73,6 +74,40 @@ func (s EndpointState) connected() bool {
}
}
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
// String implements fmt.Stringer.String.
func (s EndpointState) String() string {
switch s {
@@ -119,8 +154,17 @@ const (
notifyMTUChanged
notifyDrain
notifyReset
+ notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is
+ // say TIME_WAIT.
+ notifyTickleWorker
+ notifyError
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -273,20 +317,59 @@ func (*EndpointInfo) IsEndpointInfo() {}
// synchronized. The protocol implementation, however, runs in a single
// goroutine.
//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint must be held for all operations except
+// in e.Readiness where acquiring it will result in a deadlock in epoll
+// implementation.
+//
+// The following three mutexes can be acquired independent of e.mu but if
+// acquired with e.mu then e.mu must be acquired first.
+//
+// e.acceptMu -> protects acceptedChan.
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint. The locking of an endpoint is different
+// based on the context in which the lock is acquired. In the syscall context
+// e.LockUser/e.UnlockUser should be used and when doing background processing
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked the background processor can
+// queue the packet up and go its merry way and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly when acquiring the
+// lock from say a syscall goroutine we can implement a bit of spinning if we
+// know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long and we can avoid an expensive
+// sleep/wakeup by spinning for a shortwhile.
+//
+// For more details please see the detailed documentation on
+// e.LockUser/e.UnlockUser methods.
+//
// +stateify savable
type endpoint struct {
EndpointInfo
- // workMu is used to arbitrate which goroutine may perform protocol
- // work. Only the main protocol goroutine is expected to call Lock() on
- // it, but other goroutines (e.g., send) may call TryLock() to eagerly
- // perform work without having to wait for the main one to wake up.
- workMu tmutex.Mutex `state:"nosave"`
+ // endpointEntry is used to queue endpoints for processing to the
+ // a given tcp processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field..
+ pendingProcessing bool `state:"nosave"`
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -307,21 +390,24 @@ type endpoint struct {
rcvBufSize int
rcvBufUsed int
rcvAutoParams rcvBufAutoTuneParams
- // zeroWindow indicates that the window was closed due to receive buffer
- // space being filled up. This is set by the worker goroutine before
- // moving a segment to the rcvList. This setting is cleared by the
- // endpoint when a Read() call reads enough data for the new window to
- // be non-zero.
- zeroWindow bool
- // The following fields are protected by the mutex.
- mu sync.RWMutex `state:"nosave"`
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
state EndpointState `state:".(EndpointState)"`
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to it's correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
isPortReserved bool `state:"manual"`
- isRegistered bool
- boundNICID tcpip.NICID `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
route stack.Route `state:"manual"`
ttl uint8
v6only bool
@@ -330,19 +416,28 @@ type endpoint struct {
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
+ // Values used to reserve a port or register a transport endpoint
+ // (which ever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
+ boundDest tcpip.FullAddress
+
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
- effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running.
workerRunning bool
// workerCleanup specifies if the worker goroutine must perform cleanup
- // before exitting. This can only be set to true when workerRunning is
+ // before exiting. This can only be set to true when workerRunning is
// also true, and they're both protected by the mutex.
workerCleanup bool
@@ -356,6 +451,9 @@ type endpoint struct {
// updated if required when a new segment is received by this endpoint.
recentTS uint32
+ // recentTSTime is the unix time when we updated recentTS last.
+ recentTSTime time.Time `state:".(unixTime)"`
+
// tsOffset is a randomized offset added to the value of the
// TSVal field in the timestamp option.
tsOffset uint32
@@ -370,9 +468,6 @@ type endpoint struct {
// sack holds TCP SACK related information for this endpoint.
sack SACKInfo
- // reusePort is set to true if SO_REUSEPORT is enabled.
- reusePort bool
-
// bindToDevice is set to the NIC on which to bind or disabled if 0.
bindToDevice tcpip.NICID
@@ -392,7 +487,6 @@ type endpoint struct {
// The options below aren't implemented, but we remember the user
// settings because applications expect to be able to set/query these
// options.
- reuseAddr bool
// slowAck holds the negated state of quick ack. It is stubbed out and
// does nothing.
@@ -411,7 +505,18 @@ type endpoint struct {
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
- userMSS int
+ userMSS uint16
+
+ // maxSynRetries is the maximum number of SYN retransmits that TCP should
+ // send before aborting the attempt to connect. It cannot exceed 255.
+ //
+ // NOTE: This is currently a no-op and does not change the SYN
+ // retransmissions.
+ maxSynRetries uint8
+
+ // windowClamp is used to bound the size of the advertised window to
+ // this value.
+ windowClamp uint32
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
@@ -458,12 +563,42 @@ type endpoint struct {
// without hearing a response, the connection is closed.
keepalive keepalive
+ // userTimeout if non-zero specifies a user specified timeout for
+ // a connection w/ pending data to send. A connection that has pending
+ // unacked data will be forcibily aborted if the timeout is reached
+ // without any data being acked.
+ userTimeout time.Duration
+
+ // deferAccept if non-zero specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
// pendingAccepted is a synchronization primitive used to track number
// of connections that are queued up to be delivered to the accepted
// channel. We use this to ensure that all goroutines blocked on writing
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full ( See: endpoint.deliverAccepted ).
+ acceptCond *sync.Cond `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -502,16 +637,175 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of a time a socket
+ // a socket stays in TIME_WAIT state before being marked
+ // closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called closed on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
+
+ // txHash is the transport layer hash to be set on outbound packets
+ // emitted by this endpoint.
+ txHash uint32
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ // TODO(b/143359391): Respect TCP Min and Max size.
+ maxMSS := uint16(r.MTU() - header.TCPMinimumSize)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
+}
+
+// LockUser tries to lock e.mu and if it fails it will check if the lock is held
+// by another syscall goroutine. If yes, then it will goto sleep waiting for the
+// lock to be released, if not then it will spin till it acquires the lock or
+// another syscall goroutine acquires it in which case it will goto sleep as
+// described above.
+//
+// The assumption behind spinning here being that background packet processing
+// should not be holding the lock for long and spinning reduces latency as we
+// avoid an expensive sleep/wakeup of of the syscall goroutine).
+func (e *endpoint) LockUser() {
+ for {
+ // Try first if the sock is locked then check if it's owned
+ // by another user goroutine if not then we spin, otherwise
+ // we just goto sleep on the Lock() and wait.
+ if !e.mu.TryLock() {
+ // If socket is owned by the user then just goto sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for processing
+// and process any such segments before unlocking e.mu. This is required because
+// we when packets arrive and endpoint lock is already held then such packets
+// are queued up to be processed. If the lock is held by the endpoint goroutine
+// then it will process these packets but if the lock is instead held by the
+// syscall goroutine then we can have the syscall goroutine process the backlog
+// before unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling e.UnlockUser()
+func (e *endpoint) UnlockUser() {
+ // Lock segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if queue is empty
+ // and actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
}
// StopWork halts packet processing. Only to be used in tests.
func (e *endpoint) StopWork() {
- e.workMu.Lock()
+ e.mu.Lock()
}
// ResumeWork resumes packet processing. Only to be used in tests.
func (e *endpoint) ResumeWork() {
- e.workMu.Unlock()
+ e.mu.Unlock()
+}
+
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.stack.Stats().TCP.CurrentConnected.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
+func (e *endpoint) EndpointState() EndpointState {
+ return EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+}
+
+// setRecentTimestamp sets the recentTS field to the provided value.
+func (e *endpoint) setRecentTimestamp(recentTS uint32) {
+ e.recentTS = recentTS
+ e.recentTSTime = time.Now()
+}
+
+// recentTimestamp returns the value of the recentTS field.
+func (e *endpoint) recentTimestamp() uint32 {
+ return e.recentTS
}
// keepalive is a synchronization wrapper used to appease stateify. See the
@@ -543,13 +837,16 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSize: DefaultReceiveBufferSize,
sndBufSize: DefaultSendBufferSize,
sndMTU: int(math.MaxInt32),
- reuseAddr: true,
keepalive: keepalive{
// Linux defaults.
idle: 2 * time.Hour,
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
+ txHash: s.Rand().Uint32(),
+ windowClamp: DefaultReceiveBufferSize,
+ maxSynRetries: DefaultSynRetries,
}
var ss SendBufferSizeOption
@@ -572,14 +869,28 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.rcvAutoParams.disabled = !bool(mrb)
}
+ var de DelayEnabled
+ if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de {
+ e.SetSockOptBool(tcpip.DelayOption, true)
+ }
+
+ var tcpLT tcpip.TCPLingerTimeoutOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil {
+ e.tcpLingerTimeout = time.Duration(tcpLT)
+ }
+
+ var synRetries tcpip.TCPSynRetriesOption
+ if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil {
+ e.maxSynRetries = uint8(synRetries)
+ }
+
if p := s.GetTCPProbe(); p != nil {
e.probe = p
}
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
- e.workMu.Lock()
e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
return e
}
@@ -589,26 +900,25 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
result := waiter.EventMask(0)
- e.mu.RLock()
- defer e.mu.RUnlock()
-
- switch e.state {
+ switch e.EndpointState() {
case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv:
// Ready for nothing.
- case StateClose, StateError:
+ case StateClose, StateError, StateTimeWait:
// Ready for anything.
result = mask
case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
+ e.acceptMu.Unlock()
}
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// Determine if the endpoint is writable if requested.
if (mask & waiter.EventOut) != 0 {
e.sndBufMu.Lock()
@@ -655,69 +965,117 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ //
+ // If the endpoint connecting finishes after the check, the endpoint
+ // is either in a connected state (where we would notifyAbort anyway),
+ // SYN-RECV (where we would also notifyAbort anyway), or in an error
+ // state where nothing is required and the notification can be safely
+ // ignored.
+ //
+ // Endpoints where a Close during connecting or SYN-RECV state would be
+ // problematic are set to state connecting before being registered (and
+ // thus possible to be Aborted). They are never available in initial
+ // state.
+ //
+ // Endpoints transitioning from initial to connecting state may be
+ // safely either closed or sent notifyAbort.
+ if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
func (e *endpoint) Close() {
+ e.LockUser()
+ defer e.UnlockUser()
+ if e.closed {
+ return
+ }
+
// Issue a shutdown so that the peer knows we won't send any more data
// if we're connected, or stop accepting if we're listening.
- e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
-
- e.mu.Lock()
+ e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+ e.closeNoShutdownLocked()
+}
+// closeNoShutdown closes the endpoint without doing a full shutdown. This is
+// used when a connection needs to be aborted with a RST and we want to skip
+// a full 4 way TCP shutdown.
+func (e *endpoint) closeNoShutdownLocked() {
// For listening sockets, we always release ports inline so that they
// are immediately available for reuse after Close() is called. If also
// registered, we unregister as well otherwise the next user would fail
// in Listen() when trying to register.
- if e.state == StateListen && e.isPortReserved {
+ if e.EndpointState() == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
+ }
+
+ // Mark endpoint as closed.
+ e.closed = true
+
+ switch e.EndpointState() {
+ case StateClose, StateError:
+ return
}
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
- tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
- e.cleanupLocked()
- } else {
+ if e.workerRunning {
e.workerCleanup = true
+ tcpip.AddDanglingEndpoint(e)
+ // Worker will remove the dangling endpoint when the endpoint
+ // goroutine terminates.
e.notifyProtocolGoroutine(notifyClose)
+ } else {
+ e.transitionToStateCloseLocked()
}
-
- e.mu.Unlock()
}
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
- done := make(chan struct{})
- // Spin a goroutine up as ranging on e.acceptedChan will just block when
- // there are no more connections in the channel. Using a non-blocking
- // select does not work as it can potentially select the default case
- // even when there are pending writes but that are not yet written to
- // the channel.
- go func() {
- defer close(done)
- for n := range e.acceptedChan {
- n.mu.Lock()
- n.resetConnectionLocked(tcpip.ErrConnectionAborted)
- n.mu.Unlock()
- n.Close()
- }
- }()
- // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
- // endpoints which have completed handshake but are not yet written to
- // the e.acceptedChan. We wait here till the goroutine above can drain
- // all such connections from e.acceptedChan.
- e.pendingAccepted.Wait()
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
close(e.acceptedChan)
- <-done
+ ch := e.acceptedChan
e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Reset all connections that are waiting to be accepted.
+ for n := range ch {
+ n.notifyProtocolGoroutine(notifyReset)
+ }
+ // Wait for reset of all endpoints that are still waiting to be delivered to
+ // the now closed acceptedChan.
+ e.pendingAccepted.Wait()
}
// cleanupLocked frees all resources associated with the endpoint. It is called
@@ -726,22 +1084,25 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
- if e.acceptedChan != nil {
- e.closePendingAcceptableConnectionsLocked()
- }
+ e.closePendingAcceptableConnectionsLocked()
+
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
e.isPortReserved = false
}
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
+ e.boundDest = tcpip.FullAddress{}
e.route.Release()
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
@@ -752,16 +1113,34 @@ func (e *endpoint) initialReceiveWindow() int {
if rcvWnd > math.MaxUint16 {
rcvWnd = math.MaxUint16
}
- routeWnd := InitialCwnd * int(mssForRoute(&e.route)) * 2
+
+ // Use the user supplied MSS, if available.
+ routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2
if rcvWnd > routeWnd {
rcvWnd = routeWnd
}
+ rcvWndScale := e.rcvWndScaleForHandshake()
+
+ // Round-down the rcvWnd to a multiple of wndScale. This ensures that the
+ // window offered in SYN won't be reduced due to the loss of precision if
+ // window scaling is enabled after the handshake.
+ rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale)
+
+ // Ensure we can always accept at least 1 byte if the scale specified
+ // was too high for the provided rcvWnd.
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
return rcvWnd
}
// ModerateRecvBuf adjusts the receive buffer and the advertised window
-// based on the number of bytes copied to user space.
+// based on the number of bytes copied to userspace.
func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.LockUser()
+ defer e.UnlockUser()
+
e.rcvListMu.Lock()
if e.rcvAutoParams.disabled {
e.rcvListMu.Unlock()
@@ -807,8 +1186,14 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
// reject valid data that might already be in flight as the
// acceptable window will shrink.
if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = rcvWnd
- e.notifyProtocolGoroutine(notifyReceiveWindowChanged)
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
}
// We only update prevCopied when we grow the buffer because in cases
@@ -822,36 +1207,50 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
e.rcvListMu.Unlock()
}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
- return e.stack.IPTables(), nil
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+ err := e.lastError
+ e.lastError = nil
+ return err
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data. Also note that a RST being received
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
bufUsed := e.rcvBufUsed
- if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 {
+ if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
- e.mu.RUnlock()
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
}
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ e.stats.ReadErrors.NotConnected.Increment()
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
}
v, err := e.readLocked()
e.rcvListMu.Unlock()
- e.mu.RUnlock()
-
if err == tcpip.ErrClosedForReceive {
e.stats.ReadErrors.ReadClosed.Increment()
}
@@ -860,7 +1259,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
return buffer.View{}, tcpip.ErrClosedForReceive
}
return buffer.View{}, tcpip.ErrWouldBlock
@@ -877,11 +1276,12 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
}
e.rcvBufUsed -= len(v)
- // If the window was zero before this read and if the read freed up
- // enough buffer space for the scaled window to be non-zero then notify
- // the protocol goroutine to send a window update.
- if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
- e.zeroWindow = false
+
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above {
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -895,8 +1295,8 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
- if !e.state.connected() {
- switch e.state {
+ if !e.EndpointState().connected() {
+ switch e.EndpointState() {
case StateError:
return 0, e.HardError
default:
@@ -922,13 +1322,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
- e.mu.RLock()
+ e.LockUser()
e.sndBufMu.Lock()
avail, err := e.isEndpointWritableLocked()
if err != nil {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
e.stats.WriteErrors.WriteClosed.Increment()
return 0, nil, err
}
@@ -940,73 +1340,72 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
// are copying data in.
if !opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
// Fetch data.
v, perr := p.Payload(avail)
if perr != nil || len(v) == 0 {
- if opts.Atomic { // See above.
+ // Note that perr may be nil if len(v) == 0.
+ if opts.Atomic {
e.sndBufMu.Unlock()
- e.mu.RUnlock()
+ e.UnlockUser()
}
- // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- if !opts.Atomic { // See above.
- e.mu.RLock()
- e.sndBufMu.Lock()
+ queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) {
+ // Add data to the send queue.
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- e.stats.WriteErrors.WriteClosed.Increment()
- return 0, nil, err
- }
+ // Do the work inline.
+ e.handleWrite()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
- // Discard any excess data copied in due to avail being reduced due
- // to a simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
- }
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
}
- // Add data to the send queue.
- s := newSegmentFromView(&e.route, e.ID, v)
- e.sndBufUsed += len(v)
- e.sndBufInQueue += seqnum.Size(len(v))
- e.sndQueue.PushBack(s)
- e.sndBufMu.Unlock()
- // Release the endpoint lock to prevent deadlocks due to lock
- // order inversion when acquiring workMu.
- e.mu.RUnlock()
+ // Since we released locks in between it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR states so make
+ // sure endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
- if e.workMu.TryLock() {
- // Do the work inline.
- e.handleWrite()
- e.workMu.Unlock()
- } else {
- // Let the protocol goroutine do the work.
- e.sndWaker.Assert()
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
}
- return int64(len(v)), nil, nil
+ // Locks released in queueAndSend()
+ return queueAndSend()
}
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint can be read if it's connected, or if it's already closed
// but has some pending unread data.
- if s := e.state; !s.connected() && s != StateClose {
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
}
@@ -1018,7 +1417,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
defer e.rcvListMu.Unlock()
if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.state.connected() {
+ if e.rcvClosed || !e.EndpointState().connected() {
e.stats.ReadErrors.ReadClosed.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
}
@@ -1055,37 +1454,174 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
return num, tcpip.ControlMessages{}, nil
}
-// zeroReceiveWindow checks if the receive window to be announced now would be
-// zero, based on the amount of available buffer and the receive window scaling.
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under aMSS or under half receive buffer, whichever smaller. This
+// is useful as a receive side silly window syndrome prevention mechanism. If
+// window grows to reasonable value, we should send ACK to the sender to inform
+// the rx space is now large. We also want ensure a series of small read()'s
+// won't trigger a flood of spurious tiny ACK's.
//
-// It must be called with rcvListMu held.
-func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
- if e.rcvBufUsed >= e.rcvBufSize {
- return true
+// For large receive buffers, the threshold is aMSS - once reader reads more
+// than aMSS we'll send ACK. For tiny receive buffers, the threshold is half of
+// receive buffer size. This is chosen arbitrairly.
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+//
+// Precondition: e.mu and e.rcvListMu must be held.
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
+ newAvail := e.receiveBufferAvailableLocked()
+ oldAvail := newAvail - deltaBefore
+ if oldAvail < 0 {
+ oldAvail = 0
+ }
+
+ threshold := int(e.amss)
+ if threshold > e.rcvBufSize/2 {
+ threshold = e.rcvBufSize / 2
+ }
+
+ switch {
+ case oldAvail < threshold && newAvail >= threshold:
+ return true, true
+ case oldAvail >= threshold && newAvail < threshold:
+ return true, false
+ }
+ return false, false
+}
+
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ e.broadcast = v
+ e.UnlockUser()
+
+ case tcpip.CorkOption:
+ e.LockUser()
+ if !v {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ e.UnlockUser()
+
+ case tcpip.DelayOption:
+ if v {
+ atomic.StoreUint32(&e.delay, 1)
+ } else {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ }
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.QuickAckOption:
+ o := uint32(1)
+ if v {
+ o = 0
+ }
+ atomic.StoreUint32(&e.slowAck, o)
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We only allow this to be set when we're in the initial state.
+ if e.EndpointState() != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.LockUser()
+ e.v6only = v
+ e.UnlockUser()
}
- return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+ return nil
}
// SetSockOptInt sets a socket option.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
+ const inetECNMask = 3
+
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.userMSS = uint16(userMSS)
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
var rs ReceiveBufferSizeOption
- size := int(v)
if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
- if size < rs.Min {
- size = rs.Min
+ if v < rs.Min {
+ v = rs.Min
}
- if size > rs.Max {
- size = rs.Max
+ if v > rs.Max {
+ v = rs.Max
}
}
mask := uint32(notifyReceiveWindowChanged)
+ e.LockUser()
e.rcvListMu.Lock()
// Make sure the receive buffer size allows us to send a
@@ -1094,179 +1630,119 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
if e.rcv != nil {
scale = e.rcv.rcvWndScale
}
- if size>>scale == 0 {
- size = 1 << scale
+ if v>>scale == 0 {
+ v = 1 << scale
}
// Make sure 2*size doesn't overflow.
- if size > math.MaxInt32/2 {
- size = math.MaxInt32 / 2
+ if v > math.MaxInt32/2 {
+ v = math.MaxInt32 / 2
}
- e.rcvBufSize = size
+ availBefore := e.receiveBufferAvailableLocked()
+ e.rcvBufSize = v
+ availAfter := e.receiveBufferAvailableLocked()
+
e.rcvAutoParams.disabled = true
- if e.zeroWindow && !e.zeroReceiveWindow(scale) {
- e.zeroWindow = false
+
+ // Immediately send an ACK to uncork the sender silly window
+ // syndrome prevetion, when our available space grows above aMSS
+ // or half receive buffer, whichever smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
mask |= notifyNonZeroReceiveWindow
}
- e.rcvListMu.Unlock()
+ e.rcvListMu.Unlock()
+ e.UnlockUser()
e.notifyProtocolGoroutine(mask)
- return nil
case tcpip.SendBufferSizeOption:
// Make sure the send buffer size is within the min and max
// allowed.
- size := int(v)
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
- if size < ss.Min {
- size = ss.Min
+ if v < ss.Min {
+ v = ss.Min
}
- if size > ss.Max {
- size = ss.Max
+ if v > ss.Max {
+ v = ss.Max
}
}
e.sndBufMu.Lock()
- e.sndBufSize = size
+ e.sndBufSize = v
e.sndBufMu.Unlock()
- return nil
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
+ case tcpip.TTLOption:
+ e.LockUser()
+ e.ttl = uint8(v)
+ e.UnlockUser()
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
+ case tcpip.TCPSynCountOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
}
- return nil
+ e.LockUser()
+ e.maxSynRetries = uint8(v)
+ e.UnlockUser()
- default:
- return nil
+ case tcpip.TCPWindowClampOption:
+ if v == 0 {
+ e.LockUser()
+ switch e.EndpointState() {
+ case StateClose, StateInitial:
+ e.windowClamp = 0
+ e.UnlockUser()
+ return nil
+ default:
+ e.UnlockUser()
+ return tcpip.ErrInvalidOptionValue
+ }
+ }
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min/2 {
+ v = rs.Min / 2
+ }
+ }
+ e.LockUser()
+ e.windowClamp = uint32(v)
+ e.UnlockUser()
}
+ return nil
}
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- // Lower 2 bits represents ECN bits. RFC 3168, section 23.1
- const inetECNMask = 3
switch v := opt.(type) {
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- for nicid, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicid
- return nil
- }
- }
- return tcpip.ErrUnknownDevice
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
- case tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // We only allow this to be set when we're in the initial state.
- if e.state != StateInitial {
- return tcpip.ErrInvalidEndpointState
- }
-
- e.v6only = v != 0
- return nil
-
- case tcpip.TTLOption:
- e.mu.Lock()
- e.ttl = uint8(v)
- e.mu.Unlock()
- return nil
-
- case tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- e.keepalive.enabled = v != 0
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
+ e.LockUser()
+ e.bindToDevice = id
+ e.UnlockUser()
case tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
e.keepalive.idle = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
case tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
e.keepalive.interval = time.Duration(v)
e.keepalive.Unlock()
e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
- case tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- e.keepalive.count = int(v)
- e.keepalive.Unlock()
- e.notifyProtocolGoroutine(notifyKeepaliveChanged)
- return nil
+ case tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v != 0
- e.mu.Unlock()
- return nil
+ case tcpip.TCPUserTimeoutOption:
+ e.LockUser()
+ e.userTimeout = time.Duration(v)
+ e.UnlockUser()
case tcpip.CongestionControlOption:
// Query the available cc algorithms in the stack and
@@ -1279,22 +1755,16 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
availCC := strings.Split(string(avail), " ")
for _, cc := range availCC {
if v == tcpip.CongestionControlOption(cc) {
- // Acquire the work mutex as we may need to
- // reinitialize the congestion control state.
- e.mu.Lock()
- state := e.state
+ e.LockUser()
+ state := e.EndpointState()
e.cc = v
- e.mu.Unlock()
switch state {
case StateEstablished:
- e.workMu.Lock()
- e.mu.Lock()
- if e.state == state {
+ if e.EndpointState() == state {
e.snd.cc = e.snd.initCongestionControl(e.cc)
}
- e.mu.Unlock()
- e.workMu.Unlock()
}
+ e.UnlockUser()
return nil
}
}
@@ -1303,34 +1773,44 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
// control algorithm is specified.
return tcpip.ErrNoSuchFile
- case tcpip.IPv4TOSOption:
- e.mu.Lock()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
- return nil
+ case tcpip.TCPLingerTimeoutOption:
+ e.LockUser()
+ if v < 0 {
+ // Same as effectively disabling TCPLinger timeout.
+ v = 0
+ }
+ // Cap it to MaxTCPLingerTimeout.
+ stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout)
+ if v > stkTCPLingerTimeout {
+ v = stkTCPLingerTimeout
+ }
+ e.tcpLingerTimeout = time.Duration(v)
+ e.UnlockUser()
- case tcpip.IPv6TrafficClassOption:
- e.mu.Lock()
- // TODO(gvisor.dev/issue/995): ECN is not currently supported,
- // ignore the bits for now.
- e.sendTOS = uint8(v) & ^uint8(inetECNMask)
- e.mu.Unlock()
+ case tcpip.TCPDeferAcceptOption:
+ e.LockUser()
+ if time.Duration(v) > MaxRTO {
+ v = tcpip.TCPDeferAcceptOption(MaxRTO)
+ }
+ e.deferAccept = time.Duration(v)
+ e.UnlockUser()
+
+ case tcpip.SocketDetachFilterOption:
return nil
default:
return nil
}
+ return nil
}
// readyReceiveSize returns the number of bytes ready to be received.
func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
// The endpoint cannot be in listen state.
- if e.state == StateListen {
+ if e.EndpointState() == StateListen {
return 0, tcpip.ErrInvalidEndpointState
}
@@ -1340,9 +1820,100 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ v := e.broadcast
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.CorkOption:
+ return atomic.LoadUint32(&e.cork) != 0, nil
+
+ case tcpip.DelayOption:
+ return atomic.LoadUint32(&e.delay) != 0, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ return v, nil
+
+ case tcpip.QuickAckOption:
+ v := atomic.LoadUint32(&e.slowAck) == 0
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ v := e.portFlags.TupleOnly
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ v := e.portFlags.LoadBalanced
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.LockUser()
+ v := e.v6only
+ e.UnlockUser()
+
+ return v, nil
+
+ case tcpip.MulticastLoopOption:
+ return true, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ v := e.keepalive.count
+ e.keepalive.Unlock()
+ return v, nil
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ v := int(e.sendTOS)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or returns the
+ // actual current MSS. Netstack just returns the defaultMSS
+ // always for now.
+ v := header.TCPDefaultMSS
+ return v, nil
+
+ case tcpip.MTUDiscoverOption:
+ // Always return the path MTU discovery disabled setting since
+ // it's the only one supported.
+ return tcpip.PMTUDiscoveryDont, nil
+
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
@@ -1358,12 +1929,26 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
e.rcvListMu.Unlock()
return v, nil
- case tcpip.DelayOption:
- var o int
- if v := atomic.LoadUint32(&e.delay); v != 0 {
- o = 1
- }
- return o, nil
+ case tcpip.TTLOption:
+ e.LockUser()
+ v := int(e.ttl)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.TCPSynCountOption:
+ e.LockUser()
+ v := int(e.maxSynRetries)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.TCPWindowClampOption:
+ e.LockUser()
+ v := int(e.windowClamp)
+ e.UnlockUser()
+ return v, nil
+
+ case tcpip.MulticastTTLOption:
+ return 1, nil
default:
return -1, tcpip.ErrUnknownProtocolOption
@@ -1374,191 +1959,84 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- e.lastErrorMu.Lock()
- err := e.lastError
- e.lastError = nil
- e.lastErrorMu.Unlock()
- return err
-
- case *tcpip.MaxSegOption:
- // This is just stubbed out. Linux never returns the user_mss
- // value as it either returns the defaultMSS or returns the
- // actual current MSS. Netstack just returns the defaultMSS
- // always for now.
- *o = header.TCPDefaultMSS
- return nil
-
- case *tcpip.CorkOption:
- *o = 0
- if v := atomic.LoadUint32(&e.cork); v != 0 {
- *o = 1
- }
- return nil
-
- case *tcpip.ReuseAddressOption:
- e.mu.RLock()
- v := e.reuseAddr
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.reusePort
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
+ return e.takeLastError()
case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
- return nil
-
- case *tcpip.QuickAckOption:
- *o = 1
- if v := atomic.LoadUint32(&e.slowAck); v != 0 {
- *o = 0
- }
- return nil
-
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.TTLOption:
- e.mu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
- return nil
+ e.LockUser()
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.UnlockUser()
case *tcpip.TCPInfoOption:
*o = tcpip.TCPInfoOption{}
- e.mu.RLock()
+ e.LockUser()
snd := e.snd
- e.mu.RUnlock()
+ e.UnlockUser()
if snd != nil {
snd.rtt.Lock()
o.RTT = snd.rtt.srtt
o.RTTVar = snd.rtt.rttvar
snd.rtt.Unlock()
}
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- e.keepalive.Lock()
- v := e.keepalive.enabled
- e.keepalive.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.KeepaliveIdleOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
e.keepalive.Unlock()
- return nil
case *tcpip.KeepaliveIntervalOption:
e.keepalive.Lock()
*o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
e.keepalive.Unlock()
- return nil
- case *tcpip.KeepaliveCountOption:
- e.keepalive.Lock()
- *o = tcpip.KeepaliveCountOption(e.keepalive.count)
- e.keepalive.Unlock()
- return nil
+ case *tcpip.TCPUserTimeoutOption:
+ e.LockUser()
+ *o = tcpip.TCPUserTimeoutOption(e.userTimeout)
+ e.UnlockUser()
case *tcpip.OutOfBandInlineOption:
// We don't currently support disabling this option.
*o = 1
- return nil
-
- case *tcpip.BroadcastOption:
- e.mu.Lock()
- v := e.broadcast
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.CongestionControlOption:
- e.mu.Lock()
+ e.LockUser()
*o = e.cc
- e.mu.Unlock()
- return nil
+ e.UnlockUser()
- case *tcpip.IPv4TOSOption:
- e.mu.RLock()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
+ case *tcpip.TCPLingerTimeoutOption:
+ e.LockUser()
+ *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout)
+ e.UnlockUser()
- case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
+ case *tcpip.TCPDeferAcceptOption:
+ e.LockUser()
+ *o = tcpip.TCPDeferAcceptOption(e.deferAccept)
+ e.UnlockUser()
+
+ case *tcpip.OriginalDestinationOption:
+ ipt := e.stack.IPTables()
+ addr, port, err := ipt.OriginalDst(e.ID)
+ if err != nil {
+ return err
+ }
+ *o = tcpip.OriginalDestinationOption{
+ Addr: addr,
+ Port: port,
+ }
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
}
-
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -1583,17 +2061,17 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// yet accepted by the app, they are restored without running the main goroutine
// here.
func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
connectingAddr := addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
- if e.state.connected() {
+ if e.EndpointState().connected() {
// The endpoint is already connected. If caller hasn't been
// notified yet, return success.
if !e.isConnectNotified {
@@ -1604,8 +2082,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnected
}
- nicid := addr.NIC
- switch e.state {
+ nicID := addr.NIC
+ switch e.EndpointState() {
case StateBound:
// If we're already bound to a NIC but the caller is requesting
// that we use a different one now, we cannot proceed.
@@ -1613,11 +2091,11 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
break
}
- if nicid != 0 && nicid != e.boundNICID {
+ if nicID != 0 && nicID != e.boundNICID {
return tcpip.ErrNoRoute
}
- nicid = e.boundNICID
+ nicID = e.boundNICID
case StateInitial:
// Nothing to do. We'll eventually fill-in the gaps in the ID (if any)
@@ -1636,14 +2114,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
if err != nil {
return err
}
defer r.Release()
- origID := e.ID
-
netProtos := []tcpip.NetworkProtocolNumber{netProto}
e.ID.LocalAddress = r.LocalAddress
e.ID.RemoteAddress = r.RemoteAddress
@@ -1651,7 +2127,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
if e.ID.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
if err != nil {
return err
}
@@ -1666,7 +2142,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
// src IP to ensure that for a given tuple (srcIP, destIP,
// destPort) the offset used as a starting point is the same to
// ensure that we can cycle through the port space effectively.
- h := jenkins.Sum32(e.stack.PortSeed())
+ h := jenkins.Sum32(e.stack.Seed())
h.Write([]byte(e.ID.LocalAddress))
h.Write([]byte(e.ID.RemoteAddress))
portBuf := make([]byte, 2)
@@ -1674,44 +2150,95 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
h.Write(portBuf)
portOffset := h.Sum32()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
+ panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %#v) = %s", ProtocolNumber, &twReuse, err))
+ }
+
+ reuse := twReuse == tcpip.TCPTimeWaitReuseGlobal
+ if twReuse == tcpip.TCPTimeWaitReuseLoopbackOnly {
+ switch netProto {
+ case header.IPv4ProtocolNumber:
+ reuse = header.IsV4LoopbackAddress(e.ID.LocalAddress) && header.IsV4LoopbackAddress(e.ID.RemoteAddress)
+ case header.IPv6ProtocolNumber:
+ reuse = e.ID.LocalAddress == header.IPv6Loopback && e.ID.RemoteAddress == header.IPv6Loopback
+ }
+ }
+
if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- // reusePort is false below because connect cannot reuse a port even if
- // reusePort was set.
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
- return false, nil
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ if err != tcpip.ErrPortInUse || !reuse {
+ return false, nil
+ }
+ transEPID := e.ID
+ transEPID.LocalPort = p
+ // Check if an endpoint is registered with demuxer in TIME-WAIT and if
+ // we can reuse it. If we can't find a transport endpoint then we just
+ // skip using this port as it's possible that either an endpoint has
+ // bound the port but not registered with demuxer yet (no listen/connect
+ // done yet) or the reservation was freed between the check above and
+ // the FindTransportEndpoint below. But rather than retry the same port
+ // we just skip it and move on.
+ transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r)
+ if transEP == nil {
+ // ReservePort failed but there is no registered endpoint with
+ // demuxer. Which indicates there is at least some endpoint that has
+ // bound the port.
+ return false, nil
+ }
+
+ tcpEP := transEP.(*endpoint)
+ tcpEP.LockUser()
+ // If the endpoint is not in TIME-WAIT or if it is in TIME-WAIT but
+ // less than 1 second has elapsed since its recentTS was updated then
+ // we cannot reuse the port.
+ if tcpEP.EndpointState() != StateTimeWait || time.Since(tcpEP.recentTSTime) < 1*time.Second {
+ tcpEP.UnlockUser()
+ return false, nil
+ }
+ // Since the endpoint is in TIME-WAIT it should be safe to acquire its
+ // Lock while holding the lock for this endpoint as endpoints in
+ // TIME-WAIT do not acquire locks on other endpoints.
+ tcpEP.workerCleanup = false
+ tcpEP.cleanupLocked()
+ tcpEP.notifyProtocolGoroutine(notifyAbort)
+ tcpEP.UnlockUser()
+ // Now try and Reserve again if it fails then we skip.
+ if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil {
+ return false, nil
+ }
}
id := e.ID
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
- case nil:
- e.ID = id
- return true, nil
- case tcpip.ErrPortInUse:
- return false, nil
- default:
+ if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr)
+ if err == tcpip.ErrPortInUse {
+ return false, nil
+ }
return false, err
}
+
+ // Port picking successful. Save the details of
+ // the selected port.
+ e.ID = id
+ e.isPortReserved = true
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
+ e.boundDest = addr
+ return true, nil
}); err != nil {
return err
}
}
- // Remove the port reservation. This can happen when Bind is called
- // before Connect: in such a case we don't want to hold on to
- // reservations anymore.
- if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
- e.isPortReserved = false
- }
-
e.isRegistered = true
- e.state = StateConnecting
+ e.setEndpointState(StateConnecting)
e.route = r.Clone()
- e.boundNICID = nicid
+ e.boundNICID = nicID
e.effectiveNetProtos = netProtos
e.connectingAddress = connectingAddr
@@ -1730,14 +2257,13 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
e.segmentQueue.mu.Unlock()
e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
- e.state = StateEstablished
- e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.setEndpointState(StateEstablished)
}
if run {
e.workerRunning = true
e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
}
return tcpip.ErrConnectStarted
@@ -1751,14 +2277,17 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
- e.shutdownFlags |= flags
+ e.LockUser()
+ defer e.UnlockUser()
+ return e.shutdownLocked(flags)
+}
+func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.shutdownFlags |= flags
switch {
- case e.state.connected():
+ case e.EndpointState().connected():
// Close for read.
- if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
// Mark read side as closed.
e.rcvListMu.Lock()
e.rcvClosed = true
@@ -1767,47 +2296,56 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// If we're fully closed and we have unread data we need to abort
// the connection with a RST.
- if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
- e.notifyProtocolGoroutine(notifyReset)
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 {
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ // Wake up worker to terminate loop.
+ e.notifyProtocolGoroutine(notifyTickleWorker)
return nil
}
}
// Close for write.
- if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
e.sndBufMu.Lock()
-
if e.sndClosed {
// Already closed.
e.sndBufMu.Unlock()
- break
+ if e.EndpointState() == StateTimeWait {
+ return tcpip.ErrNotConnected
+ }
+ return nil
}
// Queue fin segment.
s := newSegmentFromView(&e.route, e.ID, nil)
e.sndQueue.PushBack(s)
e.sndBufInQueue++
-
// Mark endpoint as closed.
e.sndClosed = true
-
e.sndBufMu.Unlock()
-
- // Tell protocol goroutine to close.
- e.sndCloseWaker.Assert()
+ e.handleClose()
}
- case e.state == StateListen:
- // Tell protocolListenLoop to stop.
- if flags&tcpip.ShutdownRead != 0 {
- e.notifyProtocolGoroutine(notifyClose)
+ return nil
+ case e.EndpointState() == StateListen:
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}
-
+ return nil
default:
return tcpip.ErrNotConnected
}
-
- return nil
}
// Listen puts the endpoint in "listen" mode, which allows it to accept
@@ -1822,104 +2360,136 @@ func (e *endpoint) Listen(backlog int) *tcpip.Error {
}
func (e *endpoint) listen(backlog int) *tcpip.Error {
- e.mu.Lock()
- defer e.mu.Unlock()
-
- // Allow the backlog to be adjusted if the endpoint is not shutting down.
- // When the endpoint shuts down, it sets workerCleanup to true, and from
- // that point onward, acceptedChan is the responsibility of the cleanup()
- // method (and should not be touched anywhere else, including here).
- if e.state == StateListen && !e.workerCleanup {
- // Adjust the size of the channel iff we can fix existing
- // pending connections into the new one.
- if len(e.acceptedChan) > backlog {
- return tcpip.ErrInvalidEndpointState
- }
- if cap(e.acceptedChan) == backlog {
- return nil
- }
- origChan := e.acceptedChan
- e.acceptedChan = make(chan *endpoint, backlog)
- close(origChan)
- for ep := range origChan {
- e.acceptedChan <- ep
+ e.LockUser()
+ defer e.UnlockUser()
+
+ if e.EndpointState() == StateListen && !e.closed {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ if e.acceptedChan == nil {
+ // listen is called after shutdown.
+ e.acceptedChan = make(chan *endpoint, backlog)
+ e.shutdownFlags = 0
+ e.rcvListMu.Lock()
+ e.rcvClosed = false
+ e.rcvListMu.Unlock()
+ } else {
+ // Adjust the size of the channel iff we can fix
+ // existing pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
}
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
return nil
}
+ if e.EndpointState() == StateInitial {
+ // The listen is called on an unbound socket, the socket is
+ // automatically bound to a random free port with the local
+ // address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
// Endpoint must be bound before it can transition to listen mode.
- if e.state != StateBound {
+ if e.EndpointState() != StateBound {
e.stats.ReadErrors.InvalidEndpointState.Increment()
return tcpip.ErrInvalidEndpointState
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.reusePort, e.bindToDevice); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
return err
}
e.isRegistered = true
- e.state = StateListen
+ e.setEndpointState(StateListen)
+
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with some previously accepted (but not Accepted)
+ // endpoints.
+ e.acceptMu.Lock()
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
- e.workerRunning = true
+ e.acceptMu.Unlock()
+ e.workerRunning = true
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
-
return nil
}
// startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections.
-func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
- e.waiterQueue = waiterQueue
+func (e *endpoint) startAcceptedLoop() {
e.workerRunning = true
- go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+ e.mu.Unlock()
+ wakerInitDone := make(chan struct{})
+ go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save.
+ <-wakerInitDone
}
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode.
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ rcvClosed := e.rcvClosed
+ e.rcvListMu.Unlock()
// Endpoint must be in listen state before it can accept connections.
- if e.state != StateListen {
+ if rcvClosed || e.EndpointState() != StateListen {
return nil, nil, tcpip.ErrInvalidEndpointState
}
// Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
var n *endpoint
select {
case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
default:
return nil, nil, tcpip.ErrWouldBlock
}
-
- // Start the protocol goroutine.
- wq := &waiter.Queue{}
- n.startAcceptedLoop(wq)
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
-
- return n, wq, nil
+ return n, n.waiterQueue, nil
}
// Bind binds the endpoint to a specific local port and optionally address.
func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
- e.mu.Lock()
- defer e.mu.Unlock()
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return e.bindLocked(addr)
+}
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) {
// Don't allow binding once endpoint is not in the initial state
// anymore. This is because once the endpoint goes into a connected or
// listen state, it is already bound.
- if e.state != StateInitial {
+ if e.EndpointState() != StateInitial {
return tcpip.ErrAlreadyBound
}
e.BindAddr = addr.Addr
- netProto, err := e.checkV4Mapped(&addr)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -1935,26 +2505,30 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
return err
}
+ e.boundBindToDevice = e.bindToDevice
+ e.boundPortFlags = e.portFlags
e.isPortReserved = true
e.effectiveNetProtos = netProtos
e.ID.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func(bindToDevice tcpip.NICID) {
+ defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{})
e.isPortReserved = false
e.effectiveNetProtos = nil
e.ID.LocalPort = 0
e.ID.LocalAddress = ""
e.boundNICID = 0
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
- }(e.bindToDevice)
+ }(e.boundPortFlags, e.boundBindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
@@ -1968,16 +2542,20 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.ID.LocalAddress = addr.Addr
}
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
// Mark endpoint as bound.
- e.state = StateBound
+ e.setEndpointState(StateBound)
return nil
}
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
return tcpip.FullAddress{
Addr: e.ID.LocalAddress,
@@ -1988,10 +2566,10 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
- e.mu.RLock()
- defer e.mu.RUnlock()
+ e.LockUser()
+ defer e.UnlockUser()
- if !e.state.connected() {
+ if !e.EndpointState().connected() {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -2002,45 +2580,26 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}, nil
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
- s := newSegment(r, id, vv)
- if !s.parse() {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
- e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
- s.decRef()
- return
- }
-
- if !s.csumValid {
- e.stack.Stats().MalformedRcvdPackets.Increment()
- e.stack.Stats().TCP.ChecksumErrors.Increment()
- e.stats.ReceiveErrors.ChecksumErrors.Increment()
- s.decRef()
- return
- }
-
- e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
- e.stats.SegmentsReceived.Increment()
- if (s.flags & header.TCPFlagRst) != 0 {
- e.stack.Stats().TCP.ResetsReceived.Increment()
- }
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // TCP HandlePacket is not required anymore as inbound packets first
+ // land at the Dispatcher which then can either delivery using the
+ // worker go routine or directly do the invoke the tcp processing inline
+ // based on the state of the endpoint.
+}
+func (e *endpoint) enqueueSegment(s *segment) bool {
// Send packet to worker goroutine.
- if e.segmentQueue.enqueue(s) {
- e.newSegmentWaker.Assert()
- } else {
+ if !e.segmentQueue.enqueue(s) {
// The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
- s.decRef()
+ return false
}
+ return true
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
switch typ {
case stack.ControlPacketTooBig:
e.sndBufMu.Lock()
@@ -2051,6 +2610,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C
e.sndBufMu.Unlock()
e.notifyProtocolGoroutine(notifyMTUChanged)
+
+ case stack.ControlNoRoute:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNoRoute
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
+
+ case stack.ControlNetworkUnreachable:
+ e.lastErrorMu.Lock()
+ e.lastError = tcpip.ErrNetworkUnreachable
+ e.lastErrorMu.Unlock()
+ e.notifyProtocolGoroutine(notifyError)
}
}
@@ -2079,20 +2650,16 @@ func (e *endpoint) readyToRead(s *segment) {
if s != nil {
s.incRef()
e.rcvBufUsed += s.data.Size()
- // Check if the receive window is now closed. If so make sure
- // we set the zero window before we deliver the segment to ensure
- // that a subsequent read of the segment will correctly trigger
- // a non-zero notification.
- if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ // Increase counter if the receive window falls down below MSS
+ // or half receive buffer size, whichever smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
- e.zeroWindow = true
}
e.rcvList.PushBack(s)
} else {
e.rcvClosed = true
}
e.rcvListMu.Unlock()
-
e.waiterQueue.Notify(waiter.EventIn)
}
@@ -2156,8 +2723,8 @@ func (e *endpoint) rcvWndScaleForHandshake() int {
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
- if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
- e.recentTS = tsVal
+ if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.setRecentTimestamp(tsVal)
}
}
@@ -2167,22 +2734,21 @@ func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value,
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS {
e.sendTSOk = true
- e.recentTS = synOpts.TSVal
+ e.setRecentTimestamp(synOpts.TSVal)
}
}
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint.
func (e *endpoint) timestamp() uint32 {
- return tcpTimeStamp(e.tsOffset)
+ return tcpTimeStamp(time.Now(), e.tsOffset)
}
// tcpTimeStamp returns a timestamp offset by the provided offset. This is
// not inlined above as it's used when SYN cookies are in use and endpoint
// is not created at the time when the SYN cookie is sent.
-func tcpTimeStamp(offset uint32) uint32 {
- now := time.Now()
- return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+func tcpTimeStamp(curTime time.Time, offset uint32) uint32 {
+ return uint32(curTime.Unix()*1000+int64(curTime.Nanosecond()/1e6)) + offset
}
// timeStampOffset returns a randomized timestamp offset to be used when sending
@@ -2236,9 +2802,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
s.SegTime = time.Now()
// Copy EndpointID.
- e.mu.Lock()
s.ID = stack.TCPEndpointID(e.ID)
- e.mu.Unlock()
// Copy endpoint rcv state.
e.rcvListMu.Lock()
@@ -2256,7 +2820,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
// Endpoint TCP Option state.
s.SendTSOk = e.sendTSOk
- s.RecentTS = e.recentTS
+ s.RecentTS = e.recentTimestamp()
s.TSOffset = e.tsOffset
s.SACKPermitted = e.sackPermitted
s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
@@ -2327,6 +2891,14 @@ func (e *endpoint) completeState() stack.TCPEndpointState {
WEst: cubic.wEst,
}
}
+
+ rc := e.snd.rc
+ s.Sender.RACKState = stack.TCPRACKState{
+ XmitTime: rc.xmitTime,
+ EndSequence: rc.endSequence,
+ FACK: rc.fack,
+ RTT: rc.rtt,
+ }
return s
}
@@ -2363,17 +2935,15 @@ func (e *endpoint) initGSO() {
// State implements tcpip.Endpoint.State. It exports the endpoint's protocol
// state for diagnostics.
func (e *endpoint) State() uint32 {
- e.mu.Lock()
- defer e.mu.Unlock()
- return uint32(e.state)
+ return uint32(e.EndpointState())
}
// Info returns a copy of the endpoint info.
func (e *endpoint) Info() tcpip.EndpointInfo {
- e.mu.RLock()
+ e.LockUser()
// Make a copy of the endpoint info.
ret := e.EndpointInfo
- e.mu.RUnlock()
+ e.UnlockUser()
return &ret
}
@@ -2382,6 +2952,18 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
-func mssForRoute(r *stack.Route) uint16 {
- return uint16(r.MTU() - header.TCPMinimumSize)
+// Wait implements stack.TransportEndpoint.Wait.
+func (e *endpoint) Wait() {
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp)
+ defer e.waiterQueue.EventUnregister(&waitEntry)
+ for {
+ e.LockUser()
+ running := e.workerRunning
+ e.UnlockUser()
+ if !running {
+ break
+ }
+ <-notifyCh
+ }
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index eae17237e..723e47ddc 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -16,9 +16,10 @@ package tcp
import (
"fmt"
- "sync"
+ "sync/atomic"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -48,11 +49,10 @@ func (e *endpoint) beforeSave() {
e.mu.Lock()
defer e.mu.Unlock()
- switch e.state {
- case StateInitial, StateBound:
- // TODO(b/138137272): this enumeration duplicates
- // EndpointState.connected. remove it.
- case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ epState := e.EndpointState()
+ switch {
+ case epState == StateInitial || epState == StateBound:
+ case epState.connected() || epState.handshake():
if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)})
@@ -68,35 +68,31 @@ func (e *endpoint) beforeSave() {
break
}
fallthrough
- case StateListen, StateConnecting:
+ case epState == StateListen || epState == StateConnecting:
e.drainSegmentLocked()
- if e.state != StateClose && e.state != StateError {
+ // Refresh epState, since drainSegmentLocked may have changed it.
+ epState = e.EndpointState()
+ if !epState.closed() {
if !e.workerRunning {
panic("endpoint has no worker running in listen, connecting, or connected state")
}
- break
}
- fallthrough
- case StateError, StateClose:
- for e.state == StateError && e.workerRunning {
+ case epState.closed():
+ for e.workerRunning {
e.mu.Unlock()
time.Sleep(100 * time.Millisecond)
e.mu.Lock()
}
if e.workerRunning {
- panic("endpoint still has worker running in closed or error state")
+ panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID))
}
default:
- panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState()))
}
if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
panic("endpoint still has waiters upon save")
}
-
- if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) {
- panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
- }
}
// saveAcceptedChan is invoked by stateify.
@@ -135,7 +131,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
// saveState is invoked by stateify.
func (e *endpoint) saveState() EndpointState {
- return e.state
+ return e.EndpointState()
}
// Endpoint loading must be done in the following ordering by their state, to
@@ -148,23 +144,34 @@ var connectingLoading sync.WaitGroup
// Bound endpoint loading happens last.
// loadState is invoked by stateify.
-func (e *endpoint) loadState(state EndpointState) {
+func (e *endpoint) loadState(epState EndpointState) {
// This is to ensure that the loading wait groups include all applicable
// endpoints before any asynchronous calls to the Wait() methods.
- if state.connected() {
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if epState.connected() || epState == StateTimeWait {
connectedLoading.Add(1)
}
- switch state {
- case StateListen:
+ switch {
+ case epState == StateListen:
listenLoading.Add(1)
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
connectingLoading.Add(1)
}
- e.state = state
+ // Directly update the state here rather than using e.setEndpointState
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
}
// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad() {
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
+ // Condition variables and mutexs are not S/R'ed so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
@@ -172,34 +179,40 @@ func (e *endpoint) afterLoad() {
func (e *endpoint) Resume(s *stack.Stack) {
e.stack = s
e.segmentQueue.setLimit(MaxUnprocessedSegments)
- e.workMu.Init()
-
- state := e.state
- switch state {
+ epState := e.origEndpointState
+ switch epState {
case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
var ss SendBufferSizeOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
}
- if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
- panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
}
}
}
bind := func() {
- e.state = StateInitial
- if len(e.BindAddr) == 0 {
- e.BindAddr = e.ID.LocalAddress
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
}
- if err := e.Bind(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}); err != nil {
- panic("endpoint binding failed: " + err.String())
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
}
- switch state {
- case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ switch {
+ case epState.connected():
bind()
if len(e.connectingAddress) == 0 {
e.connectingAddress = e.ID.RemoteAddress
@@ -217,8 +230,18 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
panic("endpoint connecting failed: " + err.String())
}
+ e.mu.Lock()
+ e.state = e.origEndpointState
+ closed := e.closed
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ if epState == StateFinWait2 && closed {
+ // If the endpoint has been closed then make sure we notify so
+ // that the FIN_WAIT2 timer is started after a restore.
+ e.notifyProtocolGoroutine(notifyClose)
+ }
connectedLoading.Done()
- case StateListen:
+ case epState == StateListen:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -227,10 +250,15 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err := e.Listen(backlog); err != nil {
panic("endpoint listening failed: " + err.String())
}
+ e.LockUser()
+ if e.shutdownFlags != 0 {
+ e.shutdownLocked(e.shutdownFlags)
+ }
+ e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateConnecting, StateSynSent, StateSynRecv:
+ case epState.connecting():
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -242,7 +270,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
connectingLoading.Done()
tcpip.AsyncLoading.Done()
}()
- case StateBound:
+ case epState == StateBound:
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
@@ -251,20 +279,14 @@ func (e *endpoint) Resume(s *stack.Stack) {
bind()
tcpip.AsyncLoading.Done()
}()
- case StateClose:
- if e.isPortReserved {
- tcpip.AsyncLoading.Add(1)
- go func() {
- connectedLoading.Wait()
- listenLoading.Wait()
- connectingLoading.Wait()
- bind()
- e.state = StateClose
- tcpip.AsyncLoading.Done()
- }()
- }
- fallthrough
- case StateError:
+ case epState == StateClose:
+ e.isPortReserved = false
+ e.state = StateClose
+ e.stack.CompleteTransportEndpointCleanup(e)
+ tcpip.DeleteDanglingEndpoint(e)
+ case epState == StateError:
+ e.state = StateError
+ e.stack.CompleteTransportEndpointCleanup(e)
tcpip.DeleteDanglingEndpoint(e)
}
}
@@ -284,7 +306,17 @@ func (e *endpoint) loadLastError(s string) {
return
}
- e.lastError = loadError(s)
+ e.lastError = tcpip.StringToError(s)
+}
+
+// saveRecentTSTime is invoked by stateify.
+func (e *endpoint) saveRecentTSTime() unixTime {
+ return unixTime{e.recentTSTime.Unix(), e.recentTSTime.UnixNano()}
+}
+
+// loadRecentTSTime is invoked by stateify.
+func (e *endpoint) loadRecentTSTime(unix unixTime) {
+ e.recentTSTime = time.Unix(unix.second, unix.nano)
}
// saveHardError is invoked by stateify.
@@ -302,71 +334,7 @@ func (e *EndpointInfo) loadHardError(s string) {
return
}
- e.HardError = loadError(s)
-}
-
-var messageToError map[string]*tcpip.Error
-
-var populate sync.Once
-
-func loadError(s string) *tcpip.Error {
- populate.Do(func() {
- var errors = []*tcpip.Error{
- tcpip.ErrUnknownProtocol,
- tcpip.ErrUnknownNICID,
- tcpip.ErrUnknownDevice,
- tcpip.ErrUnknownProtocolOption,
- tcpip.ErrDuplicateNICID,
- tcpip.ErrDuplicateAddress,
- tcpip.ErrNoRoute,
- tcpip.ErrBadLinkEndpoint,
- tcpip.ErrAlreadyBound,
- tcpip.ErrInvalidEndpointState,
- tcpip.ErrAlreadyConnecting,
- tcpip.ErrAlreadyConnected,
- tcpip.ErrNoPortAvailable,
- tcpip.ErrPortInUse,
- tcpip.ErrBadLocalAddress,
- tcpip.ErrClosedForSend,
- tcpip.ErrClosedForReceive,
- tcpip.ErrWouldBlock,
- tcpip.ErrConnectionRefused,
- tcpip.ErrTimeout,
- tcpip.ErrAborted,
- tcpip.ErrConnectStarted,
- tcpip.ErrDestinationRequired,
- tcpip.ErrNotSupported,
- tcpip.ErrQueueSizeNotSupported,
- tcpip.ErrNotConnected,
- tcpip.ErrConnectionReset,
- tcpip.ErrConnectionAborted,
- tcpip.ErrNoSuchFile,
- tcpip.ErrInvalidOptionValue,
- tcpip.ErrNoLinkAddress,
- tcpip.ErrBadAddress,
- tcpip.ErrNetworkUnreachable,
- tcpip.ErrMessageTooLong,
- tcpip.ErrNoBufferSpace,
- tcpip.ErrBroadcastDisabled,
- tcpip.ErrNotPermitted,
- tcpip.ErrAddressFamilyNotSupported,
- }
-
- messageToError = make(map[string]*tcpip.Error)
- for _, e := range errors {
- if messageToError[e.String()] != nil {
- panic("tcpip errors with duplicated message: " + e.String())
- }
- messageToError[e.String()] = e
- }
- })
-
- e, ok := messageToError[s]
- if !ok {
- panic("unknown error message: " + s)
- }
-
- return e
+ e.HardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 63666f0b3..070b634b4 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -15,10 +15,8 @@
package tcp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -63,8 +61,8 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
// We only care about well-formed SYN packets.
@@ -132,7 +130,7 @@ func (r *ForwarderRequest) Complete(sendReset bool) {
// If the caller requested, send a reset.
if sendReset {
- replyWithReset(r.segment)
+ replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL())
}
// Release all resources.
@@ -159,13 +157,13 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
- })
+ }, queue, nil)
if err != nil {
return nil, err
}
// Start the protocol goroutine.
- ep.startAcceptedLoop(queue)
+ ep.startAcceptedLoop()
return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index db40785d3..c5afa2680 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -21,9 +21,11 @@
package tcp
import (
+ "runtime"
"strings"
- "sync"
+ "time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -54,41 +56,149 @@ const (
// MaxUnprocessedSegments is the maximum number of unprocessed segments
// that can be queued for a given endpoint.
MaxUnprocessedSegments = 300
+
+ // DefaultTCPLingerTimeout is the amount of time that sockets linger in
+ // FIN_WAIT_2 state before being marked closed.
+ DefaultTCPLingerTimeout = 60 * time.Second
+
+ // MaxTCPLingerTimeout is the maximum amount of time that sockets
+ // linger in FIN_WAIT_2 state before being marked closed.
+ MaxTCPLingerTimeout = 120 * time.Second
+
+ // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+ // in TIME_WAIT state before being marked closed.
+ DefaultTCPTimeWaitTimeout = 60 * time.Second
+
+ // DefaultSynRetries is the default value for the number of SYN retransmits
+ // before a connect is aborted.
+ DefaultSynRetries = 6
+)
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
)
-// SACKEnabled option can be used to enable SACK support in the TCP
-// protocol. See: https://tools.ietf.org/html/rfc2018.
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
-// SendBufferSizeOption allows the default, min and max send buffer sizes for
-// TCP endpoints to be queried or configured.
+// Recovery is used by stack.(*Stack).TransportProtocolOption to
+// set loss detection algorithm in TCP.
+type Recovery int32
+
+const (
+ // RACKLossDetection indicates RACK is used for loss detection and
+ // recovery.
+ RACKLossDetection Recovery = 1 << iota
+
+ // RACKStaticReoWnd indicates the reordering window should not be
+ // adjusted when DSACK is received.
+ RACKStaticReoWnd
+
+ // RACKNoDupTh indicates RACK should not consider the classic three
+ // duplicate acknowledgements rule to mark the segments as lost. This
+ // is used when reordering is not detected.
+ RACKNoDupTh
+)
+
+// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
+type DelayEnabled bool
+
+// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
type SendBufferSizeOption struct {
Min int
Default int
Max int
}
-// ReceiveBufferSizeOption allows the default, min and max receive buffer size
-// for TCP endpoints to be queried or configured.
+// ReceiveBufferSizeOption is used by
+// stack.(Stack*).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
-const (
- ccReno = "reno"
- ccCubic = "cubic"
-)
+// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.Threhsold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
type protocol struct {
- mu sync.Mutex
+ mu sync.RWMutex
sackEnabled bool
+ recovery Recovery
+ delayEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
moderateReceiveBuffer bool
+ lingerTimeout time.Duration
+ timeWaitTimeout time.Duration
+ timeWaitReuse tcpip.TCPTimeWaitReuseOption
+ minRTO time.Duration
+ maxRTO time.Duration
+ maxRetries uint32
+ synRcvdCount synRcvdCounter
+ synRetries uint8
+ dispatcher dispatcher
}
// Number returns the tcp protocol number.
@@ -97,7 +207,7 @@ func (*protocol) Number() tcpip.TransportProtocolNumber {
}
// NewEndpoint creates a new tcp endpoint.
-func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
@@ -119,6 +229,14 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
return h.SourcePort(), h.DestinationPort(), nil
}
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
@@ -126,8 +244,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- s := newSegment(r, id, vv)
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
defer s.decRef()
if !s.parse() || !s.csumValid {
@@ -139,24 +257,45 @@ func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Transpo
return true
}
- replyWithReset(s)
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
return true
}
// replyWithReset replies to the given segment with a reset segment.
-func replyWithReset(s *segment) {
+func replyWithReset(s *segment, tos, ttl uint8) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
+ ack := seqnum.Value(0)
+ flags := byte(header.TCPFlagRst)
+ // As per RFC 793 page 35 (Reset Generation)
+ // 1. If the connection does not exist (CLOSED) then a reset is sent
+ // in response to any incoming segment except another reset. In
+ // particular, SYNs addressed to a non-existent connection are rejected
+ // by this means.
+
+ // If the incoming segment has an ACK field, the reset takes its
+ // sequence number from the ACK field of the segment, otherwise the
+ // reset has sequence number zero and the ACK field is set to the sum
+ // of the sequence number and segment length of the incoming segment.
+ // The connection remains in the CLOSED state.
if s.flagIsSet(header.TCPFlagAck) {
seq = s.ackNumber
+ } else {
+ flags |= header.TCPFlagAck
+ ack = s.sequenceNumber.Add(s.logicalLen())
}
-
- ack := s.sequenceNumber.Add(s.logicalLen())
-
- sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
+ sendTCP(&s.route, tcpFields{
+ id: s.id,
+ ttl: ttl,
+ tos: tos,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: 0,
+ }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -165,6 +304,18 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case Recovery:
+ p.mu.Lock()
+ p.recovery = Recovery(v)
+ p.mu.Unlock()
+ return nil
+
+ case DelayEnabled:
+ p.mu.Lock()
+ p.delayEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
case SendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
@@ -202,48 +353,174 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPLingerTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.lingerTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitTimeoutOption:
+ if v < 0 {
+ v = 0
+ }
+ p.mu.Lock()
+ p.timeWaitTimeout = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPTimeWaitReuseOption:
+ if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.timeWaitReuse = v
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMinRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMinRTOOption(MinRTO)
+ }
+ p.mu.Lock()
+ p.minRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMaxRTOOption:
+ if v < 0 {
+ v = tcpip.TCPMaxRTOOption(MaxRTO)
+ }
+ p.mu.Lock()
+ p.maxRTO = time.Duration(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPMaxRetriesOption:
+ p.mu.Lock()
+ p.maxRetries = uint32(v)
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
+ case tcpip.TCPSynRetriesOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.synRetries = uint8(v)
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
- p.mu.Lock()
+ p.mu.RLock()
*v = SACKEnabled(p.sackEnabled)
- p.mu.Unlock()
+ p.mu.RUnlock()
+ return nil
+
+ case *Recovery:
+ p.mu.RLock()
+ *v = Recovery(p.recovery)
+ p.mu.RUnlock()
+ return nil
+
+ case *DelayEnabled:
+ p.mu.RLock()
+ *v = DelayEnabled(p.delayEnabled)
+ p.mu.RUnlock()
return nil
case *SendBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.sendBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *ReceiveBufferSizeOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = p.recvBufferSize
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.CongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.CongestionControlOption(p.congestionControl)
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.AvailableCongestionControlOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
- p.mu.Unlock()
+ p.mu.RUnlock()
return nil
case *tcpip.ModerateReceiveBufferOption:
- p.mu.Lock()
+ p.mu.RLock()
*v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer)
- p.mu.Unlock()
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPLingerTimeoutOption:
+ p.mu.RLock()
+ *v = tcpip.TCPLingerTimeoutOption(p.lingerTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitTimeoutOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitTimeoutOption(p.timeWaitTimeout)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPTimeWaitReuseOption:
+ p.mu.RLock()
+ *v = tcpip.TCPTimeWaitReuseOption(p.timeWaitReuse)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMinRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMinRTOOption(p.minRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMaxRTOOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRTOOption(p.maxRTO)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPMaxRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPMaxRetriesOption(p.maxRetries)
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
+ return nil
+
+ case *tcpip.TCPSynRetriesOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRetriesOption(p.synRetries)
+ p.mu.RUnlock()
return nil
default:
@@ -251,12 +528,67 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ // TCP header is variable length, peek at it first.
+ hdrLen := header.TCPMinimumSize
+ hdr, ok := pkt.Data.PullUp(hdrLen)
+ if !ok {
+ return false
+ }
+
+ // If the header has options, pull those up as well.
+ if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() {
+ // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of
+ // packets.
+ hdrLen = offset
+ }
+
+ _, ok = pkt.TransportHeader().Consume(hdrLen)
+ return ok
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ p := protocol{
+ sendBufferSize: SendBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultSendBufferSize,
+ Max: MaxBufferSize,
+ },
+ recvBufferSize: ReceiveBufferSizeOption{
+ Min: MinBufferSize,
+ Default: DefaultReceiveBufferSize,
+ Max: MaxBufferSize,
+ },
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
+ lingerTimeout: DefaultTCPLingerTimeout,
+ timeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
+ synRetries: DefaultSynRetries,
+ minRTO: MinRTO,
+ maxRTO: MaxRTO,
+ maxRetries: MaxRetries,
+ recovery: RACKLossDetection,
}
+ p.dispatcher.init(runtime.GOMAXPROCS(0))
+ return &p
}
diff --git a/pkg/tcpip/transport/tcp/rack.go b/pkg/tcpip/transport/tcp/rack.go
new file mode 100644
index 000000000..d969ca23a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack.go
@@ -0,0 +1,82 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+// RACK is a loss detection algorithm used in TCP to detect packet loss and
+// reordering using transmission timestamp of the packets instead of packet or
+// sequence counts. To use RACK, SACK should be enabled on the connection.
+
+// rackControl stores the rack related fields.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-6.1
+//
+// +stateify savable
+type rackControl struct {
+ // xmitTime is the latest transmission timestamp of rackControl.seg.
+ xmitTime time.Time `state:".(unixTime)"`
+
+ // endSequence is the ending TCP sequence number of rackControl.seg.
+ endSequence seqnum.Value
+
+ // fack is the highest selectively or cumulatively acknowledged
+ // sequence.
+ fack seqnum.Value
+
+ // rtt is the RTT of the most recently delivered packet on the
+ // connection (either cumulatively acknowledged or selectively
+ // acknowledged) that was not marked invalid as a possible spurious
+ // retransmission.
+ rtt time.Duration
+}
+
+// Update will update the RACK related fields when an ACK has been received.
+// See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+func (rc *rackControl) Update(seg *segment, ackSeg *segment, srtt time.Duration, offset uint32) {
+ rtt := time.Now().Sub(seg.xmitTime)
+
+ // If the ACK is for a retransmitted packet, do not update if it is a
+ // spurious inference which is determined by below checks:
+ // 1. When Timestamping option is available, if the TSVal is less than the
+ // transmit time of the most recent retransmitted packet.
+ // 2. When RTT calculated for the packet is less than the smoothed RTT
+ // for the connection.
+ // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2
+ // step 2
+ if seg.xmitCount > 1 {
+ if ackSeg.parsedOptions.TS && ackSeg.parsedOptions.TSEcr != 0 {
+ if ackSeg.parsedOptions.TSEcr < tcpTimeStamp(seg.xmitTime, offset) {
+ return
+ }
+ }
+ if rtt < srtt {
+ return
+ }
+ }
+
+ rc.rtt = rtt
+ // Update rc.xmitTime and rc.endSequence to the transmit time and
+ // ending sequence number of the packet which has been acknowledged
+ // most recently.
+ endSeq := seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ if rc.xmitTime.Before(seg.xmitTime) || (seg.xmitTime.Equal(rc.xmitTime) && rc.endSequence.LessThan(endSeq)) {
+ rc.xmitTime = seg.xmitTime
+ rc.endSequence = endSeq
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/rack_state.go b/pkg/tcpip/transport/tcp/rack_state.go
new file mode 100644
index 000000000..c9dc7e773
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rack_state.go
@@ -0,0 +1,29 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveXmitTime is invoked by stateify.
+func (rc *rackControl) saveXmitTime() unixTime {
+ return unixTime{rc.xmitTime.Unix(), rc.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (rc *rackControl) loadXmitTime(unix unixTime) {
+ rc.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index e90f9a7d9..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -18,6 +18,7 @@ import (
"container/heap"
"time"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
)
@@ -49,29 +50,36 @@ type receiver struct {
pendingRcvdSegments segmentHeap
pendingBufUsed seqnum.Size
pendingBufSize seqnum.Size
+
+ // Time when the last ack was received.
+ lastRcvdAckTime time.Time `state:".(unixTime)"`
}
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver {
return &receiver{
- ep: ep,
- rcvNxt: irs + 1,
- rcvAcc: irs.Add(rcvWnd + 1),
- rcvWnd: rcvWnd,
- rcvWndScale: rcvWndScale,
- pendingBufSize: pendingBufSize,
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWnd: rcvWnd,
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: pendingBufSize,
+ lastRcvdAckTime: time.Now(),
}
}
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- rcvWnd := r.rcvNxt.Size(r.rcvAcc)
- if rcvWnd == 0 {
- return segLen == 0 && segSeq == r.rcvNxt
+ // r.rcvWnd could be much larger than the window size we advertised in our
+ // outgoing packets, we should use what we have advertised for acceptability
+ // test.
+ scaledWindowSize := r.rcvWnd >> r.rcvWndScale
+ if scaledWindowSize > 0xffff {
+ // This is what we actually put in the Window field.
+ scaledWindowSize = 0xffff
}
-
- return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
- seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+ advertisedWindowSize := scaledWindowSize << r.rcvWndScale
+ return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize))
}
// getSendParams returns the parameters needed by the sender when building
@@ -93,12 +101,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
- if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
- // We never got around to announcing a zero window size, so we
- // don't need to immediately announce a nonzero one.
- return
- }
-
// Immediately send an ack.
r.ep.snd.sendAck()
}
@@ -169,22 +171,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// We just received a FIN, our next state depends on whether we sent a
// FIN already or not.
- r.ep.mu.Lock()
- switch r.ep.state {
+ switch r.ep.EndpointState() {
case StateEstablished:
- r.ep.state = StateCloseWait
+ r.ep.setEndpointState(StateCloseWait)
case StateFinWait1:
if s.flagIsSet(header.TCPFlagAck) {
// FIN-ACK, transition to TIME-WAIT.
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
} else {
// Simultaneous close, expecting a final ACK.
- r.ep.state = StateClosing
+ r.ep.setEndpointState(StateClosing)
}
case StateFinWait2:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
}
- r.ep.mu.Unlock()
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
@@ -196,6 +196,10 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
for i := first; i < len(r.pendingRcvdSegments); i++ {
r.pendingRcvdSegments[i].decRef()
+ // Note that slice truncation does not allow garbage collection of
+ // truncated items, thus truncated items must be set to nil to avoid
+ // memory leaks.
+ r.pendingRcvdSegments[i] = nil
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
@@ -204,17 +208,20 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum
// Handle ACK (not FIN-ACK, which we handled above) during one of the
// shutdown states.
- if s.flagIsSet(header.TCPFlagAck) {
- r.ep.mu.Lock()
- switch r.ep.state {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt {
+ switch r.ep.EndpointState() {
case StateFinWait1:
- r.ep.state = StateFinWait2
+ r.ep.setEndpointState(StateFinWait2)
+ // Notify protocol goroutine that we have received an
+ // ACK to our FIN so that it can start the FIN_WAIT2
+ // timer to abort connection if the other side does
+ // not close within 2MSL.
+ r.ep.notifyProtocolGoroutine(notifyClose)
case StateClosing:
- r.ep.state = StateTimeWait
+ r.ep.setEndpointState(StateTimeWait)
case StateLastAck:
- r.ep.state = StateClose
+ r.ep.transitionToStateCloseLocked()
}
- r.ep.mu.Unlock()
}
return true
@@ -253,32 +260,119 @@ func (r *receiver) updateRTT() {
r.ep.rcvListMu.Unlock()
}
-// handleRcvdSegment handles TCP segments directed at the connection managed by
-// r as they arrive. It is called by the protocol main loop.
-func (r *receiver) handleRcvdSegment(s *segment) {
+func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) {
+ r.ep.rcvListMu.Lock()
+ rcvClosed := r.ep.rcvClosed || r.closed
+ r.ep.rcvListMu.Unlock()
+
+ // If we are in one of the shutdown states then we need to do
+ // additional checks before we try and process the segment.
+ switch state {
+ case StateCloseWait:
+ // If the ACK acks something not yet sent then we send an ACK.
+ if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+ r.ep.snd.sendAck()
+ return true, nil
+ }
+ fallthrough
+ case StateClosing, StateLastAck:
+ if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+ // Just drop the segment as we have
+ // already received a FIN and this
+ // segment is after the sequence number
+ // for the FIN.
+ return true, nil
+ }
+ fallthrough
+ case StateFinWait1:
+ fallthrough
+ case StateFinWait2:
+ // If we are closed for reads (either due to an
+ // incoming FIN or the user calling shutdown(..,
+ // SHUT_RD) then any data past the rcvNxt should
+ // trigger a RST.
+ endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+ return true, tcpip.ErrConnectionAborted
+ }
+ if state == StateFinWait1 {
+ break
+ }
+
+ // If it's a retransmission of an old data segment
+ // or a pure ACK then allow it.
+ if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+ s.logicalLen() == 0 {
+ break
+ }
+
+ // In FIN-WAIT2 if the socket is fully
+ // closed(not owned by application on our end
+ // then the only acceptable segment is a
+ // FIN. Since FIN can technically also carry
+ // data we verify that the segment carrying a
+ // FIN ends at exactly e.rcvNxt+1.
+ //
+ // From RFC793 page 25.
+ //
+ // For sequence number purposes, the SYN is
+ // considered to occur before the first actual
+ // data octet of the segment in which it occurs,
+ // while the FIN is considered to occur after
+ // the last actual data octet in a segment in
+ // which it occurs.
+ if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+ return true, tcpip.ErrConnectionAborted
+ }
+ }
+
// We don't care about receive processing anymore if the receive side
// is closed.
- if r.closed {
- return
+ //
+ // NOTE: We still want to permit a FIN as it's possible only our
+ // end has closed and the peer is yet to send a FIN. Hence we
+ // compare only the payload.
+ segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+ if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+ return true, nil
}
+ return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+ state := r.ep.EndpointState()
+ closed := r.ep.closed
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
- // send an ACK. This is according to RFC 793, page 37.
+ // send an ACK and stop further processing of the segment.
+ // This is according to RFC 793, page 68.
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
- return
+ return true, nil
+ }
+
+ if state != StateEstablished {
+ drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+ if drop || err != nil {
+ return drop, err
+ }
}
+ // Store the time of the last ack.
+ r.lastRcvdAckTime = time.Now()
+
// Defer segment processing if it can't be consumed now.
if !r.consumeSegment(s, segSeq, segLen) {
if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -288,7 +382,7 @@ func (r *receiver) handleRcvdSegment(s *segment) {
// have to retransmit.
r.ep.snd.sendAck()
}
- return
+ return false, nil
}
// Since we consumed a segment update the receiver's RTT estimate
@@ -312,7 +406,70 @@ func (r *receiver) handleRcvdSegment(s *segment) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
+ return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+ segSeq := s.sequenceNumber
+ segLen := seqnum.Size(s.data.Size())
+
+ // Just silently drop any RST packets in TIME_WAIT. We do not support
+ // TIME_WAIT assasination as a result we confirm w/ fix 1 as described
+ // in https://tools.ietf.org/html/rfc1337#section-3.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return false, false
+ }
+
+ // If it's a SYN and the sequence number is higher than any seen before
+ // for this connection then try and redirect it to a listening endpoint
+ // if available.
+ //
+ // RFC 1122:
+ // "When a connection is [...] on TIME-WAIT state [...]
+ // [a TCP] MAY accept a new SYN from the remote TCP to
+ // reopen the connection directly, if it:
+
+ // (1) assigns its initial sequence number for the new
+ // connection to be larger than the largest sequence
+ // number it used on the previous connection incarnation,
+ // and
+
+ // (2) returns to TIME-WAIT state if the SYN turns out
+ // to be an old duplicate".
+ if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+
+ return false, true
+ }
+
+ // Drop the segment if it does not contain an ACK.
+ if !s.flagIsSet(header.TCPFlagAck) {
+ return false, false
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if r.ep.sendTSOk && s.parsedOptions.TS {
+ r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+ }
+
+ if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+ // If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+ // indicates our final ACK could have been lost.
+ r.ep.snd.sendAck()
+ return true, false
+ }
+
+ // If the sequence number range is outside the acceptable range or
+ // carries data then just send an ACK. This is according to RFC 793,
+ // page 37.
+ //
+ // NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+ if segSeq != r.rcvNxt || segLen != 0 {
+ r.ep.snd.sendAck()
+ }
+ return false, false
}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveLastRcvdAckTime is invoked by stateify.
+func (r *receiver) saveLastRcvdAckTime() unixTime {
+ return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()}
+}
+
+// loadLastRcvdAckTime is invoked by stateify.
+func (r *receiver) loadLastRcvdAckTime(unix unixTime) {
+ r.lastRcvdAckTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go
new file mode 100644
index 000000000..8a026ec46
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rcv_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+func TestAcceptable(t *testing.T) {
+ for _, tt := range []struct {
+ segSeq seqnum.Value
+ segLen seqnum.Size
+ rcvNxt, rcvAcc seqnum.Value
+ want bool
+ }{
+ // The segment is smaller than the window.
+ {105, 2, 100, 104, false},
+ {105, 2, 101, 105, true},
+ {105, 2, 102, 106, true},
+ {105, 2, 103, 107, true},
+ {105, 2, 104, 108, true},
+ {105, 2, 105, 109, true},
+ {105, 2, 106, 110, true},
+ {105, 2, 107, 111, false},
+
+ // The segment is larger than the window.
+ {105, 4, 103, 105, true},
+ {105, 4, 104, 106, true},
+ {105, 4, 105, 107, true},
+ {105, 4, 106, 108, true},
+ {105, 4, 107, 109, true},
+ {105, 4, 108, 110, true},
+ {105, 4, 109, 111, false},
+ {105, 4, 110, 112, false},
+
+ // The segment has no width.
+ {105, 0, 100, 102, false},
+ {105, 0, 101, 103, false},
+ {105, 0, 102, 104, false},
+ {105, 0, 103, 105, true},
+ {105, 0, 104, 106, true},
+ {105, 0, 105, 107, true},
+ {105, 0, 106, 108, false},
+ {105, 0, 107, 109, false},
+
+ // The receive window has no width.
+ {105, 2, 103, 103, false},
+ {105, 2, 104, 104, false},
+ {105, 2, 105, 105, false},
+ {105, 2, 106, 106, false},
+ {105, 2, 107, 107, false},
+ {105, 2, 108, 108, false},
+ {105, 2, 109, 109, false},
+ } {
+ if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want {
+ t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index ea725d513..94307d31a 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -35,6 +35,7 @@ type segment struct {
id stack.TransportEndpointID `state:"manual"`
route stack.Route `state:"manual"`
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View `state:"nosave"`
@@ -55,18 +56,19 @@ type segment struct {
options []byte `state:".([]byte)"`
hasNewSACKInfo bool
rcvdTime time.Time `state:".(unixTime)"`
- // xmitTime is the last transmit time of this segment. A zero value
- // indicates that the segment has yet to be transmitted.
- xmitTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment.
+ xmitTime time.Time `state:".(unixTime)"`
+ xmitCount uint32
}
-func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
- s.data = vv.Clone(s.views[:])
+ s.data = pkt.Data.Clone(s.views[:])
+ s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
return s
}
@@ -77,9 +79,11 @@ func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.V
id: id,
route: r.Clone(),
}
- s.views[0] = v
- s.data = buffer.NewVectorisedView(len(v), s.views[:1])
s.rcvdTime = time.Now()
+ if len(v) != 0 {
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ }
return s
}
@@ -94,13 +98,21 @@ func (s *segment) clone() *segment {
route: s.route.Clone(),
viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
+ xmitTime: s.xmitTime,
+ xmitCount: s.xmitCount,
}
t.data = s.data.Clone(t.views[:])
return t
}
-func (s *segment) flagIsSet(flag uint8) bool {
- return (s.flags & flag) != 0
+// flagIsSet checks if at least one flag in flags is set in s.flags.
+func (s *segment) flagIsSet(flags uint8) bool {
+ return s.flags&flags != 0
+}
+
+// flagsAreSet checks if all flags in flags are set in s.flags.
+func (s *segment) flagsAreSet(flags uint8) bool {
+ return s.flags&flags == flags
}
func (s *segment) decRef() {
@@ -126,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
@@ -136,8 +154,6 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h := header.TCP(s.data.First())
-
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
@@ -148,12 +164,12 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(h.DataOffset())
- if offset < header.TCPMinimumSize || offset > len(h) {
+ offset := int(s.hdr.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(s.hdr) {
return false
}
- s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.options = []byte(s.hdr[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -162,21 +178,19 @@ func (s *segment) parse() bool {
if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
s.csumValid = true
verifyChecksum = false
- s.data.TrimFront(offset)
}
if verifyChecksum {
- s.csum = h.Checksum()
- xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = h.CalculateChecksum(xsum)
- s.data.TrimFront(offset)
+ s.csum = s.hdr.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr)))
+ xsum = s.hdr.CalculateChecksum(xsum)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(h.SequenceNumber())
- s.ackNumber = seqnum.Value(h.AckNumber())
- s.flags = h.Flags()
- s.window = seqnum.Size(h.WindowSize())
+ s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(s.hdr.AckNumber())
+ s.flags = s.hdr.Flags()
+ s.window = seqnum.Size(s.hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
index 9fd061d7d..8d3ddce4b 100644
--- a/pkg/tcpip/transport/tcp/segment_heap.go
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -14,21 +14,25 @@
package tcp
+import "container/heap"
+
type segmentHeap []*segment
+var _ heap.Interface = (*segmentHeap)(nil)
+
// Len returns the length of h.
-func (h segmentHeap) Len() int {
- return len(h)
+func (h *segmentHeap) Len() int {
+ return len(*h)
}
// Less determines whether the i-th element of h is less than the j-th element.
-func (h segmentHeap) Less(i, j int) bool {
- return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+func (h *segmentHeap) Less(i, j int) bool {
+ return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber)
}
// Swap swaps the i-th and j-th elements of h.
-func (h segmentHeap) Swap(i, j int) {
- h[i], h[j] = h[j], h[i]
+func (h *segmentHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}
// Push adds x as the last element of h.
@@ -41,6 +45,7 @@ func (h *segmentHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
+ old[n-1] = nil
*h = old[:n-1]
return x
}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
index e0759225e..48a257137 100644
--- a/pkg/tcpip/transport/tcp/segment_queue.go
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -15,7 +15,7 @@
package tcp
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
@@ -28,10 +28,16 @@ type segmentQueue struct {
used int
}
+// emptyLocked determines if the queue is empty.
+// Preconditions: q.mu must be held.
+func (q *segmentQueue) emptyLocked() bool {
+ return q.used == 0
+}
+
// empty determines if the queue is empty.
func (q *segmentQueue) empty() bool {
q.mu.Lock()
- r := q.used == 0
+ r := q.emptyLocked()
q.mu.Unlock()
return r
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index d3f7c9125..c55589c45 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -15,12 +15,13 @@
package tcp
import (
+ "fmt"
"math"
- "sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -28,8 +29,11 @@ import (
)
const (
- // minRTO is the minimum allowed value for the retransmit timeout.
- minRTO = 200 * time.Millisecond
+ // MinRTO is the minimum allowed value for the retransmit timeout.
+ MinRTO = 200 * time.Millisecond
+
+ // MaxRTO is the maximum allowed value for the retransmit timeout.
+ MaxRTO = 120 * time.Second
// InitialCwnd is the initial congestion window.
InitialCwnd = 10
@@ -37,6 +41,11 @@ const (
// nDupAckThreshold is the number of duplicate ACK's required
// before fast-retransmit is entered.
nDupAckThreshold = 3
+
+ // MaxRetries is the maximum number of probe retries sender does
+ // before timing out the connection.
+ // Linux default TCP_RETR2, net.ipv4.tcp_retries2.
+ MaxRetries = 15
)
// ccState indicates the current congestion control state for this sender.
@@ -123,10 +132,6 @@ type sender struct {
// sndNxt is the sequence number of the next segment to be sent.
sndNxt seqnum.Value
- // sndNxtList is the sequence number of the next segment to be added to
- // the send list.
- sndNxtList seqnum.Value
-
// rttMeasureSeqNum is the sequence number being used for the latest RTT
// measurement.
rttMeasureSeqNum seqnum.Value
@@ -134,6 +139,18 @@ type sender struct {
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
rttMeasureTime time.Time `state:".(unixTime)"`
+ // firstRetransmittedSegXmitTime is the original transmit time of
+ // the first segment that was retransmitted due to RTO expiration.
+ firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
+ // zeroWindowProbing is set if the sender is currently probing
+ // for zero receive window.
+ zeroWindowProbing bool `state:"nosave"`
+
+ // unackZeroWindowProbes is the number of unacknowledged zero
+ // window probes.
+ unackZeroWindowProbes uint32 `state:"nosave"`
+
closed bool
writeNext *segment
writeList segmentList
@@ -146,6 +163,15 @@ type sender struct {
rtt rtt
rto time.Duration
+ // minRTO is the minimum permitted value for sender.rto.
+ minRTO time.Duration
+
+ // maxRTO is the maximum permitted value for sender.rto.
+ maxRTO time.Duration
+
+ // maxRetries is the maximum permitted retransmissions.
+ maxRetries uint32
+
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
maxPayloadSize int
@@ -165,6 +191,10 @@ type sender struct {
// cc is the congestion control algorithm in use for this sender.
cc congestionControl
+
+ // rc has the fields needed for implementing RACK loss detection
+ // algorithm.
+ rc rackControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
@@ -222,7 +252,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
- sndNxtList: iss + 1,
rto: 1 * time.Second,
rttMeasureSeqNum: iss + 1,
lastSendTime: time.Now(),
@@ -258,6 +287,25 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint
// etc.
s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+ // Get Stack wide config.
+ var minRTO tcpip.TCPMinRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil {
+ panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
+ }
+ s.minRTO = time.Duration(minRTO)
+
+ var maxRTO tcpip.TCPMaxRTOOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil {
+ panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err))
+ }
+ s.maxRTO = time.Duration(maxRTO)
+
+ var maxRetries tcpip.TCPMaxRetriesOption
+ if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil {
+ panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err))
+ }
+ s.maxRetries = uint32(maxRetries)
+
return s
}
@@ -392,8 +440,8 @@ func (s *sender) updateRTO(rtt time.Duration) {
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
- if s.rto < minRTO {
- s.rto = minRTO
+ if s.rto < s.minRTO {
+ s.rto = s.minRTO
}
}
@@ -435,17 +483,56 @@ func (s *sender) retransmitTimerExpired() bool {
return true
}
+ // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases
+ // when writeList is empty. Remove this once we have a proper fix for this
+ // issue.
+ if s.writeList.Front() == nil {
+ return true
+ }
+
s.ep.stack.Stats().TCP.Timeouts.Increment()
s.ep.stats.SendErrors.Timeouts.Increment()
- // Give up if we've waited more than a minute since the last resend.
- if s.rto >= 60*time.Second {
+ // Give up if we've waited more than a minute since the last resend or
+ // if a user time out is set and we have exceeded the user specified
+ // timeout since the first retransmission.
+ uto := s.ep.userTimeout
+
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ // We store the original xmitTime of the segment that we are
+ // about to retransmit as the retransmission time. This is
+ // required as by the time the retransmitTimer has expired the
+ // segment has already been sent and unacked for the RTO at the
+ // time the segment was sent.
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime
+ }
+
+ elapsed := time.Since(s.firstRetransmittedSegXmitTime)
+ remaining := s.maxRTO
+ if uto != 0 {
+ // Cap to the user specified timeout if one is specified.
+ remaining = uto - elapsed
+ }
+
+ // Always honor the user-timeout irrespective of whether the zero
+ // window probes were acknowledged.
+ // net/ipv4/tcp_timer.c::tcp_probe_timer()
+ if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries {
return false
}
// Set new timeout. The timer will be restarted by the call to sendData
// below.
s.rto *= 2
+ // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5
+ if s.rto > s.maxRTO {
+ s.rto = s.maxRTO
+ }
+
+ // Cap RTO to remaining time.
+ if s.rto > remaining {
+ s.rto = remaining
+ }
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
//
@@ -488,6 +575,26 @@ func (s *sender) retransmitTimerExpired() bool {
// information is usable after an RTO.
s.ep.scoreboard.Reset()
s.writeNext = s.writeList.Front()
+
+ // RFC 1122 4.2.2.17: Start sending zero window probes when we still see a
+ // zero receive window after retransmission interval and we have data to
+ // send.
+ if s.zeroWindowProbing {
+ s.sendZeroWindowProbe()
+ // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed
+ // indefinitely. As long as the receiving TCP continues to send
+ // acknowledgments in response to the probe segments, the sending TCP
+ // MUST allow the connection to stay open.
+ return true
+ }
+
+ seg := s.writeNext
+ // RFC 1122 4.2.3.5: Close the connection when the number of
+ // retransmissions for this segment is beyond a limit.
+ if seg != nil && seg.xmitCount > s.maxRetries {
+ return false
+ }
+
s.sendData()
return true
@@ -515,25 +622,51 @@ func (s *sender) splitSeg(seg *segment, size int) {
nSeg.data.TrimFront(size)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
s.writeList.InsertAfter(seg, nSeg)
+
+ // The segment being split does not carry PUSH flag because it is
+ // followed by the newly split segment.
+ // RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+ // segment (i.e., when there is no more queued data to be sent).
+ // Linux removes PSH flag only when the segment is being split over MSS
+ // and retains it when we are splitting the segment over lack of sender
+ // window space.
+ // ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+ if seg.data.Size() > s.maxPayloadSize {
+ seg.flags ^= header.TCPFlagPsh
+ }
+
seg.data.CapLength(size)
}
-// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that
-// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2
-// is handled by the normal send logic.
-func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
+// NextSeg implements the RFC6675 NextSeg() operation.
+//
+// NextSeg starts scanning the writeList starting from nextSegHint and returns
+// the hint to be passed on the next call to NextSeg. This is required to avoid
+// iterating the write list repeatedly when NextSeg is invoked in a loop during
+// recovery. The returned hint will be nil if there are no more segments that
+// can match rules defined by NextSeg operation in RFC6675.
+//
+// rescueRtx will be true only if nextSeg is a rescue retransmission as
+// described by Step 4) of the NextSeg algorithm.
+func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) {
var s3 *segment
var s4 *segment
- smss := s.ep.scoreboard.SMSS()
// Step 1.
- for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
- if !s.isAssignedSequenceNumber(seg) {
+ for seg := nextSegHint; seg != nil; seg = seg.Next() {
+ // Stop iteration if we hit a segment that has never been
+ // transmitted (i.e. either it has no assigned sequence number
+ // or if it does have one, it's >= the next sequence number
+ // to be sent [i.e. >= s.sndNxt]).
+ if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) {
+ hint = nil
break
}
segSeq := seg.sequenceNumber
- if seg.data.Size() > int(smss) {
+ if smss := s.ep.scoreboard.SMSS(); seg.data.Size() > int(smss) {
s.splitSeg(seg, int(smss))
}
+
// See RFC 6675 Section 4
//
// 1. If there exists a smallest unSACKED sequence number
@@ -550,8 +683,9 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
// NextSeg():
// (1.c) IsLost(S2) returns true.
if s.ep.scoreboard.IsLost(segSeq) {
- return seg, s3, s4
+ return seg, seg.Next(), false
}
+
// NextSeg():
//
// (3): If the conditions for rules (1) and (2)
@@ -563,6 +697,7 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
// SHOULD be returned.
if s3 == nil {
s3 = seg
+ hint = seg.Next()
}
}
// NextSeg():
@@ -571,10 +706,12 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
// but there exists outstanding unSACKED data, we
// provide the opportunity for a single "rescue"
// retransmission per entry into loss recovery. If
- // HighACK is greater than RescueRxt, the one
- // segment of upto SMSS octects that MUST include
- // the highest outstanding unSACKed sequence number
- // SHOULD be returned.
+ // HighACK is greater than RescueRxt (or RescueRxt
+ // is undefined), then one segment of upto SMSS
+ // octects that MUST include the highest outstanding
+ // unSACKed sequence number SHOULD be returned, and
+ // RescueRxt set to RecoveryPoint. HighRxt MUST NOT
+ // be updated.
if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
if s4 != nil {
if s4.sequenceNumber.LessThan(segSeq) {
@@ -583,12 +720,31 @@ func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
} else {
s4 = seg
}
- s.fr.rescueRxt = s.fr.last
}
}
}
- return nil, s3, s4
+ // If we got here then no segment matched step (1).
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // We do not split the segment here to <= smss as it has
+ // potentially not been assigned a sequence number yet.
+ return seg, nil, false
+ }
+
+ if s3 != nil {
+ return s3, hint, false
+ }
+
+ return s4, nil, true
}
// maybeSendSegment tries to send the specified segment and either coalesces
@@ -601,7 +757,7 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if !s.isAssignedSequenceNumber(seg) {
// Merge segments if allowed.
if seg.data.Size() != 0 {
- available := int(seg.sequenceNumber.Size(end))
+ available := int(s.sndNxt.Size(end))
if available > limit {
available = limit
}
@@ -644,8 +800,11 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// sent all at once.
return false
}
- if atomic.LoadUint32(&s.ep.cork) != 0 {
- // Hold back the segment until full.
+ // With TCP_CORK, hold back until minimum of the available
+ // send space and MSS.
+ // TODO(gvisor.dev/issue/2833): Drain the held segments after a
+ // timeout.
+ if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 {
return false
}
}
@@ -664,18 +823,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
}
seg.flags = header.TCPFlagAck | header.TCPFlagFin
segEnd = seg.sequenceNumber.Add(1)
- // Transition to FIN-WAIT1 state since we're initiating an active close.
- s.ep.mu.Lock()
- switch s.ep.state {
+ // Update the state to reflect that we have now
+ // queued a FIN.
+ switch s.ep.EndpointState() {
case StateCloseWait:
- // We've already received a FIN and are now sending our own. The
- // sender is now awaiting a final ACK for this FIN.
- s.ep.state = StateLastAck
+ s.ep.setEndpointState(StateLastAck)
default:
- s.ep.state = StateFinWait1
+ s.ep.setEndpointState(StateFinWait1)
}
- s.ep.stack.Stats().TCP.CurrentEstablished.Decrement()
- s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
if seg.flags&header.TCPFlagFin != 0 {
@@ -690,10 +845,52 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
if available == 0 {
return false
}
+
+ // If the whole segment or at least 1MSS sized segment cannot
+ // be accomodated in the receiver advertized window, skip
+ // splitting and sending of the segment. ref:
+ // net/ipv4/tcp_output.c::tcp_snd_wnd_test()
+ //
+ // Linux checks this for all segment transmits not triggered by
+ // a probe timer. On this condition, it defers the segment split
+ // and transmit to a short probe timer.
+ //
+ // ref: include/net/tcp.h::tcp_check_probe_timer()
+ // ref: net/ipv4/tcp_output.c::tcp_write_wakeup()
+ //
+ // Instead of defining a new transmit timer, we attempt to split
+ // the segment right here if there are no pending segments. If
+ // there are pending segments, segment transmits are deferred to
+ // the retransmit timer handler.
+ if s.sndUna != s.sndNxt {
+ switch {
+ case available >= seg.data.Size():
+ // OK to send, the whole segments fits in the
+ // receiver's advertised window.
+ case available >= s.maxPayloadSize:
+ // OK to send, at least 1 MSS sized segment fits
+ // in the receiver's advertised window.
+ default:
+ return false
+ }
+ }
+
+ // The segment size limit is computed as a function of sender
+ // congestion window and MSS. When sender congestion window is >
+ // 1, this limit can be larger than MSS. Ensure that the
+ // currently available send space is not greater than minimum of
+ // this limit and MSS.
if available > limit {
available = limit
}
+ // If GSO is not in use then cap available to
+ // maxPayloadSize. When GSO is in use the gVisor GSO logic or
+ // the host GSO logic will cap the segment to the correct size.
+ if s.ep.gso == nil && available > s.maxPayloadSize {
+ available = s.maxPayloadSize
+ }
+
if seg.data.Size() > available {
s.splitSeg(seg, available)
}
@@ -716,64 +913,47 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
// section 5, step C.
func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
s.SetPipe()
+
+ if smss := int(s.ep.scoreboard.SMSS()); limit > smss {
+ // Cap segment size limit to s.smss as SACK recovery requires
+ // that all retransmissions or new segments send during recovery
+ // be of <= SMSS.
+ limit = smss
+ }
+
+ nextSegHint := s.writeList.Front()
for s.outstanding < s.sndCwnd {
- nextSeg, s3, s4 := s.NextSeg()
+ var nextSeg *segment
+ var rescueRtx bool
+ nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint)
if nextSeg == nil {
- // NextSeg():
- //
- // Step (2): "If no sequence number 'S2' per rule (1)
- // exists but there exists available unsent data and the
- // receiver's advertised window allows, the sequence
- // range of one segment of up to SMSS octets of
- // previously unsent data starting with sequence number
- // HighData+1 MUST be returned."
- for seg := s.writeNext; seg != nil; seg = seg.Next() {
- if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
- continue
- }
- // Step C.3 described below is handled by
- // maybeSendSegment which increments sndNxt when
- // a segment is transmitted.
- //
- // Step C.3 "If any of the data octets sent in
- // (C.1) are above HighData, HighData must be
- // updated to reflect the transmission of
- // previously unsent data."
- if sent := s.maybeSendSegment(seg, limit, end); !sent {
- break
- }
- dataSent = true
- s.outstanding++
- s.writeNext = seg.Next()
- nextSeg = seg
- break
- }
- if nextSeg != nil {
- continue
- }
- }
- rescueRtx := false
- if nextSeg == nil && s3 != nil {
- nextSeg = s3
- }
- if nextSeg == nil && s4 != nil {
- nextSeg = s4
- rescueRtx = true
+ return dataSent
}
- if nextSeg == nil {
- break
- }
- segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
- if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) {
- // RFC 6675, Step C.2
+ if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) {
+ // New data being sent.
+
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
//
- // "If any of the data octets sent in (C.1) are below
- // HighData, HighRxt MUST be set to the highest sequence
- // number of the retransmitted segment unless NextSeg ()
- // rule (4) was invoked for this retransmission."
- s.fr.highRxt = segEnd - 1
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ //
+ // We pass s.smss as the limit as the Step 2) requires that
+ // new data sent should be of size s.smss or less.
+ if sent := s.maybeSendSegment(nextSeg, limit, end); !sent {
+ return dataSent
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = nextSeg.Next()
+ continue
}
+ // Now handle the retransmission case where we matched either step 1,3 or 4
+ // of the NextSeg algorithm.
// RFC 6675, Step C.4.
//
// "The estimate of the amount of data outstanding in the network
@@ -782,10 +962,54 @@ func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool)
s.outstanding++
dataSent = true
s.sendSegment(nextSeg)
+
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if rescueRtx {
+ // We do the last part of rule (4) of NextSeg here to update
+ // RescueRxt as until this point we don't know if we are going
+ // to use the rescue transmission.
+ s.fr.rescueRxt = s.fr.last
+ } else {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
}
return dataSent
}
+func (s *sender) sendZeroWindowProbe() {
+ ack, win := s.ep.rcv.getSendParams()
+ s.unackZeroWindowProbes++
+ // Send a zero window probe with sequence number pointing to
+ // the last acknowledged byte.
+ s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win)
+ // Rearm the timer to continue probing.
+ s.resendTimer.enable(s.rto)
+}
+
+func (s *sender) enableZeroWindowProbing() {
+ s.zeroWindowProbing = true
+ // We piggyback the probing on the retransmit timer with the
+ // current retranmission interval, as we may start probing while
+ // segment retransmissions.
+ if s.firstRetransmittedSegXmitTime.IsZero() {
+ s.firstRetransmittedSegXmitTime = time.Now()
+ }
+ s.resendTimer.enable(s.rto)
+}
+
+func (s *sender) disableZeroWindowProbing() {
+ s.zeroWindowProbing = false
+ s.unackZeroWindowProbes = 0
+ s.firstRetransmittedSegXmitTime = time.Time{}
+ s.resendTimer.disable()
+}
+
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
func (s *sender) sendData() {
@@ -799,7 +1023,7 @@ func (s *sender) sendData() {
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
- if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+ if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto {
if s.sndCwnd > InitialCwnd {
s.sndCwnd = InitialCwnd
}
@@ -817,6 +1041,9 @@ func (s *sender) sendData() {
limit = cwndLimit
}
if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ // Move writeNext along so that we don't try and scan data that
+ // has already been SACKED.
+ s.writeNext = seg.Next()
continue
}
if sent := s.maybeSendSegment(seg, limit, end); !sent {
@@ -834,6 +1061,13 @@ func (s *sender) sendData() {
s.ep.disableKeepaliveTimer()
}
+ // If the sender has advertized zero receive window and we have
+ // data to be sent out, start zero window probing to query the
+ // the remote for it's receive window size.
+ if s.writeNext != nil && s.sndWnd == 0 {
+ s.enableZeroWindowProbing()
+ }
+
// Enable the timer if we have pending data and it's not enabled yet.
if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
s.resendTimer.enable(s.rto)
@@ -855,6 +1089,8 @@ func (s *sender) enterFastRecovery() {
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ s.fr.highRxt = s.sndUna
+ s.fr.rescueRxt = s.sndUna
if s.ep.sackPermitted {
s.state = SACKRecovery
s.ep.stack.Stats().TCP.SACKRecovery.Increment()
@@ -1040,21 +1276,21 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
-func (s *sender) handleRcvdSegment(seg *segment) {
+func (s *sender) handleRcvdSegment(rcvdSeg *segment) {
// Check if we can extract an RTT measurement from this ack.
- if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ if !rcvdSeg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(rcvdSeg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
- if s.ep.sendTSOk && seg.parsedOptions.TS {
- s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(rcvdSeg.parsedOptions.TSVal, s.maxSentAck, rcvdSeg.sequenceNumber)
}
// Insert SACKBlock information into our scoreboard.
if s.ep.sackPermitted {
- for _, sb := range seg.parsedOptions.SACKBlocks {
+ for _, sb := range rcvdSeg.parsedOptions.SACKBlocks {
// Only insert the SACK block if the following holds
// true:
// * SACK block acks data after the ack number in the
@@ -1067,22 +1303,40 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// NOTE: This check specifically excludes DSACK blocks
// which have start/end before sndUna and are used to
// indicate spurious retransmissions.
- if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ if rcvdSeg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
s.ep.scoreboard.Insert(sb)
- seg.hasNewSACKInfo = true
+ rcvdSeg.hasNewSACKInfo = true
}
}
s.SetPipe()
}
// Count the duplicates and do the fast retransmit if needed.
- rtx := s.checkDuplicateAck(seg)
+ rtx := s.checkDuplicateAck(rcvdSeg)
// Stash away the current window size.
- s.sndWnd = seg.window
+ s.sndWnd = rcvdSeg.window
+
+ ack := rcvdSeg.ackNumber
+
+ // Disable zero window probing if remote advertizes a non-zero receive
+ // window. This can be with an ACK to the zero window probe (where the
+ // acknumber refers to the already acknowledged byte) OR to any previously
+ // unacknowledged segment.
+ if s.zeroWindowProbing && rcvdSeg.window > 0 &&
+ (ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
+ s.disableZeroWindowProbing()
+ }
+
+ // On receiving the ACK for the zero window probe, account for it and
+ // skip trying to send any segment as we are still probing for
+ // receive window to become non-zero.
+ if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna {
+ s.unackZeroWindowProbes--
+ return
+ }
// Ignore ack if it doesn't acknowledge any new data.
- ack := seg.ackNumber
if (ack - 1).InRange(s.sndUna, s.sndNxt) {
s.dupAckCount = 0
@@ -1094,15 +1348,15 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
- if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ if s.ep.sendTSOk && rcvdSeg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
- elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ elapsed := time.Duration(s.ep.timestamp()-rcvdSeg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
// When an ack is received we must rearm the timer.
- // RFC 6298 5.2
+ // RFC 6298 5.3
s.resendTimer.enable(s.rto)
// Remove all acknowledged data from the write list.
@@ -1111,6 +1365,9 @@ func (s *sender) handleRcvdSegment(seg *segment) {
ackLeft := acked
originalOutstanding := s.outstanding
+ s.rtt.Lock()
+ srtt := s.rtt.srtt
+ s.rtt.Unlock()
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
@@ -1129,6 +1386,12 @@ func (s *sender) handleRcvdSegment(seg *segment) {
if s.writeNext == seg {
s.writeNext = seg.Next()
}
+
+ // Update the RACK fields if SACK is enabled.
+ if s.ep.sackPermitted {
+ s.rc.Update(seg, rcvdSeg, srtt, s.ep.tsOffset)
+ }
+
s.writeList.Remove(seg)
// if SACK is enabled then Only reduce outstanding if
@@ -1169,6 +1432,8 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// RFC 6298 Rule 5.3
if s.sndUna == s.sndNxt {
s.outstanding = 0
+ // Reset firstRetransmittedSegXmitTime to the zero value.
+ s.firstRetransmittedSegXmitTime = time.Time{}
s.resendTimer.disable()
}
}
@@ -1182,14 +1447,14 @@ func (s *sender) handleRcvdSegment(seg *segment) {
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
- if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || rcvdSeg.hasNewSACKInfo {
s.sendData()
}
}
// sendSegment sends the specified segment.
func (s *sender) sendSegment(seg *segment) *tcpip.Error {
- if !seg.xmitTime.IsZero() {
+ if seg.xmitCount > 0 {
s.ep.stack.Stats().TCP.Retransmits.Increment()
s.ep.stats.SendErrors.Retransmits.Increment()
if s.sndCwnd < s.sndSsthresh {
@@ -1197,7 +1462,24 @@ func (s *sender) sendSegment(seg *segment) *tcpip.Error {
}
}
seg.xmitTime = time.Now()
- return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+ seg.xmitCount++
+ err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled and we are retransmitting data
+ // then use the conservative timer described in RFC6675 Section 6.0,
+ // otherwise follow the standard time described in RFC6298 Section 5.1.
+ if err != nil && seg.data.Size() != 0 {
+ if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
+ return err
}
// sendSegmentFromView sends a new segment containing the given payload, flags
@@ -1213,19 +1495,5 @@ func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq
// Remember the max sent ack.
s.maxSentAck = rcvNxt
- // Every time a packet containing data is sent (including a
- // retransmission), if SACK is enabled then use the conservative timer
- // described in RFC6675 Section 4.0, otherwise follow the standard time
- // described in RFC6298 Section 5.2.
- if data.Size() != 0 {
- if s.ep.sackPermitted {
- s.resendTimer.enable(s.rto)
- } else {
- if !s.resendTimer.enabled() {
- s.resendTimer.enable(s.rto)
- }
- }
- }
-
return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
index 12eff8afc..8b20c3455 100644
--- a/pkg/tcpip/transport/tcp/snd_state.go
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -48,3 +48,13 @@ func (s *sender) loadRttMeasureTime(unix unixTime) {
func (s *sender) afterLoad() {
s.resendTimer.init(&s.resendWaker)
}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+ return unixTime{s.firstRetransmittedSegXmitTime.Unix(), s.firstRetransmittedSegXmitTime.UnixNano()}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+ s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 782d7b42c..b9993ce1a 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
func TestFastRecovery(t *testing.T) {
@@ -40,7 +41,7 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -49,7 +50,7 @@ func TestFastRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -86,16 +87,23 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmitted packet.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
+ // Wait before checking metrics.
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
+ }
+ return nil
}
- if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRecovery.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Now send 7 mode duplicate acks. Each of these should cause a window
@@ -117,12 +125,18 @@ func TestFastRecovery(t *testing.T) {
// Receive the retransmit due to partial ack.
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
+ // Wait before checking metrics.
+ metricPollFn = func() error {
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+ }
+ return nil
}
-
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
- t.Errorf("got stats.TCP.Retransmit.Value = %v, want = %v", got, want)
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Receive the 10 extra packets that should have been released due to
@@ -192,7 +206,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -201,7 +215,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
expected := tcp.InitialCwnd
@@ -234,7 +248,7 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -243,7 +257,7 @@ func TestCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -338,7 +352,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
@@ -348,7 +362,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -447,7 +461,7 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -457,11 +471,11 @@ func TestRetransmit(t *testing.T) {
// MTU size though.
half := data[:len(data)/2]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
half = data[len(data)/2:]
if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -492,24 +506,33 @@ func TestRetransmit(t *testing.T) {
rtxOffset := bytesRead - maxPayload*expected
c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
- if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Timeouts.Value = %v, want = %v", got, want)
- }
+ metricPollFn := func() error {
+ if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
- t.Errorf("got EP SendErrors.Timeouts.Value = %v, want = %v", got, want)
- }
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want)
+ }
+
+ if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.Retransmits.Value = %v, want = %v", got, want)
+ return nil
}
- if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.SlowStartRetransmits.Value = %v, want = %v", got, want)
+ // Poll when checking metrics.
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Acknowledge half of the pending data.
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
new file mode 100644
index 000000000..e03f101e8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -0,0 +1,74 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+// TestRACKUpdate tests the RACK related fields are updated when an ACK is
+// received on a SACK enabled connection.
+func TestRACKUpdate(t *testing.T) {
+ const maxPayload = 10
+ const tsOptionSize = 12
+ const maxTCPOptionSize = 40
+
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+ defer c.Cleanup()
+
+ var xmitTime time.Time
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint Sender.RACKState is what we expect.
+ if state.Sender.RACKState.XmitTime.Before(xmitTime) {
+ t.Fatalf("RACK transmit time failed to update when an ACK is received")
+ }
+
+ gotSeq := state.Sender.RACKState.EndSequence
+ wantSeq := state.Sender.SndNxt
+ if !gotSeq.LessThanEq(wantSeq) || gotSeq.LessThan(wantSeq) {
+ t.Fatalf("RACK sequence number failed to update, got: %v, but want: %v", gotSeq, wantSeq)
+ }
+
+ if state.Sender.RACKState.RTT == 0 {
+ t.Fatalf("RACK RTT failed to update when an ACK is received")
+ }
+ })
+ setStackSACKPermitted(t, c, true)
+ createConnectedWithSACKAndTS(c)
+
+ data := buffer.NewView(maxPayload)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ // Write the data.
+ xmitTime = time.Now()
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ bytesRead := 0
+ c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+ bytesRead += maxPayload
+ c.SendAck(790, bytesRead)
+ time.Sleep(200 * time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index afea124ec..99521f0c1 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
// createConnectedWithSACKPermittedOption creates and connects c.ep with the
@@ -46,7 +47,7 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err)
}
}
@@ -149,21 +150,22 @@ func TestSackPermittedAccept(t *testing.T) {
{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
@@ -222,21 +224,23 @@ func TestSackDisabledAccept(t *testing.T) {
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
- if tc.cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- } else {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+
+ if tc.cookieEnabled {
+ // Set the SynRcvd threshold to
+ // zero to force a syn cookie
+ // based accept to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -387,7 +391,7 @@ func TestSACKRecovery(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- const iterations = 7
+ const iterations = 3
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
for i := range data {
data[i] = byte(i)
@@ -396,7 +400,7 @@ func TestSACKRecovery(t *testing.T) {
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Do slow start for a few iterations.
@@ -436,21 +440,28 @@ func TestSACKRecovery(t *testing.T) {
// Receive the retransmitted packet.
c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
- tcpStats := c.Stack().Stats().TCP
- stats := []struct {
- stat *tcpip.StatCounter
- name string
- want uint64
- }{
- {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
- {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
- {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
- {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
- }
- for _, s := range stats {
- if got, want := s.stat.Value(), s.want; got != want {
- t.Errorf("got %s.Value() = %v, want = %v", s.name, got, want)
+ metricPollFn := func() error {
+ tcpStats := c.Stack().Stats().TCP
+ stats := []struct {
+ stat *tcpip.StatCounter
+ name string
+ want uint64
+ }{
+ {tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+ {tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+ {tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+ {tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
}
+ for _, s := range stats {
+ if got, want := s.stat.Value(), s.want; got != want {
+ return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+ }
+ }
+ return nil
+ }
+
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
// Now send 7 mode duplicate ACKs. In SACK TCP dupAcks do not cause
@@ -514,22 +525,28 @@ func TestSACKRecovery(t *testing.T) {
bytesRead += maxPayload
}
- // In SACK recovery only the first segment is fast retransmitted when
- // entering recovery.
- if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got stats.TCP.FastRetransmit.Value = %v, want = %v", got, want)
- }
+ metricPollFn = func() error {
+ // In SACK recovery only the first segment is fast retransmitted when
+ // entering recovery.
+ if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
- t.Errorf("got EP stats SendErrors.FastRetransmit = %v, want = %v", got, want)
- }
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+ return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
+ }
- if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got stats.TCP.Retransmits.Value = %v, want = %v", got, want)
- }
+ if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
+ }
- if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
- t.Errorf("got EP stats Stats.SendErrors.Retransmits = %v, want = %v", got, want)
+ if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+ return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 6d808328c..55ae09a2f 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -34,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -55,7 +57,7 @@ func TestGiveUpConnect(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -64,7 +66,7 @@ func TestGiveUpConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Close the connection, wait for completion.
@@ -73,7 +75,21 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+ t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+
+ // Call Connect again to retreive the handshake failure status
+ // and stats updates.
+ if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+
+ if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
}
}
@@ -86,7 +102,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -99,10 +115,10 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats.FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
}
}
@@ -113,20 +129,38 @@ func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
want := stats.TCP.FailedConnectionAttempts.Value() + 1
if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
- t.Errorf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
+ t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
}
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
- t.Errorf("got EP stats FailedConnectionAttempts = %v, want = %v", got, want)
+ t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+ }
+}
+
+func TestCloseWithoutConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ c.EP.Close()
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -140,10 +174,10 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
- t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
- t.Errorf("got EP stats SegmentsSent.Value() = %v, want = %v", got, want)
+ t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
}
}
@@ -154,16 +188,16 @@ func TestTCPResetsSentIncrement(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
want := stats.TCP.SegmentsSent.Value() + 1
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -194,8 +228,15 @@ func TestTCPResetsSentIncrement(t *testing.T) {
c.SendPacket(nil, ackHeaders)
c.GetPacket()
- if got := stats.TCP.ResetsSent.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want)
+
+ metricPollFn := func() error {
+ if got := stats.TCP.ResetsSent.Value(); got != want {
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
}
}
@@ -206,17 +247,18 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLingerTimeout to 5 seconds so that sockets are marked closed
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send a SYN request.
@@ -256,7 +298,7 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -264,6 +306,13 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
}
}
+ // Lower stackwide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+ t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
c.EP.Close()
checker.IPv4(t, c.GetPacket(), checker.TCP(
checker.SrcPort(context.StackPort),
@@ -271,7 +320,6 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SeqNum(uint32(c.IRS+1)),
checker.AckNum(uint32(iss)+1),
checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
-
finHeaders := &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
@@ -285,6 +333,11 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
// Get the ACK to the FIN we just sent.
c.GetPacket()
+ // Since an active close was done we need to wait for a little more than
+ // tcpLingerTimeout for the port reservations to be released and the
+ // socket to move to a CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
// Now resend the same ACK, this ACK should generate a RST as there
// should be no endpoint in SYN-RCVD state and we are not using
// syn-cookies yet. The reason we send the same ACK is we need a valid
@@ -296,8 +349,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS+1)),
- checker.AckNum(uint32(iss)+1),
- checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
}
func TestTCPResetsReceivedIncrement(t *testing.T) {
@@ -320,7 +373,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
}
@@ -344,7 +397,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
})
if got := stats.TCP.ResetsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
}
c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
}
@@ -368,7 +421,7 @@ func TestNonBlockingClose(t *testing.T) {
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -376,6 +429,13 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPLinger to 3 seconds so that sockets are marked closed
+ // after 3 second in FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -396,12 +456,24 @@ func TestConnectResetAfterClose(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1),
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
RcvWnd: 30000,
})
- // Wait for the ep to give up waiting for a FIN, and send a RST.
- time.Sleep(3 * time.Second)
for {
b := c.GetPacket()
tcpHdr := header.TCP(header.IPv4(b).Payload())
@@ -413,15 +485,219 @@ func TestConnectResetAfterClose(t *testing.T) {
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
- checker.AckNum(790),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
break
}
}
+// TestCurrentConnectedIncrement tests increment of the current
+// established and connected counters.
+func TestCurrentConnectedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
+ }
+ gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
+ if gotConnected != 1 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
+ }
+
+ ep.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
+ }
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments would be
+// re-enqueued to a any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSED-WAIT
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack a few ms to transition the endpoint out of ESTABLISHED
+ // state.
+ time.Sleep(10 * time.Millisecond)
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+ t.Errorf("unexpected endpoint state: want %d, got %d", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK
+ ep.Close()
+
+ // Get the FIN
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Pause the endpoint`s protocolMainLoop.
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue last ACK followed by an ACK matching the endpoint
+ //
+ // Send Last ACK for LAST_ACK --> CLOSED
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Send a packet with ACK set, this would generate RST when
+ // not using SYN cookies as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Unpause endpoint`s protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(10 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+
+ // Check if the endpoint was moved to CLOSED and netstack a reset in
+ // response to the ACK packet that we sent after last-ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+}
+
func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -433,7 +709,7 @@ func TestSimpleReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -456,7 +732,7 @@ func TestSimpleReceive(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -474,6 +750,488 @@ func TestSimpleReceive(t *testing.T) {
)
}
+// TestUserSuppliedMSSOnConnect tests that the user supplied MSS is used when
+// creating a new active TCP socket. It should be present in the sent TCP
+// SYN segment.
+func TestUserSuppliedMSSOnConnect(t *testing.T) {
+ const mtu = 5000
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ connectAddr tcpip.Address
+ checker func(*testing.T, *context.Context, uint16, int)
+ maxMSS uint16
+ }{
+ {
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ connectAddr: context.TestAddr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
+ },
+ {
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(true)
+ },
+ connectAddr: context.TestV6Addr,
+ checker: func(t *testing.T, c *context.Context, mss uint16, ws int) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: ws})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
+ },
+ }
+
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ ip.createEP(c)
+
+ // Set the MSS socket option.
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
+
+ // Get expected window size.
+ rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption): %s", err)
+ }
+ ws := tcp.FindWndScale(seqnum.Size(rcvBufSize))
+
+ connectAddr := tcpip.FullAddress{Addr: ip.connectAddr, Port: context.TestPort}
+ if err := c.EP.Connect(connectAddr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Connect(%+v): %s", connectAddr, err)
+ }
+
+ // Receive SYN packet with our user supplied MSS.
+ ip.checker(t, c, test.expMSS, ws)
+ })
+ }
+ })
+ }
+}
+
+// TestUserSuppliedMSSOnListenAccept tests that the user supplied MSS is used
+// when completing the handshake for a new TCP connection from a TCP
+// listening socket. It should be present in the sent TCP SYN-ACK segment.
+func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
+ const (
+ nonSynCookieAccepts = 2
+ totalAccepts = 4
+ mtu = 5000
+ )
+
+ ips := []struct {
+ name string
+ createEP func(*context.Context)
+ sendPkt func(*context.Context, *context.Headers)
+ checker func(*testing.T, *context.Context, uint16, uint16)
+ maxMSS uint16
+ }{
+ {
+ name: "IPv4",
+ createEP: func(c *context.Context) {
+ c.Create(-1)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendPacket(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize,
+ },
+ {
+ name: "IPv6",
+ createEP: func(c *context.Context) {
+ c.CreateV6Endpoint(false)
+ },
+ sendPkt: func(c *context.Context, h *context.Headers) {
+ c.SendV6Packet(nil, h)
+ },
+ checker: func(t *testing.T, c *context.Context, srcPort, mss uint16) {
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagSyn|header.TCPFlagAck),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: -1})))
+ },
+ maxMSS: mtu - header.IPv6MinimumSize - header.TCPMinimumSize,
+ },
+ }
+
+ for _, ip := range ips {
+ t.Run(ip.name, func(t *testing.T) {
+ tests := []struct {
+ name string
+ setMSS uint16
+ expMSS uint16
+ }{
+ {
+ name: "EqualToMaxMSS",
+ setMSS: ip.maxMSS,
+ expMSS: ip.maxMSS,
+ },
+ {
+ name: "LessThanMaxMSS",
+ setMSS: ip.maxMSS - 1,
+ expMSS: ip.maxMSS - 1,
+ },
+ {
+ name: "GreaterThanMaxMSS",
+ setMSS: ip.maxMSS + 1,
+ expMSS: ip.maxMSS,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ ip.createEP(c)
+
+ // Set the SynRcvd threshold to force a syn cookie based accept to happen.
+ opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+ t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err)
+ }
+
+ if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
+ t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
+ }
+
+ bindAddr := tcpip.FullAddress{Port: context.StackPort}
+ if err := c.EP.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s:", bindAddr, err)
+ }
+
+ if err := c.EP.Listen(totalAccepts); err != nil {
+ t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ }
+
+ // The first nonSynCookieAccepts packets sent will trigger a gorooutine
+ // based accept. The rest will trigger a cookie based accept.
+ for i := 0; i < totalAccepts; i++ {
+ // Send a SYN requests.
+ iss := seqnum.Value(i)
+ srcPort := context.TestPort + uint16(i)
+ ip.sendPkt(c, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ ip.checker(t, c, srcPort, test.expMSS)
+ }
+ })
+ }
+ })
+ }
+}
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable ammount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable ammount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests for the listening endpoint replying with RST
+// on read shutdown.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+// TestListenCloseWhileConnect tests for the listening endpoint to
+// drain the accept-queue when closed. This should reset all of the
+// pending connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
func TestTOSV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
@@ -485,17 +1243,17 @@ func TestTOSV4(t *testing.T) {
c.EP = ep
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv4TOSOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Errorf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ t.Errorf("GetSockoptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
}
testV4Connect(t, c, checker.TOS(tos, 0))
@@ -533,17 +1291,17 @@ func TestTrafficClassV6(t *testing.T) {
c.CreateV6Endpoint(false)
const tos = 0xC0
- if err := c.EP.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv6TrafficClassOption(tos), err)
+ if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
+ t.Errorf("SetSockOpInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
}
- var v tcpip.IPv6TrafficClassOption
- if err := c.EP.GetSockOpt(&v); err != nil {
- t.Fatalf("GetSockopt failed: %s", err)
+ v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ t.Fatalf("GetSockoptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
}
// Test the connection request.
@@ -578,12 +1336,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -598,7 +1356,7 @@ func TestConnectBindToDevice(t *testing.T) {
defer c.WQ.EventUnregister(&waitEntry)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -610,7 +1368,7 @@ func TestConnectBindToDevice(t *testing.T) {
),
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
}
tcpHdr := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
@@ -629,7 +1387,95 @@ func TestConnectBindToDevice(t *testing.T) {
c.GetPacket()
if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
- t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Fatalf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ })
+ }
+}
+
+func TestSynSent(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ reset bool
+ }{
+ {"RstOnSynSent", true},
+ {"CloseOnSynSent", false},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create an endpoint, don't handshake because we want to interfere with the
+ // handshake process.
+ c.Create(-1)
+
+ // Start connection attempt.
+ waitEntry, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}
+ if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ if test.reset {
+ // Send a packet with a proper ACK and a RST flag to cause the socket
+ // to error and close out.
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagRst | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+ } else {
+ c.EP.Close()
+ }
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatal("timed out waiting for packet to arrive")
+ }
+
+ if test.reset {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
+ }
+ } else {
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
+ }
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+
+ // Due to the RST the endpoint should be in an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Fatalf("got State() = %s, want %s", got, want)
}
})
}
@@ -646,7 +1492,7 @@ func TestOutOfOrderReceive(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send second half of data first, with seqnum 3 ahead of expected.
@@ -673,7 +1519,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send the first 3 bytes now.
@@ -700,7 +1546,7 @@ func TestOutOfOrderReceive(t *testing.T) {
}
continue
}
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -730,7 +1576,7 @@ func TestOutOfOrderFlood(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 100 packets before the actual one that is expected.
@@ -807,7 +1653,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -850,7 +1696,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// This final ACK should be ignored because an ACK on a reset doesn't mean
@@ -876,7 +1722,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
data := []byte{1, 2, 3}
@@ -918,7 +1764,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
))
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Cause a RST to be generated by closing the read end now since we have
@@ -930,12 +1776,14 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
- // We shouldn't consume a sequence number on RST.
- checker.SeqNum(uint32(c.IRS)+1),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence
+ // number of the FIN itself.
+ checker.SeqNum(uint32(c.IRS)+2),
))
// The RST puts the endpoint into an error state.
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// The ACK to the FIN should now be rejected since the connection has been
@@ -957,19 +1805,19 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
- t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %v want %v", got, want)
+ t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
}
}
@@ -985,7 +1833,7 @@ func TestFullWindowReceive(t *testing.T) {
_, _, err := c.EP.Read(nil)
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
// Fill up the window.
@@ -1020,7 +1868,7 @@ func TestFullWindowReceive(t *testing.T) {
// Receive data and check it.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -1029,7 +1877,7 @@ func TestFullWindowReceive(t *testing.T) {
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
- t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %v want %v", got, want)
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState got %d want %d", got, want)
}
// Check that we get an ACK for the newly non-zero window.
@@ -1052,7 +1900,7 @@ func TestNoWindowShrinking(t *testing.T) {
c.CreateConnected(789, 30000, 10)
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -1060,7 +1908,7 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send 3 bytes, check that the peer acknowledges them.
@@ -1124,7 +1972,7 @@ func TestNoWindowShrinking(t *testing.T) {
for len(read) < len(data) {
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
read = append(read, v...)
@@ -1158,7 +2006,7 @@ func TestSimpleSend(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received.
@@ -1192,7 +2040,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, -1 /* epRcvBuf */)
+ c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1200,11 +2048,20 @@ func TestZeroWindowSend(t *testing.T) {
_, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
- // Since the window is currently zero, check that no packet is received.
- c.CheckNoPacket("Packet received when window is zero")
+ // Check if we got a zero-window probe.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
// Open up the window. Data should be received now.
c.SendPacket(nil, &context.Headers{
@@ -1217,7 +2074,7 @@ func TestZeroWindowSend(t *testing.T) {
})
// Check that data is received.
- b := c.GetPacket()
+ b = c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
@@ -1259,7 +2116,7 @@ func TestScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -1291,7 +2148,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -1319,21 +2176,21 @@ func TestScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -1351,7 +2208,7 @@ func TestScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -1364,7 +2221,7 @@ func TestScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xbfff,
@@ -1392,21 +2249,21 @@ func TestNonScaledWindowAccept(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
@@ -1425,7 +2282,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -1438,7 +2295,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received, and that advertised window is 0xffff,
@@ -1522,10 +2379,14 @@ func TestZeroScaledWindowReceive(t *testing.T) {
)
}
- // Read some data. An ack should be sent in response to that.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %v", err)
+ // Read at least 1MSS of data. An ack should be sent in response to that.
+ sz := 0
+ for sz < defaultMTU {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ sz += len(v)
}
checker.IPv4(t, c.GetPacket(),
@@ -1534,7 +2395,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(790+sent)),
- checker.Window(uint16(len(v)>>ws)),
+ checker.Window(uint16(sz>>ws)),
checker.TCPFlags(header.TCPFlagAck),
),
)
@@ -1558,10 +2419,10 @@ func TestSegmentMerging(t *testing.T) {
{
"cork",
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(1))
+ ep.SetSockOptBool(tcpip.CorkOption, true)
},
func(ep tcpip.Endpoint) {
- ep.SetSockOpt(tcpip.CorkOption(0))
+ ep.SetSockOptBool(tcpip.CorkOption, false)
},
},
}
@@ -1573,20 +2434,50 @@ func TestSegmentMerging(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- // Prevent the endpoint from processing packets.
- test.stop(c.EP)
+ // Send tcp.InitialCwnd number of segments to fill up
+ // InitialWindow but don't ACK. That should prevent
+ // anymore packets from going out.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+ // Now send the segments that should get merged as the congestion
+ // window is full and we won't be able to send any more packets.
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
- // Let the endpoint process the segments that we just sent.
- test.resume(c.EP)
+ // Check that we get tcp.InitialCwnd packets.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
// Check that data is received.
b := c.GetPacket()
@@ -1594,7 +2485,7 @@ func TestSegmentMerging(t *testing.T) {
checker.PayloadLen(len(allData)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.SeqNum(uint32(c.IRS)+1),
+ checker.SeqNum(uint32(c.IRS)+11),
checker.AckNum(790),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
@@ -1610,7 +2501,7 @@ func TestSegmentMerging(t *testing.T) {
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: 790,
- AckNum: c.IRS.Add(1 + seqnum.Size(len(allData))),
+ AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))),
RcvWnd: 30000,
})
})
@@ -1623,14 +2514,14 @@ func TestDelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -1671,13 +2562,13 @@ func TestUndelay(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOptInt(tcpip.DelayOption, 1)
+ c.EP.SetSockOptBool(tcpip.DelayOption, true)
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -1704,7 +2595,7 @@ func TestUndelay(t *testing.T) {
// Check that we don't get the second packet yet.
c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond)
- c.EP.SetSockOptInt(tcpip.DelayOption, 0)
+ c.EP.SetSockOptBool(tcpip.DelayOption, false)
// Check that data is received.
second := c.GetPacket()
@@ -1741,8 +2632,8 @@ func TestMSSNotDelayed(t *testing.T) {
fn func(tcpip.Endpoint)
}{
{"no-op", func(tcpip.Endpoint) {}},
- {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptInt(tcpip.DelayOption, 1) }},
- {"cork", func(ep tcpip.Endpoint) { ep.SetSockOpt(tcpip.CorkOption(1)) }},
+ {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }},
+ {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }},
}
for _, test := range tests {
@@ -1761,7 +2652,7 @@ func TestMSSNotDelayed(t *testing.T) {
for i, data := range allData {
view := buffer.NewViewFromBytes(data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write #%d failed: %v", i+1, err)
+ t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -1812,7 +2703,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that data is received in chunks.
@@ -1880,15 +2771,15 @@ func TestSetTTL(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
- if err := c.EP.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("Unexpected return value from Connect: %v", err)
+ t.Fatalf("unexpected return value from Connect: %s", err)
}
// Receive SYN packet.
@@ -1920,7 +2811,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
@@ -1928,15 +2819,15 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -1954,7 +2845,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -1974,26 +2865,24 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
// Set the SynRcvd threshold to zero to force a syn cookie based accept
// to happen.
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Do 3-way handshake.
@@ -2011,7 +2900,7 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -2045,7 +2934,7 @@ func TestForwarderSendMSSLessThanMTU(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("Error creating endpoint: %v", err)
+ t.Fatalf("Error creating endpoint: %s", err)
}
case <-time.After(2 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2064,7 +2953,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Set the buffer size to a deterministic size so that we can check the
@@ -2072,7 +2961,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
const rcvBufferSize = 0x20000
const wndScale = 2
if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
- t.Fatalf("SetSockOpt failed failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed failed: %s", rcvBufferSize, err)
}
// Start connection attempt.
@@ -2081,7 +2970,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
// Receive SYN packet.
@@ -2135,7 +3024,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
select {
case <-ch:
if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -2150,22 +3039,22 @@ func TestCloseListener(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Close the listener and measure how long it takes.
t0 := time.Now()
ep.Close()
if diff := time.Now().Sub(t0); diff > 3*time.Second {
- t.Fatalf("Took too long to close: %v", diff)
+ t.Fatalf("Took too long to close: %s", diff)
}
}
@@ -2201,16 +3090,26 @@ loop:
case tcpip.ErrConnectionReset:
break loop
default:
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
+
+ if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
+ t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
}
func TestSendOnResetConnection(t *testing.T) {
@@ -2234,7 +3133,162 @@ func TestSendOnResetConnection(t *testing.T) {
// Try to write.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Write(...) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ }
+}
+
+// TestMaxRetransmitsTimeout tests if the connection is timed out after
+// a segment has been retransmitted MaxRetries times.
+func TestMaxRetransmitsTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const numRetries = 2
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil {
+ t.Fatalf("could not set protocol option MaxRetries.\n")
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Expect first transmit and MaxRetries retransmits.
+ for i := 0; i < numRetries+1; i++ {
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+ ),
+ )
+ }
+ // Wait for the connection to timeout after MaxRetries retransmits.
+ initRTO := 1 * time.Second
+ select {
+ case <-notifyCh:
+ case <-time.After((2 << numRetries) * initRTO):
+ t.Fatalf("connection still alive after maximum retransmits.\n")
+ }
+
+ // Send an ACK and expect a RST as the connection would have been closed.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+// TestMaxRTO tests if the retransmit interval caps to MaxRTO.
+func TestMaxRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rto := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err)
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ const numRetransmits = 2
+ for i := 0; i < numRetransmits; i++ {
+ start := time.Now()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO.\n")
+ }
+ }
+}
+
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect two retransmitted packets, and that all packets received have
+ // unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
}
}
@@ -2246,7 +3300,7 @@ func TestFinImmediately(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2289,7 +3343,7 @@ func TestFinRetransmit(t *testing.T) {
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2344,7 +3398,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -2370,7 +3424,7 @@ func TestFinWithNoPendingData(t *testing.T) {
// Shutdown, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2417,7 +3471,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
view := buffer.NewView(10)
for i := tcp.InitialCwnd; i > 0; i-- {
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
}
@@ -2439,7 +3493,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// because the congestion window doesn't allow it. Wait until a
// retransmit is received.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2503,7 +3557,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -2529,7 +3583,7 @@ func TestFinWithPendingData(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2545,7 +3599,7 @@ func TestFinWithPendingData(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2590,7 +3644,7 @@ func TestFinWithPartialAck(t *testing.T) {
// FIN from the test side.
view := buffer.NewView(10)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -2627,7 +3681,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Write new data, but don't acknowledge it.
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2643,7 +3697,7 @@ func TestFinWithPartialAck(t *testing.T) {
// Shutdown the connection, check that we do get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2689,20 +3743,20 @@ func TestUpdateListenBacklog(t *testing.T) {
var wq waiter.Queue
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Update the backlog with another Listen() on the same endpoint.
if err := ep.Listen(20); err != nil {
- t.Fatalf("Listen failed to update backlog: %v", err)
+ t.Fatalf("Listen failed to update backlog: %s", err)
}
ep.Close()
@@ -2734,7 +3788,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
// Send some data. Check that it's capped by the window size.
view := buffer.NewView(65535)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Check that only data that fits in the scaled window is sent.
@@ -2780,18 +3834,18 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
})
if got := stats.TCP.ValidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want {
- t.Errorf("got EP stats Stats.SegmentsReceived = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want)
}
// Ensure there were no errors during handshake. If these stats have
// incremented, then the connection should not have been established.
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoRoute = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0)
}
if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 {
- t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %v, want = %v", got, 0)
+ t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0)
}
}
@@ -2809,16 +3863,16 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want {
- t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
}
}
@@ -2836,7 +3890,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
// Overwrite a byte in the payload which should cause checksum
// verification to fail.
tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
@@ -2905,6 +3959,13 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
@@ -2912,12 +3973,12 @@ func TestReadAfterClosedState(t *testing.T) {
defer c.WQ.EventUnregister(&we)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
checker.IPv4(t, c.GetPacket(),
@@ -2931,7 +3992,7 @@ func TestReadAfterClosedState(t *testing.T) {
)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Send some data and acknowledge the FIN.
@@ -2955,13 +4016,12 @@ func TestReadAfterClosedState(t *testing.T) {
),
)
- // Give the stack the chance to transition to closed state. Note that since
- // both the sender and receiver are now closed, we effectively skip the
- // TIME-WAIT state.
- time.Sleep(1 * time.Second)
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Wait for receive to be notified.
@@ -2975,7 +4035,7 @@ func TestReadAfterClosedState(t *testing.T) {
peekBuf := make([]byte, 10)
n, _, err := c.EP.Peek([][]byte{peekBuf})
if err != nil {
- t.Fatalf("Peek failed: %v", err)
+ t.Fatalf("Peek failed: %s", err)
}
peekBuf = peekBuf[:n]
@@ -2986,7 +4046,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Receive data.
v, _, err := c.EP.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
if !bytes.Equal(data, v) {
@@ -2996,11 +4056,11 @@ func TestReadAfterClosedState(t *testing.T) {
// Now that we drained the queue, check that functions fail with the
// right error code.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %v, want = %v", err, tcpip.ErrClosedForReceive)
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
}
}
@@ -3014,66 +4074,84 @@ func TestReusePort(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Second case, an endpoint that was bound and is connecting..
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
c.EP.Close()
// Third case, an endpoint that was bound and is listening.
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
c.EP.Close()
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
}
@@ -3082,11 +4160,11 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got receive buffer size = %v, want = %v", s, v)
+ t.Fatalf("got receive buffer size = %d, want = %d", s, v)
}
}
@@ -3095,11 +4173,11 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
if err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("GetSockOpt failed: %s", err)
}
if int(s) != v {
- t.Fatalf("got send buffer size = %v, want = %v", s, v)
+ t.Fatalf("got send buffer size = %d, want = %d", s, v)
}
}
@@ -3112,7 +4190,7 @@ func TestDefaultBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer func() {
if ep != nil {
@@ -3124,28 +4202,34 @@ func TestDefaultBufferSizes(t *testing.T) {
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default send buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize * 2, tcp.DefaultSendBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultSendBufferSize * 2,
+ Max: tcp.DefaultSendBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize)
// Change the default receive buffer size.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize * 3, tcp.DefaultReceiveBufferSize * 30}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{
+ Min: 1,
+ Default: tcp.DefaultReceiveBufferSize * 3,
+ Max: tcp.DefaultReceiveBufferSize * 30}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
ep.Close()
ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2)
@@ -3161,41 +4245,41 @@ func TestMinMaxBufferSizes(t *testing.T) {
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
// Change the min/max values for send/receive
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultReceiveBufferSize * 2, tcp.DefaultReceiveBufferSize * 20}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultSendBufferSize * 3, tcp.DefaultSendBufferSize * 30}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Set values below the min.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err)
}
checkRecvBufferSize(t, ep, 200)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
- t.Fatalf("GetSockOpt failed: %v", err)
+ t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
@@ -3208,50 +4292,45 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %s", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %s, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -3314,12 +4393,12 @@ func TestSelfConnect(t *testing.T) {
var wq waiter.Queue
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Register for notification, then start connection attempt.
@@ -3328,12 +4407,12 @@ func TestSelfConnect(t *testing.T) {
defer wq.EventUnregister(&waitEntry)
if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted {
- t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted)
+ t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
}
<-notifyCh
if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("Connect failed: %v", err)
+ t.Fatalf("Connect failed: %s", err)
}
// Write something.
@@ -3341,7 +4420,7 @@ func TestSelfConnect(t *testing.T) {
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
// Read back what was written.
@@ -3350,12 +4429,12 @@ func TestSelfConnect(t *testing.T) {
rd, _, err := ep.Read(nil)
if err != nil {
if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
<-notifyCh
rd, _, err = ep.Read(nil)
if err != nil {
- t.Fatalf("Read failed: %v", err)
+ t.Fatalf("Read failed: %s", err)
}
}
@@ -3439,18 +4518,18 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
}
ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
eps = append(eps, ep)
switch network {
case "ipv4":
case "ipv6":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(1)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
}
case "dual":
- if err := ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- t.Fatalf("SetSockOpt(V6OnlyOption(0)) failed: %v", err)
+ if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+ t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
}
default:
t.Fatalf("unknown network: '%s'", network)
@@ -3490,7 +4569,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
- t.Fatalf("Bind(%d) failed: %v", i, err)
+ t.Fatalf("Bind(%d) failed: %s", i, err)
}
}
want := tcpip.ErrConnectStarted
@@ -3498,7 +4577,7 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
want = tcpip.ErrNoPortAvailable
}
if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
- t.Fatalf("got ep.Connect(..) = %v, want = %v", err, want)
+ t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
}
})
}
@@ -3532,7 +4611,7 @@ func TestPathMTUDiscovery(t *testing.T) {
}
if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
@@ -3635,7 +4714,7 @@ func TestStackSetCongestionControl(t *testing.T) {
var oldCC tcpip.CongestionControlOption
if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil {
- t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err)
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err)
}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err {
@@ -3722,12 +4801,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
var oldCC tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&oldCC); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err)
}
if connected {
@@ -3735,12 +4814,12 @@ func TestEndpointSetCongestionControl(t *testing.T) {
}
if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
- t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err)
+ t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
}
var cc tcpip.CongestionControlOption
if err := c.EP.GetSockOpt(&cc); err != nil {
- t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err)
+ t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err)
}
got, want := cc, oldCC
@@ -3763,7 +4842,7 @@ func enableCUBIC(t *testing.T, c *context.Context) {
t.Helper()
opt := tcpip.CongestionControlOption("cubic")
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
- t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err)
+ t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err)
}
}
@@ -3773,10 +4852,11 @@ func TestKeepalive(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
- c.EP.SetSockOpt(tcpip.KeepaliveCountOption(5))
- c.EP.SetSockOpt(tcpip.KeepaliveEnabledOption(1))
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
// 5 unacked keepalives are sent. ACK each one, and check that the
// connection stays alive after 5.
@@ -3804,14 +4884,14 @@ func TestKeepalive(t *testing.T) {
// Check that the connection is still alive.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
}
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
view := buffer.NewView(3)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Write failed: %v", err)
+ t.Fatalf("Write failed: %s", err)
}
next := uint32(c.IRS) + 1
@@ -3864,18 +4944,45 @@ func TestKeepalive(t *testing.T) {
)
}
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
// The connection should be terminated after 5 unacked keepalives.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(next)),
- checker.AckNum(uint32(790)),
- checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
),
)
+ if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+ }
+
if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
}
}
@@ -3890,7 +4997,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
RcvWnd: 30000,
})
- // Receive the SYN-ACK reply.w
+ // Receive the SYN-ACK reply.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss = seqnum.Value(tcp.SequenceNumber())
@@ -3923,6 +5030,50 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki
return irs, iss
}
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+ // Send a SYN request.
+ irs = seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss = seqnum.Value(tcp.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(srcPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+
+ if synCookieInUse {
+ // When cookies are in use window scaling is disabled.
+ tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+ WS: -1,
+ MSS: c.MSSWithoutOptionsV6(),
+ }))
+ }
+
+ checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: srcPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ return irs, iss
+}
+
// TestListenBacklogFull tests that netstack does not complete handshakes if the
// listen backlog for the endpoint is full.
func TestListenBacklogFull(t *testing.T) {
@@ -3933,19 +5084,19 @@ func TestListenBacklogFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 2
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
for i := 0; i < listenBacklog; i++ {
@@ -3978,7 +5129,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4007,7 +5158,7 @@ func TestListenBacklogFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4021,7 +5172,215 @@ func TestListenBacklogFull(t *testing.T) {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
+ }
+}
+
+// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
+// non unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+ multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+ otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv4Any,
+ context.StackAddr,
+ },
+ {
+ "SourceBroadcast",
+ header.IPv4Broadcast,
+ context.StackAddr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackAddr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackAddr,
+ },
+ {
+ "DestUnspecified",
+ context.TestAddr,
+ header.IPv4Any,
+ },
+ {
+ "DestBroadcast",
+ context.TestAddr,
+ header.IPv4Broadcast,
+ },
+ {
+ "DestOurMulticast",
+ context.TestAddr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestAddr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ test := test // capture range variable
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestAddr, context.StackAddr)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
+// non unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv6Any,
+ context.StackV6Addr,
+ },
+ {
+ "SourceAllNodes",
+ header.IPv6AllNodesMulticastAddress,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "DestUnspecified",
+ context.TestV6Addr,
+ header.IPv6Any,
+ },
+ {
+ "DestAllNodes",
+ context.TestV6Addr,
+ header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ "DestOurMulticast",
+ context.TestV6Addr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestV6Addr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ test := test // capture range variable
+
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestV6Addr, context.StackV6Addr)
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
}
}
@@ -4033,19 +5392,19 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
// Start listening.
listenBacklog := 1
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
@@ -4056,7 +5415,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(789),
+ SeqNum: irs,
RcvWnd: 30000,
})
@@ -4111,7 +5470,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
case <-ch:
newEP, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4125,30 +5484,28 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
- t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
}
}
func TestListenBacklogFullSynCookieInUse(t *testing.T) {
- saved := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = saved
- }()
- tcp.SynRcvdCountThreshold = 1
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+ }
+
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -4156,7 +5513,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
listenBacklog := 1
portOffset := uint16(0)
if err := c.EP.Listen(listenBacklog); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
executeHandshake(t, c, context.TestPort+portOffset, false)
@@ -4167,7 +5524,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort,
+ // pick a different src port for new SYN.
+ SrcPort: context.TestPort + 1,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
@@ -4188,7 +5546,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4207,26 +5565,145 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSynRcvdBadSeqNumber(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state
+ irs := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcpHdr.SequenceNumber())
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.AckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now send a packet with an out-of-window sequence number
+ largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: largeSeqnum,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Should receive an ACK with the expected SEQ number
+ b = c.GetPacket()
+ tcpCheckers = []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.AckNum(uint32(irs) + 1),
+ checker.SeqNum(uint32(iss + 1)),
+ }
+ checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+ // Now that the socket replied appropriately with the ACK,
+ // complete the connection to test that the large SEQ num
+ // did not change the state from SYN-RCVD.
+
+ // Send ACK to move to ESTABLISHED state.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ newEP, _, err := c.EP.Accept()
+
+ if err != nil && err != tcpip.ErrWouldBlock {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ // Try to accept the connections in the backlog.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ newEP, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now verify that the TCP socket is usable and in a connected state.
+ data := "Don't panic"
+ _, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ pkt := c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+ if string(tcpHdr.Payload()) != data {
+ t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+ }
+}
+
func TestPassiveConnectionAttemptIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
stats := c.Stack().Stats()
@@ -4247,7 +5724,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4256,7 +5733,7 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) {
}
if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
- t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
}
}
@@ -4267,14 +5744,14 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
stats := c.Stack().Stats()
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
c.EP = ep
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
srcPort := uint16(context.TestPort)
@@ -4299,10 +5776,10 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
time.Sleep(50 * time.Millisecond)
if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
}
if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
- t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %v, want = %v", got, want)
+ t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
}
we, ch := waiter.NewChannelEntry(nil)
@@ -4317,7 +5794,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
case <-ch:
_, _, err = c.EP.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4332,29 +5809,28 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
wq := &waiter.Queue{}
ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
- t.Fatalf("Bind failed: %v", err)
+ t.Fatalf("Bind failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- // Expect InvalidEndpointState errors on a read at this point.
- if _, _, err := ep.Read(nil); err != tcpip.ErrInvalidEndpointState {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrInvalidEndpointState)
+ if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+ t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
}
- if got := ep.Stats().(*tcp.Stats).ReadErrors.InvalidEndpointState.Value(); got != 1 {
- t.Fatalf("got EP stats Stats.ReadErrors.InvalidEndpointState got %v want %v", got, 1)
+ if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+ t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
if err := ep.Listen(10); err != nil {
- t.Fatalf("Listen failed: %v", err)
+ t.Fatalf("Listen failed: %s", err)
}
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -4371,7 +5847,7 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
case <-ch:
aep, _, err = ep.Accept()
if err != nil {
- t.Fatalf("Accept failed: %v", err)
+ t.Fatalf("Accept failed: %s", err)
}
case <-time.After(1 * time.Second):
@@ -4379,22 +5855,25 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
}
}
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+ if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+ t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
}
// Listening endpoint remains in listen state.
if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
ep.Close()
// Give worker goroutines time to receive the close notification.
time.Sleep(1 * time.Second)
if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
// Accepted endpoint remains open when the listen endpoint is closed.
if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
- t.Errorf("Unexpected endpoint state: want %v, got %v", want, got)
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
}
@@ -4414,13 +5893,13 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// the segment queue holding unprocessed packets is limited to 500.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size defined above.
@@ -4464,6 +5943,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
packetsSent++
}
+
// Resume the worker so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -4512,7 +5992,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
return
}
if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
- t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w, wantRcvWnd)
+ t.Errorf("expected a non-zero window: got %d, want <= wantRcvWnd", w)
}
},
))
@@ -4531,16 +6011,16 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
stk := c.Stack()
// Set lower limits for auto-tuning tests. This is required because the
// test stops the worker which can cause packets to be dropped because
- // the segment queue holding unprocessed packets is limited to 500.
+ // the segment queue holding unprocessed packets is limited to 300.
const receiveBufferSize = 80 << 10 // 80KB.
const maxReceiveBufferSize = receiveBufferSize * 10
- if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, receiveBufferSize, maxReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Enable auto-tuning.
if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
// Change the expected window scale to match the value needed for the
// maximum buffer size used by stack.
@@ -4586,6 +6066,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
totalSent += mss
packetsSent++
}
+
// Resume it so that it only sees the packets once all of them
// are waiting to be read.
worker.ResumeWork()
@@ -4618,7 +6099,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// Invoke the moderation API. This is required for auto-tuning
// to happen. This method is normally expected to be invoked
// from a higher layer than tcpip.Endpoint. So we simulate
- // copying to user-space by invoking it explicitly here.
+ // copying to userspace by invoking it explicitly here.
c.EP.ModerateRecvBuf(totalCopied)
// Now send a keep-alive packet to trigger an ACK so that we can
@@ -4668,3 +6149,1300 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
payloadSize *= 2
}
}
+
+func TestDelayEnabled(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ checkDelayOption(t, c, false, false) // Delay is disabled by default.
+
+ for _, v := range []struct {
+ delayEnabled tcp.DelayEnabled
+ wantDelayOption bool
+ }{
+ {delayEnabled: false, wantDelayOption: false},
+ {delayEnabled: true, wantDelayOption: true},
+ } {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil {
+ t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err)
+ }
+ checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption)
+ }
+}
+
+func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) {
+ t.Helper()
+
+ var gotDelayEnabled tcp.DelayEnabled
+ if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil {
+ t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err)
+ }
+ if gotDelayEnabled != wantDelayEnabled {
+ t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled)
+ }
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue))
+ if err != nil {
+ t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err)
+ }
+ gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption)
+ if err != nil {
+ t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err)
+ }
+ if gotDelayOption != wantDelayOption {
+ t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption)
+ }
+}
+
+func TestTCPLingerTimeout(t *testing.T) {
+ c := context.New(t, 1500 /* mtu */)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ testCases := []struct {
+ name string
+ tcpLingerTimeout time.Duration
+ want time.Duration
+ }{
+ {"NegativeLingerTimeout", -123123, 0},
+ {"ZeroLingerTimeout", 0, 0},
+ {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second},
+ // Values > stack's TCPLingerTimeout are capped to the stack's
+ // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds)
+ {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil {
+ t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err)
+ }
+ var v tcpip.TCPLingerTimeoutOption
+ if err := c.EP.GetSockOpt(&v); err != nil {
+ t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err)
+ }
+ if got, want := time.Duration(v), tc.want; got != want {
+ t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want)
+ }
+ })
+ }
+}
+
+func TestTCPTimeWaitRSTIgnored(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now send a RST and this should be ignored and not
+ // generate an ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagRst,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second)
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitOutOfOrder(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Out of order ACK should generate an immediate ACK in
+ // TIME_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 3,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+}
+
+func TestTCPTimeWaitNewSyn(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Send a SYN request w/ sequence number lower than
+ // the highest sequence number sent. We just reuse
+ // the same number.
+ iss = seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second)
+
+ // Send a SYN request w/ sequence number higher than
+ // the highest sequence number sent.
+ iss = seqnum.Value(792)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b = c.GetPacket()
+ tcpHdr = header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+}
+
+func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ time.Sleep(2 * time.Second)
+
+ // Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+ // by another 5 seconds and also send us a duplicate ACK as it should
+ // indicate that the final ACK was potentially lost.
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+2)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Sleep for 4 seconds so at this point we are 1 second past the
+ // original tcpLingerTimeout of 5 seconds.
+ time.Sleep(4 * time.Second)
+
+ // Send an ACK and it should not generate any packet as the socket
+ // should still be in TIME_WAIT for another another 5 seconds due
+ // to the duplicate FIN we sent earlier.
+ *ackHeaders = *finHeaders
+ ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+ ackHeaders.Flags = header.TCPFlagAck
+ c.SendPacket(nil, ackHeaders)
+
+ c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+ // Now sleep for another 2 seconds so that we are past the
+ // extended TIME_WAIT of 7 seconds (2 + 5).
+ time.Sleep(2 * time.Second)
+
+ // Resend the same ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Receive the RST that should be generated as there is no valid
+ // endpoint.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+}
+
+func TestTCPCloseWithData(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed
+ // after 5 seconds in TIME_WAIT state.
+ tcpTimeWaitTimeout := 5 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ RcvWnd: 30000,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Now trigger a passive close by sending a FIN.
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ RcvWnd: 30000,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+2),
+ checker.TCPFlags(header.TCPFlagAck)))
+
+ // Now write a few bytes and then close the endpoint.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b = c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+
+ c.EP.Close()
+ // Check the FIN.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))),
+ checker.AckNum(uint32(iss+2)),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ // First send a partial ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)-1),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send a full ACK.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Now ACK the FIN.
+ ackHeaders.AckNum++
+ c.SendPacket(nil, ackHeaders)
+
+ // Now send an ACK and we should get a RST back as the endpoint should
+ // be in CLOSED state.
+ ackHeaders = &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 2,
+ AckNum: c.IRS + 1 + seqnum.Value(len(data)),
+ RcvWnd: 30000,
+ }
+ c.SendPacket(nil, ackHeaders)
+
+ // Check the RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(ackHeaders.AckNum)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ // Ensure that on the next retransmit timer fire, the user timeout has
+ // expired.
+ initRTO := 1 * time.Second
+ userTimeout := initRTO / 2
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Send some data and wait before ACKing it.
+ view := buffer.NewView(3)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ next := uint32(c.IRS) + 1
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(len(view)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(next),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ // Wait for the retransmit timer to be fired and the user timeout to cause
+ // close of the connection.
+ select {
+ case <-notifyCh:
+ case <-time.After(2 * initRTO):
+ t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout)
+ }
+
+ // No packet should be received as the connection should be silently
+ // closed due to timeout.
+ c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+ next += uint32(len(view))
+
+ // The connection should be terminated after userTimeout has expired.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(next),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(next)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+ const keepAliveInterval = 3 * time.Second
+ c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+ c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+ c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
+ c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+ // Set userTimeout to be the duration to be 1 keepalive
+ // probes. Which means that after the first probe is sent
+ // the second one should cause the connection to be
+ // closed due to userTimeout being hit.
+ userTimeout := 1 * keepAliveInterval
+ c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+ // Check that the connection is still alive.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Now receive 1 keepalives, but don't ACK it.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)),
+ checker.AckNum(uint32(790)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Sleep for a litte over the KeepAlive interval to make sure
+ // the timer has time to fire after the last ACK and close the
+ // close the socket.
+ time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
+ // The connection should be closed with a timeout.
+ // Send an ACK to trigger a RST from the stack as the endpoint should
+ // be dead.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: seqnum.Value(c.IRS + 1),
+ RcvWnd: 30000,
+ })
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(0)),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+ }
+ if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+func TestIncreaseWindowOnReceive(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after recv() when the window grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // We now have < 1 MSS in the buffer space. Read the data! An
+ // ack should be sent in response to that. The window was not
+ // zero, but it grew to larger than MSS.
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestIncreaseWindowOnBufferResize(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after available recv buffer grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // Increasing the buffer from should generate an ACK,
+ // since window grew from small value to larger equal MSS
+ c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestTCPDeferAccept(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give a bit of time for the socket to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ const tcpDeferAccept = 1 * time.Second
+ if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+ t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Sleep for a little of the tcpDeferAccept timeout.
+ time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+ // On timeout expiry we should get a SYN-ACK retransmission.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+
+ // Send data. This should result in an acceptable endpoint.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+
+ // Give sometime for the endpoint to be delivered to the accept queue.
+ time.Sleep(50 * time.Millisecond)
+ aep, _, err := c.EP.Accept()
+ if err != nil {
+ t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err)
+ }
+
+ aep.Close()
+ // Closing aep without reading the data should trigger a RST.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ iss := seqnum.Value(789)
+ c.CreateConnected(iss, 30000, -1 /* epRecvBuf */)
+ // Send some data to make sure there is some unread
+ // data to trigger a reset on c.Close.
+ irs := c.IRS
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss.Add(1),
+ AckNum: irs.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(irs.Add(1))),
+ checker.AckNum(uint32(iss.Add(5)))))
+
+ // Close in a separate goroutine so that we can trigger
+ // a race with the RST we send below. This should not
+ // panic due to the route being released depeding on
+ // whether Close() sends an active RST or the RST sent
+ // below is processed by the worker first.
+ var wg sync.WaitGroup
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(5),
+ AckNum: c.IRS.Add(5),
+ RcvWnd: 30000,
+ Flags: header.TCPFlagRst,
+ })
+ }()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c.EP.Close()
+ }()
+
+ wg.Wait()
+}
+
+func TestStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &twReuse, err)
+ }
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseLoopbackOnly; got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+}
+
+func TestSetStackTimeWaitReuse(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ testCases := []struct {
+ v int
+ err *tcpip.Error
+ }{
+ {int(tcpip.TCPTimeWaitReuseDisabled), nil},
+ {int(tcpip.TCPTimeWaitReuseGlobal), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly), nil},
+ {int(tcpip.TCPTimeWaitReuseLoopbackOnly) + 1, tcpip.ErrInvalidOptionValue},
+ {int(tcpip.TCPTimeWaitReuseDisabled) - 1, tcpip.ErrInvalidOptionValue},
+ }
+
+ for _, tc := range testCases {
+ err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v))
+ if got, want := err, tc.err; got != want {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err)
+ }
+ if tc.err != nil {
+ continue
+ }
+
+ var twReuse tcpip.TCPTimeWaitReuseOption
+ if err := s.TransportProtocolOption(tcp.ProtocolNumber, &twReuse); err != nil {
+ t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want nil", tcp.ProtocolNumber, &twReuse, err)
+ }
+
+ if got, want := twReuse, tcpip.TCPTimeWaitReuseOption(tc.v); got != want {
+ t.Fatalf("got tcpip.TCPTimeWaitReuseOption: %v, want: %v", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index a641e953d..8edbff964 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -127,16 +127,14 @@ func TestTimeStampDisabledConnect(t *testing.T) {
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
}
- c := context.New(t, defaultMTU)
- defer c.Cleanup()
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
@@ -148,7 +146,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option TSEcr field
@@ -190,17 +188,15 @@ func TestTimeStampEnabledAccept(t *testing.T) {
}
func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
- savedSynCountThreshold := tcp.SynRcvdCountThreshold
- defer func() {
- tcp.SynRcvdCountThreshold = savedSynCountThreshold
- }()
- if cookieEnabled {
- tcp.SynRcvdCountThreshold = 0
- }
-
c := context.New(t, defaultMTU)
defer c.Cleanup()
+ if cookieEnabled {
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+ }
+
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
@@ -211,7 +207,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
- t.Fatalf("Unexpected error from Write: %v", err)
+ t.Fatalf("Unexpected error from Write: %s", err)
}
// Check that data is received and that the timestamp option is disabled
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index 19b0d31c5..ce6a2c31d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,9 +6,8 @@ go_library(
name = "context",
testonly = 1,
srcs = ["context.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
- "//:sandbox",
+ "//visibility:public",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ef823e4ae..b6031354e 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -18,6 +18,7 @@ package context
import (
"bytes"
+ "context"
"testing"
"time"
@@ -142,13 +143,22 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultReceiveBufferSize, 10 * tcp.DefaultReceiveBufferSize}); err != nil {
- t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Increase minimum RTO in tests to avoid test flakes due to early
+ // retransmit in case the test executors are overloaded and cause timers
+ // to fire earlier than expected.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
+ t.Fatalf("failed to set stack-wide minRTO: %s", err)
}
// Some of the congestion control tests send up to 640 packets, we so
@@ -158,15 +168,17 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ opts2 := stack.NICOptions{Name: "nic2"}
+ if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
@@ -192,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
@@ -201,6 +213,7 @@ func (c *Context) Cleanup() {
if c.EP != nil {
c.EP.Close()
}
+ c.Stack().Close()
}
// Stack returns a reference to the stack in the Context.
@@ -213,11 +226,10 @@ func (c *Context) Stack() *stack.Stack {
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
- select {
- case <-c.linkEP.C:
+ ctx, cancel := context.WithTimeout(context.Background(), wait)
+ defer cancel()
+ if _, ok := c.linkEP.ReadContext(ctx); ok {
c.t.Fatal(errMsg)
-
- case <-time.After(wait):
}
}
@@ -231,27 +243,29 @@ func (c *Context) CheckNoPacket(errMsg string) {
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+ c.t.Helper()
- if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
- c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
- }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
- case <-time.After(2 * time.Second):
- c.t.Fatalf("Packet wasn't written out")
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+ c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
}
- return nil
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
@@ -259,24 +273,26 @@ func (c *Context) GetPacket() []byte {
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv4.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
- }
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+ c.t.Helper()
- checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
- return b
- default:
+ p, ok := c.linkEP.Read()
+ if !ok {
return nil
}
+
+ if p.Proto != ipv4.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+ return b
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
-func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
@@ -302,11 +318,20 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ })
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -319,8 +344,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
+ SrcAddr: src,
+ DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -337,7 +362,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -350,13 +375,29 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.Inject(ipv4.ProtocolNumber, s)
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: s,
+ })
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: c.BuildSegment(payload, h),
+ })
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
+// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
+// provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, pkt)
}
// SendAck sends an ACK packet.
@@ -389,6 +430,8 @@ func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlock
// verifies that the packet packet payload of packet matches the slice
// of data indicated by offset & size.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+ c.t.Helper()
+
c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
}
@@ -397,6 +440,8 @@ func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
// data indicated by offset & size and skips optlen bytes in addition to the IP
// TCP headers when comparing the data.
func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
+ c.t.Helper()
+
b := c.GetPacket()
checker.IPv4(c.t, b,
checker.PayloadLen(size+header.TCPMinimumSize+optlen),
@@ -419,6 +464,8 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
// data indicated by offset & size. It returns true if a packet was received and
// processed.
func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+ c.t.Helper()
+
b := c.GetPacketNonBlocking()
if b == nil {
return false
@@ -450,11 +497,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- var v tcpip.V6OnlyOption
- if v6only {
- v = 1
- }
- if err := c.EP.SetSockOpt(v); err != nil {
+ if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
@@ -462,28 +505,36 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
- select {
- case p := <-c.linkEP.C:
- if p.Proto != ipv6.ProtocolNumber {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
- }
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
- return b
+ c.t.Helper()
- case <-time.After(2 * time.Second):
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
+ }
+
+ if p.Proto != ipv6.ProtocolNumber {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
- return nil
+ checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+ return b
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -494,8 +545,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: TestV6Addr,
- DstAddr: StackV6Addr,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
@@ -511,14 +562,17 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ })
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, pkt)
}
// CreateConnected creates a connected TCP endpoint.
@@ -535,6 +589,8 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf
//
// PreCondition: c.EP must already be created.
func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
+ c.t.Helper()
+
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
@@ -1051,7 +1107,11 @@ func (c *Context) SACKEnabled() bool {
// SetGSOEnabled enables or disables generic segmentation offload.
func (c *Context) SetGSOEnabled(enable bool) {
- c.linkEP.GSO = enable
+ if enable {
+ c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+ } else {
+ c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+ }
}
// MSSWithoutOptions returns the value for the MSS used by the stack when no
@@ -1059,3 +1119,9 @@ func (c *Context) SetGSOEnabled(enable bool) {
func (c *Context) MSSWithoutOptions() uint16 {
return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index c70525f27..7981d469b 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -85,6 +85,7 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
+ *t = timer{}
}
// checkExpiration checks if the given timer has actually expired, it should be
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+)
+
+func TestCleanup(t *testing.T) {
+ const (
+ timerDurationSeconds = 2
+ isAssertedTimeoutSeconds = timerDurationSeconds + 1
+ )
+
+ tmr := timer{}
+ w := sleep.Waker{}
+ tmr.init(&w)
+ tmr.enable(timerDurationSeconds * time.Second)
+ tmr.cleanup()
+
+ if want := (timer{}); tmr != want {
+ t.Errorf("got tmr = %+v, want = %+v", tmr, want)
+ }
+
+ // The waker should not be asserted.
+ for i := 0; i < isAssertedTimeoutSeconds; i++ {
+ time.Sleep(time.Second)
+ if w.IsAsserted() {
+ t.Fatalf("waker asserted unexpectedly")
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 43fcc27f0..3ad6994a7 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,12 +1,10 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "tcpconntrack",
srcs = ["tcp_conntrack.go"],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack",
visibility = ["//visibility:public"],
deps = [
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
index 93712cd45..558b06df0 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
return st
}
+// State returns the current state of the TCB.
+func (t *TCB) State() Result {
+ return t.state
+}
+
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
@@ -311,17 +316,7 @@ type stream struct {
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
- wnd := s.una.Size(s.end)
- if wnd == 0 {
- return segLen == 0 && segSeq == s.una
- }
-
- // Make sure [segSeq, seqSeq+segLen) is non-empty.
- if segLen == 0 {
- segLen = 1
- }
-
- return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+ return header.Acceptable(segSeq, segLen, s.una, s.end)
}
// closed determines if the stream has already been closed. This happens when
@@ -347,3 +342,16 @@ func logicalLen(tcp header.TCP) seqnum.Size {
}
return l
}
+
+// IsEmpty returns true if tcb is not initialized.
+func (t *TCB) IsEmpty() bool {
+ if t.inbound != (stream{}) || t.outbound != (stream{}) {
+ return false
+ }
+
+ if t.firstFin != nil || t.state != ResultDrop {
+ return false
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c9460aa0d..b5d2d0ba6 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -25,15 +24,15 @@ go_library(
"protocol.go",
"udp_packet_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/udp",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
+ "//pkg/tcpip/ports",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/raw",
"//pkg/waiter",
@@ -59,11 +58,3 @@ go_test(
"//pkg/waiter",
],
)
-
-filegroup(
- name = "autogen",
- srcs = [
- "udp_packet_list.go",
- ],
- visibility = ["//:sandbox"],
-)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 91c8487f3..73608783c 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,12 +15,14 @@
package udp
import (
- "sync"
+ "fmt"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -29,11 +31,11 @@ import (
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -80,6 +82,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -93,6 +96,7 @@ type endpoint struct {
// The following fields are protected by the mu mutex.
mu sync.RWMutex `state:"nosave"`
sndBufSize int
+ sndBufSizeMax int
state EndpointState
route stack.Route `state:"manual"`
dstPort uint16
@@ -102,14 +106,34 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
multicastLoop bool
- reusePort bool
+ portFlags ports.Flags
bindToDevice tcpip.NICID
broadcast bool
+ noChecksum bool
+
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // Values used to reserve a port or register a transport endpoint.
+ // (which ever happens first).
+ boundBindToDevice tcpip.NICID
+ boundPortFlags ports.Flags
// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
// applied while sending packets. Defaults to 0 as on Linux.
sendTOS uint8
+ // receiveTOS determines if the incoming IPv4 TOS header field is passed
+ // as ancillary data to ControlMessages on Read.
+ receiveTOS bool
+
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -127,6 +151,9 @@ type endpoint struct {
// TODO(b/142022063): Add ability to save and restore per endpoint stats.
stats tcpip.TransportEndpointStats `state:"nosave"`
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
}
// +stateify savable
@@ -136,7 +163,7 @@ type multicastMembership struct {
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
- return &endpoint{
+ e := &endpoint{
stack: s,
TransportEndpointInfo: stack.TransportEndpointInfo{
NetProto: netProto,
@@ -158,9 +185,42 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
multicastTTL: 1,
multicastLoop: true,
rcvBufSizeMax: 32 * 1024,
- sndBufSize: 32 * 1024,
+ sndBufSizeMax: 32 * 1024,
state: StateInitial,
+ uniqueID: s.UniqueID(),
}
+
+ // Override with stack defaults.
+ var ss stack.SendBufferSizeOption
+ if err := s.Option(&ss); err == nil {
+ e.sndBufSizeMax = ss.Default
+ }
+
+ var rs stack.ReceiveBufferSizeOption
+ if err := s.Option(&rs); err == nil {
+ e.rcvBufSizeMax = rs.Default
+ }
+
+ return e
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+func (e *endpoint) takeLastError() *tcpip.Error {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ err := e.lastError
+ e.lastError = nil
+ return err
+}
+
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
}
// Close puts the endpoint in a closed state and frees all resources
@@ -171,8 +231,10 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ e.boundBindToDevice = 0
+ e.boundPortFlags = ports.Flags{}
}
for _, mem := range e.multicastMemberships {
@@ -203,14 +265,13 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// IPTables implements tcpip.Endpoint.IPTables.
-func (e *endpoint) IPTables() (iptables.IPTables, error) {
- return e.stack.IPTables(), nil
-}
-
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -232,7 +293,29 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
*addr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ cm := tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ }
+ e.mu.RLock()
+ receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
+ receiveIPPacketInfo := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ if receiveTOS {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
+ return p.data.ToView(), cm, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -278,7 +361,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// connectRoute establishes a route to the specified interface or the
// configured multicast interface if no interface is specified and the
// specified address is a multicast address.
-func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
localAddr := e.ID.LocalAddress
if isBroadcastOrMulticast(localAddr) {
// A packet can only originate from a unicast address (i.e., an interface).
@@ -286,20 +369,20 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
}
if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
- if nicid == 0 {
- nicid = e.multicastNICID
+ if nicID == 0 {
+ nicID = e.multicastNICID
}
- if localAddr == "" && nicid == 0 {
+ if localAddr == "" && nicID == 0 {
localAddr = e.multicastAddr
}
}
// Find a route to the desired destination.
- r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
if err != nil {
return stack.Route{}, 0, err
}
- return r, nicid, nil
+ return r, nicID, nil
}
// Write writes data to the endpoint's peer. This method does not block
@@ -328,6 +411,10 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ if err := e.takeLastError(); err != nil {
+ return 0, nil, err
+ }
+
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -356,58 +443,68 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
var route *stack.Route
+ var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error)
var dstPort uint16
if to == nil {
route = &e.route
dstPort = e.dstPort
-
- if route.IsResolutionRequired() {
- // Promote lock to exclusive if using a shared route, given that it may need to
- // change in Route.Resolve() call below.
+ resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) {
+ // Promote lock to exclusive if using a shared route, given that it may
+ // need to change in Route.Resolve() call below.
e.mu.RUnlock()
- defer e.mu.RLock()
-
e.mu.Lock()
- defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != StateConnected {
- return 0, nil, tcpip.ErrInvalidEndpointState
+ err = tcpip.ErrInvalidEndpointState
+ }
+ if err == nil && route.IsResolutionRequired() {
+ ch, err = route.Resolve(waker)
+ }
+
+ e.mu.Unlock()
+ e.mu.RLock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != StateConnected {
+ err = tcpip.ErrInvalidEndpointState
}
+ return
}
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
- nicid := to.NIC
+ nicID := to.NIC
if e.BindNICID != 0 {
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return 0, nil, tcpip.ErrNoRoute
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
}
- if to.Addr == header.IPv4Broadcast && !e.broadcast {
- return 0, nil, tcpip.ErrBroadcastDisabled
- }
-
- netProto, err := e.checkV4Mapped(to, false)
+ dst, netProto, err := e.checkV4MappedLocked(*to)
if err != nil {
return 0, nil, err
}
- r, _, err := e.connectRoute(nicid, *to, netProto)
+ r, _, err := e.connectRoute(nicID, dst, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
- dstPort = to.Port
+ dstPort = dst.Port
+ resolve = route.Resolve
+ }
+
+ if !e.broadcast && route.IsOutboundBroadcast() {
+ return 0, nil, tcpip.ErrBroadcastDisabled
}
if route.IsResolutionRequired() {
- if ch, err := route.Resolve(nil); err != nil {
+ if ch, err := resolve(nil); err != nil {
if err == tcpip.ErrWouldBlock {
return 0, ch, tcpip.ErrNoLinkAddress
}
@@ -433,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
useDefaultTTL = false
}
- if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil {
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil {
return 0, nil, err
}
return int64(len(v)), nil, nil
@@ -444,14 +541,54 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return nil
-}
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v
+ e.mu.Unlock()
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = v
+ e.mu.Unlock()
+
+ case tcpip.NoChecksumOption:
+ e.mu.Lock()
+ e.noChecksum = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveTOSOption:
+ e.mu.Lock()
+ e.receiveTOS = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.portFlags.MostRecent = v
+ e.mu.Unlock()
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.portFlags.LoadBalanced = v
+ e.mu.Unlock()
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -466,24 +603,94 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.v6only = v != 0
+ e.v6only = v
+ }
+
+ return nil
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if the value is not disabling path
+ // MTU discovery.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
e.mu.Unlock()
- case tcpip.MulticastTTLOption:
+ case tcpip.IPv4TOSOption:
e.mu.Lock()
- e.multicastTTL = uint8(v)
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.mu.Lock()
+ e.sendTOS = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs stack.ReceiveBufferSizeOption
+ if err := e.stack.Option(&rs); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err))
+ }
+
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+
+ e.mu.Lock()
+ e.rcvBufSizeMax = v
e.mu.Unlock()
+ return nil
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss stack.SendBufferSizeOption
+ if err := e.stack.Option(&ss); err != nil {
+ panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+ }
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+
+ e.mu.Lock()
+ e.sndBufSizeMax = v
+ e.mu.Unlock()
+ return nil
+ }
+
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
- netProto, err := e.checkV4Mapped(&fa, false)
+ fa, netProto, err := e.checkV4MappedLocked(fa)
if err != nil {
return err
}
@@ -601,56 +808,124 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
- case tcpip.MulticastLoopOption:
+ case tcpip.BindToDeviceOption:
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
+ }
e.mu.Lock()
- e.multicastLoop = bool(v)
+ e.bindToDevice = id
e.mu.Unlock()
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
+ case tcpip.SocketDetachFilterOption:
+ return nil
+ }
+ return nil
+}
- case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.NoChecksumOption:
+ e.mu.RLock()
+ v := e.noChecksum
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveTOSOption:
+ e.mu.RLock()
+ v := e.receiveTOS
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
}
- for nicid, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicid
- return nil
- }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.RLock()
+ v := e.portFlags.MostRecent
+ e.mu.RUnlock()
+
+ return v, nil
+
+ case tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.portFlags.LoadBalanced
+ e.mu.RUnlock()
+
+ return v, nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
}
- return tcpip.ErrUnknownDevice
- case tcpip.BroadcastOption:
- e.mu.Lock()
- e.broadcast = v != 0
- e.mu.Unlock()
+ e.mu.RLock()
+ v := e.v6only
+ e.mu.RUnlock()
- return nil
+ return v, nil
+
+ default:
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
case tcpip.IPv4TOSOption:
- e.mu.Lock()
- e.sendTOS = uint8(v)
- e.mu.Unlock()
- return nil
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
case tcpip.IPv6TrafficClassOption:
+ e.mu.RLock()
+ v := int(e.sendTOS)
+ e.mu.RUnlock()
+ return v, nil
+
+ case tcpip.MTUDiscoverOption:
+ // The only supported setting is path MTU discovery disabled.
+ return tcpip.PMTUDiscoveryDont, nil
+
+ case tcpip.MulticastTTLOption:
e.mu.Lock()
- e.sendTOS = uint8(v)
+ v := int(e.multicastTTL)
e.mu.Unlock()
- return nil
- }
- return nil
-}
+ return v, nil
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
e.rcvMu.Lock()
@@ -663,7 +938,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
case tcpip.SendBufferSizeOption:
e.mu.Lock()
- v := e.sndBufSize
+ v := e.sndBufSizeMax
e.mu.Unlock()
return v, nil
@@ -672,45 +947,23 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
v := e.rcvBufSizeMax
e.rcvMu.Unlock()
return v, nil
- }
- return -1, tcpip.ErrUnknownProtocolOption
+ case tcpip.TTLOption:
+ e.mu.Lock()
+ v := int(e.ttl)
+ e.mu.Unlock()
+ return v, nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
- return nil
-
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.TTLOption:
- e.mu.Lock()
- *o = tcpip.TTLOption(e.ttl)
- e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastTTLOption:
- e.mu.Lock()
- *o = tcpip.MulticastTTLOption(e.multicastTTL)
- e.mu.Unlock()
- return nil
-
+ return e.takeLastError()
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
@@ -718,87 +971,43 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.multicastAddr,
}
e.mu.Unlock()
- return nil
-
- case *tcpip.MulticastLoopOption:
- e.mu.RLock()
- v := e.multicastLoop
- e.mu.RUnlock()
-
- *o = tcpip.MulticastLoopOption(v)
- return nil
-
- case *tcpip.ReusePortOption:
- e.mu.RLock()
- v := e.reusePort
- e.mu.RUnlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
- case *tcpip.BroadcastOption:
- e.mu.RLock()
- v := e.broadcast
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
e.mu.RUnlock()
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
- case *tcpip.IPv4TOSOption:
- e.mu.RLock()
- *o = tcpip.IPv4TOSOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
-
- case *tcpip.IPv6TrafficClassOption:
- e.mu.RLock()
- *o = tcpip.IPv6TrafficClassOption(e.sendTOS)
- e.mu.RUnlock()
- return nil
-
default:
return tcpip.ErrUnknownProtocolOption
}
+ return nil
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
-func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error {
- // Allocate a buffer for the UDP header.
- hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error {
+ pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
+ Data: data,
+ })
+ pkt.Owner = owner
- // Initialize the header.
- udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ // Initialize the UDP header.
+ udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
- length := uint16(hdr.UsedLength() + data.Size())
+ length := uint16(pkt.Size())
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
- // Only calculate the checksum if offloading isn't supported.
- if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ // Set the checksum field unless TX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value indicates the
+ // transmitter skipped the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 &&
+ (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
@@ -809,7 +1018,11 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
if useDefaultTTL {
ttl = r.DefaultTTL()
}
- if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil {
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
+ Protocol: ProtocolNumber,
+ TTL: ttl,
+ TOS: tos,
+ }, pkt); err != nil {
r.Stats().UDP.PacketSendErrors.Increment()
return err
}
@@ -819,36 +1032,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
return nil
}
-func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
- netProto := e.NetProto
- if len(addr.Addr) == 0 {
- return netProto, nil
- }
- if header.IsV4MappedAddress(addr.Addr) {
- // Fail if using a v4 mapped address on a v6only endpoint.
- if e.v6only {
- return 0, tcpip.ErrNoRoute
- }
-
- netProto = header.IPv4ProtocolNumber
- addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
- if addr.Addr == header.IPv4Any {
- addr.Addr = ""
- }
-
- // Fail if we are bound to an IPv6 address.
- if !allowMismatch && len(e.ID.LocalAddress) == 16 {
- return 0, tcpip.ErrNetworkUnreachable
- }
- }
-
- // Fail if we're bound to an address length different from the one we're
- // checking.
- if l := len(e.ID.LocalAddress); l != 0 && l != len(addr.Addr) {
- return 0, tcpip.ErrInvalidEndpointState
+// checkV4MappedLocked determines the effective network protocol and converts
+// addr to its canonical form.
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
}
-
- return netProto, nil
+ return unwrapped, netProto, nil
}
// Disconnect implements tcpip.Endpoint.Disconnect.
@@ -859,7 +1050,15 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if e.state != StateConnected {
return nil
}
- id := stack.TransportEndpointID{}
+ var (
+ id stack.TransportEndpointID
+ btd tcpip.NICID
+ )
+
+ // We change this value below and we need the old value to unregister
+ // the endpoint.
+ boundPortFlags := e.boundPortFlags
+
// Exclude ephemerally bound endpoints.
if e.BindNICID != 0 || e.ID.LocalAddress == "" {
var err *tcpip.Error
@@ -867,21 +1066,24 @@ func (e *endpoint) Disconnect() *tcpip.Error {
LocalPort: e.ID.LocalPort,
LocalAddress: e.ID.LocalAddress,
}
- id, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
return err
}
e.state = StateBound
+ boundPortFlags = e.boundPortFlags
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ e.boundPortFlags = ports.Flags{}
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice)
e.ID = id
+ e.boundBindToDevice = btd
e.route.Release()
e.route = stack.Route{}
e.dstPort = 0
@@ -891,10 +1093,6 @@ func (e *endpoint) Disconnect() *tcpip.Error {
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
- netProto, err := e.checkV4Mapped(&addr, false)
- if err != nil {
- return err
- }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -903,7 +1101,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- nicid := addr.NIC
+ nicID := addr.NIC
var localPort uint16
switch e.state {
case StateInitial:
@@ -913,16 +1111,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
break
}
- if nicid != 0 && nicid != e.BindNICID {
+ if nicID != 0 && nicID != e.BindNICID {
return tcpip.ErrInvalidEndpointState
}
- nicid = e.BindNICID
+ nicID = e.BindNICID
default:
return tcpip.ErrInvalidEndpointState
}
- r, nicid, err := e.connectRoute(nicid, addr, netProto)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ r, nicID, err := e.connectRoute(nicID, addr, netProto)
if err != nil {
return err
}
@@ -950,20 +1153,23 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ oldPortFlags := e.boundPortFlags
+
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
if e.ID.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.bindToDevice)
+ e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice)
}
e.ID = id
+ e.boundBindToDevice = btd
e.route = r.Clone()
e.dstPort = addr.Port
- e.RegisterNICID = nicid
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
e.state = StateConnected
@@ -1018,20 +1224,22 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
-func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{})
if err != nil {
- return id, err
+ return id, e.bindToDevice, err
}
id.LocalPort = port
}
+ e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.boundPortFlags = ports.Flags{}
}
- return id, err
+ return id, e.bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1041,7 +1249,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- netProto, err := e.checkV4Mapped(&addr, true)
+ addr, netProto, err := e.checkV4MappedLocked(addr)
if err != nil {
return err
}
@@ -1057,11 +1265,11 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
}
}
- nicid := addr.NIC
+ nicID := addr.NIC
if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) {
// A local unicast address was specified, verify that it's valid.
- nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
- if nicid == 0 {
+ nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicID == 0 {
return tcpip.ErrBadLocalAddress
}
}
@@ -1070,13 +1278,14 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
- id, err = e.registerWithStack(nicid, netProtos, id)
+ id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
return err
}
e.ID = id
- e.RegisterNICID = nicid
+ e.boundBindToDevice = btd
+ e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
@@ -1111,9 +1320,14 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
+ addr := e.ID.LocalAddress
+ if e.state == StateConnected {
+ addr = e.route.LocalAddress
+ }
+
return tcpip.FullAddress{
NIC: e.RegisterNICID,
- Addr: e.ID.LocalAddress,
+ Addr: addr,
Port: e.ID.LocalPort,
}, nil
}
@@ -1154,22 +1368,47 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
-func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
- vv.TrimFront(header.UDPMinimumSize)
+ // Never receive from a multicast address.
+ if header.IsV4MulticastAddress(id.RemoteAddress) ||
+ header.IsV6MulticastAddress(id.RemoteAddress) {
+ e.stack.Stats().UDP.InvalidSourceAddress.Increment()
+ e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment()
+ e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ // Verify checksum unless RX checksum offload is enabled.
+ // On IPv4, UDP checksum is optional, and a zero value means
+ // the transmitter omitted the checksum generation (RFC768).
+ // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
+ if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+ (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+ for _, v := range pkt.Data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ if hdr.CalculateChecksum(xsum) != 0xffff {
+ // Checksum Error.
+ e.stack.Stats().UDP.ChecksumErrors.Increment()
+ e.stats.ReceiveErrors.ChecksumErrors.Increment()
+ return
+ }
+ }
- e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
e.stats.PacketsReceived.Increment()
+ e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed {
e.rcvMu.Unlock()
@@ -1188,18 +1427,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
- pkt := &udpPacket{
+ packet := &udpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
}
- pkt.data = vv.Clone(pkt.views[:])
- e.rcvList.PushBack(pkt)
- e.rcvBufSize += vv.Size()
+ packet.data = pkt.Data
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += pkt.Data.Size()
+
+ // Save any useful information from the network header to the packet.
+ switch r.NetProto {
+ case header.IPv4ProtocolNumber:
+ packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader().View()).TOS()
+ }
- pkt.timestamp = e.stack.NowNanoseconds()
+ // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
+ // address. packetInfo.LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.LocalAddress
+ packet.packetInfo.NIC = r.NICID()
+ packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
@@ -1210,7 +1463,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ if typ == stack.ControlPortUnreachable {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state == StateConnected {
+ e.lastErrorMu.Lock()
+ defer e.lastErrorMu.Unlock()
+
+ e.lastError = tcpip.ErrConnectionRefused
+ }
+ }
}
// State implements tcpip.Endpoint.State.
@@ -1234,6 +1498,13 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
+// Wait implements tcpip.Endpoint.Wait.
+func (*endpoint) Wait() {}
+
func isBroadcastOrMulticast(a tcpip.Address) bool {
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index b227e353b..851e6b635 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -37,6 +37,24 @@ func (u *udpPacket) loadData(data buffer.VectorisedView) {
u.data = data
}
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = tcpip.StringToError(s)
+}
+
// beforeSave is invoked by stateify.
func (e *endpoint) beforeSave() {
// Stop incoming packets from being handled (and mutate endpoint state).
@@ -69,6 +87,9 @@ func (e *endpoint) afterLoad() {
// Resume implements tcpip.ResumableEndpoint.Resume.
func (e *endpoint) Resume(s *stack.Stack) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
e.stack = s
for _, m := range e.multicastMemberships {
@@ -109,7 +130,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
// pass it to the reservation machinery.
id := e.ID
e.ID.LocalPort = 0
- e.ID, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+ e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
if err != nil {
panic(err)
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index d399ec722..c67e0ba95 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -16,7 +16,6 @@ package udp
import (
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -44,12 +43,12 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
-func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
f.handler(&ForwarderRequest{
stack: f.stack,
route: r,
id: id,
- vv: vv,
+ pkt: pkt,
})
return true
@@ -62,7 +61,7 @@ type ForwarderRequest struct {
stack *stack.Stack
route *stack.Route
id stack.TransportEndpointID
- vv buffer.VectorisedView
+ pkt *stack.PacketBuffer
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
@@ -74,7 +73,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -83,6 +82,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.route = r.route.Clone()
ep.dstPort = r.id.RemotePort
ep.RegisterNICID = r.route.NICID()
+ ep.boundPortFlags = ep.portFlags
ep.state = StateConnected
@@ -90,7 +90,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.rcvReady = true
ep.rcvMu.Unlock()
- ep.HandlePacket(r.route, r.id, r.vv)
+ ep.HandlePacket(r.route, r.id, r.pkt)
return ep, nil
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 5c3358a5e..63d4bed7c 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -32,9 +32,24 @@ import (
const (
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
+
+ // MinBufferSize is the smallest size of a receive or send buffer.
+ MinBufferSize = 4 << 10 // 4KiB bytes.
+
+ // DefaultSendBufferSize is the default size of the send buffer for
+ // an endpoint.
+ DefaultSendBufferSize = 32 << 10 // 32KiB
+
+ // DefaultReceiveBufferSize is the default size of the receive buffer
+ // for an endpoint.
+ DefaultReceiveBufferSize = 32 << 10 // 32KiB
+
+ // MaxBufferSize is the largest size a receive/send buffer can grow to.
+ MaxBufferSize = 4 << 20 // 4MiB
)
-type protocol struct{}
+type protocol struct {
+}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
@@ -66,10 +81,9 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
- // Get the header then trim it from the view.
- hdr := header.UDP(vv.First())
- if int(hdr.Length()) > vv.Size() {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ hdr := header.UDP(pkt.TransportHeader().View())
+ if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true
@@ -116,28 +130,30 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ payloadLen := pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size() + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- // The buffers used by vv and netHeader may be used elsewhere
- // in the system. For example, a raw or packet socket may use
- // what UDP considers an unreachable destination. Thus we deep
- // copy vv and netHeader to prevent multiple ownership and SR
- // errors.
- newNetHeader := make(buffer.View, len(netHeader))
- copy(newNetHeader, netHeader)
- payload := buffer.NewVectorisedView(len(newNetHeader), []buffer.View{newNetHeader})
- payload.Append(vv.ToView().ToVectorisedView())
+ // The buffers used by pkt may be used elsewhere in the system.
+ // For example, a raw or packet socket may use what UDP
+ // considers an unreachable destination. Thus we deep copy pkt
+ // to prevent multiple ownership and SR errors.
+ newHeader := append(buffer.View(nil), pkt.NetworkHeader().View()...)
+ newHeader = append(newHeader, pkt.TransportHeader().View()...)
+ payload := newHeader.ToVectorisedView()
+ payload.AppendView(pkt.Data.ToView())
payload.CapLength(payloadLen)
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
- pkt.SetType(header.ICMPv4DstUnreachable)
- pkt.SetCode(header.ICMPv4PortUnreachable)
- pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize))
+ icmpHdr.SetType(header.ICMPv4DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv4PortUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
case header.IPv6AddressSize:
if !r.Stack().AllowICMPMessage() {
@@ -158,34 +174,50 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
}
headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
available := int(mtu) - headerLen
- payloadLen := len(netHeader) + vv.Size()
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ payloadLen := len(network) + len(transport) + pkt.Data.Size()
if payloadLen > available {
payloadLen = available
}
- payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
- payload.Append(vv)
+ payload := buffer.NewVectorisedView(len(network)+len(transport), []buffer.View{network, transport})
+ payload.Append(pkt.Data)
payload.CapLength(payloadLen)
- hdr := buffer.NewPrependable(headerLen)
- pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
- pkt.SetType(header.ICMPv6DstUnreachable)
- pkt.SetCode(header.ICMPv6PortUnreachable)
- pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
- r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS})
+ icmpPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+ ReserveHeaderBytes: headerLen,
+ Data: payload,
+ })
+ icmpHdr := header.ICMPv6(icmpPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize))
+ icmpHdr.SetType(header.ICMPv6DstUnreachable)
+ icmpHdr.SetCode(header.ICMPv6PortUnreachable)
+ icmpHdr.SetChecksum(header.ICMPv6Checksum(icmpHdr, r.LocalAddress, r.RemoteAddress, icmpPkt.Data))
+ r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, icmpPkt)
}
return true
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
+// Parse implements stack.TransportProtocol.Parse.
+func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
+ _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize)
+ return ok
+}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index b724d788c..f87d99d5a 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -16,6 +16,7 @@ package udp_test
import (
"bytes"
+ "context"
"fmt"
"math/rand"
"testing"
@@ -56,6 +57,7 @@ const (
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
broadcastAddr = header.IPv4Broadcast
+ testTOS = 0x80
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -81,16 +83,18 @@ type header4Tuple struct {
type testFlow int
const (
- unicastV4 testFlow = iota // V4 unicast on a V4 socket
- unicastV4in6 // V4-mapped unicast on a V6-dual socket
- unicastV6 // V6 unicast on a V6 socket
- unicastV6Only // V6 unicast on a V6-only socket
- multicastV4 // V4 multicast on a V4 socket
- multicastV4in6 // V4-mapped multicast on a V6-dual socket
- multicastV6 // V6 multicast on a V6 socket
- multicastV6Only // V6 multicast on a V6-only socket
- broadcast // V4 broadcast on a V4 socket
- broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ unicastV4 testFlow = iota // V4 unicast on a V4 socket
+ unicastV4in6 // V4-mapped unicast on a V6-dual socket
+ unicastV6 // V6 unicast on a V6 socket
+ unicastV6Only // V6 unicast on a V6-only socket
+ multicastV4 // V4 multicast on a V4 socket
+ multicastV4in6 // V4-mapped multicast on a V6-dual socket
+ multicastV6 // V6 multicast on a V6 socket
+ multicastV6Only // V6 multicast on a V6-only socket
+ broadcast // V4 broadcast on a V4 socket
+ broadcastIn6 // V4-mapped broadcast on a V6-dual socket
+ reverseMulticast4 // V4 multicast src. Must fail.
+ reverseMulticast6 // V6 multicast src. Must fail.
)
func (flow testFlow) String() string {
@@ -115,6 +119,10 @@ func (flow testFlow) String() string {
return "broadcast"
case broadcastIn6:
return "broadcastIn6"
+ case reverseMulticast4:
+ return "reverseMulticast4"
+ case reverseMulticast6:
+ return "reverseMulticast6"
default:
return "unknown"
}
@@ -166,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple {
h.dstAddr.Addr = multicastV6Addr
}
}
+ if flow.isReverseMulticast() {
+ h.srcAddr.Addr = flow.getMcastAddr()
+ }
return h
}
@@ -197,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber {
// endpoint for this flow.
func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber {
switch flow {
- case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6:
+ case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6:
return ipv6.ProtocolNumber
- case unicastV4, multicastV4, broadcast:
+ case unicastV4, multicastV4, broadcast, reverseMulticast4:
return ipv4.ProtocolNumber
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -222,7 +233,7 @@ func (flow testFlow) isV6Only() bool {
switch flow {
case unicastV6Only, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -233,7 +244,7 @@ func (flow testFlow) isMulticast() bool {
switch flow {
case multicastV4, multicastV4in6, multicastV6, multicastV6Only:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -244,7 +255,7 @@ func (flow testFlow) isBroadcast() bool {
switch flow {
case broadcast, broadcastIn6:
return true
- case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only:
+ case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
@@ -255,13 +266,22 @@ func (flow testFlow) isMapped() bool {
switch flow {
case unicastV4in6, multicastV4in6, broadcastIn6:
return true
- case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast:
+ case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6:
return false
default:
panic(fmt.Sprintf("invalid testFlow given: %d", flow))
}
}
+func (flow testFlow) isReverseMulticast() bool {
+ switch flow {
+ case reverseMulticast4, reverseMulticast6:
+ return true
+ default:
+ return false
+ }
+}
+
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
@@ -273,11 +293,16 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
-
- s := stack.New(stack.Options{
+ return newDualTestContextWithOptions(t, mtu, stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
})
+}
+
+func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext {
+ t.Helper()
+
+ s := stack.New(options)
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -285,15 +310,15 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
wep = sniffer.New(ep)
}
if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %v", err)
+ t.Fatalf("CreateNIC failed: %s", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %v", err)
+ t.Fatalf("AddAddress failed: %s", err)
}
s.SetRouteTable([]tcpip.Route{
@@ -335,12 +360,12 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
} else if flow.isBroadcast() {
- if err := c.ep.SetSockOpt(tcpip.BroadcastOption(1)); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ c.t.Fatalf("SetSockOptBool failed: %s", err)
}
}
}
@@ -351,30 +376,30 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte {
c.t.Helper()
- select {
- case p := <-c.linkEP.C:
- if p.Proto != flow.netProto() {
- c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
- }
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
-
- h := flow.header4Tuple(outgoing)
- checkers := append(
- checkers,
- checker.SrcAddr(h.srcAddr.Addr),
- checker.DstAddr(h.dstAddr.Addr),
- checker.UDP(checker.DstPort(h.dstAddr.Port)),
- )
- flow.checkerFn()(c.t, b, checkers...)
- return b
-
- case <-time.After(2 * time.Second):
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
c.t.Fatalf("Packet wasn't written out")
+ return nil
}
- return nil
+ if p.Proto != flow.netProto() {
+ c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto())
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ h := flow.header4Tuple(outgoing)
+ checkers = append(
+ checkers,
+ checker.SrcAddr(h.srcAddr.Addr),
+ checker.DstAddr(h.dstAddr.Addr),
+ checker.UDP(checker.DstPort(h.dstAddr.Port)),
+ )
+ flow.checkerFn()(c.t, b, checkers...)
+ return b
}
// injectPacket creates a packet of the given flow and with the given payload,
@@ -384,24 +409,30 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) {
h := flow.header4Tuple(incoming)
if flow.isV4() {
- c.injectV4Packet(payload, &h, true /* valid */)
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
} else {
- c.injectV6Packet(payload, &h, true /* valid */)
+ buf := c.buildV6Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
}
}
-// injectV6Packet creates a V6 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV6Packet creates a V6 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
@@ -411,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
- l := uint16(header.UDPMinimumSize + len(payload))
- if !valid {
- // Change the UDP payload length to corrupt the header
- // as requested by the caller.
- l++
- }
u.Encode(&header.UDPFields{
SrcPort: h.srcAddr.Port,
DstPort: h.dstAddr.Port,
- Length: l,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
@@ -430,23 +455,22 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ return buf
}
-// injectV4Packet creates a V4 test packet with the given payload and header
-// values, and injects it into the link endpoint. valid indicates if the
-// caller intends to inject a packet with a valid or an invalid UDP header.
-// We can invalidate the header by corrupting the UDP payload length.
-func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) {
+// buildV4Packet creates a V4 test packet with the given payload and header
+// values in a buffer.
+func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
- copy(buf[len(buf)-len(payload):], payload)
+ payloadStart := len(buf) - len(payload)
+ copy(buf[payloadStart:], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
+ TOS: testTOS,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
@@ -470,8 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum))
- // Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ return buf
}
func newPayload() []byte {
@@ -493,50 +516,46 @@ func TestBindToDeviceOption(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
- t.Fatalf("NewEndpoint failed; %v", err)
+ t.Fatalf("NewEndpoint failed; %s", err)
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -545,8 +564,8 @@ func TestBindToDeviceOption(t *testing.T) {
// testReadInternal sends a packet of the given test flow into the stack by
// injecting it into the link endpoint. It then attempts to read it from the
// UDP endpoint and depending on if this was expected to succeed verifies its
-// correctness.
-func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool) {
+// correctness including any additional checker functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
payload := newPayload()
@@ -561,12 +580,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var addr tcpip.FullAddress
- v, _, err := c.ep.Read(&addr)
+ v, cm, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, _, err = c.ep.Read(&addr)
+ v, cm, err = c.ep.Read(&addr)
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -592,22 +611,28 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr)
+ c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
c.t.Fatalf("bad payload: got %x, want %x", v, payload)
}
+
+ // Run any checkers against the ControlMessages.
+ for _, f := range checkers {
+ f(c.t, cm)
+ }
+
c.checkEndpointReadStats(1, epstats, err)
}
// testRead sends a packet of the given test flow into the stack by injecting it
// into the link endpoint. It then reads it from the UDP endpoint and verifies
-// its correctness.
-func testRead(c *testContext, flow testFlow) {
+// its correctness including any additional checker functions provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
c.t.Helper()
- testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */)
+ testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
}
// testFailingRead sends a packet of the given test flow into the stack by
@@ -625,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}
@@ -636,19 +661,19 @@ func TestBindReservedPort(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
- t.Fatalf("GetLocalAddress failed: %v", err)
+ t.Fatalf("GetLocalAddress failed: %s", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
@@ -659,7 +684,7 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
@@ -669,7 +694,7 @@ func TestBindReservedPort(t *testing.T) {
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
@@ -679,11 +704,11 @@ func TestBindReservedPort(t *testing.T) {
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
- t.Fatalf("NewEndpoint failed: %v", err)
+ t.Fatalf("NewEndpoint failed: %s", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
+ t.Fatalf("ep.Bind(...) failed: %s", err)
}
}()
}
@@ -696,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -711,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -726,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -741,13 +766,59 @@ func TestV6ReadOnV6(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
testRead(c, unicastV6)
}
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ handleLocal bool
+ wantErr *tcpip.Error
+ wantInvalidSource uint64
+ }{
+ {"HandleLocal", false, nil, 0},
+ {"NoHandleLocal", true, tcpip.ErrWouldBlock, 1},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ HandleLocal: tt.handleLocal,
+ })
+ defer c.cleanup()
+
+ c.createEndpointForFlow(unicastV4)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ h.srcAddr = h.dstAddr
+
+ buf := c.buildV4Packet(payload, &h)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+ t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+ }
+
+ if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+ t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -756,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Test acceptance.
@@ -819,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) {
}
}
+// TestReadFromMulticast checks that an endpoint will NOT receive a packet
+// that was sent with multicast SOURCE address.
+func TestReadFromMulticast(t *testing.T) {
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ testFailingRead(c, flow, false /* expectReadError */)
+ })
+ }
+}
+
+// TestReadFromMulticaststats checks that a discarded packet
+// that that was sent with multicast SOURCE address increments
+// the correct counters and that a regular packet does not.
+func TestReadFromMulticastStats(t *testing.T) {
+ t.Helper()
+ for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ c.injectPacket(flow, payload)
+
+ var want uint64 = 0
+ if flow.isReverseMulticast() {
+ want = 1
+ }
+ if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want {
+ t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want)
+ }
+ if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want {
+ t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+ })
+ }
+}
+
// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
// and receive broadcast and unicast data.
func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
@@ -894,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
if err != nil {
- c.t.Fatalf("Write failed: %v", err)
+ c.t.Fatalf("Write failed: %s", err)
}
if n != int64(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
@@ -944,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
p := testDualWrite(c)
@@ -961,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV6)
@@ -982,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testWrite(c, unicastV4in6)
@@ -1009,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
// Write to v6 address.
@@ -1024,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) {
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV6)
@@ -1038,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) {
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
testWriteWithoutDestination(c, unicastV4)
@@ -1173,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
testRead(c, unicastV4)
@@ -1184,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
}
}
+func TestReadIPPacketInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedLocalAddr tcpip.Address
+ expectedDestAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedLocalAddr: stackAddr,
+ expectedDestAddr: stackAddr,
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastAddr,
+ expectedDestAddr: multicastAddr,
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: broadcastAddr,
+ expectedDestAddr: broadcastAddr,
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedLocalAddr: stackV6Addr,
+ expectedDestAddr: stackV6Addr,
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastV6Addr,
+ expectedDestAddr: multicastV6Addr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err)
+ }
+ }
+
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
+ t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
+ }
+
+ testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: 1,
+ LocalAddr: test.expectedLocalAddr,
+ DestinationAddr: test.expectedDestAddr,
+ }))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1198,6 +1422,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
+func TestNoChecksum(t *testing.T) {
+ for _, flow := range []testFlow{unicastV4, unicastV6} {
+ t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+
+ // Disable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ // This option is effective on IPv4 only.
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4())))
+
+ // Enable the checksum generation.
+ if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil {
+ t.Fatalf("SetSockOptBool failed: %s", err)
+ }
+ testWrite(c, flow, checker.UDP(checker.NoChecksum(false)))
+ })
+ }
+}
+
func TestTTL(t *testing.T) {
for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1207,8 +1455,8 @@ func TestTTL(t *testing.T) {
c.createEndpointForFlow(flow)
const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil {
+ c.t.Fatalf("SetSockOptInt failed: %s", err)
}
var wantTTL uint8
@@ -1221,10 +1469,10 @@ func TestTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
+ ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
wantTTL = ep.DefaultTTL()
ep.Close()
}
@@ -1244,8 +1492,8 @@ func TestSetTTL(t *testing.T) {
c.createEndpointForFlow(flow)
- if err := c.ep.SetSockOpt(tcpip.TTLOption(wantTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
}
var p stack.NetworkProtocol
@@ -1254,10 +1502,10 @@ func TestSetTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
+ ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
ep.Close()
testWrite(c, flow, checker.TTL(wantTTL))
@@ -1267,7 +1515,7 @@ func TestSetTTL(t *testing.T) {
}
}
-func TestTOSV4(t *testing.T) {
+func TestSetTOS(t *testing.T) {
for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1275,26 +1523,27 @@ func TestTOSV4(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
- var v tcpip.IPv4TOSOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ const tos = testTOS
+ v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
}
- if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tos {
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1302,7 +1551,7 @@ func TestTOSV4(t *testing.T) {
}
}
-func TestTOSV6(t *testing.T) {
+func TestSetTClass(t *testing.T) {
for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1310,33 +1559,96 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = 0xC0
- var v tcpip.IPv6TrafficClassOption
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ const tClass = testTOS
+ v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
+ c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
}
- if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if v != tClass {
+ c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
}
- testWrite(c, flow, checker.TOS(tos, 0))
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
})
}
}
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
+
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
+
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
+
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+ }
+
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
+ }
+}
+
func TestMulticastInterfaceOption(t *testing.T) {
for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
@@ -1375,12 +1687,12 @@ func TestMulticastInterfaceOption(t *testing.T) {
Port: stackPort,
}
if err := c.ep.Connect(addr); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
}
if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
+ c.t.Fatalf("SetSockOpt failed: %s", err)
}
// Verify multicast interface addr and NIC were set correctly.
@@ -1388,7 +1700,7 @@ func TestMulticastInterfaceOption(t *testing.T) {
ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %v", err)
+ c.t.Fatalf("GetSockOpt failed: %s", err)
}
if ifoptGot != ifoptWant {
c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
@@ -1431,48 +1743,51 @@ func TestV4UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
- if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
- hdr := header.IPv4(pkt)
- checker.IPv4(t, hdr, checker.ICMPv4(
- checker.ICMPv4Type(header.ICMPv4DstUnreachable),
- checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- icmpPkt := header.ICMPv4(hdr.Payload())
- payloadIPHeader := header.IPv4(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
- }
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %d, want: %d", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
}
})
}
@@ -1505,54 +1820,57 @@ func TestV6UnknownDestination(t *testing.T) {
}
c.injectPacket(tc.flow, payload)
if !tc.icmpRequired {
- select {
- case p := <-c.linkEP.C:
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ if p, ok := c.linkEP.ReadContext(ctx); ok {
t.Fatalf("unexpected packet received: %+v", p)
- case <-time.After(1 * time.Second):
- return
}
+ return
}
- select {
- case p := <-c.linkEP.C:
- var pkt []byte
- pkt = append(pkt, p.Header...)
- pkt = append(pkt, p.Payload...)
- if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
- t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
- }
+ // ICMP required.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ p, ok := c.linkEP.ReadContext(ctx)
+ if !ok {
+ t.Fatalf("packet wasn't written out")
+ return
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ pkt := vv.ToView()
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
- hdr := header.IPv6(pkt)
- checker.IPv6(t, hdr, checker.ICMPv6(
- checker.ICMPv6Type(header.ICMPv6DstUnreachable),
- checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
- icmpPkt := header.ICMPv6(hdr.Payload())
- payloadIPHeader := header.IPv6(icmpPkt.Payload())
- wantLen := len(payload)
- if tc.largePayload {
- wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
- }
- // In case of large payloads the IP packet may be truncated. Update
- // the length field before retrieving the udp datagram payload.
- payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
- origDgram := header.UDP(payloadIPHeader.Payload())
- if got, want := len(origDgram.Payload()), wantLen; got != want {
- t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
- }
- if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
- t.Fatalf("unexpected payload got: %v, want: %v", got, want)
- }
- case <-time.After(1 * time.Second):
- t.Fatalf("packet wasn't written out")
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
}
})
}
}
// TestIncrementMalformedPacketsReceived verifies if the malformed received
-// global and endpoint stats get incremented.
+// global and endpoint stats are incremented.
func TestIncrementMalformedPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1560,20 +1878,271 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
payload := newPayload()
- c.t.Helper()
h := unicastV6.header4Tuple(incoming)
- c.injectV6Packet(payload, &h, false /* !valid */)
+ buf := c.buildV6Packet(payload, &h)
- var want uint64 = 1
+ // Invalidate the UDP header length field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetLength(u.Length() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want)
+ t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
}
if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
- t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want)
+ t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+ }
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ h := unicastV6.header4Tuple(incoming)
+
+ // Allocate a buffer for an IPv6 and too-short UDP header.
+ const udpSize = header.UDPMinimumSize - 1
+ buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
+ PayloadLength: uint16(udpSize),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: h.srcAddr.Addr,
+ DstAddr: h.dstAddr.Addr,
+ })
+
+ // Initialize the UDP header.
+ udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+ udpHdr.Encode(&header.UDPFields{
+ SrcPort: h.srcAddr.Port,
+ DstPort: h.dstAddr.Port,
+ Length: header.UDPMinimumSize,
+ })
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+ udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+ // Copy all but the last byte of the UDP header into the packet.
+ copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+ // Inject packet.
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+ t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field, taking care to avoid
+ // overflow to zero, which would disable checksum validation.
+ for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+ u.SetChecksum(u.Checksum() + 1)
+ if u.Checksum() != 0 {
+ break
+ }
+ }
+
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+
+ // Invalidate the UDP header checksum field.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(u.Checksum() + 1)
+
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV4 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestPayloadModifiedV6 verifies if a checksum error is detected,
+// global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Modify the payload so that the checksum value in the UDP header will be incorrect.
+ buf[len(buf)-1]++
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV4 verifies if the checksum value is zero, global and
+// endpoint states are *not* incremented (UDP checksum is optional on IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv4.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV4.header4Tuple(incoming)
+ buf := c.buildV4Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv4MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 0
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+ }
+}
+
+// TestChecksumZeroV6 verifies if the checksum value is zero, global and
+// endpoint states are incremented (UDP checksum is *not* optional on IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(ipv6.ProtocolNumber)
+ // Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+
+ payload := newPayload()
+ h := unicastV6.header4Tuple(incoming)
+ buf := c.buildV6Packet(payload, &h)
+ // Set the checksum field in the UDP header to zero.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.SetChecksum(0)
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buf.ToVectorisedView(),
+ }))
+
+ const want = 1
+ if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
}
}
@@ -1587,15 +2156,15 @@ func TestShutdownRead(t *testing.T) {
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatalf("Bind failed: %v", err)
+ c.t.Fatalf("Bind failed: %s", err)
}
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingRead(c, unicastV6, true /* expectReadError */)
@@ -1618,11 +2187,11 @@ func TestShutdownWrite(t *testing.T) {
c.createEndpoint(ipv6.ProtocolNumber)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
- c.t.Fatalf("Connect failed: %v", err)
+ c.t.Fatalf("Connect failed: %s", err)
}
if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
- t.Fatalf("Shutdown failed: %v", err)
+ t.Fatalf("Shutdown failed: %s", err)
}
testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
@@ -1664,3 +2233,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn
c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want)
}
}
+
+func TestOutgoingSubnetBroadcast(t *testing.T) {
+ const nicID1 = 1
+
+ ipv4Addr := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 24,
+ }
+ ipv4Subnet := ipv4Addr.Subnet()
+ ipv4SubnetBcast := ipv4Subnet.Broadcast()
+ ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01")
+ ipv4AddrPrefix31 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 31,
+ }
+ ipv4Subnet31 := ipv4AddrPrefix31.Subnet()
+ ipv4Subnet31Bcast := ipv4Subnet31.Broadcast()
+ ipv4AddrPrefix32 := tcpip.AddressWithPrefix{
+ Address: "\xc0\xa8\x01\x3a",
+ PrefixLen: 32,
+ }
+ ipv4Subnet32 := ipv4AddrPrefix32.Subnet()
+ ipv4Subnet32Bcast := ipv4Subnet32.Broadcast()
+ ipv6Addr := tcpip.AddressWithPrefix{
+ Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ PrefixLen: 64,
+ }
+ ipv6Subnet := ipv6Addr.Subnet()
+ ipv6SubnetBcast := ipv6Subnet.Broadcast()
+ remNetAddr := tcpip.AddressWithPrefix{
+ Address: "\x64\x0a\x7b\x18",
+ PrefixLen: 24,
+ }
+ remNetSubnet := remNetAddr.Subnet()
+ remNetSubnetBcast := remNetSubnet.Broadcast()
+
+ tests := []struct {
+ name string
+ nicAddr tcpip.ProtocolAddress
+ routes []tcpip.Route
+ remoteAddr tcpip.Address
+ requiresBroadcastOpt bool
+ }{
+ {
+ name: "IPv4 Broadcast to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4SubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ {
+ name: "IPv4 Broadcast to local /31 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix31,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet31,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet31Bcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to local /32 subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4AddrPrefix32,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv4Subnet32,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv4Subnet32Bcast,
+ requiresBroadcastOpt: false,
+ },
+ // IPv6 has no notion of a broadcast.
+ {
+ name: "IPv6 'Broadcast' to local subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv6ProtocolNumber,
+ AddressWithPrefix: ipv6Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: ipv6Subnet,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: ipv6SubnetBcast,
+ requiresBroadcastOpt: false,
+ },
+ {
+ name: "IPv4 Broadcast to remote subnet",
+ nicAddr: tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: ipv4Addr,
+ },
+ routes: []tcpip.Route{
+ {
+ Destination: remNetSubnet,
+ Gateway: ipv4Gateway,
+ NIC: nicID1,
+ },
+ },
+ remoteAddr: remNetSubnetBcast,
+ requiresBroadcastOpt: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+ if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil {
+ t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err)
+ }
+
+ s.SetRouteTable(test.routes)
+
+ var netProto tcpip.NetworkProtocolNumber
+ switch l := len(test.remoteAddr); l {
+ case header.IPv4AddressSize:
+ netProto = header.IPv4ProtocolNumber
+ case header.IPv6AddressSize:
+ netProto = header.IPv6ProtocolNumber
+ default:
+ t.Fatalf("got unexpected address length = %d bytes", l)
+ }
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err)
+ }
+ defer ep.Close()
+
+ data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ to := tcpip.FullAddress{
+ Addr: test.remoteAddr,
+ Port: 80,
+ }
+ opts := tcpip.WriteOptions{To: &to}
+ expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled
+ if !test.requiresBroadcastOpt {
+ expectedErrWithoutBcastOpt = nil
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != nil {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err)
+ }
+
+ if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil {
+ t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err)
+ }
+
+ if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt)
+ }
+ })
+ }
+}
diff --git a/runsc/criutil/BUILD b/pkg/test/criutil/BUILD
index 558133a0e..a7b082cee 100644
--- a/runsc/criutil/BUILD
+++ b/pkg/test/criutil/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,7 +6,9 @@ go_library(
name = "criutil",
testonly = 1,
srcs = ["criutil.go"],
- importpath = "gvisor.dev/gvisor/runsc/criutil",
visibility = ["//:sandbox"],
- deps = ["//runsc/testutil"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
)
diff --git a/runsc/criutil/criutil.go b/pkg/test/criutil/criutil.go
index 773f5a1c4..70945f234 100644
--- a/runsc/criutil/criutil.go
+++ b/pkg/test/criutil/criutil.go
@@ -22,48 +22,72 @@ import (
"fmt"
"os"
"os/exec"
+ "path"
+ "regexp"
+ "strconv"
"strings"
"time"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
-const endpointPrefix = "unix://"
-
// Crictl contains information required to run the crictl utility.
type Crictl struct {
- executable string
- timeout time.Duration
- imageEndpoint string
- runtimeEndpoint string
+ logger testutil.Logger
+ endpoint string
+ runpArgs []string
+ cleanup []func()
+}
+
+// ResolvePath attempts to find binary paths. It may set the path to invalid,
+// which will cause the execution to fail with a sensible error.
+func ResolvePath(executable string) string {
+ runtime, err := dockerutil.RuntimePath()
+ if err == nil {
+ // Check first the directory of the runtime itself.
+ if dir := path.Dir(runtime); dir != "" && dir != "." {
+ guess := path.Join(dir, executable)
+ if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 {
+ return guess
+ }
+ }
+ }
+
+ // Try to find via the path.
+ guess, err := exec.LookPath(executable)
+ if err == nil {
+ return guess
+ }
+
+ // Return a default path.
+ return fmt.Sprintf("/usr/local/bin/%s", executable)
}
// NewCrictl returns a Crictl configured with a timeout and an endpoint over
// which it will talk to containerd.
-func NewCrictl(timeout time.Duration, endpoint string) *Crictl {
- // Bazel doesn't pass PATH through, assume the location of crictl
- // unless specified by environment variable.
- executable := os.Getenv("CRICTL_PATH")
- if executable == "" {
- executable = "/usr/local/bin/crictl"
- }
+func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl {
+ // Attempt to find the executable, but don't bother propagating the
+ // error at this point. The first command executed will return with a
+ // binary not found error.
return &Crictl{
- executable: executable,
- timeout: timeout,
- imageEndpoint: endpointPrefix + endpoint,
- runtimeEndpoint: endpointPrefix + endpoint,
+ logger: logger,
+ endpoint: endpoint,
+ runpArgs: runpArgs,
}
}
-// Pull pulls an container image. It corresponds to `crictl pull`.
-func (cc *Crictl) Pull(imageName string) error {
- _, err := cc.run("pull", imageName)
- return err
+// CleanUp executes cleanup functions.
+func (cc *Crictl) CleanUp() {
+ for _, c := range cc.cleanup {
+ c()
+ }
+ cc.cleanup = nil
}
// RunPod creates a sandbox. It corresponds to `crictl runp`.
-func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
- podID, err := cc.run("runp", sbSpecFile)
+func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) {
+ podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile)
if err != nil {
return "", fmt.Errorf("runp failed: %v", err)
}
@@ -74,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) {
// Create creates a container within a sandbox. It corresponds to `crictl
// create`.
func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) {
- podID, err := cc.run("create", podID, contSpecFile, sbSpecFile)
+ // In version 1.16.0, crictl annoying starting attempting to pull the
+ // container, even if it was already available locally. We therefore
+ // need to parse the version and add an appropriate --no-pull argument
+ // since the image has already been loaded locally.
+ out, err := cc.run("-v")
+ if err != nil {
+ return "", err
+ }
+ r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9+])")
+ vs := r.FindStringSubmatch(out)
+ if len(vs) != 4 {
+ return "", fmt.Errorf("crictl -v had unexpected output: %s", out)
+ }
+ major, err := strconv.ParseUint(vs[1], 10, 64)
+ if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+ minor, err := strconv.ParseUint(vs[2], 10, 64)
if err != nil {
+ return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out)
+ }
+
+ args := []string{"create"}
+ if (major == 1 && minor >= 16) || major > 1 {
+ args = append(args, "--no-pull")
+ }
+ args = append(args, podID)
+ args = append(args, contSpecFile)
+ args = append(args, sbSpecFile)
+
+ podID, err = cc.run(args...)
+ if err != nil {
+ time.Sleep(10 * time.Minute) // XXX
return "", fmt.Errorf("create failed: %v", err)
}
+
// Strip the trailing newline from crictl output.
return strings.TrimSpace(podID), nil
}
@@ -108,6 +164,17 @@ func (cc *Crictl) Exec(contID string, args ...string) (string, error) {
return output, nil
}
+// Logs retrieves the container logs. It corresponds to `crictl logs`.
+func (cc *Crictl) Logs(contID string, args ...string) (string, error) {
+ a := []string{"logs", contID}
+ a = append(a, args...)
+ output, err := cc.run(a...)
+ if err != nil {
+ return "", fmt.Errorf("logs failed: %v", err)
+ }
+ return output, nil
+}
+
// Rm removes a container. It corresponds to `crictl rm`.
func (cc *Crictl) Rm(contID string) error {
_, err := cc.run("rm", contID)
@@ -157,27 +224,66 @@ func (cc *Crictl) RmPod(podID string) error {
return err
}
+// Import imports the given container from the local Docker instance.
+func (cc *Crictl) Import(image string) error {
+ // Note that we provide a 10 minute timeout after connect because we may
+ // be pushing a lot of bytes in order to import the image. The connect
+ // timeout stays the same and is inherited from the Crictl instance.
+ cmd := testutil.Command(cc.logger,
+ ResolvePath("ctr"),
+ fmt.Sprintf("--connect-timeout=%s", 30*time.Second),
+ fmt.Sprintf("--address=%s", cc.endpoint),
+ "-n", "k8s.io", "images", "import", "-")
+ cmd.Stderr = os.Stderr // Pass through errors.
+
+ // Create a pipe and start the program.
+ w, err := cmd.StdinPipe()
+ if err != nil {
+ return err
+ }
+ if err := cmd.Start(); err != nil {
+ return err
+ }
+
+ // Save the image on the other end.
+ if err := dockerutil.Save(cc.logger, image, w); err != nil {
+ cmd.Wait()
+ return err
+ }
+
+ // Close our pipe reference & see if it was loaded.
+ if err := w.Close(); err != nil {
+ return w.Close()
+ }
+
+ return cmd.Wait()
+}
+
// StartContainer pulls the given image ands starts the container in the
// sandbox with the given podID.
+//
+// Note that the image will always be imported from the local docker daemon.
func (cc *Crictl) StartContainer(podID, image, sbSpec, contSpec string) (string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", err
+ }
+
// Write the specs to files that can be read by crictl.
- sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
if err != nil {
return "", fmt.Errorf("failed to write sandbox spec: %v", err)
}
- contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec)
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
if err != nil {
return "", fmt.Errorf("failed to write container spec: %v", err)
}
+ cc.cleanup = append(cc.cleanup, cleanup)
return cc.startContainer(podID, image, sbSpecFile, contSpecFile)
}
func (cc *Crictl) startContainer(podID, image, sbSpecFile, contSpecFile string) (string, error) {
- if err := cc.Pull(image); err != nil {
- return "", fmt.Errorf("failed to pull %s: %v", image, err)
- }
-
contID, err := cc.Create(podID, contSpecFile, sbSpecFile)
if err != nil {
return "", fmt.Errorf("failed to create container in pod %q: %v", podID, err)
@@ -203,20 +309,26 @@ func (cc *Crictl) StopContainer(contID string) error {
return nil
}
-// StartPodAndContainer pulls an image, then starts a sandbox and container in
-// that sandbox. It returns the pod ID and container ID.
-func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) {
+// StartPodAndContainer starts a sandbox and container in that sandbox. It
+// returns the pod ID and container ID.
+func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) {
+ if err := cc.Import(image); err != nil {
+ return "", "", err
+ }
+
// Write the specs to files that can be read by crictl.
- sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec)
+ sbSpecFile, cleanup, err := testutil.WriteTmpFile("sbSpec", sbSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write sandbox spec: %v", err)
}
- contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec)
+ cc.cleanup = append(cc.cleanup, cleanup)
+ contSpecFile, cleanup, err := testutil.WriteTmpFile("contSpec", contSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write container spec: %v", err)
}
+ cc.cleanup = append(cc.cleanup, cleanup)
- podID, err := cc.RunPod(sbSpecFile)
+ podID, err := cc.RunPod(runtime, sbSpecFile)
if err != nil {
return "", "", err
}
@@ -243,35 +355,14 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error {
return nil
}
-// run runs crictl with the given args and returns an error if it takes longer
-// than cc.Timeout to run.
+// run runs crictl with the given args.
func (cc *Crictl) run(args ...string) (string, error) {
defaultArgs := []string{
- "--image-endpoint", cc.imageEndpoint,
- "--runtime-endpoint", cc.runtimeEndpoint,
- }
- cmd := exec.Command(cc.executable, append(defaultArgs, args...)...)
-
- // Run the command with a timeout.
- done := make(chan string)
- errCh := make(chan error)
- go func() {
- output, err := cmd.CombinedOutput()
- if err != nil {
- errCh <- fmt.Errorf("error: \"%v\", output: %s", err, string(output))
- return
- }
- done <- string(output)
- }()
- select {
- case output := <-done:
- return output, nil
- case err := <-errCh:
- return "", err
- case <-time.After(cc.timeout):
- if err := testutil.KillCommand(cmd); err != nil {
- return "", fmt.Errorf("timed out, then couldn't kill process %+v: %v", cmd, err)
- }
- return "", fmt.Errorf("timed out: %+v", cmd)
+ ResolvePath("crictl"),
+ "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
+ "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint),
}
+ fullArgs := append(defaultArgs, args...)
+ out, err := testutil.Command(cc.logger, fullArgs...).CombinedOutput()
+ return string(out), err
}
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
new file mode 100644
index 000000000..a5e84658a
--- /dev/null
+++ b/pkg/test/dockerutil/BUILD
@@ -0,0 +1,42 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dockerutil",
+ testonly = 1,
+ srcs = [
+ "container.go",
+ "dockerutil.go",
+ "exec.go",
+ "network.go",
+ "profile.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_docker_docker//api/types:go_default_library",
+ "@com_github_docker_docker//api/types/container:go_default_library",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ "@com_github_docker_docker//api/types/network:go_default_library",
+ "@com_github_docker_docker//client:go_default_library",
+ "@com_github_docker_docker//pkg/stdcopy:go_default_library",
+ "@com_github_docker_go_connections//nat:go_default_library",
+ ],
+)
+
+go_test(
+ name = "profile_test",
+ size = "large",
+ srcs = [
+ "profile_test.go",
+ ],
+ library = ":dockerutil",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ # Also requires the test to be run as root.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md
new file mode 100644
index 000000000..870292096
--- /dev/null
+++ b/pkg/test/dockerutil/README.md
@@ -0,0 +1,86 @@
+# dockerutil
+
+This package is for creating and controlling docker containers for testing
+runsc, gVisor's docker/kubernetes binary. A simple test may look like:
+
+```
+ func TestSuperCool(t *testing.T) {
+ ctx := context.Background()
+ c := dockerutil.MakeContainer(ctx, t)
+ got, err := c.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine"
+ }, "echo", "super cool")
+ if err != nil {
+ t.Fatalf("err was not nil: %v", err)
+ }
+ want := "super cool"
+ if !strings.Contains(got, want){
+ t.Fatalf("want: %s, got: %s", want, got)
+ }
+ }
+```
+
+For further examples, see many of our end to end tests elsewhere in the repo,
+such as those in //test/e2e or benchmarks at //test/benchmarks.
+
+dockerutil uses the "official" docker golang api, which is
+[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil
+is a thin wrapper around this API, allowing desired new use cases to be easily
+implemented.
+
+## Profiling
+
+dockerutil is capable of generating profiles. Currently, the only option is to
+use pprof profiles generated by `runsc debug`. The profiler will generate Block,
+CPU, Heap, Goroutine, and Mutex profiles. To generate profiles:
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+ ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or
+ `--vfs2`.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles run:
+
+```
+make sudo TARGETS=//path/to:target \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
+
+Container name in most tests and benchmarks in gVisor is usually the test name
+and some random characters like so:
+`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2`
+
+Profiling requires root as runsc debug inspects running containers in /var/run
+among other things.
+
+### Writing for Profiling
+
+The below shows an example of using profiles with dockerutil.
+
+```
+func TestSuperCool(t *testing.T){
+ ctx := context.Background()
+ // profiled and using runtime from dockerutil.runtime flag
+ profiled := MakeContainer()
+
+ // not profiled and using runtime runc
+ native := MakeNativeContainer()
+
+ err := profiled.Spawn(ctx, RunOpts{
+ Image: "some/image",
+ }, "sleep", "100000")
+ // profiling has begun here
+ ...
+ expensive setup that I don't want to profile.
+ ...
+ profiled.RestartProfiles()
+ // profiled activity
+}
+```
+
+In the above example, `profiled` would be profiled and `native` would not. The
+call to `RestartProfiles()` restarts the clock on profiling. This is useful if
+the main activity being tested is done with `docker exec` or `container.Spawn()`
+followed by one or more `container.Exec()` calls.
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
new file mode 100644
index 000000000..052b6b99d
--- /dev/null
+++ b/pkg/test/dockerutil/container.go
@@ -0,0 +1,539 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "path"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/container"
+ "github.com/docker/docker/api/types/mount"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "github.com/docker/docker/pkg/stdcopy"
+ "github.com/docker/go-connections/nat"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Container represents a Docker Container allowing
+// user to configure and control as one would with the 'docker'
+// client. Container is backed by the offical golang docker API.
+// See: https://pkg.go.dev/github.com/docker/docker.
+type Container struct {
+ Name string
+ runtime string
+
+ logger testutil.Logger
+ client *client.Client
+ id string
+ mounts []mount.Mount
+ links []string
+ copyErr error
+ cleanups []func()
+
+ // Profiles are profiles added to this container. They contain methods
+ // that are run after Creation, Start, and Cleanup of this Container, along
+ // a handle to restart the profile. Generally, tests/benchmarks using
+ // profiles need to run as root.
+ profiles []Profile
+}
+
+// RunOpts are options for running a container.
+type RunOpts struct {
+ // Image is the image relative to images/. This will be mangled
+ // appropriately, to ensure that only first-party images are used.
+ Image string
+
+ // Memory is the memory limit in bytes.
+ Memory int
+
+ // Cpus in which to allow execution. ("0", "1", "0-2").
+ CpusetCpus string
+
+ // Ports are the ports to be allocated.
+ Ports []int
+
+ // WorkDir sets the working directory.
+ WorkDir string
+
+ // ReadOnly sets the read-only flag.
+ ReadOnly bool
+
+ // Env are additional environment variables.
+ Env []string
+
+ // User is the user to use.
+ User string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // CapAdd are the extra set of capabilities to add.
+ CapAdd []string
+
+ // CapDrop are the extra set of capabilities to drop.
+ CapDrop []string
+
+ // Mounts is the list of directories/files to be mounted inside the container.
+ Mounts []mount.Mount
+
+ // Links is the list of containers to be connected to the container.
+ Links []string
+}
+
+// MakeContainer sets up the struct for a Docker container.
+//
+// Names of containers will be unique.
+// Containers will check flags for profiling requests.
+func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ c := MakeNativeContainer(ctx, logger)
+ c.runtime = *runtime
+ if p := MakePprofFromFlags(c); p != nil {
+ c.AddProfile(p)
+ }
+ return c
+}
+
+// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native
+// containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ // Slashes are not allowed in container names.
+ name := testutil.RandomID(logger.Name())
+ name = strings.ReplaceAll(name, "/", "-")
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+ return &Container{
+ logger: logger,
+ Name: name,
+ runtime: "",
+ client: client,
+ }
+}
+
+// AddProfile adds a profile to this container.
+func (c *Container) AddProfile(p Profile) {
+ c.profiles = append(c.profiles, p)
+}
+
+// RestartProfiles calls Restart on all profiles for this container.
+func (c *Container) RestartProfiles() error {
+ for _, profile := range c.profiles {
+ if err := profile.Restart(c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Spawn is analogous to 'docker run -d'.
+func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ return err
+ }
+ return c.Start(ctx)
+}
+
+// SpawnProcess is analogous to 'docker run -it'. It returns a process
+// which represents the root process.
+func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) {
+ config, hostconf, netconf := c.ConfigsFrom(r, args...)
+ config.Tty = true
+ config.OpenStdin = true
+
+ if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil {
+ return Process{}, err
+ }
+
+ // Open a connection to the container for parsing logs and for TTY.
+ stream, err := c.client.ContainerAttach(ctx, c.id,
+ types.ContainerAttachOptions{
+ Stream: true,
+ Stdin: true,
+ Stdout: true,
+ Stderr: true,
+ })
+ if err != nil {
+ return Process{}, fmt.Errorf("connect failed container id %s: %v", c.id, err)
+ }
+
+ c.cleanups = append(c.cleanups, func() { stream.Close() })
+
+ if err := c.Start(ctx); err != nil {
+ return Process{}, err
+ }
+
+ return Process{container: c, conn: stream}, nil
+}
+
+// Run is analogous to 'docker run'.
+func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
+ return "", err
+ }
+
+ if err := c.Start(ctx); err != nil {
+ return "", err
+ }
+
+ if err := c.Wait(ctx); err != nil {
+ return "", err
+ }
+
+ return c.Logs(ctx)
+}
+
+// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom'
+// and Start.
+func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) {
+ return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{}
+}
+
+// MakeLink formats a link to add to a RunOpts.
+func (c *Container) MakeLink(target string) string {
+ return fmt.Sprintf("%s:%s", c.Name, target)
+}
+
+// CreateFrom creates a container from the given configs.
+func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ return c.create(ctx, conf, hostconf, netconf)
+}
+
+// Create is analogous to 'docker create'.
+func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
+ return c.create(ctx, c.config(r, args), c.hostConfig(r), nil)
+}
+
+func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
+ cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
+ if err != nil {
+ return err
+ }
+ c.id = cont.ID
+ for _, profile := range c.profiles {
+ if err := profile.OnCreate(c); err != nil {
+ return fmt.Errorf("OnCreate method failed with: %v", err)
+ }
+ }
+ return nil
+}
+
+func (c *Container) config(r RunOpts, args []string) *container.Config {
+ ports := nat.PortSet{}
+ for _, p := range r.Ports {
+ port := nat.Port(fmt.Sprintf("%d", p))
+ ports[port] = struct{}{}
+ }
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+
+ return &container.Config{
+ Image: testutil.ImageByName(r.Image),
+ Cmd: args,
+ ExposedPorts: ports,
+ Env: env,
+ WorkingDir: r.WorkDir,
+ User: r.User,
+ }
+}
+
+func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
+ c.mounts = append(c.mounts, r.Mounts...)
+
+ return &container.HostConfig{
+ Runtime: c.runtime,
+ Mounts: c.mounts,
+ PublishAllPorts: true,
+ Links: r.Links,
+ CapAdd: r.CapAdd,
+ CapDrop: r.CapDrop,
+ Privileged: r.Privileged,
+ ReadonlyRootfs: r.ReadOnly,
+ Resources: container.Resources{
+ Memory: int64(r.Memory), // In bytes.
+ CpusetCpus: r.CpusetCpus,
+ },
+ }
+}
+
+// Start is analogous to 'docker start'.
+func (c *Container) Start(ctx context.Context) error {
+ if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil {
+ return fmt.Errorf("ContainerStart failed: %v", err)
+ }
+ for _, profile := range c.profiles {
+ if err := profile.OnStart(c); err != nil {
+ return fmt.Errorf("OnStart method failed: %v", err)
+ }
+ }
+ return nil
+}
+
+// Stop is analogous to 'docker stop'.
+func (c *Container) Stop(ctx context.Context) error {
+ return c.client.ContainerStop(ctx, c.id, nil)
+}
+
+// Pause is analogous to'docker pause'.
+func (c *Container) Pause(ctx context.Context) error {
+ return c.client.ContainerPause(ctx, c.id)
+}
+
+// Unpause is analogous to 'docker unpause'.
+func (c *Container) Unpause(ctx context.Context) error {
+ return c.client.ContainerUnpause(ctx, c.id)
+}
+
+// Checkpoint is analogous to 'docker checkpoint'.
+func (c *Container) Checkpoint(ctx context.Context, name string) error {
+ return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true})
+}
+
+// Restore is analogous to 'docker start --checkname [name]'.
+func (c *Container) Restore(ctx context.Context, name string) error {
+ return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name})
+}
+
+// Logs is analogous 'docker logs'.
+func (c *Container) Logs(ctx context.Context) (string, error) {
+ var out bytes.Buffer
+ err := c.logs(ctx, &out, &out)
+ return out.String(), err
+}
+
+func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error {
+ opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true}
+ writer, err := c.client.ContainerLogs(ctx, c.id, opts)
+ if err != nil {
+ return err
+ }
+ defer writer.Close()
+ _, err = stdcopy.StdCopy(stdout, stderr, writer)
+
+ return err
+}
+
+// ID returns the container id.
+func (c *Container) ID() string {
+ return c.id
+}
+
+// SandboxPid returns the container's pid.
+func (c *Container) SandboxPid(ctx context.Context) (int, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, err
+ }
+ return resp.ContainerJSONBase.State.Pid, nil
+}
+
+// FindIP returns the IP address of the container.
+func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return nil, err
+ }
+
+ var ip net.IP
+ if ipv6 {
+ ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.GlobalIPv6Address)
+ } else {
+ ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress)
+ }
+ if ip == nil {
+ return net.IP{}, fmt.Errorf("invalid IP: %q", ip)
+ }
+ return ip, nil
+}
+
+// FindPort returns the host port that is mapped to 'sandboxPort'.
+func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) {
+ desc, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+ }
+
+ format := fmt.Sprintf("%d/tcp", sandboxPort)
+ ports, ok := desc.NetworkSettings.Ports[nat.Port(format)]
+ if !ok {
+ return -1, fmt.Errorf("error retrieving port: %v", err)
+
+ }
+
+ port, err := strconv.Atoi(ports[0].HostPort)
+ if err != nil {
+ return -1, fmt.Errorf("error parsing port %q: %v", port, err)
+ }
+ return port, nil
+}
+
+// CopyFiles copies in and mounts the given files. They are always ReadOnly.
+func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) {
+ dir, err := ioutil.TempDir("", c.Name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err)
+ return
+ }
+ c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) })
+ if err := os.Chmod(dir, 0755); err != nil {
+ c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err)
+ return
+ }
+ for _, name := range sources {
+ src, err := testutil.FindFile(name)
+ if err != nil {
+ c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err)
+ return
+ }
+ dst := path.Join(dir, path.Base(name))
+ if err := testutil.Copy(src, dst); err != nil {
+ c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
+ return
+ }
+ c.logger.Logf("copy: %s -> %s", src, dst)
+ }
+ opts.Mounts = append(opts.Mounts, mount.Mount{
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: target,
+ ReadOnly: false,
+ })
+}
+
+// Status inspects the container returns its status.
+func (c *Container) Status(ctx context.Context) (types.ContainerState, error) {
+ resp, err := c.client.ContainerInspect(ctx, c.id)
+ if err != nil {
+ return types.ContainerState{}, err
+ }
+ return *resp.State, err
+}
+
+// Wait waits for the container to exit.
+func (c *Container) Wait(ctx context.Context) error {
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitTimeout waits for the container to exit with a timeout.
+func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error {
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning)
+ select {
+ case <-ctx.Done():
+ if ctx.Err() == context.DeadlineExceeded {
+ return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds())
+ }
+ return nil
+ case err := <-errChan:
+ return err
+ case <-statusChan:
+ return nil
+ }
+}
+
+// WaitForOutput searches container logs for pattern and returns or timesout.
+func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) {
+ matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout)
+ if err != nil {
+ return "", err
+ }
+ if len(matches) == 0 {
+ return "", fmt.Errorf("didn't find pattern %s logs", pattern)
+ }
+ return matches[0], nil
+}
+
+// WaitForOutputSubmatch searches container logs for the given
+// pattern or times out. It returns any regexp submatches as well.
+func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) {
+ ctx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ re := regexp.MustCompile(pattern)
+ for {
+ logs, err := c.Logs(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get logs: %v logs: %s", err, logs)
+ }
+ if matches := re.FindStringSubmatch(logs); matches != nil {
+ return matches, nil
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+}
+
+// Kill kills the container.
+func (c *Container) Kill(ctx context.Context) error {
+ return c.client.ContainerKill(ctx, c.id, "")
+}
+
+// Remove is analogous to 'docker rm'.
+func (c *Container) Remove(ctx context.Context) error {
+ // Remove the image.
+ remove := types.ContainerRemoveOptions{
+ RemoveVolumes: c.mounts != nil,
+ RemoveLinks: c.links != nil,
+ Force: true,
+ }
+ return c.client.ContainerRemove(ctx, c.Name, remove)
+}
+
+// CleanUp kills and deletes the container (best effort).
+func (c *Container) CleanUp(ctx context.Context) {
+ // Execute profile cleanups before the container goes down.
+ for _, profile := range c.profiles {
+ profile.OnCleanUp(c)
+ }
+
+ // Forget profiles.
+ c.profiles = nil
+
+ // Execute all cleanups. We execute cleanups here to close any
+ // open connections to the container before closing. Open connections
+ // can cause Kill and Remove to hang.
+ for _, c := range c.cleanups {
+ c()
+ }
+ c.cleanups = nil
+
+ // Kill the container.
+ if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
+ // Just log; can't do anything here.
+ c.logger.Logf("error killing container %q: %v", c.Name, err)
+ }
+ // Remove the image.
+ if err := c.Remove(ctx); err != nil {
+ c.logger.Logf("error removing container %q: %v", c.Name, err)
+ }
+ // Forget all mounts.
+ c.mounts = nil
+}
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
new file mode 100644
index 000000000..952871f95
--- /dev/null
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -0,0 +1,177 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package dockerutil is a collection of utility functions.
+package dockerutil
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "os/exec"
+ "regexp"
+ "strconv"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+var (
+ // runtime is the runtime to use for tests. This will be applied to all
+ // containers. Note that the default here ("runsc") corresponds to the
+ // default used by the installations. This is important, because the
+ // default installer for vm_tests (in tools/installers:head, invoked
+ // via tools/vm:defs.bzl) will install with this name. So without
+ // changing anything, tests should have a runsc runtime available to
+ // them. Otherwise installers should update the existing runtime
+ // instead of installing a new one.
+ runtime = flag.String("runtime", "runsc", "specify which runtime to use")
+
+ // config is the default Docker daemon configuration path.
+ config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+
+ // The following flags are for the "pprof" profiler tool.
+
+ // pprofBaseDir allows the user to change the directory to which profiles are
+ // written. By default, profiles will appear under:
+ // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
+ pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+
+ // duration is the max duration `runsc debug` will run and capture profiles.
+ // If the container's clean up method is called prior to duration, the
+ // profiling process will be killed.
+ duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds")
+
+ // The below flags enable each type of profile. Multiple profiles can be
+ // enabled for each run.
+ pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
+ pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
+ pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug")
+ pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
+ pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug")
+)
+
+// EnsureSupportedDockerVersion checks if correct docker is installed.
+//
+// This logs directly to stderr, as it is typically called from a Main wrapper.
+func EnsureSupportedDockerVersion() {
+ cmd := exec.Command("docker", "version")
+ out, err := cmd.CombinedOutput()
+ if err != nil {
+ log.Fatalf("error running %q: %v", "docker version", err)
+ }
+ re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`)
+ matches := re.FindStringSubmatch(string(out))
+ if len(matches) != 3 {
+ log.Fatalf("Invalid docker output: %s", out)
+ }
+ major, _ := strconv.Atoi(matches[1])
+ minor, _ := strconv.Atoi(matches[2])
+ if major < 17 || (major == 17 && minor < 9) {
+ log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor)
+ }
+}
+
+// RuntimePath returns the binary path for the current runtime.
+func RuntimePath() (string, error) {
+ rs, err := runtimeMap()
+ if err != nil {
+ return "", err
+ }
+
+ p, ok := rs["path"].(string)
+ if !ok {
+ // The runtime does not declare a path.
+ return "", fmt.Errorf("runtime does not declare a path: %v", rs)
+ }
+ return p, nil
+}
+
+// UsingVFS2 returns true if the 'runtime' has the vfs2 flag set.
+// TODO(gvisor.dev/issue/1624): Remove.
+func UsingVFS2() (bool, error) {
+ rMap, err := runtimeMap()
+ if err != nil {
+ return false, err
+ }
+
+ list, ok := rMap["runtimeArgs"].([]interface{})
+ if !ok {
+ return false, fmt.Errorf("unexpected format: %v", rMap)
+ }
+
+ for _, element := range list {
+ if element == "--vfs2" {
+ return true, nil
+ }
+ }
+ return false, nil
+}
+
+func runtimeMap() (map[string]interface{}, error) {
+ // Read the configuration data; the file must exist.
+ configBytes, err := ioutil.ReadFile(*config)
+ if err != nil {
+ return nil, err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return nil, err
+ }
+
+ // Decode the expected configuration.
+ r, ok := c["runtimes"]
+ if !ok {
+ return nil, fmt.Errorf("no runtimes declared: %v", c)
+ }
+ rs, ok := r.(map[string]interface{})
+ if !ok {
+ // The runtimes are not a map.
+ return nil, fmt.Errorf("unexpected format: %v", rs)
+ }
+ r, ok = rs[*runtime]
+ if !ok {
+ // The expected runtime is not declared.
+ return nil, fmt.Errorf("runtime %q not found: %v", *runtime, rs)
+ }
+ rs, ok = r.(map[string]interface{})
+ if !ok {
+ // The runtime is not a map.
+ return nil, fmt.Errorf("unexpected format: %v", r)
+ }
+ return rs, nil
+}
+
+// Save exports a container image to the given Writer.
+//
+// Note that the writer should be actively consuming the output, otherwise it
+// is not guaranteed that the Save will make any progress and the call may
+// stall indefinitely.
+//
+// This is called by criutil in order to import imports.
+func Save(logger testutil.Logger, image string, w io.Writer) error {
+ cmd := testutil.Command(logger, "docker", "save", testutil.ImageByName(image))
+ cmd.Stdout = w // Send directly to the writer.
+ return cmd.Run()
+}
+
+// Runtime returns the value of the flag runtime.
+func Runtime() string {
+ return *runtime
+}
diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go
new file mode 100644
index 000000000..4c739c9e9
--- /dev/null
+++ b/pkg/test/dockerutil/exec.go
@@ -0,0 +1,193 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/pkg/stdcopy"
+)
+
+// ExecOpts holds arguments for Exec calls.
+type ExecOpts struct {
+ // Env are additional environment variables.
+ Env []string
+
+ // Privileged enables privileged mode.
+ Privileged bool
+
+ // User is the user to use.
+ User string
+
+ // Enables Tty and stdin for the created process.
+ UseTTY bool
+
+ // WorkDir is the working directory of the process.
+ WorkDir string
+}
+
+// Exec creates a process inside the container.
+func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) {
+ p, err := c.doExec(ctx, opts, args)
+ if err != nil {
+ return "", err
+ }
+
+ if exitStatus, err := p.WaitExitStatus(ctx); err != nil {
+ return "", err
+ } else if exitStatus != 0 {
+ out, _ := p.Logs()
+ return out, fmt.Errorf("process terminated with status: %d", exitStatus)
+ }
+
+ return p.Logs()
+}
+
+// ExecProcess creates a process inside the container and returns a process struct
+// for the caller to use.
+func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) {
+ return c.doExec(ctx, opts, args)
+}
+
+func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) {
+ config := c.execConfig(r, args)
+ resp, err := c.client.ContainerExecCreate(ctx, c.id, config)
+ if err != nil {
+ return Process{}, fmt.Errorf("exec create failed with err: %v", err)
+ }
+
+ hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{})
+ if err != nil {
+ return Process{}, fmt.Errorf("exec attach failed with err: %v", err)
+ }
+
+ if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil {
+ hijack.Close()
+ return Process{}, fmt.Errorf("exec start failed with err: %v", err)
+ }
+
+ return Process{
+ container: c,
+ execid: resp.ID,
+ conn: hijack,
+ }, nil
+}
+
+func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig {
+ env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name))
+ return types.ExecConfig{
+ AttachStdin: r.UseTTY,
+ AttachStderr: true,
+ AttachStdout: true,
+ Cmd: cmd,
+ Privileged: r.Privileged,
+ WorkingDir: r.WorkDir,
+ Env: env,
+ Tty: r.UseTTY,
+ User: r.User,
+ }
+
+}
+
+// Process represents a containerized process.
+type Process struct {
+ container *Container
+ execid string
+ conn types.HijackedResponse
+}
+
+// Write writes buf to the process's stdin.
+func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) {
+ p.conn.Conn.SetDeadline(time.Now().Add(timeout))
+ return p.conn.Conn.Write(buf)
+}
+
+// Read returns process's stdout and stderr.
+func (p *Process) Read() (string, string, error) {
+ var stdout, stderr bytes.Buffer
+ if err := p.read(&stdout, &stderr); err != nil {
+ return "", "", err
+ }
+ return stdout.String(), stderr.String(), nil
+}
+
+// Logs returns combined stdout/stderr from the process.
+func (p *Process) Logs() (string, error) {
+ var out bytes.Buffer
+ if err := p.read(&out, &out); err != nil {
+ return "", err
+ }
+ return out.String(), nil
+}
+
+func (p *Process) read(stdout, stderr *bytes.Buffer) error {
+ _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader)
+ return err
+}
+
+// ExitCode returns the process's exit code.
+func (p *Process) ExitCode(ctx context.Context) (int, error) {
+ _, exitCode, err := p.runningExitCode(ctx)
+ return exitCode, err
+}
+
+// IsRunning checks if the process is running.
+func (p *Process) IsRunning(ctx context.Context) (bool, error) {
+ running, _, err := p.runningExitCode(ctx)
+ return running, err
+}
+
+// WaitExitStatus until process completes and returns exit status.
+func (p *Process) WaitExitStatus(ctx context.Context) (int, error) {
+ waitChan := make(chan (int))
+ errChan := make(chan (error))
+
+ go func() {
+ for {
+ running, exitcode, err := p.runningExitCode(ctx)
+ if err != nil {
+ errChan <- fmt.Errorf("error waiting process %s: container %v", p.execid, p.container.Name)
+ }
+ if !running {
+ waitChan <- exitcode
+ }
+ time.Sleep(time.Millisecond * 500)
+ }
+ }()
+
+ select {
+ case ws := <-waitChan:
+ return ws, nil
+ case err := <-errChan:
+ return -1, err
+ }
+}
+
+// runningExitCode collects if the process is running and the exit code.
+// The exit code is only valid if the process has exited.
+func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) {
+ // If execid is not empty, this is a execed process.
+ if p.execid != "" {
+ status, err := p.container.client.ContainerExecInspect(ctx, p.execid)
+ return status.Running, status.ExitCode, err
+ }
+ // else this is the root process.
+ status, err := p.container.Status(ctx)
+ return status.Running, status.ExitCode, err
+}
diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go
new file mode 100644
index 000000000..047091e75
--- /dev/null
+++ b/pkg/test/dockerutil/network.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "net"
+
+ "github.com/docker/docker/api/types"
+ "github.com/docker/docker/api/types/network"
+ "github.com/docker/docker/client"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Network is a docker network.
+type Network struct {
+ client *client.Client
+ id string
+ logger testutil.Logger
+ Name string
+ containers []*Container
+ Subnet *net.IPNet
+}
+
+// NewNetwork sets up the struct for a Docker network. Names of networks
+// will be unique.
+func NewNetwork(ctx context.Context, logger testutil.Logger) *Network {
+ client, err := client.NewClientWithOpts(client.FromEnv)
+ if err != nil {
+ logger.Logf("create client failed with: %v", err)
+ return nil
+ }
+ client.NegotiateAPIVersion(ctx)
+
+ return &Network{
+ logger: logger,
+ Name: testutil.RandomID(logger.Name()),
+ client: client,
+ }
+}
+
+func (n *Network) networkCreate() types.NetworkCreate {
+
+ var subnet string
+ if n.Subnet != nil {
+ subnet = n.Subnet.String()
+ }
+
+ ipam := network.IPAM{
+ Config: []network.IPAMConfig{{
+ Subnet: subnet,
+ }},
+ }
+
+ return types.NetworkCreate{
+ CheckDuplicate: true,
+ IPAM: &ipam,
+ }
+}
+
+// Create is analogous to 'docker network create'.
+func (n *Network) Create(ctx context.Context) error {
+
+ opts := n.networkCreate()
+ resp, err := n.client.NetworkCreate(ctx, n.Name, opts)
+ if err != nil {
+ return err
+ }
+ n.id = resp.ID
+ return nil
+}
+
+// Connect is analogous to 'docker network connect' with the arguments provided.
+func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error {
+ settings := network.EndpointSettings{
+ IPAMConfig: &network.EndpointIPAMConfig{
+ IPv4Address: ipv4,
+ IPv6Address: ipv6,
+ },
+ }
+ err := n.client.NetworkConnect(ctx, n.id, container.id, &settings)
+ if err == nil {
+ n.containers = append(n.containers, container)
+ }
+ return err
+}
+
+// Inspect returns this network's info.
+func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) {
+ return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true})
+}
+
+// Cleanup cleans up the docker network and all the containers attached to it.
+func (n *Network) Cleanup(ctx context.Context) error {
+ for _, c := range n.containers {
+ c.CleanUp(ctx)
+ }
+ n.containers = nil
+
+ return n.client.NetworkRemove(ctx, n.id)
+}
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
new file mode 100644
index 000000000..f0396ef24
--- /dev/null
+++ b/pkg/test/dockerutil/profile.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+// Profile represents profile-like operations on a container,
+// such as running perf or pprof. It is meant to be added to containers
+// such that the container type calls the Profile during its lifecycle.
+type Profile interface {
+ // OnCreate is called just after the container is created when the container
+ // has a valid ID (e.g. c.ID()).
+ OnCreate(c *Container) error
+
+ // OnStart is called just after the container is started when the container
+ // has a valid Pid (e.g. c.SandboxPid()).
+ OnStart(c *Container) error
+
+ // Restart restarts the Profile on request.
+ Restart(c *Container) error
+
+ // OnCleanUp is called during the container's cleanup method.
+ // Cleanups should just log errors if they have them.
+ OnCleanUp(c *Container) error
+}
+
+// Pprof is for running profiles with 'runsc debug'. Pprof workloads
+// should be run as root and ONLY against runsc sandboxes. The runtime
+// should have --profile set as an option in /etc/docker/daemon.json in
+// order for profiling to work with Pprof.
+type Pprof struct {
+ BasePath string // path to put profiles
+ BlockProfile bool
+ CPUProfile bool
+ HeapProfile bool
+ MutexProfile bool
+ Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
+ shouldRun bool
+ cmd *exec.Cmd
+ stdout io.ReadCloser
+ stderr io.ReadCloser
+}
+
+// MakePprofFromFlags makes a Pprof profile from flags.
+func MakePprofFromFlags(c *Container) *Pprof {
+ if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) {
+ return nil
+ }
+ return &Pprof{
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+ BlockProfile: *pprofBlock,
+ CPUProfile: *pprofCPU,
+ HeapProfile: *pprofHeap,
+ MutexProfile: *pprofMutex,
+ Duration: *duration,
+ }
+}
+
+// OnCreate implements Profile.OnCreate.
+func (p *Pprof) OnCreate(c *Container) error {
+ return os.MkdirAll(p.BasePath, 0755)
+}
+
+// OnStart implements Profile.OnStart.
+func (p *Pprof) OnStart(c *Container) error {
+ path, err := RuntimePath()
+ if err != nil {
+ return fmt.Errorf("failed to get runtime path: %v", err)
+ }
+
+ // The root directory of this container's runtime.
+ root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+ // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+ args := []string{root, "debug"}
+ args = append(args, p.makeProfileArgs(c)...)
+ args = append(args, c.ID())
+
+ // Best effort wait until container is running.
+ for now := time.Now(); time.Since(now) < 5*time.Second; {
+ if status, err := c.Status(context.Background()); err != nil {
+ return fmt.Errorf("failed to get status with: %v", err)
+
+ } else if status.Running {
+ break
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ p.cmd = exec.Command(path, args...)
+ if err := p.cmd.Start(); err != nil {
+ return fmt.Errorf("process failed: %v", err)
+ }
+ return nil
+}
+
+// Restart implements Profile.Restart.
+func (p *Pprof) Restart(c *Container) error {
+ p.OnCleanUp(c)
+ return p.OnStart(c)
+}
+
+// OnCleanUp implements Profile.OnCleanup
+func (p *Pprof) OnCleanUp(c *Container) error {
+ defer func() { p.cmd = nil }()
+ if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() {
+ return p.cmd.Process.Kill()
+ }
+ return nil
+}
+
+// makeProfileArgs turns Pprof fields into runsc debug flags.
+func (p *Pprof) makeProfileArgs(c *Container) []string {
+ var ret []string
+ if p.BlockProfile {
+ ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
+ }
+ if p.CPUProfile {
+ ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
+ }
+ if p.HeapProfile {
+ ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
+ }
+ if p.MutexProfile {
+ ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+ }
+ ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
+ return ret
+}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
new file mode 100644
index 000000000..8c4ffe483
--- /dev/null
+++ b/pkg/test/dockerutil/profile_test.go
@@ -0,0 +1,116 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+type testCase struct {
+ name string
+ pprof Pprof
+ expectedFiles []string
+}
+
+func TestPprof(t *testing.T) {
+ // Basepath and expected file names for each type of profile.
+ basePath := "/tmp/test/profile"
+ block := "block.pprof"
+ cpu := "cpu.pprof"
+ goprofle := "go.pprof"
+ heap := "heap.pprof"
+ mutex := "mutex.pprof"
+
+ testCases := []testCase{
+ {
+ name: "Cpu",
+ pprof: Pprof{
+ BasePath: basePath,
+ CPUProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{cpu},
+ },
+ {
+ name: "All",
+ pprof: Pprof{
+ BasePath: basePath,
+ BlockProfile: true,
+ CPUProfile: true,
+ HeapProfile: true,
+ MutexProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{block, cpu, goprofle, heap, mutex},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx := context.Background()
+ c := MakeContainer(ctx, t)
+ // Set basepath to include the container name so there are no conflicts.
+ tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
+ c.AddProfile(&tc.pprof)
+
+ func() {
+ defer c.CleanUp(ctx)
+ // Start a container.
+ if err := c.Spawn(ctx, RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("run failed with: %v", err)
+ }
+
+ if status, err := c.Status(context.Background()); !status.Running {
+ t.Fatalf("container is not yet running: %+v err: %v", status, err)
+ }
+
+ // End early if the expected files exist and have data.
+ for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
+ if err := checkFiles(tc); err == nil {
+ break
+ }
+ }
+ }()
+
+ // Check all expected files exist and have data.
+ if err := checkFiles(tc); err != nil {
+ t.Fatalf(err.Error())
+ }
+ })
+ }
+}
+
+func checkFiles(tc testCase) error {
+ for _, file := range tc.expectedFiles {
+ stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+ if err != nil {
+ return fmt.Errorf("stat failed with: %v", err)
+ } else if stat.Size() < 1 {
+ return fmt.Errorf("file not written to: %+v", stat)
+ }
+ }
+ return nil
+}
+
+func TestMain(m *testing.M) {
+ EnsureSupportedDockerVersion()
+ os.Exit(m.Run())
+}
diff --git a/runsc/testutil/BUILD b/pkg/test/testutil/BUILD
index c96ca2eb6..2d8f56bc0 100644
--- a/runsc/testutil/BUILD
+++ b/pkg/test/testutil/BUILD
@@ -1,18 +1,20 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "testutil",
testonly = 1,
- srcs = ["testutil.go"],
- importpath = "gvisor.dev/gvisor/runsc/testutil",
+ srcs = [
+ "testutil.go",
+ "testutil_runfiles.go",
+ ],
visibility = ["//:sandbox"],
deps = [
- "//pkg/log",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
diff --git a/runsc/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 9632776d2..1580527b5 100644
--- a/runsc/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -25,23 +25,25 @@ import (
"fmt"
"io"
"io/ioutil"
+ "log"
"math"
"math/rand"
"net/http"
"os"
"os/exec"
"os/signal"
+ "path"
"path/filepath"
"strconv"
"strings"
- "sync"
"sync/atomic"
"syscall"
+ "testing"
"time"
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
- "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -50,23 +52,15 @@ var (
checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
)
-func init() {
- rand.Seed(time.Now().UnixNano())
-}
-
// IsCheckpointSupported returns the relevant command line flag.
func IsCheckpointSupported() bool {
return *checkpoint
}
-// TmpDir returns the absolute path to a writable directory that can be used as
-// scratch by the test.
-func TmpDir() string {
- dir := os.Getenv("TEST_TMPDIR")
- if dir == "" {
- dir = "/tmp"
- }
- return dir
+// ImageByName mangles the image name used locally. This depends on the image
+// build infrastructure in images/ and tools/vm.
+func ImageByName(name string) string {
+ return fmt.Sprintf("gvisor.dev/images/%s", name)
}
// ConfigureExePath configures the executable for runsc in the test environment.
@@ -79,75 +73,84 @@ func ConfigureExePath() error {
return nil
}
-// FindFile searchs for a file inside the test run environment. It returns the
-// full path to the file. It fails if none or more than one file is found.
-func FindFile(path string) (string, error) {
- wd, err := os.Getwd()
- if err != nil {
- return "", err
+// TmpDir returns the absolute path to a writable directory that can be used as
+// scratch by the test.
+func TmpDir() string {
+ dir := os.Getenv("TEST_TMPDIR")
+ if dir == "" {
+ dir = "/tmp"
}
+ return dir
+}
- // The test root is demarcated by a path element called "__main__". Search for
- // it backwards from the working directory.
- root := wd
- for {
- dir, name := filepath.Split(root)
- if name == "__main__" {
- break
- }
- if len(dir) == 0 {
- return "", fmt.Errorf("directory __main__ not found in %q", wd)
- }
- // Remove ending slash to loop around.
- root = dir[:len(dir)-1]
- }
+// Logger is a simple logging wrapper.
+//
+// This is designed to be implemented by *testing.T.
+type Logger interface {
+ Name() string
+ Logf(fmt string, args ...interface{})
+}
+
+// DefaultLogger logs using the log package.
+type DefaultLogger string
+
+// Name implements Logger.Name.
+func (d DefaultLogger) Name() string {
+ return string(d)
+}
+
+// Logf implements Logger.Logf.
+func (d DefaultLogger) Logf(fmt string, args ...interface{}) {
+ log.Printf(fmt, args...)
+}
+
+// Cmd is a simple wrapper.
+type Cmd struct {
+ logger Logger
+ *exec.Cmd
+}
- // Annoyingly, bazel adds the build type to the directory path for go
- // binaries, but not for c++ binaries. We use two different patterns to
- // to find our file.
- patterns := []string{
- // Try the obvious path first.
- filepath.Join(root, path),
- // If it was a go binary, use a wildcard to match the build
- // type. The pattern is: /test-path/__main__/directories/*/file.
- filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+// CombinedOutput returns the output and logs.
+func (c *Cmd) CombinedOutput() ([]byte, error) {
+ out, err := c.Cmd.CombinedOutput()
+ if len(out) > 0 {
+ c.logger.Logf("output: %s", string(out))
}
+ if err != nil {
+ c.logger.Logf("error: %v", err)
+ }
+ return out, err
+}
- for _, p := range patterns {
- matches, err := filepath.Glob(p)
- if err != nil {
- // "The only possible returned error is ErrBadPattern,
- // when pattern is malformed." -godoc
- return "", fmt.Errorf("error globbing %q: %v", p, err)
- }
- switch len(matches) {
- case 0:
- // Try the next pattern.
- case 1:
- // We found it.
- return matches[0], nil
- default:
- return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
- }
+// Command is a simple wrapper around exec.Command, that logs.
+func Command(logger Logger, args ...string) *Cmd {
+ logger.Logf("command: %s", strings.Join(args, " "))
+ return &Cmd{
+ logger: logger,
+ Cmd: exec.Command(args[0], args[1:]...),
}
- return "", fmt.Errorf("file %q not found", path)
}
// TestConfig returns the default configuration to use in tests. Note that
// 'RootDir' must be set by caller if required.
-func TestConfig() *boot.Config {
+func TestConfig(t *testing.T) *boot.Config {
+ logDir := os.TempDir()
+ if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
+ logDir = dir + "/"
+ }
return &boot.Config{
- Debug: true,
- LogFormat: "text",
- DebugLogFormat: "text",
- AlsoLogToStderr: true,
- LogPackets: true,
- Network: boot.NetworkNone,
- Strace: true,
- Platform: "ptrace",
- FileAccess: boot.FileAccessExclusive,
+ Debug: true,
+ DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"),
+ LogFormat: "text",
+ DebugLogFormat: "text",
+ LogPackets: true,
+ Network: boot.NetworkNone,
+ Strace: true,
+ Platform: "ptrace",
+ FileAccess: boot.FileAccessExclusive,
+ NumNetworkChannels: 1,
+
TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
- NumNetworkChannels: 1,
}
}
@@ -168,6 +171,13 @@ func NewSpecWithArgs(args ...string) *specs.Spec {
Capabilities: specutils.AllCapabilities(),
},
Mounts: []specs.Mount{
+ // Hide the host /etc to avoid any side-effects.
+ // For example, bash reads /etc/passwd and if it is
+ // very big, tests can fail by timeout.
+ {
+ Type: "tmpfs",
+ Destination: "/etc",
+ },
// Root is readonly, but many tests want to write to tmpdir.
// This creates a writable mount inside the root. Also, when tmpdir points
// to "/tmp", it makes the the actual /tmp to be mounted and not a tmpfs
@@ -183,37 +193,45 @@ func NewSpecWithArgs(args ...string) *specs.Spec {
}
// SetupRootDir creates a root directory for containers.
-func SetupRootDir() (string, error) {
+func SetupRootDir() (string, func(), error) {
rootDir, err := ioutil.TempDir(TmpDir(), "containers")
if err != nil {
- return "", fmt.Errorf("error creating root dir: %v", err)
+ return "", nil, fmt.Errorf("error creating root dir: %v", err)
}
- return rootDir, nil
+ return rootDir, func() { os.RemoveAll(rootDir) }, nil
}
// SetupContainer creates a bundle and root dir for the container, generates a
// test config, and writes the spec to config.json in the bundle dir.
-func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, err error) {
- rootDir, err = SetupRootDir()
+func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, cleanup func(), err error) {
+ rootDir, rootCleanup, err := SetupRootDir()
if err != nil {
- return "", "", err
+ return "", "", nil, err
}
conf.RootDir = rootDir
- bundleDir, err = SetupBundleDir(spec)
- return rootDir, bundleDir, err
+ bundleDir, bundleCleanup, err := SetupBundleDir(spec)
+ if err != nil {
+ rootCleanup()
+ return "", "", nil, err
+ }
+ return rootDir, bundleDir, func() {
+ bundleCleanup()
+ rootCleanup()
+ }, err
}
// SetupBundleDir creates a bundle dir and writes the spec to config.json.
-func SetupBundleDir(spec *specs.Spec) (bundleDir string, err error) {
- bundleDir, err = ioutil.TempDir(TmpDir(), "bundle")
+func SetupBundleDir(spec *specs.Spec) (string, func(), error) {
+ bundleDir, err := ioutil.TempDir(TmpDir(), "bundle")
if err != nil {
- return "", fmt.Errorf("error creating bundle dir: %v", err)
+ return "", nil, fmt.Errorf("error creating bundle dir: %v", err)
}
-
- if err = writeSpec(bundleDir, spec); err != nil {
- return "", fmt.Errorf("error writing spec: %v", err)
+ cleanup := func() { os.RemoveAll(bundleDir) }
+ if err := writeSpec(bundleDir, spec); err != nil {
+ cleanup()
+ return "", nil, fmt.Errorf("error writing spec: %v", err)
}
- return bundleDir, nil
+ return bundleDir, cleanup, nil
}
// writeSpec writes the spec to disk in the given directory.
@@ -225,22 +243,28 @@ func writeSpec(dir string, spec *specs.Spec) error {
return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755)
}
-// UniqueContainerID generates a unique container id for each test.
-//
-// The container id is used to create an abstract unix domain socket, which must
-// be unique. While the container forbids creating two containers with the same
-// name, sometimes between test runs the socket does not get cleaned up quickly
-// enough, causing container creation to fail.
-func UniqueContainerID() string {
+// RandomID returns 20 random bytes following the given prefix.
+func RandomID(prefix string) string {
// Read 20 random bytes.
b := make([]byte, 20)
// "[Read] always returns len(p) and a nil error." --godoc
if _, err := rand.Read(b); err != nil {
panic("rand.Read failed: " + err.Error())
}
- // base32 encode the random bytes, so that the name is a valid
- // container id and can be used as a socket name in the filesystem.
- return fmt.Sprintf("test-container-%s", base32.StdEncoding.EncodeToString(b))
+ if prefix != "" {
+ prefix = prefix + "-"
+ }
+ return fmt.Sprintf("%s%s", prefix, base32.StdEncoding.EncodeToString(b))
+}
+
+// RandomContainerID generates a random container id for each test.
+//
+// The container id is used to create an abstract unix domain socket, which
+// must be unique. While the container forbids creating two containers with the
+// same name, sometimes between test runs the socket does not get cleaned up
+// quickly enough, causing container creation to fail.
+func RandomContainerID() string {
+ return RandomID("test-container-")
}
// Copy copies file from src to dst.
@@ -251,12 +275,39 @@ func Copy(src, dst string) error {
}
defer in.Close()
- out, err := os.Create(dst)
+ st, err := in.Stat()
+ if err != nil {
+ return err
+ }
+
+ out, err := os.OpenFile(dst, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, st.Mode().Perm())
if err != nil {
return err
}
defer out.Close()
+ // Mirror the local user's permissions across all users. This is
+ // because as we inject things into the container, the UID/GID will
+ // change. Also, the build system may generate artifacts with different
+ // modes. At the top-level (volume mapping) we have a big read-only
+ // knob that can be applied to prevent modifications.
+ //
+ // Note that this must be done via a separate Chmod call, otherwise the
+ // current process's umask will get in the way.
+ var mode os.FileMode
+ if st.Mode()&0100 != 0 {
+ mode |= 0111
+ }
+ if st.Mode()&0200 != 0 {
+ mode |= 0222
+ }
+ if st.Mode()&0400 != 0 {
+ mode |= 0444
+ }
+ if err := os.Chmod(dst, mode); err != nil {
+ return err
+ }
+
_, err = io.Copy(out, in)
return err
}
@@ -265,6 +316,11 @@ func Copy(src, dst string) error {
func Poll(cb func() error, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
+ return PollContext(ctx, cb)
+}
+
+// PollContext is like Poll, but takes a context instead of a timeout.
+func PollContext(ctx context.Context, cb func() error) error {
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
return backoff.Retry(cb, b)
}
@@ -279,7 +335,7 @@ func WaitForHTTP(port int, timeout time.Duration) error {
url := fmt.Sprintf("http://localhost:%d/", port)
resp, err := c.Get(url)
if err != nil {
- log.Infof("Waiting %s: %v", url, err)
+ log.Printf("Waiting %s: %v", url, err)
return err
}
resp.Body.Close()
@@ -389,6 +445,8 @@ func WaitUntilRead(r io.Reader, want string, split bufio.SplitFunc, timeout time
// KillCommand kills the process running cmd unless it hasn't been started. It
// returns an error if it cannot kill the process unless the reason is that the
// process has already exited.
+//
+// KillCommand will also reap the process.
func KillCommand(cmd *exec.Cmd) error {
if cmd.Process == nil {
return nil
@@ -398,26 +456,21 @@ func KillCommand(cmd *exec.Cmd) error {
return fmt.Errorf("failed to kill process %v: %v", cmd, err)
}
}
- return nil
+ return cmd.Wait()
}
// WriteTmpFile writes text to a temporary file, closes the file, and returns
-// the name of the file.
-func WriteTmpFile(pattern, text string) (string, error) {
+// the name of the file. A cleanup function is also returned.
+func WriteTmpFile(pattern, text string) (string, func(), error) {
file, err := ioutil.TempFile(TmpDir(), pattern)
if err != nil {
- return "", err
+ return "", nil, err
}
defer file.Close()
if _, err := file.Write([]byte(text)); err != nil {
- return "", err
+ return "", nil, err
}
- return file.Name(), nil
-}
-
-// RandomName create a name with a 6 digit random number appended to it.
-func RandomName(prefix string) string {
- return fmt.Sprintf("%s-%06d", prefix, rand.Int31n(1000000))
+ return file.Name(), func() { os.RemoveAll(file.Name()) }, nil
}
// IsStatic returns true iff the given file is a static binary.
@@ -434,43 +487,55 @@ func IsStatic(filename string) (bool, error) {
return true, nil
}
-// TestBoundsForShard calculates the beginning and end indices for the test
-// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The
-// returned ints are the beginning (inclusive) and end (exclusive) of the
-// subslice corresponding to the shard. If either of the env vars are not
-// present, then the function will return bounds that include all tests. If
-// there are more shards than there are tests, then the returned list may be
-// empty.
-func TestBoundsForShard(numTests int) (int, int, error) {
+// TouchShardStatusFile indicates to Bazel that the test runner supports
+// sharding by creating or updating the last modified date of the file
+// specified by TEST_SHARD_STATUS_FILE.
+//
+// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner.
+func TouchShardStatusFile() error {
+ if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" {
+ cmd := exec.Command("touch", statusFile)
+ if b, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error())
+ }
+ }
+ return nil
+}
+
+// TestIndicesForShard returns indices for this test shard based on the
+// TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars.
+//
+// If either of the env vars are not present, then the function will return all
+// tests. If there are more shards than there are tests, then the returned list
+// may be empty.
+func TestIndicesForShard(numTests int) ([]int, error) {
var (
- begin = 0
- end = numTests
+ shardIndex = 0
+ shardTotal = 1
)
- indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
- if indexStr == "" || totalStr == "" {
- return begin, end, nil
- }
- // Parse index and total to ints.
- shardIndex, err := strconv.Atoi(indexStr)
- if err != nil {
- return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
- }
- shardTotal, err := strconv.Atoi(totalStr)
- if err != nil {
- return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr != "" && totalStr != "" {
+ // Parse index and total to ints.
+ var err error
+ shardIndex, err = strconv.Atoi(indexStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err = strconv.Atoi(totalStr)
+ if err != nil {
+ return nil, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
}
// Calculate!
- shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal)))
- begin = shardIndex * shardSize
- end = ((shardIndex + 1) * shardSize)
- if begin > numTests {
- // Nothing to run.
- return 0, 0, nil
- }
- if end > numTests {
- end = numTests
+ var indices []int
+ numBlocks := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ for i := 0; i < numBlocks; i++ {
+ pick := i*shardTotal + shardIndex
+ if pick < numTests {
+ indices = append(indices, pick)
+ }
}
- return begin, end, nil
+ return indices, nil
}
diff --git a/pkg/test/testutil/testutil_runfiles.go b/pkg/test/testutil/testutil_runfiles.go
new file mode 100644
index 000000000..ece9ea9a1
--- /dev/null
+++ b/pkg/test/testutil/testutil_runfiles.go
@@ -0,0 +1,75 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testutil
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// FindFile searchs for a file inside the test run environment. It returns the
+// full path to the file. It fails if none or more than one file is found.
+func FindFile(path string) (string, error) {
+ wd, err := os.Getwd()
+ if err != nil {
+ return "", err
+ }
+
+ // The test root is demarcated by a path element called "__main__". Search for
+ // it backwards from the working directory.
+ root := wd
+ for {
+ dir, name := filepath.Split(root)
+ if name == "__main__" {
+ break
+ }
+ if len(dir) == 0 {
+ return "", fmt.Errorf("directory __main__ not found in %q", wd)
+ }
+ // Remove ending slash to loop around.
+ root = dir[:len(dir)-1]
+ }
+
+ // Annoyingly, bazel adds the build type to the directory path for go
+ // binaries, but not for c++ binaries. We use two different patterns to
+ // to find our file.
+ patterns := []string{
+ // Try the obvious path first.
+ filepath.Join(root, path),
+ // If it was a go binary, use a wildcard to match the build
+ // type. The pattern is: /test-path/__main__/directories/*/file.
+ filepath.Join(root, filepath.Dir(path), "*", filepath.Base(path)),
+ }
+
+ for _, p := range patterns {
+ matches, err := filepath.Glob(p)
+ if err != nil {
+ // "The only possible returned error is ErrBadPattern,
+ // when pattern is malformed." -godoc
+ return "", fmt.Errorf("error globbing %q: %v", p, err)
+ }
+ switch len(matches) {
+ case 0:
+ // Try the next pattern.
+ case 1:
+ // We found it.
+ return matches[0], nil
+ default:
+ return "", fmt.Errorf("more than one match found for %q: %s", path, matches)
+ }
+ }
+ return "", fmt.Errorf("file %q not found", path)
+}
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
deleted file mode 100644
index 6afdb29b7..000000000
--- a/pkg/tmutex/BUILD
+++ /dev/null
@@ -1,18 +0,0 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tmutex",
- srcs = ["tmutex.go"],
- importpath = "gvisor.dev/gvisor/pkg/tmutex",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "tmutex_test",
- size = "medium",
- srcs = ["tmutex_test.go"],
- embed = [":tmutex"],
-)
diff --git a/pkg/tmutex/tmutex.go b/pkg/tmutex/tmutex.go
deleted file mode 100644
index c4685020d..000000000
--- a/pkg/tmutex/tmutex.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tmutex provides the implementation of a mutex that implements an
-// efficient TryLock function in addition to Lock and Unlock.
-package tmutex
-
-import (
- "sync/atomic"
-)
-
-// Mutex is a mutual exclusion primitive that implements TryLock in addition
-// to Lock and Unlock.
-type Mutex struct {
- v int32
- ch chan struct{}
-}
-
-// Init initializes the mutex.
-func (m *Mutex) Init() {
- m.v = 1
- m.ch = make(chan struct{}, 1)
-}
-
-// Lock acquires the mutex. If it is currently held by another goroutine, Lock
-// will wait until it has a chance to acquire it.
-func (m *Mutex) Lock() {
- // Uncontended case.
- if atomic.AddInt32(&m.v, -1) == 0 {
- return
- }
-
- for {
- // Try to acquire the mutex again, at the same time making sure
- // that m.v is negative, which indicates to the owner of the
- // lock that it is contended, which will force it to try to wake
- // someone up when it releases the mutex.
- if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 {
- return
- }
-
- // Wait for the mutex to be released before trying again.
- <-m.ch
- }
-}
-
-// TryLock attempts to acquire the mutex without blocking. If the mutex is
-// currently held by another goroutine, it fails to acquire it and returns
-// false.
-func (m *Mutex) TryLock() bool {
- v := atomic.LoadInt32(&m.v)
- if v <= 0 {
- return false
- }
- return atomic.CompareAndSwapInt32(&m.v, 1, 0)
-}
-
-// Unlock releases the mutex.
-func (m *Mutex) Unlock() {
- if atomic.SwapInt32(&m.v, 1) == 0 {
- // There were no pending waiters.
- return
- }
-
- // Wake some waiter up.
- select {
- case m.ch <- struct{}{}:
- default:
- }
-}
diff --git a/pkg/tmutex/tmutex_test.go b/pkg/tmutex/tmutex_test.go
deleted file mode 100644
index ce34c7962..000000000
--- a/pkg/tmutex/tmutex_test.go
+++ /dev/null
@@ -1,257 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package tmutex
-
-import (
- "fmt"
- "runtime"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-)
-
-func TestBasicLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- m.Lock()
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- ch <- struct{}{}
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-
- // Make sure we can lock and unlock again.
- m.Lock()
- m.Unlock()
-}
-
-func TestTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Try to lock. It should succeed.
- if !m.TryLock() {
- t.Fatalf("TryLock failed on unlocked mutex")
- }
-
- // Try to lock again, it should now fail.
- if m.TryLock() {
- t.Fatalf("TryLock succeeded on locked mutex")
- }
-
- // Try blocking lock the mutex from a different goroutine. This must
- // not block because the mutex is held.
- ch := make(chan struct{}, 1)
- go func() {
- m.Lock()
- ch <- struct{}{}
- m.Unlock()
- }()
-
- select {
- case <-ch:
- t.Fatalf("Lock succeeded on locked mutex")
- case <-time.After(100 * time.Millisecond):
- }
-
- // Unlock the mutex and make sure that the goroutine waiting on Lock()
- // unblocks and succeeds.
- m.Unlock()
-
- select {
- case <-ch:
- case <-time.After(100 * time.Millisecond):
- t.Fatalf("Lock failed to acquire unlocked mutex")
- }
-}
-
-func TestMutualExclusion(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Test mutual exclusion by running "gr" goroutines concurrently, and
- // have each one increment a counter "iters" times within the critical
- // section established by the mutex.
- //
- // If at the end the counter is not gr * iters, then we know that
- // goroutines ran concurrently within the critical section.
- //
- // If one of the goroutines doesn't complete, it's likely a bug that
- // causes to it to wait forever.
- const gr = 1000
- const iters = 100000
- v := 0
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(1)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- if v != gr*iters {
- t.Fatalf("Bad count: got %v, want %v", v, gr*iters)
- }
-}
-
-func TestMutualExclusionWithTryLock(t *testing.T) {
- var m Mutex
- m.Init()
-
- // Similar to the previous, with the addition of some goroutines that
- // only increment the count if TryLock succeeds.
- const gr = 1000
- const iters = 100000
- total := int64(gr * iters)
- var tryTotal int64
- v := int64(0)
- var wg sync.WaitGroup
- for i := 0; i < gr; i++ {
- wg.Add(2)
- go func() {
- for j := 0; j < iters; j++ {
- m.Lock()
- v++
- m.Unlock()
- }
- wg.Done()
- }()
- go func() {
- local := int64(0)
- for j := 0; j < iters; j++ {
- if m.TryLock() {
- v++
- m.Unlock()
- local++
- }
- }
- atomic.AddInt64(&tryTotal, local)
- wg.Done()
- }()
- }
-
- wg.Wait()
-
- t.Logf("tryTotal = %d", tryTotal)
- total += tryTotal
-
- if v != total {
- t.Fatalf("Bad count: got %v, want %v", v, total)
- }
-}
-
-// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following
-// differences:
-//
-// - The number of goroutines is variable, with the maximum value depending on
-// GOMAXPROCS.
-//
-// - The number of iterations per benchmark is controlled by the benchmarking
-// framework.
-//
-// - Care is taken to ensure that all goroutines participating in the benchmark
-// have been created before the benchmark begins.
-func BenchmarkTmutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m Mutex
- m.Init()
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
-
-// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as
-// a comparison point.
-func BenchmarkSyncMutex(b *testing.B) {
- for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 {
- b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
- var m sync.Mutex
-
- var ready sync.WaitGroup
- begin := make(chan struct{})
- var end sync.WaitGroup
- for i := 0; i < n; i++ {
- ready.Add(1)
- end.Add(1)
- go func() {
- ready.Done()
- <-begin
- for j := 0; j < b.N; j++ {
- m.Lock()
- m.Unlock()
- }
- end.Done()
- }()
- }
-
- ready.Wait()
- b.ResetTimer()
- close(begin)
- end.Wait()
- })
- }
-}
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index 8f6f180e5..a86501fa2 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -1,5 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -9,7 +8,6 @@ go_library(
"unet.go",
"unet_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/pkg/unet",
visibility = ["//visibility:public"],
deps = [
"//pkg/gate",
@@ -23,5 +21,6 @@ go_test(
srcs = [
"unet_test.go",
],
- embed = [":unet"],
+ library = ":unet",
+ deps = ["//pkg/sync"],
)
diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go
index a3cc6f5d3..5c4b9e8e9 100644
--- a/pkg/unet/unet_test.go
+++ b/pkg/unet/unet_test.go
@@ -19,10 +19,11 @@ import (
"os"
"path/filepath"
"reflect"
- "sync"
"syscall"
"testing"
"time"
+
+ "gvisor.dev/gvisor/pkg/sync"
)
func randomFilename() (string, error) {
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b6bbb0ea2..850c34ed0 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -1,16 +1,15 @@
-load("//tools/go_stateify:defs.bzl", "go_library")
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "urpc",
srcs = ["urpc.go"],
- importpath = "gvisor.dev/gvisor/pkg/urpc",
visibility = ["//:sandbox"],
deps = [
"//pkg/fd",
"//pkg/log",
+ "//pkg/sync",
"//pkg/unet",
],
)
@@ -19,6 +18,6 @@ go_test(
name = "urpc_test",
size = "small",
srcs = ["urpc_test.go"],
- embed = [":urpc"],
+ library = ":urpc",
deps = ["//pkg/unet"],
)
diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go
index df59ffab1..13b2ea314 100644
--- a/pkg/urpc/urpc.go
+++ b/pkg/urpc/urpc.go
@@ -27,10 +27,10 @@ import (
"os"
"reflect"
"runtime"
- "sync"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
)
diff --git a/pkg/sentry/usermem/BUILD b/pkg/usermem/BUILD
index 684f59a6b..6c9ada9c7 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -26,19 +25,17 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/usermem",
- visibility = ["//pkg/sentry:internal"],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/atomicbitops",
"//pkg/binary",
+ "//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
- "//pkg/sentry/context",
- "//pkg/sentry/safemem",
+ "//pkg/safemem",
"//pkg/syserror",
- "//pkg/tcpip/buffer",
],
)
@@ -49,10 +46,10 @@ go_test(
"addr_range_seq_test.go",
"usermem_test.go",
],
- embed = [":usermem"],
+ library = ":usermem",
deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/safemem",
+ "//pkg/context",
+ "//pkg/safemem",
"//pkg/syserror",
],
)
diff --git a/pkg/sentry/usermem/README.md b/pkg/usermem/README.md
index f6d2137eb..f6d2137eb 100644
--- a/pkg/sentry/usermem/README.md
+++ b/pkg/usermem/README.md
diff --git a/pkg/sentry/usermem/access_type.go b/pkg/usermem/access_type.go
index 9c1742a59..9c1742a59 100644
--- a/pkg/sentry/usermem/access_type.go
+++ b/pkg/usermem/access_type.go
diff --git a/pkg/sentry/usermem/addr.go b/pkg/usermem/addr.go
index e79210804..c4100481e 100644
--- a/pkg/sentry/usermem/addr.go
+++ b/pkg/usermem/addr.go
@@ -106,3 +106,20 @@ func (ar AddrRange) IsPageAligned() bool {
func (ar AddrRange) String() string {
return fmt.Sprintf("[%#x, %#x)", ar.Start, ar.End)
}
+
+// PageRoundDown/Up are equivalent to Addr.RoundDown/Up, but without the
+// potentially truncating conversion from uint64 to Addr. This is necessary
+// because there is no way to define generic "PageRoundDown/Up" functions in Go.
+
+// PageRoundDown returns x rounded down to the nearest page boundary.
+func PageRoundDown(x uint64) uint64 {
+ return x &^ (PageSize - 1)
+}
+
+// PageRoundUp returns x rounded up to the nearest page boundary.
+// ok is true iff rounding up did not wrap around.
+func PageRoundUp(x uint64) (addr uint64, ok bool) {
+ addr = PageRoundDown(x + PageSize - 1)
+ ok = addr >= x
+ return
+}
diff --git a/pkg/sentry/usermem/addr_range_seq_test.go b/pkg/usermem/addr_range_seq_test.go
index 82f735026..82f735026 100644
--- a/pkg/sentry/usermem/addr_range_seq_test.go
+++ b/pkg/usermem/addr_range_seq_test.go
diff --git a/pkg/sentry/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go
index c09337c15..c09337c15 100644
--- a/pkg/sentry/usermem/addr_range_seq_unsafe.go
+++ b/pkg/usermem/addr_range_seq_unsafe.go
diff --git a/pkg/sentry/usermem/bytes_io.go b/pkg/usermem/bytes_io.go
index 8d88396ba..e177d30eb 100644
--- a/pkg/sentry/usermem/bytes_io.go
+++ b/pkg/usermem/bytes_io.go
@@ -15,8 +15,8 @@
package usermem
import (
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -102,19 +102,34 @@ func (b *BytesIO) rangeCheck(addr Addr, length int) (int, error) {
}
func (b *BytesIO) blocksFromAddrRanges(ars AddrRangeSeq) (safemem.BlockSeq, error) {
- blocks := make([]safemem.Block, 0, ars.NumRanges())
- for !ars.IsEmpty() {
- ar := ars.Head()
- n, err := b.rangeCheck(ar.Start, int(ar.Length()))
- if n != 0 {
- blocks = append(blocks, safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start):int(ar.Start)+n]))
+ switch ars.NumRanges() {
+ case 0:
+ return safemem.BlockSeq{}, nil
+ case 1:
+ block, err := b.blockFromAddrRange(ars.Head())
+ return safemem.BlockSeqOf(block), err
+ default:
+ blocks := make([]safemem.Block, 0, ars.NumRanges())
+ for !ars.IsEmpty() {
+ block, err := b.blockFromAddrRange(ars.Head())
+ if block.Len() != 0 {
+ blocks = append(blocks, block)
+ }
+ if err != nil {
+ return safemem.BlockSeqFromSlice(blocks), err
+ }
+ ars = ars.Tail()
}
- if err != nil {
- return safemem.BlockSeqFromSlice(blocks), err
- }
- ars = ars.Tail()
+ return safemem.BlockSeqFromSlice(blocks), nil
+ }
+}
+
+func (b *BytesIO) blockFromAddrRange(ar AddrRange) (safemem.Block, error) {
+ n, err := b.rangeCheck(ar.Start, int(ar.Length()))
+ if n == 0 {
+ return safemem.Block{}, err
}
- return safemem.BlockSeqFromSlice(blocks), nil
+ return safemem.BlockFromSafeSlice(b.Bytes[int(ar.Start) : int(ar.Start)+n]), err
}
// BytesIOSequence returns an IOSequence representing the given byte slice.
diff --git a/pkg/sentry/usermem/bytes_io_unsafe.go b/pkg/usermem/bytes_io_unsafe.go
index fca5952f4..20de5037d 100644
--- a/pkg/sentry/usermem/bytes_io_unsafe.go
+++ b/pkg/usermem/bytes_io_unsafe.go
@@ -19,7 +19,7 @@ import (
"unsafe"
"gvisor.dev/gvisor/pkg/atomicbitops"
- "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/context"
)
// SwapUint32 implements IO.SwapUint32.
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/usermem/usermem.go
index 7b1f312b1..cd6a0ea6b 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -22,15 +22,13 @@ import (
"strconv"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
// IO provides access to the contents of a virtual memory space.
-//
-// FIXME(b/38173783): Implementations of IO cannot expect ctx to contain any
-// meaningful data.
type IO interface {
// CopyOut copies len(src) bytes from src to the memory mapped at addr. It
// returns the number of bytes copied. If the number of bytes copied is <
@@ -251,7 +249,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +270,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
diff --git a/pkg/sentry/usermem/usermem_arm64.go b/pkg/usermem/usermem_arm64.go
index fdfc30a66..fdfc30a66 100644
--- a/pkg/sentry/usermem/usermem_arm64.go
+++ b/pkg/usermem/usermem_arm64.go
diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/usermem/usermem_test.go
index 299f64754..bf3c5df2b 100644
--- a/pkg/sentry/usermem/usermem_test.go
+++ b/pkg/usermem/usermem_test.go
@@ -22,8 +22,8 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
diff --git a/pkg/sentry/usermem/usermem_x86.go b/pkg/usermem/usermem_x86.go
index 8059b72d2..d96f829fb 100644
--- a/pkg/sentry/usermem/usermem_x86.go
+++ b/pkg/usermem/usermem_x86.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build amd64 i386
+// +build amd64 386
package usermem
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 1f7efb064..852480a09 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -1,6 +1,5 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -22,8 +21,8 @@ go_library(
"waiter.go",
"waiter_list.go",
],
- importpath = "gvisor.dev/gvisor/pkg/waiter",
visibility = ["//visibility:public"],
+ deps = ["//pkg/sync"],
)
go_test(
@@ -32,13 +31,5 @@ go_test(
srcs = [
"waiter_test.go",
],
- embed = [":waiter"],
-)
-
-filegroup(
- name = "autogen",
- srcs = [
- "waiter_list.go",
- ],
- visibility = ["//:sandbox"],
+ library = ":waiter",
)
diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go
index 8a65ed164..67a950444 100644
--- a/pkg/waiter/waiter.go
+++ b/pkg/waiter/waiter.go
@@ -58,11 +58,11 @@
package waiter
import (
- "sync"
+ "gvisor.dev/gvisor/pkg/sync"
)
// EventMask represents io events as used in the poll() syscall.
-type EventMask uint16
+type EventMask uint64
// Events that waiters can wait on. The meaning is the same as those in the
// poll() syscall.
@@ -128,13 +128,6 @@ type EntryCallback interface {
//
// +stateify savable
type Entry struct {
- // Context stores any state the waiter may wish to store in the entry
- // itself, which may be used at wake up time.
- //
- // Note that use of this field is optional and state may alternatively be
- // stored in the callback itself.
- Context interface{}
-
Callback EntryCallback
// The following fields are protected by the queue lock.
@@ -142,13 +135,14 @@ type Entry struct {
waiterEntry
}
-type channelCallback struct{}
+type channelCallback struct {
+ ch chan struct{}
+}
// Callback implements EntryCallback.Callback.
-func (*channelCallback) Callback(e *Entry) {
- ch := e.Context.(chan struct{})
+func (c *channelCallback) Callback(*Entry) {
select {
- case ch <- struct{}{}:
+ case c.ch <- struct{}{}:
default:
}
}
@@ -164,7 +158,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) {
c = make(chan struct{}, 1)
}
- return Entry{Context: c, Callback: &channelCallback{}}, c
+ return Entry{Callback: &channelCallback{ch: c}}, c
}
// Queue represents the wait queue where waiters can be added and
diff --git a/runsc/BUILD b/runsc/BUILD
index e4e8e64a3..96f697a5f 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -1,7 +1,6 @@
-package(licenses = ["notice"]) # Apache 2.0
+load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar")
+package(licenses = ["notice"])
go_binary(
name = "runsc",
@@ -9,7 +8,7 @@ go_binary(
"main.go",
"version.go",
],
- pure = "on",
+ pure = True,
visibility = [
"//visibility:public",
],
@@ -20,16 +19,19 @@ go_binary(
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
+ "//runsc/flag",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
],
)
# The runsc-race target is a race-compatible BUILD target. This must be built
-# via "bazel build --features=race //runsc:runsc-race", since the race feature
-# must apply to all dependencies due a bug in gazelle file selection. The pure
-# attribute must be off because the race detector requires linking with non-Go
-# components, although we still require a static binary.
+# via: bazel build --features=race :runsc-race
+#
+# This is neccessary because the race feature must apply to all dependencies
+# due a bug in gazelle file selection. The pure attribute must be off because
+# the race detector requires linking with non-Go components, although we still
+# require a static binary.
#
# Note that in the future this might be convertible to a compatible target by
# using the pure and static attributes within a select function, but select is
@@ -42,7 +44,7 @@ go_binary(
"main.go",
"version.go",
],
- static = "on",
+ static = True,
visibility = [
"//visibility:public",
],
@@ -53,39 +55,57 @@ go_binary(
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
+ "//runsc/flag",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
],
)
pkg_tar(
- name = "runsc-bin",
- srcs = [":runsc"],
+ name = "debian-bin",
+ srcs = [
+ ":runsc",
+ "//shim/v1:gvisor-containerd-shim",
+ "//shim/v2:containerd-shim-runsc-v1",
+ ],
mode = "0755",
package_dir = "/usr/bin",
- strip_prefix = "/runsc/linux_amd64_pure_stripped",
)
pkg_tar(
name = "debian-data",
extension = "tar.gz",
deps = [
- ":runsc-bin",
+ ":debian-bin",
+ "//shim:config",
],
)
genrule(
name = "deb-version",
+ # Note that runsc must appear in the srcs parameter and not the tools
+ # parameter, otherwise it will not be stamped. This is reasonable, as tools
+ # may be encoded differently in the build graph (cached more aggressively
+ # because they are assumes to be hermetic).
+ srcs = [":runsc"],
outs = ["version.txt"],
- cmd = "$(location :runsc) -version | grep 'runsc version' | sed 's/^[^0-9]*//' > $@",
+ # Note that the little dance here is necessary because files in the $(SRCS)
+ # attribute are not executable by default, and we can't touch in place.
+ cmd = "cp $(location :runsc) $(@D)/runsc && \
+ chmod a+x $(@D)/runsc && \
+ $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \
+ rm -f $(@D)/runsc",
stamp = 1,
- tools = [":runsc"],
)
pkg_deb(
name = "runsc-debian",
architecture = "amd64",
data = ":debian-data",
+ # Note that the description_file will be flatten (all newlines removed),
+ # and therefore it is kept to a simple one-line description. The expected
+ # format for debian packages is "short summary\nLonger explanation of
+ # tool." and this is impossible with the flattening.
description_file = "debian/description",
homepage = "https://gvisor.dev/",
maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>",
@@ -101,5 +121,7 @@ sh_test(
name = "version_test",
size = "small",
srcs = ["version_test.sh"],
+ args = ["$(location :runsc)"],
data = [":runsc"],
+ tags = ["noguitar"],
)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 6fe2b57de..9f52438c2 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,38 +7,42 @@ go_library(
srcs = [
"compat.go",
"compat_amd64.go",
+ "compat_arm64.go",
"config.go",
"controller.go",
"debug.go",
"events.go",
- "fds.go",
"fs.go",
"limits.go",
"loader.go",
"network.go",
- "pprof.go",
"strace.go",
- "user.go",
+ "vfs.go",
],
- importpath = "gvisor.dev/gvisor/runsc/boot",
visibility = [
+ "//pkg/test:__subpackages__",
"//runsc:__subpackages__",
"//test:__subpackages__",
],
deps = [
"//pkg/abi",
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/control/server",
"//pkg/cpuid",
"//pkg/eventchannel",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/memutil",
"//pkg/rand",
"//pkg/refs",
"//pkg/sentry/arch",
"//pkg/sentry/arch:registers_go_proto",
- "//pkg/sentry/context",
"//pkg/sentry/control",
+ "//pkg/sentry/devices/memdev",
+ "//pkg/sentry/devices/ttydev",
+ "//pkg/sentry/devices/tundev",
+ "//pkg/sentry/fdimport",
"//pkg/sentry/fs",
"//pkg/sentry/fs/dev",
"//pkg/sentry/fs/gofer",
@@ -48,6 +52,16 @@ go_library(
"//pkg/sentry/fs/sys",
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/fs/tty",
+ "//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/devpts",
+ "//pkg/sentry/fsimpl/devtmpfs",
+ "//pkg/sentry/fsimpl/fuse",
+ "//pkg/sentry/fsimpl/gofer",
+ "//pkg/sentry/fsimpl/host",
+ "//pkg/sentry/fsimpl/overlay",
+ "//pkg/sentry/fsimpl/proc",
+ "//pkg/sentry/fsimpl/sys",
+ "//pkg/sentry/fsimpl/tmpfs",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel:uncaught_signal_go_proto",
@@ -60,20 +74,24 @@ go_library(
"//pkg/sentry/socket/hostinet",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netlink/route",
+ "//pkg/sentry/socket/netlink/uevent",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix",
"//pkg/sentry/state",
"//pkg/sentry/strace",
- "//pkg/sentry/syscalls/linux",
+ "//pkg/sentry/syscalls/linux/vfs2",
"//pkg/sentry/time",
"//pkg/sentry/unimpl:unimplemented_syscall_go_proto",
"//pkg/sentry/usage",
- "//pkg/sentry/usermem",
+ "//pkg/sentry/vfs",
"//pkg/sentry/watchdog",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/link/packetsocket",
+ "//pkg/tcpip/link/qdisc/fifo",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
@@ -86,9 +104,10 @@ go_library(
"//pkg/urpc",
"//runsc/boot/filter",
"//runsc/boot/platforms",
+ "//runsc/boot/pprof",
"//runsc/specutils",
"@com_github_golang_protobuf//proto:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
@@ -100,19 +119,20 @@ go_test(
"compat_test.go",
"fs_test.go",
"loader_test.go",
- "user_test.go",
],
- embed = [":boot"],
+ library = ":boot",
deps = [
"//pkg/control/server",
+ "//pkg/fspath",
"//pkg/log",
"//pkg/p9",
- "//pkg/sentry/arch:registers_go_proto",
- "//pkg/sentry/context/contexttest",
+ "//pkg/sentry/contexttest",
"//pkg/sentry/fs",
- "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/unet",
"//runsc/fsgofer",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index 07e35ab10..84c67cbc2 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -17,18 +17,16 @@ package boot
import (
"fmt"
"os"
- "sync"
"syscall"
"github.com/golang/protobuf/proto"
- "gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/arch"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/strace"
spb "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto"
+ "gvisor.dev/gvisor/pkg/sync"
)
func initCompatLogs(fd int) error {
@@ -53,9 +51,9 @@ type compatEmitter struct {
}
func newCompatEmitter(logFD int) (*compatEmitter, error) {
- nameMap, ok := strace.Lookup(abi.Linux, arch.AMD64)
+ nameMap, ok := getSyscallNameMap()
if !ok {
- return nil, fmt.Errorf("amd64 Linux syscall table not found")
+ return nil, fmt.Errorf("Linux syscall table not found")
}
c := &compatEmitter{
@@ -67,7 +65,7 @@ func newCompatEmitter(logFD int) (*compatEmitter, error) {
if logFD > 0 {
f := os.NewFile(uintptr(logFD), "user log file")
- target := log.MultiEmitter{c.sink, log.K8sJSONEmitter{log.Writer{Next: f}}}
+ target := &log.MultiEmitter{c.sink, log.K8sJSONEmitter{&log.Writer{Next: f}}}
c.sink = &log.BasicLogger{Level: log.Info, Emitter: target}
}
return c, nil
@@ -86,16 +84,16 @@ func (c *compatEmitter) Emit(msg proto.Message) (bool, error) {
}
func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) {
- regs := us.Registers.GetArch().(*rpb.Registers_Amd64).Amd64
+ regs := us.Registers
c.mu.Lock()
defer c.mu.Unlock()
- sysnr := regs.OrigRax
+ sysnr := syscallNum(regs)
tr := c.trackers[sysnr]
if tr == nil {
switch sysnr {
- case syscall.SYS_PRCTL, syscall.SYS_ARCH_PRCTL:
+ case syscall.SYS_PRCTL:
// args: cmd, ...
tr = newArgsTracker(0)
@@ -112,12 +110,22 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) {
tr = newArgsTracker(2)
default:
- tr = &onceTracker{}
+ tr = newArchArgsTracker(sysnr)
+ if tr == nil {
+ tr = &onceTracker{}
+ }
}
c.trackers[sysnr] = tr
}
+
if tr.shouldReport(regs) {
- c.sink.Infof("Unsupported syscall: %s, regs: %+v", c.nameMap.Name(uintptr(sysnr)), regs)
+ name := c.nameMap.Name(uintptr(sysnr))
+ c.sink.Infof("Unsupported syscall %s(%#x,%#x,%#x,%#x,%#x,%#x). It is "+
+ "likely that you can safely ignore this message and that this is not "+
+ "the cause of any error. Please, refer to %s/%s for more information.",
+ name, argVal(0, regs), argVal(1, regs), argVal(2, regs), argVal(3, regs),
+ argVal(4, regs), argVal(5, regs), syscallLink, name)
+
tr.onReported(regs)
}
}
@@ -139,10 +147,10 @@ func (c *compatEmitter) Close() error {
// the syscall and arguments.
type syscallTracker interface {
// shouldReport returns true is the syscall should be reported.
- shouldReport(regs *rpb.AMD64Registers) bool
+ shouldReport(regs *rpb.Registers) bool
// onReported marks the syscall as reported.
- onReported(regs *rpb.AMD64Registers)
+ onReported(regs *rpb.Registers)
}
// onceTracker reports only a single time, used for most syscalls.
@@ -150,10 +158,45 @@ type onceTracker struct {
reported bool
}
-func (o *onceTracker) shouldReport(_ *rpb.AMD64Registers) bool {
+func (o *onceTracker) shouldReport(_ *rpb.Registers) bool {
return !o.reported
}
-func (o *onceTracker) onReported(_ *rpb.AMD64Registers) {
+func (o *onceTracker) onReported(_ *rpb.Registers) {
o.reported = true
}
+
+// argsTracker reports only once for each different combination of arguments.
+// It's used for generic syscalls like ioctl to report once per 'cmd'.
+type argsTracker struct {
+ // argsIdx is the syscall arguments to use as unique ID.
+ argsIdx []int
+ reported map[string]struct{}
+ count int
+}
+
+func newArgsTracker(argIdx ...int) *argsTracker {
+ return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})}
+}
+
+// key returns the command based on the syscall argument index.
+func (a *argsTracker) key(regs *rpb.Registers) string {
+ var rv string
+ for _, idx := range a.argsIdx {
+ rv += fmt.Sprintf("%d|", argVal(idx, regs))
+ }
+ return rv
+}
+
+func (a *argsTracker) shouldReport(regs *rpb.Registers) bool {
+ if a.count >= reportLimit {
+ return false
+ }
+ _, ok := a.reported[a.key(regs)]
+ return !ok
+}
+
+func (a *argsTracker) onReported(regs *rpb.Registers) {
+ a.count++
+ a.reported[a.key(regs)] = struct{}{}
+}
diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go
index 43cd0db94..8eb76b2ba 100644
--- a/runsc/boot/compat_amd64.go
+++ b/runsc/boot/compat_amd64.go
@@ -16,62 +16,85 @@ package boot
import (
"fmt"
+ "syscall"
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
)
-// reportLimit is the max number of events that should be reported per tracker.
-const reportLimit = 100
+const (
+ // reportLimit is the max number of events that should be reported per
+ // tracker.
+ reportLimit = 100
+ syscallLink = "https://gvisor.dev/c/linux/amd64"
+)
-// argsTracker reports only once for each different combination of arguments.
-// It's used for generic syscalls like ioctl to report once per 'cmd'.
-type argsTracker struct {
- // argsIdx is the syscall arguments to use as unique ID.
- argsIdx []int
- reported map[string]struct{}
- count int
+// newRegs create a empty Registers instance.
+func newRegs() *rpb.Registers {
+ return &rpb.Registers{
+ Arch: &rpb.Registers_Amd64{
+ Amd64: &rpb.AMD64Registers{},
+ },
+ }
}
-func newArgsTracker(argIdx ...int) *argsTracker {
- return &argsTracker{argsIdx: argIdx, reported: make(map[string]struct{})}
-}
+func argVal(argIdx int, regs *rpb.Registers) uint64 {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
-// cmd returns the command based on the syscall argument index.
-func (a *argsTracker) key(regs *rpb.AMD64Registers) string {
- var rv string
- for _, idx := range a.argsIdx {
- rv += fmt.Sprintf("%d|", argVal(idx, regs))
+ switch argIdx {
+ case 0:
+ return amd64Regs.Rdi
+ case 1:
+ return amd64Regs.Rsi
+ case 2:
+ return amd64Regs.Rdx
+ case 3:
+ return amd64Regs.R10
+ case 4:
+ return amd64Regs.R8
+ case 5:
+ return amd64Regs.R9
}
- return rv
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
}
-func argVal(argIdx int, regs *rpb.AMD64Registers) uint32 {
+func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
+
switch argIdx {
case 0:
- return uint32(regs.Rdi)
+ amd64Regs.Rdi = argVal
case 1:
- return uint32(regs.Rsi)
+ amd64Regs.Rsi = argVal
case 2:
- return uint32(regs.Rdx)
+ amd64Regs.Rdx = argVal
case 3:
- return uint32(regs.R10)
+ amd64Regs.R10 = argVal
case 4:
- return uint32(regs.R8)
+ amd64Regs.R8 = argVal
case 5:
- return uint32(regs.R9)
+ amd64Regs.R9 = argVal
+ default:
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
}
- panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
}
-func (a *argsTracker) shouldReport(regs *rpb.AMD64Registers) bool {
- if a.count >= reportLimit {
- return false
- }
- _, ok := a.reported[a.key(regs)]
- return !ok
+func getSyscallNameMap() (strace.SyscallMap, bool) {
+ return strace.Lookup(abi.Linux, arch.AMD64)
+}
+
+func syscallNum(regs *rpb.Registers) uint64 {
+ amd64Regs := regs.GetArch().(*rpb.Registers_Amd64).Amd64
+ return amd64Regs.OrigRax
}
-func (a *argsTracker) onReported(regs *rpb.AMD64Registers) {
- a.count++
- a.reported[a.key(regs)] = struct{}{}
+func newArchArgsTracker(sysnr uint64) syscallTracker {
+ switch sysnr {
+ case syscall.SYS_ARCH_PRCTL:
+ // args: cmd, ...
+ return newArgsTracker(0)
+ }
+ return nil
}
diff --git a/runsc/boot/compat_arm64.go b/runsc/boot/compat_arm64.go
new file mode 100644
index 000000000..bce9d95b3
--- /dev/null
+++ b/runsc/boot/compat_arm64.go
@@ -0,0 +1,95 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
+ "gvisor.dev/gvisor/pkg/sentry/strace"
+)
+
+const (
+ // reportLimit is the max number of events that should be reported per
+ // tracker.
+ reportLimit = 100
+ syscallLink = "https://gvisor.dev/c/linux/arm64"
+)
+
+// newRegs create a empty Registers instance.
+func newRegs() *rpb.Registers {
+ return &rpb.Registers{
+ Arch: &rpb.Registers_Arm64{
+ Arm64: &rpb.ARM64Registers{},
+ },
+ }
+}
+
+func argVal(argIdx int, regs *rpb.Registers) uint64 {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+
+ switch argIdx {
+ case 0:
+ return arm64Regs.R0
+ case 1:
+ return arm64Regs.R1
+ case 2:
+ return arm64Regs.R2
+ case 3:
+ return arm64Regs.R3
+ case 4:
+ return arm64Regs.R4
+ case 5:
+ return arm64Regs.R5
+ }
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+}
+
+func setArgVal(argIdx int, argVal uint64, regs *rpb.Registers) {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+
+ switch argIdx {
+ case 0:
+ arm64Regs.R0 = argVal
+ case 1:
+ arm64Regs.R1 = argVal
+ case 2:
+ arm64Regs.R2 = argVal
+ case 3:
+ arm64Regs.R3 = argVal
+ case 4:
+ arm64Regs.R4 = argVal
+ case 5:
+ arm64Regs.R5 = argVal
+ default:
+ panic(fmt.Sprintf("invalid syscall argument index %d", argIdx))
+ }
+}
+
+func getSyscallNameMap() (strace.SyscallMap, bool) {
+ return strace.Lookup(abi.Linux, arch.ARM64)
+}
+
+func syscallNum(regs *rpb.Registers) uint64 {
+ arm64Regs := regs.GetArch().(*rpb.Registers_Arm64).Arm64
+ return arm64Regs.R8
+}
+
+func newArchArgsTracker(sysnr uint64) syscallTracker {
+ // currently, no arch specific syscalls need to be handled here.
+ return nil
+}
diff --git a/runsc/boot/compat_test.go b/runsc/boot/compat_test.go
index 388298d8d..839c5303b 100644
--- a/runsc/boot/compat_test.go
+++ b/runsc/boot/compat_test.go
@@ -16,8 +16,6 @@ package boot
import (
"testing"
-
- rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
)
func TestOnceTracker(t *testing.T) {
@@ -35,31 +33,34 @@ func TestOnceTracker(t *testing.T) {
func TestArgsTracker(t *testing.T) {
for _, tc := range []struct {
- name string
- idx []int
- rdi1 uint64
- rdi2 uint64
- rsi1 uint64
- rsi2 uint64
- want bool
+ name string
+ idx []int
+ arg1_1 uint64
+ arg1_2 uint64
+ arg2_1 uint64
+ arg2_2 uint64
+ want bool
}{
- {name: "same rdi", idx: []int{0}, rdi1: 123, rdi2: 123, want: false},
- {name: "same rsi", idx: []int{1}, rsi1: 123, rsi2: 123, want: false},
- {name: "diff rdi", idx: []int{0}, rdi1: 123, rdi2: 321, want: true},
- {name: "diff rsi", idx: []int{1}, rsi1: 123, rsi2: 321, want: true},
- {name: "cmd is uint32", idx: []int{0}, rsi1: 0xdead00000123, rsi2: 0xbeef00000123, want: false},
- {name: "same 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 123, rdi2: 321, want: false},
- {name: "diff 2 args", idx: []int{0, 1}, rsi1: 123, rdi1: 321, rsi2: 789, rdi2: 987, want: true},
+ {name: "same arg1", idx: []int{0}, arg1_1: 123, arg1_2: 123, want: false},
+ {name: "same arg2", idx: []int{1}, arg2_1: 123, arg2_2: 123, want: false},
+ {name: "diff arg1", idx: []int{0}, arg1_1: 123, arg1_2: 321, want: true},
+ {name: "diff arg2", idx: []int{1}, arg2_1: 123, arg2_2: 321, want: true},
+ {name: "cmd is uint32", idx: []int{0}, arg2_1: 0xdead00000123, arg2_2: 0xbeef00000123, want: false},
+ {name: "same 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 123, arg1_2: 321, want: false},
+ {name: "diff 2 args", idx: []int{0, 1}, arg2_1: 123, arg1_1: 321, arg2_2: 789, arg1_2: 987, want: true},
} {
t.Run(tc.name, func(t *testing.T) {
c := newArgsTracker(tc.idx...)
- regs := &rpb.AMD64Registers{Rdi: tc.rdi1, Rsi: tc.rsi1}
+ regs := newRegs()
+ setArgVal(0, tc.arg1_1, regs)
+ setArgVal(1, tc.arg2_1, regs)
if !c.shouldReport(regs) {
t.Error("first call to shouldReport, got: false, want: true")
}
c.onReported(regs)
- regs.Rdi, regs.Rsi = tc.rdi2, tc.rsi2
+ setArgVal(0, tc.arg1_2, regs)
+ setArgVal(1, tc.arg2_2, regs)
if got := c.shouldReport(regs); tc.want != got {
t.Errorf("second call to shouldReport, got: %t, want: %t", got, tc.want)
}
@@ -70,7 +71,9 @@ func TestArgsTracker(t *testing.T) {
func TestArgsTrackerLimit(t *testing.T) {
c := newArgsTracker(0, 1)
for i := 0; i < reportLimit; i++ {
- regs := &rpb.AMD64Registers{Rdi: 123, Rsi: uint64(i)}
+ regs := newRegs()
+ setArgVal(0, 123, regs)
+ setArgVal(1, uint64(i), regs)
if !c.shouldReport(regs) {
t.Error("shouldReport before limit was reached, got: false, want: true")
}
@@ -78,7 +81,9 @@ func TestArgsTrackerLimit(t *testing.T) {
}
// Should hit the count limit now.
- regs := &rpb.AMD64Registers{Rdi: 123, Rsi: 123456}
+ regs := newRegs()
+ setArgVal(0, 123, regs)
+ setArgVal(1, 123456, regs)
if c.shouldReport(regs) {
t.Error("shouldReport after limit was reached, got: true, want: false")
}
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 72a33534f..80da8b3e6 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -158,6 +158,9 @@ type Config struct {
// DebugLog is the path to log debug information to, if not empty.
DebugLog string
+ // PanicLog is the path to log GO's runtime messages, if not empty.
+ PanicLog string
+
// DebugLogFormat is the log format for debug.
DebugLogFormat string
@@ -184,6 +187,16 @@ type Config struct {
// SoftwareGSO indicates that software segmentation offload is enabled.
SoftwareGSO bool
+ // TXChecksumOffload indicates that TX Checksum Offload is enabled.
+ TXChecksumOffload bool
+
+ // RXChecksumOffload indicates that RX Checksum Offload is enabled.
+ RXChecksumOffload bool
+
+ // QDisc indicates the type of queuening discipline to use by default
+ // for non-loopback interfaces.
+ QDisc QueueingDiscipline
+
// LogPackets indicates that all network packets should be logged.
LogPackets bool
@@ -234,8 +247,10 @@ type Config struct {
// ReferenceLeakMode sets reference leak check mode
ReferenceLeakMode refs.LeakMode
- // OverlayfsStaleRead causes cached FDs to reopen after a file is opened for
- // write to workaround overlayfs limitation on kernels before 4.19.
+ // OverlayfsStaleRead instructs the sandbox to assume that the root mount
+ // is on a Linux overlayfs mount, which does not necessarily preserve
+ // coherence between read-only and subsequent writable file descriptors
+ // representing the "same" file.
OverlayfsStaleRead bool
// TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
@@ -250,6 +265,18 @@ type Config struct {
// multiple tests are run in parallel, since there is no way to pass
// parameters to the runtime from docker.
TestOnlyTestNameEnv string
+
+ // CPUNumFromQuota sets CPU number count to available CPU quota, using
+ // least integer value greater than or equal to quota.
+ //
+ // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2.
+ CPUNumFromQuota bool
+
+ // Enables VFS2 (not plumbled through yet).
+ VFS2 bool
+
+ // Enables FUSE usage (not plumbled through yet).
+ FUSE bool
}
// ToFlags returns a slice of flags that correspond to the given Config.
@@ -260,6 +287,7 @@ func (c *Config) ToFlags() []string {
"--log=" + c.LogFilename,
"--log-format=" + c.LogFormat,
"--debug-log=" + c.DebugLog,
+ "--panic-log=" + c.PanicLog,
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
@@ -280,7 +308,13 @@ func (c *Config) ToFlags() []string {
"--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
"--gso=" + strconv.FormatBool(c.HardwareGSO),
"--software-gso=" + strconv.FormatBool(c.SoftwareGSO),
+ "--rx-checksum-offload=" + strconv.FormatBool(c.RXChecksumOffload),
+ "--tx-checksum-offload=" + strconv.FormatBool(c.TXChecksumOffload),
"--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead),
+ "--qdisc=" + c.QDisc.String(),
+ }
+ if c.CPUNumFromQuota {
+ f = append(f, "--cpu-num-from-quota")
}
// Only include these if set since it is never to be used by users.
if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
@@ -289,5 +323,14 @@ func (c *Config) ToFlags() []string {
if len(c.TestOnlyTestNameEnv) != 0 {
f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
+
+ if c.VFS2 {
+ f = append(f, "--vfs2=true")
+ }
+
+ if c.FUSE {
+ f = append(f, "--fuse=true")
+ }
+
return f
}
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 5f644b57e..626a3816e 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -51,7 +52,7 @@ const (
ContainerEvent = "containerManager.Event"
// ContainerExecuteAsync is the URPC endpoint for executing a command in a
- // container..
+ // container.
ContainerExecuteAsync = "containerManager.ExecuteAsync"
// ContainerPause pauses the container.
@@ -103,6 +104,8 @@ const (
StartCPUProfile = "Profile.StartCPUProfile"
StopCPUProfile = "Profile.StopCPUProfile"
HeapProfile = "Profile.HeapProfile"
+ BlockProfile = "Profile.BlockProfile"
+ MutexProfile = "Profile.MutexProfile"
StartTrace = "Profile.StartTrace"
StopTrace = "Profile.StopTrace"
)
@@ -125,43 +128,55 @@ type controller struct {
// manager holds the containerManager methods.
manager *containerManager
+
+ // pprop holds the profile instance if enabled. It may be nil.
+ pprof *control.Profile
}
// newController creates a new controller. The caller must call
// controller.srv.StartServing() to start the controller.
func newController(fd int, l *Loader) (*controller, error) {
- srv, err := server.CreateFromFD(fd)
+ ctrl := &controller{}
+ var err error
+ ctrl.srv, err = server.CreateFromFD(fd)
if err != nil {
return nil, err
}
- manager := &containerManager{
+ ctrl.manager = &containerManager{
startChan: make(chan struct{}),
startResultChan: make(chan error),
l: l,
}
- srv.Register(manager)
+ ctrl.srv.Register(ctrl.manager)
- if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok {
+ if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok {
net := &Network{
Stack: eps.Stack,
}
- srv.Register(net)
+ ctrl.srv.Register(net)
}
- srv.Register(&debug{})
- srv.Register(&control.Logging{})
- if l.conf.ProfileEnable {
- srv.Register(&control.Profile{})
+ ctrl.srv.Register(&debug{})
+ ctrl.srv.Register(&control.Logging{})
+
+ if l.root.conf.ProfileEnable {
+ ctrl.pprof = &control.Profile{Kernel: l.k}
+ ctrl.srv.Register(ctrl.pprof)
}
- return &controller{
- srv: srv,
- manager: manager,
- }, nil
+ return ctrl, nil
+}
+
+func (c *controller) stop() {
+ if c.pprof != nil {
+ // These are noop if there is nothing being profiled.
+ _ = c.pprof.StopCPUProfile(nil, nil)
+ _ = c.pprof.StopTrace(nil, nil)
+ }
}
-// containerManager manages sandboes containers.
+// containerManager manages sandbox containers.
type containerManager struct {
// startChan is used to signal when the root container process should
// be started.
@@ -327,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Pause the kernel while we build a new one.
cm.l.k.Pause()
- p, err := createPlatform(cm.l.conf, deviceFile)
+ p, err := createPlatform(cm.l.root.conf, deviceFile)
if err != nil {
return fmt.Errorf("creating platform: %v", err)
}
@@ -339,12 +354,12 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return fmt.Errorf("creating memory file: %v", err)
}
k.SetMemoryFile(mf)
- networkStack := cm.l.k.NetworkStack()
+ networkStack := cm.l.k.RootNetworkNamespace().Stack()
cm.l.k = k
// Set up the restore environment.
- mntr := newContainerMounter(cm.l.spec, cm.l.goferFDs, cm.l.k, cm.l.mountHints)
- renv, err := mntr.createRestoreEnvironment(cm.l.conf)
+ mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints)
+ renv, err := mntr.createRestoreEnvironment(cm.l.root.conf)
if err != nil {
return fmt.Errorf("creating RestoreEnvironment: %v", err)
}
@@ -362,10 +377,10 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return fmt.Errorf("file cannot be empty")
}
- if cm.l.conf.ProfileEnable {
- // initializePProf opens /proc/self/maps, so has to be
- // called before installing seccomp filters.
- initializePProf()
+ if cm.l.root.conf.ProfileEnable {
+ // pprof.Initialize opens /proc/self/maps, so has to be called before
+ // installing seccomp filters.
+ pprof.Initialize()
}
// Seccomp filters have to be applied before parsing the state file.
@@ -380,12 +395,14 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
}
// Since we have a new kernel we also must make a new watchdog.
- dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction)
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = cm.l.root.conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
// Change the loader fields to reflect the changes made when restoring.
cm.l.k = k
cm.l.watchdog = dog
- cm.l.rootProcArgs = kernel.CreateProcessArgs{}
+ cm.l.root.procArgs = kernel.CreateProcessArgs{}
cm.l.restore = true
// Reinitialize the sandbox ID and processes map. Note that it doesn't
diff --git a/runsc/boot/fds.go b/runsc/boot/fds.go
deleted file mode 100644
index e5de1f3d7..000000000
--- a/runsc/boot/fds.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package boot
-
-import (
- "fmt"
-
- "gvisor.dev/gvisor/pkg/sentry/context"
- "gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/fs/host"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
-)
-
-// createFDTable creates an FD table that contains stdin, stdout, and stderr.
-// If console is true, then ioctl calls will be passed through to the host FD.
-// Upon success, createFDMap dups then closes stdioFDs.
-func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, error) {
- if len(stdioFDs) != 3 {
- return nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
- }
-
- k := kernel.KernelFromContext(ctx)
- fdTable := k.NewFDTable()
- defer fdTable.DecRef()
- mounter := fs.FileOwnerFromContext(ctx)
-
- var ttyFile *fs.File
- for appFD, hostFD := range stdioFDs {
- var appFile *fs.File
-
- if console && appFD < 3 {
- // Import the file as a host TTY file.
- if ttyFile == nil {
- var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, true /* isTTY */)
- if err != nil {
- return nil, err
- }
- defer appFile.DecRef()
-
- // Remember this in the TTY file, as we will
- // use it for the other stdio FDs.
- ttyFile = appFile
- } else {
- // Re-use the existing TTY file, as all three
- // stdio FDs must point to the same fs.File in
- // order to share TTY state, specifically the
- // foreground process group id.
- appFile = ttyFile
- }
- } else {
- // Import the file as a regular host file.
- var err error
- appFile, err = host.ImportFile(ctx, hostFD, mounter, false /* isTTY */)
- if err != nil {
- return nil, err
- }
- defer appFile.DecRef()
- }
-
- // Add the file to the FD map.
- if err := fdTable.NewFDAt(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil {
- return nil, err
- }
- }
-
- fdTable.IncRef()
- return fdTable, nil
-}
diff --git a/runsc/boot/filter/BUILD b/runsc/boot/filter/BUILD
index f5509b6b7..ed18f0047 100644
--- a/runsc/boot/filter/BUILD
+++ b/runsc/boot/filter/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,12 +6,14 @@ go_library(
name = "filter",
srcs = [
"config.go",
+ "config_amd64.go",
+ "config_arm64.go",
+ "config_profile.go",
"extra_filters.go",
"extra_filters_msan.go",
"extra_filters_race.go",
"filter.go",
],
- importpath = "gvisor.dev/gvisor/runsc/boot/filter",
visibility = [
"//runsc/boot:__subpackages__",
],
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 5ad108261..149eb0b1b 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -26,10 +26,6 @@ import (
// allowedSyscalls is the set of syscalls executed by the Sentry to the host OS.
var allowedSyscalls = seccomp.SyscallRules{
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_GET_FS)},
- {seccomp.AllowValue(linux.ARCH_SET_FS)},
- },
syscall.SYS_CLOCK_GETTIME: {},
syscall.SYS_CLONE: []seccomp.Rule{
{
@@ -42,9 +38,15 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.CLONE_THREAD),
},
},
- syscall.SYS_CLOSE: {},
- syscall.SYS_DUP: {},
- syscall.SYS_DUP2: {},
+ syscall.SYS_CLOSE: {},
+ syscall.SYS_DUP: {},
+ syscall.SYS_DUP3: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_CLOEXEC),
+ },
+ },
syscall.SYS_EPOLL_CREATE1: {},
syscall.SYS_EPOLL_CTL: {},
syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
@@ -132,11 +134,6 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowValue(syscall.SOL_SOCKET),
seccomp.AllowValue(syscall.SO_SNDBUF),
},
- {
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.SOL_SOCKET),
- seccomp.AllowValue(syscall.SO_REUSEADDR),
- },
},
syscall.SYS_GETTID: {},
syscall.SYS_GETTIMEOFDAY: {},
@@ -177,6 +174,18 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_LSEEK: {},
syscall.SYS_MADVISE: {},
syscall.SYS_MINCORE: {},
+ // Used by the Go runtime as a temporarily workaround for a Linux
+ // 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
syscall.SYS_MMAP: []seccomp.Rule{
{
seccomp.AllowAny{},
@@ -220,7 +229,11 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_NANOSLEEP: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
+ syscall.SYS_PREADV: {},
+ unix.SYS_PREADV2: {},
syscall.SYS_PWRITE64: {},
+ syscall.SYS_PWRITEV: {},
+ unix.SYS_PWRITEV2: {},
syscall.SYS_READ: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
@@ -273,26 +286,36 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
+ unix.SYS_STATX: {},
syscall.SYS_SYNC_FILE_RANGE: {},
+ syscall.SYS_TEE: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(1), /* len */
+ seccomp.AllowValue(unix.SPLICE_F_NONBLOCK), /* flags */
+ },
+ },
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
},
},
- syscall.SYS_WRITE: {},
- // The only user in rawfile.NonBlockingWrite3 always passes iovcnt with
- // values 2 or 3. Three iovec-s are passed, when the PACKET_VNET_HDR
- // option is enabled for a packet socket.
- syscall.SYS_WRITEV: []seccomp.Rule{
+ syscall.SYS_UTIMENSAT: []seccomp.Rule{
{
seccomp.AllowAny{},
+ seccomp.AllowValue(0), /* null pathname */
seccomp.AllowAny{},
- seccomp.AllowValue(2),
+ seccomp.AllowValue(0), /* flags */
},
+ },
+ syscall.SYS_WRITE: {},
+ // For rawfile.NonBlockingWriteIovec.
+ syscall.SYS_WRITEV: []seccomp.Rule{
{
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(3),
+ seccomp.GreaterThan(0),
},
},
}
@@ -315,6 +338,26 @@ func hostInetFilters() seccomp.SyscallRules {
syscall.SYS_GETSOCKOPT: []seccomp.Rule{
{
seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_TOS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_RECVTOS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_TCLASS),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
+ },
+ {
+ seccomp.AllowAny{},
seccomp.AllowValue(syscall.SOL_IPV6),
seccomp.AllowValue(syscall.IPV6_V6ONLY),
},
@@ -416,6 +459,34 @@ func hostInetFilters() seccomp.SyscallRules {
seccomp.AllowAny{},
seccomp.AllowValue(4),
},
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_TOS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IP),
+ seccomp.AllowValue(syscall.IP_RECVTOS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_TCLASS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.SOL_IPV6),
+ seccomp.AllowValue(syscall.IPV6_RECVTCLASS),
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4),
+ },
},
syscall.SYS_SHUTDOWN: []seccomp.Rule{
{
@@ -479,16 +550,3 @@ func controlServerFilters(fd int) seccomp.SyscallRules {
},
}
}
-
-// profileFilters returns extra syscalls made by runtime/pprof package.
-func profileFilters() seccomp.SyscallRules {
- return seccomp.SyscallRules{
- syscall.SYS_OPENAT: []seccomp.Rule{
- {
- seccomp.AllowAny{},
- seccomp.AllowAny{},
- seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
- },
- },
- }
-}
diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go
new file mode 100644
index 000000000..5335ff82c
--- /dev/null
+++ b/runsc/boot/filter/config_amd64.go
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+func init() {
+ allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL],
+ seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)},
+ seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)},
+ )
+}
diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go
new file mode 100644
index 000000000..7fa9bbda3
--- /dev/null
+++ b/runsc/boot/filter/config_arm64.go
@@ -0,0 +1,21 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package filter
+
+// Reserve for future customization.
+func init() {
+}
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
new file mode 100644
index 000000000..194952a7b
--- /dev/null
+++ b/runsc/boot/filter/config_profile.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+// profileFilters returns extra syscalls made by runtime/pprof package.
+func profileFilters() seccomp.SyscallRules {
+ return seccomp.SyscallRules{
+ syscall.SYS_OPENAT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ },
+ },
+ }
+}
diff --git a/runsc/boot/filter/extra_filters_msan.go b/runsc/boot/filter/extra_filters_msan.go
index 5e5a3c998..209e646a7 100644
--- a/runsc/boot/filter/extra_filters_msan.go
+++ b/runsc/boot/filter/extra_filters_msan.go
@@ -26,6 +26,8 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
Report("MSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
+ syscall.SYS_CLONE: {},
+ syscall.SYS_MMAP: {},
syscall.SYS_SCHED_GETAFFINITY: {},
syscall.SYS_SET_ROBUST_LIST: {},
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 76036c147..9dd5b0184 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -16,7 +16,6 @@ package boot
import (
"fmt"
- "path"
"path/filepath"
"sort"
"strconv"
@@ -30,14 +29,22 @@ import (
_ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/gofer"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ procvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ sysvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ tmpfsvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
@@ -45,27 +52,19 @@ import (
)
const (
- // Filesystem name for 9p gofer mounts.
- rootFsName = "9p"
-
// Device name for root mount.
rootDevice = "9pfs-/"
// MountPrefix is the annotation prefix for mount hints.
- MountPrefix = "gvisor.dev/spec/mount"
-
- // Filesystems that runsc supports.
- bind = "bind"
- devpts = "devpts"
- devtmpfs = "devtmpfs"
- proc = "proc"
- sysfs = "sysfs"
- tmpfs = "tmpfs"
- nonefs = "none"
+ MountPrefix = "dev.gvisor.spec.mount."
+
+ // Supported filesystems that map to different internal filesystem.
+ bind = "bind"
+ nonefs = "none"
)
// tmpfs has some extra supported options that we must pass through.
-var tmpfsAllowedOptions = []string{"mode", "uid", "gid"}
+var tmpfsAllowedData = []string{"mode", "uid", "gid"}
func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) {
// Upper layer uses the same flags as lower, but it must be read-write.
@@ -109,12 +108,12 @@ func compileMounts(spec *specs.Spec) []specs.Mount {
// Always mount /dev.
mounts = append(mounts, specs.Mount{
- Type: devtmpfs,
+ Type: devtmpfs.Name,
Destination: "/dev",
})
mounts = append(mounts, specs.Mount{
- Type: devpts,
+ Type: devpts.Name,
Destination: "/dev/pts",
})
@@ -138,13 +137,13 @@ func compileMounts(spec *specs.Spec) []specs.Mount {
var mandatoryMounts []specs.Mount
if !procMounted {
mandatoryMounts = append(mandatoryMounts, specs.Mount{
- Type: proc,
+ Type: procvfs2.Name,
Destination: "/proc",
})
}
if !sysMounted {
mandatoryMounts = append(mandatoryMounts, specs.Mount{
- Type: sysfs,
+ Type: sysvfs2.Name,
Destination: "/sys",
})
}
@@ -156,13 +155,17 @@ func compileMounts(spec *specs.Spec) []specs.Mount {
return mounts
}
-// p9MountOptions creates a slice of options for a p9 mount.
-func p9MountOptions(fd int, fa FileAccessType) []string {
+// p9MountData creates a slice of p9 mount data.
+func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string {
opts := []string{
"trans=fd",
"rfdno=" + strconv.Itoa(fd),
"wfdno=" + strconv.Itoa(fd),
- "privateunixsocket=true",
+ }
+ if !vfs2 {
+ // privateunixsocket is always enabled in VFS2. VFS1 requires explicit
+ // enablement.
+ opts = append(opts, "privateunixsocket=true")
}
if fa == FileAccessShared {
opts = append(opts, "cache=remote_revalidating")
@@ -232,8 +235,8 @@ func isSupportedMountFlag(fstype, opt string) bool {
case "rw", "ro", "noatime", "noexec":
return true
}
- if fstype == tmpfs {
- ok, err := parseMountOption(opt, tmpfsAllowedOptions...)
+ if fstype == tmpfsvfs2.Name {
+ ok, err := parseMountOption(opt, tmpfsAllowedData...)
return ok && err == nil
}
return false
@@ -279,6 +282,9 @@ func subtargets(root string, mnts []specs.Mount) []string {
}
func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ if conf.VFS2 {
+ return setupContainerVFS2(ctx, conf, mntr, procArgs)
+ }
mns, err := mntr.setupFS(conf, procArgs)
if err != nil {
return err
@@ -287,19 +293,12 @@ func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter,
// Set namespace here so that it can be found in ctx.
procArgs.MountNamespace = mns
- return setExecutablePath(ctx, procArgs)
-}
-
-// setExecutablePath sets the procArgs.Filename by searching the PATH for an
-// executable matching the procArgs.Argv[0].
-func setExecutablePath(ctx context.Context, procArgs *kernel.CreateProcessArgs) error {
- paths := fs.GetPath(procArgs.Envv)
- exe := procArgs.Argv[0]
- f, err := procArgs.MountNamespace.ResolveExecutablePath(ctx, procArgs.WorkingDirectory, exe, paths)
+ // Resolve the executable path from working dir and environment.
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
if err != nil {
- return fmt.Errorf("searching for executable %q, cwd: %q, $PATH=%q: %v", exe, procArgs.WorkingDirectory, strings.Join(paths, ":"), err)
+ return err
}
- procArgs.Filename = f
+ procArgs.Filename = resolved
return nil
}
@@ -392,6 +391,10 @@ type mountHint struct {
// root is the inode where the volume is mounted. For mounts with 'pod' share
// the volume is mounted once and then bind mounted inside the containers.
root *fs.Inode
+
+ // vfsMount is the master mount for the volume. For mounts with 'pod' share
+ // the master volume is bind mounted inside the containers.
+ vfsMount *vfs.Mount
}
func (m *mountHint) setField(key, val string) error {
@@ -439,7 +442,7 @@ func (m *mountHint) setOptions(val string) error {
}
func (m *mountHint) isSupported() bool {
- return m.mount.Type == tmpfs && m.share == pod
+ return m.mount.Type == tmpfsvfs2.Name && m.share == pod
}
// checkCompatible verifies that shared mount is compatible with master.
@@ -465,6 +468,13 @@ func (m *mountHint) checkCompatible(mount specs.Mount) error {
return nil
}
+func (m *mountHint) fileAccessType() FileAccessType {
+ if m.share == container {
+ return FileAccessExclusive
+ }
+ return FileAccessShared
+}
+
func filterUnsupportedOptions(mount specs.Mount) []string {
rv := make([]string, 0, len(mount.Options))
for _, o := range mount.Options {
@@ -483,14 +493,15 @@ type podMountHints struct {
func newPodMountHints(spec *specs.Spec) (*podMountHints, error) {
mnts := make(map[string]*mountHint)
for k, v := range spec.Annotations {
- // Look for 'gvisor.dev/spec/mount' annotations and parse them.
+ // Look for 'dev.gvisor.spec.mount' annotations and parse them.
if strings.HasPrefix(k, MountPrefix) {
- parts := strings.Split(k, "/")
- if len(parts) != 5 {
+ // Remove the prefix and split the rest.
+ parts := strings.Split(k[len(MountPrefix):], ".")
+ if len(parts) != 2 {
return nil, fmt.Errorf("invalid mount annotation: %s=%s", k, v)
}
- name := parts[3]
- if len(name) == 0 || path.Clean(name) != name {
+ name := parts[0]
+ if len(name) == 0 {
return nil, fmt.Errorf("invalid mount name: %s", name)
}
mnt := mnts[name]
@@ -498,7 +509,7 @@ func newPodMountHints(spec *specs.Spec) (*podMountHints, error) {
mnt = &mountHint{name: name}
mnts[name] = mnt
}
- if err := mnt.setField(parts[4], v); err != nil {
+ if err := mnt.setField(parts[1], v); err != nil {
return nil, err
}
}
@@ -565,9 +576,17 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin
// processHints processes annotations that container hints about how volumes
// should be mounted (e.g. a volume shared between containers). It must be
// called for the root container only.
-func (c *containerMounter) processHints(conf *Config) error {
+func (c *containerMounter) processHints(conf *Config, creds *auth.Credentials) error {
+ if conf.VFS2 {
+ return c.processHintsVFS2(conf, creds)
+ }
ctx := c.k.SupervisorContext()
for _, hint := range c.hints.mounts {
+ // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
+ // common gofer to mount all shared volumes.
+ if hint.mount.Type != tmpfsvfs2.Name {
+ continue
+ }
log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
inode, err := c.mountSharedMaster(ctx, conf, hint)
if err != nil {
@@ -621,7 +640,7 @@ func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Confi
func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error {
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, m := range c.mounts {
log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options)
@@ -702,7 +721,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*
fd := c.fds.remove()
log.Infof("Mounting root over 9P, ioFD: %d", fd)
p9FS := mustFindFilesystem("9p")
- opts := p9MountOptions(fd, conf.FileAccess)
+ opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
if conf.OverlayfsStaleRead {
// We can't check for overlayfs here because sandbox is chroot'ed and gofer
@@ -748,36 +767,40 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
)
switch m.Type {
- case devpts, devtmpfs, proc, sysfs:
+ case devpts.Name, devtmpfs.Name, procvfs2.Name, sysvfs2.Name:
fsName = m.Type
case nonefs:
- fsName = sysfs
- case tmpfs:
+ fsName = sysvfs2.Name
+ case tmpfsvfs2.Name:
fsName = m.Type
var err error
- opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedOptions...)
+ opts, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...)
if err != nil {
return "", nil, false, err
}
case bind:
fd := c.fds.remove()
- fsName = "9p"
- // Non-root bind mounts are always shared.
- opts = p9MountOptions(fd, FileAccessShared)
+ fsName = gofervfs2.Name
+ opts = p9MountData(fd, c.getMountAccessType(m), conf.VFS2)
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
default:
- // TODO(nlacasse): Support all the mount types and make this a fatal error.
- // Most applications will "just work" without them, so this is a warning
- // for now.
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
return fsName, opts, useOverlay, nil
}
+func (c *containerMounter) getMountAccessType(mount specs.Mount) FileAccessType {
+ if hint := c.hints.findMount(mount); hint != nil {
+ return hint.fileAccessType()
+ }
+ // Non-root bind mounts are always shared if no hints were provided.
+ return FileAccessShared
+}
+
// mountSubmount mounts volumes inside the container's root. Because mounts may
// be readonly, a lower ramfs overlay is added to create the mount point dir.
// Another overlay is added with tmpfs on top if Config.Overlay is true.
@@ -805,7 +828,20 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
inode, err := filesystem.Mount(ctx, mountDevice(m), mf, strings.Join(opts, ","), nil)
if err != nil {
- return fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ err := fmt.Errorf("creating mount with source %q: %v", m.Source, err)
+ // Check to see if this is a common error due to a Linux bug.
+ // This error is generated here in order to cause it to be
+ // printed to the user using Docker via 'runsc create' etc. rather
+ // than simply printed to the logs for the 'runsc boot' command.
+ //
+ // We check the error message string rather than type because the
+ // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system
+ // implementation (e.g. p9).
+ // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved.
+ if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) {
+ return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug"))
+ }
+ return err
}
// If there are submounts, we need to overlay the mount on top of a ramfs
@@ -832,12 +868,12 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns
if err != nil {
return fmt.Errorf("can't find mount destination %q: %v", m.Destination, err)
}
- defer dirent.DecRef()
+ defer dirent.DecRef(ctx)
if err := mns.Mount(ctx, dirent, inode); err != nil {
return fmt.Errorf("mount %q error: %v", m.Destination, err)
}
- log.Infof("Mounted %q to %q type %s", m.Source, m.Destination, m.Type)
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", m.Source, m.Destination, m.Type, opts)
return nil
}
@@ -853,12 +889,12 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun
if err != nil {
return fmt.Errorf("can't find mount destination %q: %v", mount.Destination, err)
}
- defer target.DecRef()
+ defer target.DecRef(ctx)
// Take a ref on the inode that is about to be (re)-mounted.
source.root.IncRef()
if err := mns.Mount(ctx, target, source.root); err != nil {
- source.root.DecRef()
+ source.root.DecRef(ctx)
return fmt.Errorf("bind mount %q error: %v", mount.Destination, err)
}
@@ -900,7 +936,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn
// Add root mount.
fd := c.fds.remove()
- opts := p9MountOptions(fd, conf.FileAccess)
+ opts := p9MountData(fd, conf.FileAccess, false /* vfs2 */)
mf := fs.MountSourceFlags{}
if c.root.Readonly || conf.Overlay {
@@ -912,7 +948,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn
Flags: mf,
DataString: strings.Join(opts, ","),
}
- renv.MountSources[rootFsName] = append(renv.MountSources[rootFsName], rootMount)
+ renv.MountSources[gofervfs2.Name] = append(renv.MountSources[gofervfs2.Name], rootMount)
// Add submounts.
var tmpMounted bool
@@ -928,7 +964,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn
// TODO(b/67958150): handle '/tmp' properly (see mountTmp()).
if !tmpMounted {
tmpMount := specs.Mount{
- Type: tmpfs,
+ Type: tmpfsvfs2.Name,
Destination: "/tmp",
}
if err := c.addRestoreMount(conf, renv, tmpMount); err != nil {
@@ -961,12 +997,12 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.M
switch err {
case nil:
// Found '/tmp' in filesystem, check if it's empty.
- defer tmp.DecRef()
+ defer tmp.DecRef(ctx)
f, err := tmp.Inode.GetFile(ctx, tmp, fs.FileFlags{Read: true, Directory: true})
if err != nil {
return err
}
- defer f.DecRef()
+ defer f.DecRef(ctx)
serializer := &fs.CollectEntriesSerializer{}
if err := f.Readdir(ctx, serializer); err != nil {
return err
@@ -984,11 +1020,11 @@ func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.M
// No '/tmp' found (or fallthrough from above). Safe to mount internal
// tmpfs.
tmpMount := specs.Mount{
- Type: tmpfs,
+ Type: tmpfsvfs2.Name,
Destination: "/tmp",
// Sticky bit is added to prevent accidental deletion of files from
// another user. This is normally done for /tmp.
- Options: []string{"mode=1777"},
+ Options: []string{"mode=01777"},
}
return c.mountSubmount(ctx, conf, mns, root, tmpMount)
diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go
index 49ab34b33..912037075 100644
--- a/runsc/boot/fs_test.go
+++ b/runsc/boot/fs_test.go
@@ -15,7 +15,6 @@
package boot
import (
- "path"
"reflect"
"strings"
"testing"
@@ -26,19 +25,19 @@ import (
func TestPodMountHintsHappy(t *testing.T) {
spec := &specs.Spec{
Annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
- path.Join(MountPrefix, "mount2", "source"): "bar",
- path.Join(MountPrefix, "mount2", "type"): "bind",
- path.Join(MountPrefix, "mount2", "share"): "container",
- path.Join(MountPrefix, "mount2", "options"): "rw,private",
+ MountPrefix + "mount2.source": "bar",
+ MountPrefix + "mount2.type": "bind",
+ MountPrefix + "mount2.share": "container",
+ MountPrefix + "mount2.options": "rw,private",
},
}
podHints, err := newPodMountHints(spec)
if err != nil {
- t.Errorf("newPodMountHints failed: %v", err)
+ t.Fatalf("newPodMountHints failed: %v", err)
}
// Check that fields were set correctly.
@@ -86,95 +85,95 @@ func TestPodMountHintsErrors(t *testing.T) {
{
name: "too short",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1"): "foo",
+ MountPrefix + "mount1": "foo",
},
error: "invalid mount annotation",
},
{
name: "no name",
annotations: map[string]string{
- MountPrefix + "//source": "foo",
+ MountPrefix + ".source": "foo",
},
error: "invalid mount name",
},
{
name: "missing source",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
},
error: "source field",
},
{
name: "missing type",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.share": "pod",
},
error: "type field",
},
{
name: "missing share",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
},
error: "share field",
},
{
name: "invalid field name",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "invalid"): "foo",
+ MountPrefix + "mount1.invalid": "foo",
},
error: "invalid mount annotation",
},
{
name: "invalid source",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.source": "",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
},
error: "source cannot be empty",
},
{
name: "invalid type",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "invalid-type",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "invalid-type",
+ MountPrefix + "mount1.share": "pod",
},
error: "invalid type",
},
{
name: "invalid share",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "invalid-share",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "invalid-share",
},
error: "invalid share",
},
{
name: "invalid options",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "pod",
- path.Join(MountPrefix, "mount1", "options"): "invalid-option",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
+ MountPrefix + "mount1.options": "invalid-option",
},
error: "unknown mount option",
},
{
name: "duplicate source",
annotations: map[string]string{
- path.Join(MountPrefix, "mount1", "source"): "foo",
- path.Join(MountPrefix, "mount1", "type"): "tmpfs",
- path.Join(MountPrefix, "mount1", "share"): "pod",
+ MountPrefix + "mount1.source": "foo",
+ MountPrefix + "mount1.type": "tmpfs",
+ MountPrefix + "mount1.share": "pod",
- path.Join(MountPrefix, "mount2", "source"): "foo",
- path.Join(MountPrefix, "mount2", "type"): "bind",
- path.Join(MountPrefix, "mount2", "share"): "container",
+ MountPrefix + "mount2.source": "foo",
+ MountPrefix + "mount2.type": "bind",
+ MountPrefix + "mount2.share": "container",
},
error: "have the same mount source",
},
@@ -191,3 +190,61 @@ func TestPodMountHintsErrors(t *testing.T) {
})
}
}
+
+func TestGetMountAccessType(t *testing.T) {
+ const source = "foo"
+ for _, tst := range []struct {
+ name string
+ annotations map[string]string
+ want FileAccessType
+ }{
+ {
+ name: "container=exclusive",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "container",
+ },
+ want: FileAccessExclusive,
+ },
+ {
+ name: "pod=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "pod",
+ },
+ want: FileAccessShared,
+ },
+ {
+ name: "shared=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source,
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "shared",
+ },
+ want: FileAccessShared,
+ },
+ {
+ name: "default=shared",
+ annotations: map[string]string{
+ MountPrefix + "mount1.source": source + "mismatch",
+ MountPrefix + "mount1.type": "bind",
+ MountPrefix + "mount1.share": "container",
+ },
+ want: FileAccessShared,
+ },
+ } {
+ t.Run(tst.name, func(t *testing.T) {
+ spec := &specs.Spec{Annotations: tst.annotations}
+ podHints, err := newPodMountHints(spec)
+ if err != nil {
+ t.Fatalf("newPodMountHints failed: %v", err)
+ }
+ mounter := containerMounter{hints: podHints}
+ if got := mounter.getMountAccessType(specs.Mount{Source: source}); got != tst.want {
+ t.Errorf("getMountAccessType(), want: %v, got: %v", tst.want, got)
+ }
+ })
+ }
+}
diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go
index d1c0bb9b5..ce62236e5 100644
--- a/runsc/boot/limits.go
+++ b/runsc/boot/limits.go
@@ -16,12 +16,12 @@ package boot
import (
"fmt"
- "sync"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sync"
)
// Mapping from linux resource names to limits.LimitType.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 0c0eba99e..40c6f99fd 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -16,26 +16,30 @@
package boot
import (
+ "errors"
"fmt"
mrand "math/rand"
"os"
"runtime"
- "sync"
"sync/atomic"
- "syscall"
gtime "time"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/fdimport"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/host"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -43,11 +47,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/sighandling"
- slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2"
"gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -59,43 +66,46 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
_ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms.
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
// Include supported socket providers.
"gvisor.dev/gvisor/pkg/sentry/socket/hostinet"
_ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
_ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route"
+ _ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
_ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
)
-// Loader keeps state needed to start the kernel and run the container..
-type Loader struct {
- // k is the kernel.
- k *kernel.Kernel
-
- // ctrl is the control server.
- ctrl *controller
-
+type containerInfo struct {
conf *Config
- // console is set to true if terminal is enabled.
- console bool
+ // spec is the base configuration for the root container.
+ spec *specs.Spec
- watchdog *watchdog.Watchdog
+ // procArgs refers to the container's init task.
+ procArgs kernel.CreateProcessArgs
// stdioFDs contains stdin, stdout, and stderr.
stdioFDs []int
// goferFDs are the FDs that attach the sandbox to the gofers.
goferFDs []int
+}
- // spec is the base configuration for the root container.
- spec *specs.Spec
+// Loader keeps state needed to start the kernel and run the container..
+type Loader struct {
+ // k is the kernel.
+ k *kernel.Kernel
+
+ // ctrl is the control server.
+ ctrl *controller
- // startSignalForwarding enables forwarding of signals to the sandboxed
- // container. It should be called after the init process is loaded.
- startSignalForwarding func() func()
+ // root contains information about the root container in the sandbox.
+ root containerInfo
+
+ watchdog *watchdog.Watchdog
// stopSignalForwarding disables forwarding of signals to the sandboxed
// container. It should be called when a sandbox is destroyed.
@@ -104,9 +114,6 @@ type Loader struct {
// restore is set to true if we are restoring a container.
restore bool
- // rootProcArgs refers to the root sandbox init task.
- rootProcArgs kernel.CreateProcessArgs
-
// sandboxID is the ID for the whole sandbox.
sandboxID string
@@ -139,6 +146,9 @@ type execProcess struct {
// tty will be nil if the process is not attached to a terminal.
tty *host.TTYFileOperations
+ // tty will be nil if the process is not attached to a terminal.
+ ttyVFS2 *hostvfs2.TTYFileDescription
+
// pidnsPath is the pid namespace path in spec
pidnsPath string
}
@@ -146,9 +156,6 @@ type execProcess struct {
func init() {
// Initialize the random number generator.
mrand.Seed(gtime.Now().UnixNano())
-
- // Register the global syscall table.
- kernel.RegisterSyscallTable(slinux.AMD64)
}
// Args are the arguments for New().
@@ -159,16 +166,18 @@ type Args struct {
Spec *specs.Spec
// Conf is the system configuration.
Conf *Config
- // ControllerFD is the FD to the URPC controller.
+ // ControllerFD is the FD to the URPC controller. The Loader takes ownership
+ // of this FD and may close it at any time.
ControllerFD int
- // Device is an optional argument that is passed to the platform.
+ // Device is an optional argument that is passed to the platform. The Loader
+ // takes ownership of this file and may close it at any time.
Device *os.File
- // GoferFDs is an array of FDs used to connect with the Gofer.
+ // GoferFDs is an array of FDs used to connect with the Gofer. The Loader
+ // takes ownership of these FDs and may close them at any time.
GoferFDs []int
- // StdioFDs is the stdio for the application.
+ // StdioFDs is the stdio for the application. The Loader takes ownership of
+ // these FDs and may close them at any time.
StdioFDs []int
- // Console is set to true if using TTY.
- Console bool
// NumCPU is the number of CPUs to create inside the sandbox.
NumCPU int
// TotalMem is the initial amount of total memory to report back to the
@@ -178,6 +187,9 @@ type Args struct {
UserLogFD int
}
+// make sure stdioFDs are always the same on initial start and on restore
+const startingStdioFD = 256
+
// New initializes a new kernel loader configured by spec.
// New also handles setting up a kernel for restoring a container.
func New(args Args) (*Loader, error) {
@@ -191,6 +203,16 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("setting up memory usage: %v", err)
}
+ // Is this a VFSv2 kernel?
+ if args.Conf.VFS2 {
+ kernel.VFS2Enabled = true
+ if args.Conf.FUSE {
+ kernel.FUSEEnabled = true
+ }
+
+ vfs2.Override()
+ }
+
// Create kernel and platform.
p, err := createPlatform(args.Conf, args.Device)
if err != nil {
@@ -210,9 +232,7 @@ func New(args Args) (*Loader, error) {
// Create VDSO.
//
// Pass k as the platform since it is savable, unlike the actual platform.
- //
- // FIXME(b/109889800): Use non-nil context.
- vdso, err := loader.PrepareVDSO(nil, k)
+ vdso, err := loader.PrepareVDSO(k)
if err != nil {
return nil, fmt.Errorf("creating vdso: %v", err)
}
@@ -228,11 +248,8 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("enabling strace: %v", err)
}
- // Create an empty network stack because the network namespace may be empty at
- // this point. Netns is configured before Run() is called. Netstack is
- // configured using a control uRPC message. Host network is configured inside
- // Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k)
+ // Create root network namespace/stack.
+ netns, err := newRootNetworkNamespace(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -275,7 +292,7 @@ func New(args Args) (*Loader, error) {
FeatureSet: cpuid.HostFeatureSet(),
Timekeeper: tk,
RootUserNamespace: creds.UserNamespace,
- NetworkStack: networkStack,
+ RootNetworkNamespace: netns,
ApplicationCores: uint(args.NumCPU),
Vdso: vdso,
RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace),
@@ -286,6 +303,12 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("initializing kernel: %v", err)
}
+ if kernel.VFS2Enabled {
+ if err := registerFilesystems(k); err != nil {
+ return nil, fmt.Errorf("registering filesystems: %w", err)
+ }
+ }
+
if err := adjustDirentCache(k); err != nil {
return nil, err
}
@@ -300,9 +323,11 @@ func New(args Args) (*Loader, error) {
}
// Create a watchdog.
- dog := watchdog.New(k, watchdog.DefaultTimeout, args.Conf.WatchdogAction)
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
- procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace())
+ procArgs, err := createProcessArgs(args.ID, args.Spec, creds, k, k.RootPIDNamespace())
if err != nil {
return nil, fmt.Errorf("creating init process for root container: %v", err)
}
@@ -316,19 +341,57 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("creating pod mount hints: %v", err)
}
+ if kernel.VFS2Enabled {
+ // Set up host mount that will be used for imported fds.
+ hostFilesystem, err := hostvfs2.NewFilesystem(k.VFS())
+ if err != nil {
+ return nil, fmt.Errorf("failed to create hostfs filesystem: %v", err)
+ }
+ defer hostFilesystem.DecRef(k.SupervisorContext())
+ hostMount, err := k.VFS().NewDisconnectedMount(hostFilesystem, nil, &vfs.MountOptions{})
+ if err != nil {
+ return nil, fmt.Errorf("failed to create hostfs mount: %v", err)
+ }
+ k.SetHostMount(hostMount)
+ }
+
+ // Make host FDs stable between invocations. Host FDs must map to the exact
+ // same number when the sandbox is restored. Otherwise the wrong FD will be
+ // used.
+ var stdioFDs []int
+ newfd := startingStdioFD
+ for _, fd := range args.StdioFDs {
+ // Check that newfd is unused to avoid clobbering over it.
+ if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) {
+ if err != nil {
+ return nil, fmt.Errorf("error checking for FD (%d) conflict: %w", newfd, err)
+ }
+ return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd)
+ }
+
+ err := unix.Dup3(fd, newfd, unix.O_CLOEXEC)
+ if err != nil {
+ return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err)
+ }
+ stdioFDs = append(stdioFDs, newfd)
+ _ = unix.Close(fd)
+ newfd++
+ }
+
eid := execID{cid: args.ID}
l := &Loader{
- k: k,
- conf: args.Conf,
- console: args.Console,
- watchdog: dog,
- spec: args.Spec,
- goferFDs: args.GoferFDs,
- stdioFDs: args.StdioFDs,
- rootProcArgs: procArgs,
- sandboxID: args.ID,
- processes: map[execID]*execProcess{eid: {}},
- mountHints: mountHints,
+ k: k,
+ watchdog: dog,
+ sandboxID: args.ID,
+ processes: map[execID]*execProcess{eid: {}},
+ mountHints: mountHints,
+ root: containerInfo{
+ conf: args.Conf,
+ stdioFDs: stdioFDs,
+ goferFDs: args.GoferFDs,
+ spec: args.Spec,
+ procArgs: procArgs,
+ },
}
// We don't care about child signals; some platforms can generate a
@@ -337,29 +400,6 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("ignore child stop signals failed: %v", err)
}
- // Handle signals by forwarding them to the root container process
- // (except for panic signal, which should cause a panic).
- l.startSignalForwarding = sighandling.PrepareHandler(func(sig linux.Signal) {
- // Panic signal should cause a panic.
- if args.Conf.PanicSignal != -1 && sig == linux.Signal(args.Conf.PanicSignal) {
- panic("Signal-induced panic")
- }
-
- // Otherwise forward to root container.
- deliveryMode := DeliverToProcess
- if args.Console {
- // Since we are running with a console, we should
- // forward the signal to the foreground process group
- // so that job control signals like ^C can be handled
- // properly.
- deliveryMode = DeliverToForegroundProcessGroup
- }
- log.Infof("Received external signal %d, mode: %v", sig, deliveryMode)
- if err := l.signal(args.ID, 0, int32(sig), deliveryMode); err != nil {
- log.Warningf("error sending signal %v to container %q: %v", sig, args.ID, err)
- }
- })
-
// Create the control server using the provided FD.
//
// This must be done *after* we have initialized the kernel since the
@@ -379,19 +419,24 @@ func New(args Args) (*Loader, error) {
return l, nil
}
-// newProcess creates a process that can be run with kernel.CreateProcess.
-func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) {
+// createProcessArgs creates args that can be used with kernel.CreateProcess.
+func createProcessArgs(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) {
// Create initial limits.
ls, err := createLimitSet(spec)
if err != nil {
return kernel.CreateProcessArgs{}, fmt.Errorf("creating limits: %v", err)
}
+ wd := spec.Process.Cwd
+ if wd == "" {
+ wd = "/"
+ }
+
// Create the process arguments.
procArgs := kernel.CreateProcessArgs{
Argv: spec.Process.Args,
Envv: spec.Process.Env,
- WorkingDirectory: spec.Process.Cwd, // Defaults to '/' if empty.
+ WorkingDirectory: wd,
Credentials: creds,
Umask: 0022,
Limits: ls,
@@ -419,6 +464,11 @@ func (l *Loader) Destroy() {
l.stopSignalForwarding()
}
l.watchdog.Stop()
+
+ for i, fd := range l.root.stdioFDs {
+ _ = unix.Close(fd)
+ l.root.stdioFDs[i] = -1
+ }
}
func createPlatform(conf *Config, deviceFile *os.File) (platform.Platform, error) {
@@ -449,13 +499,13 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) {
}
func (l *Loader) installSeccompFilters() error {
- if l.conf.DisableSeccomp {
+ if l.root.conf.DisableSeccomp {
filter.Report("syscall filter is DISABLED. Running in less secure mode.")
} else {
opts := filter.Options{
Platform: l.k.Platform,
- HostNetwork: l.conf.Network == NetworkHost,
- ProfileEnable: l.conf.ProfileEnable,
+ HostNetwork: l.root.conf.Network == NetworkHost,
+ ProfileEnable: l.root.conf.ProfileEnable,
ControllerFD: l.ctrl.srv.FD(),
}
if err := filter.Install(opts); err != nil {
@@ -481,11 +531,11 @@ func (l *Loader) Run() error {
}
func (l *Loader) run() error {
- if l.conf.Network == NetworkHost {
+ if l.root.conf.Network == NetworkHost {
// Delay host network configuration to this point because network namespace
// is configured after the loader is created and before Run() is called.
log.Debugf("Configuring host network")
- stack := l.k.NetworkStack().(*hostinet.Stack)
+ stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack)
if err := stack.Configure(); err != nil {
return err
}
@@ -503,8 +553,8 @@ func (l *Loader) run() error {
// If we are restoring, we do not want to create a process.
// l.restore is set by the container manager when a restore call is made.
if !l.restore {
- if l.conf.ProfileEnable {
- initializePProf()
+ if l.root.conf.ProfileEnable {
+ pprof.Initialize()
}
// Finally done with all configuration. Setup filters before user code
@@ -513,62 +563,50 @@ func (l *Loader) run() error {
return err
}
- // Create the FD map, which will set stdin, stdout, and stderr. If console
- // is true, then ioctl calls will be passed through to the host fd.
- ctx := l.rootProcArgs.NewContext(l.k)
- fdTable, err := createFDTable(ctx, l.console, l.stdioFDs)
- if err != nil {
- return fmt.Errorf("importing fds: %v", err)
- }
- // CreateProcess takes a reference on FDMap if successful. We won't need
- // ours either way.
- l.rootProcArgs.FDTable = fdTable
-
- // Setup the root container file system.
- l.startGoferMonitor(l.sandboxID, l.goferFDs)
-
- mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints)
- if err := mntr.processHints(l.conf); err != nil {
- return err
- }
- if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil {
- return err
- }
-
- // Add the HOME enviroment variable if it is not already set.
- envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
- if err != nil {
- return err
- }
- l.rootProcArgs.Envv = envv
-
// Create the root container init task. It will begin running
// when the kernel is started.
- if _, _, err := l.k.CreateProcess(l.rootProcArgs); err != nil {
- return fmt.Errorf("creating init process: %v", err)
+ if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil {
+ return err
}
-
- // CreateProcess takes a reference on FDTable if successful.
- l.rootProcArgs.FDTable.DecRef()
}
ep.tg = l.k.GlobalInit()
- if ns, ok := specutils.GetNS(specs.PIDNamespace, l.spec); ok {
+ if ns, ok := specutils.GetNS(specs.PIDNamespace, l.root.spec); ok {
ep.pidnsPath = ns.Path
}
- if l.console {
- ttyFile, _ := l.rootProcArgs.FDTable.Get(0)
- defer ttyFile.DecRef()
- ep.tty = ttyFile.FileOperations.(*host.TTYFileOperations)
- // Set the foreground process group on the TTY to the global
- // init process group, since that is what we are about to
- // start running.
- ep.tty.InitForegroundProcessGroup(ep.tg.ProcessGroup())
- }
+ // Handle signals by forwarding them to the root container process
+ // (except for panic signal, which should cause a panic).
+ l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) {
+ // Panic signal should cause a panic.
+ if l.root.conf.PanicSignal != -1 && sig == linux.Signal(l.root.conf.PanicSignal) {
+ panic("Signal-induced panic")
+ }
- // Start signal forwarding only after an init process is created.
- l.stopSignalForwarding = l.startSignalForwarding()
+ // Otherwise forward to root container.
+ deliveryMode := DeliverToProcess
+ if l.root.spec.Process.Terminal {
+ // Since we are running with a console, we should forward the signal to
+ // the foreground process group so that job control signals like ^C can
+ // be handled properly.
+ deliveryMode = DeliverToForegroundProcessGroup
+ }
+ log.Infof("Received external signal %d, mode: %v", sig, deliveryMode)
+ if err := l.signal(l.sandboxID, 0, int32(sig), deliveryMode); err != nil {
+ log.Warningf("error sending signal %v to container %q: %v", sig, l.sandboxID, err)
+ }
+ })
+
+ // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again
+ // either in createFDTable() during initial start or in descriptor.initAfterLoad()
+ // during restore, we can release l.stdioFDs now. VFS2 takes ownership of the
+ // passed FDs, so only close for VFS1.
+ if !kernel.VFS2Enabled {
+ for i, fd := range l.root.stdioFDs {
+ _ = unix.Close(fd)
+ l.root.stdioFDs[i] = -1
+ }
+ }
log.Infof("Process should have started...")
l.watchdog.Start()
@@ -601,8 +639,8 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file
l.mu.Lock()
defer l.mu.Unlock()
- eid := execID{cid: cid}
- if _, ok := l.processes[eid]; !ok {
+ ep := l.processes[execID{cid: cid}]
+ if ep == nil {
return fmt.Errorf("trying to start a deleted container %q", cid)
}
@@ -636,61 +674,112 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file
if pidns == nil {
pidns = l.k.RootPIDNamespace().NewChild(l.k.RootUserNamespace())
}
- l.processes[eid].pidnsPath = ns.Path
+ ep.pidnsPath = ns.Path
} else {
pidns = l.k.RootPIDNamespace()
}
- procArgs, err := newProcess(cid, spec, creds, l.k, pidns)
+
+ info := &containerInfo{
+ conf: conf,
+ spec: spec,
+ }
+ info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns)
if err != nil {
return fmt.Errorf("creating new process: %v", err)
}
// setupContainerFS() dups stdioFDs, so we don't need to dup them here.
- var stdioFDs []int
for _, f := range files[:3] {
- stdioFDs = append(stdioFDs, int(f.Fd()))
- }
-
- // Create the FD map, which will set stdin, stdout, and stderr.
- ctx := procArgs.NewContext(l.k)
- fdTable, err := createFDTable(ctx, false, stdioFDs)
- if err != nil {
- return fmt.Errorf("importing fds: %v", err)
+ info.stdioFDs = append(info.stdioFDs, int(f.Fd()))
}
- // CreateProcess takes a reference on fdTable if successful. We won't
- // need ours either way.
- procArgs.FDTable = fdTable
// Can't take ownership away from os.File. dup them to get a new FDs.
- var goferFDs []int
for _, f := range files[3:] {
- fd, err := syscall.Dup(int(f.Fd()))
+ fd, err := unix.Dup(int(f.Fd()))
if err != nil {
return fmt.Errorf("failed to dup file: %v", err)
}
- goferFDs = append(goferFDs, fd)
+ info.goferFDs = append(info.goferFDs, fd)
}
+ tg, err := l.createContainerProcess(false, cid, info, ep)
+ if err != nil {
+ return err
+ }
+
+ // Success!
+ l.k.StartProcess(tg)
+ ep.tg = tg
+ return nil
+}
+
+func (l *Loader) createContainerProcess(root bool, cid string, info *containerInfo, ep *execProcess) (*kernel.ThreadGroup, error) {
+ console := false
+ if root {
+ // Only root container supports terminal for now.
+ console = info.spec.Process.Terminal
+ }
+
+ // Create the FD map, which will set stdin, stdout, and stderr.
+ ctx := info.procArgs.NewContext(l.k)
+ fdTable, ttyFile, ttyFileVFS2, err := createFDTable(ctx, console, info.stdioFDs)
+ if err != nil {
+ return nil, fmt.Errorf("importing fds: %v", err)
+ }
+ // CreateProcess takes a reference on fdTable if successful. We won't need
+ // ours either way.
+ info.procArgs.FDTable = fdTable
+
// Setup the child container file system.
- l.startGoferMonitor(cid, goferFDs)
+ l.startGoferMonitor(cid, info.goferFDs)
- mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
- if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil {
- return err
+ mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints)
+ if root {
+ if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil {
+ return nil, err
+ }
+ }
+ if err := setupContainerFS(ctx, info.conf, mntr, &info.procArgs); err != nil {
+ return nil, err
}
- // Create and start the new process.
- tg, _, err := l.k.CreateProcess(procArgs)
+ // Add the HOME enviroment variable if it is not already set.
+ var envv []string
+ if kernel.VFS2Enabled {
+ envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2,
+ info.procArgs.Credentials.RealKUID, info.procArgs.Envv)
+
+ } else {
+ envv, err = user.MaybeAddExecUserHome(ctx, info.procArgs.MountNamespace,
+ info.procArgs.Credentials.RealKUID, info.procArgs.Envv)
+ }
if err != nil {
- return fmt.Errorf("creating process: %v", err)
+ return nil, err
}
- l.k.StartProcess(tg)
+ info.procArgs.Envv = envv
+ // Create and start the new process.
+ tg, _, err := l.k.CreateProcess(info.procArgs)
+ if err != nil {
+ return nil, fmt.Errorf("creating process: %v", err)
+ }
// CreateProcess takes a reference on FDTable if successful.
- procArgs.FDTable.DecRef()
+ info.procArgs.FDTable.DecRef(ctx)
+
+ // Set the foreground process group on the TTY to the global init process
+ // group, since that is what we are about to start running.
+ if root {
+ switch {
+ case ttyFileVFS2 != nil:
+ ep.ttyVFS2 = ttyFileVFS2
+ ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup())
+ case ttyFile != nil:
+ ep.tty = ttyFile
+ ttyFile.InitForegroundProcessGroup(tg.ProcessGroup())
+ }
+ }
- l.processes[eid].tg = tg
- return nil
+ return tg, nil
}
// startGoferMonitor runs a goroutine to monitor gofer's health. It polls on
@@ -738,14 +827,14 @@ func (l *Loader) destroyContainer(cid string) error {
l.mu.Lock()
defer l.mu.Unlock()
- _, _, started, err := l.threadGroupFromIDLocked(execID{cid: cid})
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: cid})
if err != nil {
// Container doesn't exist.
return err
}
- // The container exists, has it been started?
- if started {
+ // The container exists, but has it been started?
+ if tg != nil {
if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil {
return fmt.Errorf("sending SIGKILL to all container processes: %v", err)
}
@@ -787,45 +876,63 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
l.mu.Lock()
defer l.mu.Unlock()
- tg, _, started, err := l.threadGroupFromIDLocked(execID{cid: args.ContainerID})
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: args.ContainerID})
if err != nil {
return 0, err
}
- if !started {
+ if tg == nil {
return 0, fmt.Errorf("container %q not started", args.ContainerID)
}
// Get the container MountNamespace from the Task.
- tg.Leader().WithMuLocked(func(t *kernel.Task) {
- // task.MountNamespace() does not take a ref, so we must do so
- // ourselves.
- args.MountNamespace = t.MountNamespace()
- args.MountNamespace.IncRef()
- })
- defer args.MountNamespace.DecRef()
+ if kernel.VFS2Enabled {
+ // task.MountNamespace() does not take a ref, so we must do so ourselves.
+ args.MountNamespaceVFS2 = tg.Leader().MountNamespaceVFS2()
+ args.MountNamespaceVFS2.IncRef()
+ } else {
+ tg.Leader().WithMuLocked(func(t *kernel.Task) {
+ // task.MountNamespace() does not take a ref, so we must do so ourselves.
+ args.MountNamespace = t.MountNamespace()
+ args.MountNamespace.IncRef()
+ })
+ }
- // Add the HOME enviroment varible if it is not already set.
- root := args.MountNamespace.Root()
- defer root.DecRef()
- ctx := fs.WithRoot(l.k.SupervisorContext(), root)
- envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
- if err != nil {
- return 0, err
+ // Add the HOME environment variable if it is not already set.
+ if kernel.VFS2Enabled {
+ root := args.MountNamespaceVFS2.Root()
+ ctx := vfs.WithRoot(l.k.SupervisorContext(), root)
+ defer args.MountNamespaceVFS2.DecRef(ctx)
+ defer root.DecRef(ctx)
+ envv, err := user.MaybeAddExecUserHomeVFS2(ctx, args.MountNamespaceVFS2, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+ } else {
+ root := args.MountNamespace.Root()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ defer args.MountNamespace.DecRef(ctx)
+ defer root.DecRef(ctx)
+ envv, err := user.MaybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
}
- args.Envv = envv
// Start the process.
proc := control.Proc{Kernel: l.k}
args.PIDNamespace = tg.PIDNamespace()
- newTG, tgid, ttyFile, err := control.ExecAsync(&proc, args)
+ newTG, tgid, ttyFile, ttyFileVFS2, err := control.ExecAsync(&proc, args)
if err != nil {
return 0, err
}
eid := execID{cid: args.ContainerID, pid: tgid}
l.processes[eid] = &execProcess{
- tg: newTG,
- tty: ttyFile,
+ tg: newTG,
+ tty: ttyFile,
+ ttyVFS2: ttyFileVFS2,
}
log.Debugf("updated processes: %v", l.processes)
@@ -836,7 +943,7 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
func (l *Loader) waitContainer(cid string, waitStatus *uint32) error {
// Don't defer unlock, as doing so would make it impossible for
// multiple clients to wait on the same container.
- tg, _, err := l.threadGroupFromID(execID{cid: cid})
+ tg, err := l.threadGroupFromID(execID{cid: cid})
if err != nil {
return fmt.Errorf("can't wait for container %q: %v", cid, err)
}
@@ -855,7 +962,7 @@ func (l *Loader) waitPID(tgid kernel.ThreadID, cid string, waitStatus *uint32) e
// Try to find a process that was exec'd
eid := execID{cid: cid, pid: tgid}
- execTG, _, err := l.threadGroupFromID(eid)
+ execTG, err := l.threadGroupFromID(eid)
if err == nil {
ws := l.wait(execTG)
*waitStatus = ws
@@ -869,7 +976,7 @@ func (l *Loader) waitPID(tgid kernel.ThreadID, cid string, waitStatus *uint32) e
// The caller may be waiting on a process not started directly via exec.
// In this case, find the process in the container's PID namespace.
- initTG, _, err := l.threadGroupFromID(execID{cid: cid})
+ initTG, err := l.threadGroupFromID(execID{cid: cid})
if err != nil {
return fmt.Errorf("waiting for PID %d: %v", tgid, err)
}
@@ -902,50 +1009,98 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
// Wait for container.
l.k.WaitExited()
+ // Cleanup
+ l.ctrl.stop()
+
+ refs.OnExit()
+
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
+func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) {
+ // Create an empty network stack because the network namespace may be empty at
+ // this point. Netns is configured before Run() is called. Netstack is
+ // configured using a control uRPC message. Host network is configured inside
+ // Run().
switch conf.Network {
case NetworkHost:
- return hostinet.NewStack(), nil
+ // No network namespacing support for hostinet yet, hence creator is nil.
+ return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
case NetworkNone, NetworkSandbox:
- // NetworkNone sets up loopback using netstack.
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
- s := netstack.Stack{stack.New(stack.Options{
- NetworkProtocols: netProtos,
- TransportProtocols: transProtos,
- Clock: clock,
- Stats: netstack.Metrics,
- HandleLocal: true,
- // Enable raw sockets for users with sufficient
- // privileges.
- RawFactory: raw.EndpointFactory{},
- })}
-
- // Enable SACK Recovery.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
- return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ if err != nil {
+ return nil, err
}
+ creator := &sandboxNetstackCreator{
+ clock: clock,
+ uniqueID: uniqueID,
+ }
+ return inet.NewRootNamespace(s, creator), nil
- // Set default TTLs as required by socket/netstack.
- s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
- s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ default:
+ panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ }
- // Enable Receive Buffer Auto-Tuning.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
- }
+}
- s.FillDefaultIPTables()
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
+ // Enable raw sockets for users with sufficient
+ // privileges.
+ RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
+ })}
- return &s, nil
+ // Enable SACK Recovery.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+ return nil, fmt.Errorf("failed to enable SACK: %s", err)
+ }
- default:
- panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ return &s, nil
+}
+
+// sandboxNetstackCreator implements kernel.NetworkStackCreator.
+//
+// +stateify savable
+type sandboxNetstackCreator struct {
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+}
+
+// CreateStack implements kernel.NetworkStackCreator.CreateStack.
+func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Setup loopback.
+ n := &Network{Stack: s.(*netstack.Stack).Stack}
+ nicID := tcpip.NICID(f.uniqueID.UniqueID())
+ link := DefaultLoopbackLink
+ linkEP := loopback.New()
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return nil, err
}
+
+ return s, nil
}
// signal sends a signal to one or more processes in a container. If PID is 0,
@@ -975,8 +1130,7 @@ func (l *Loader) signal(cid string, pid, signo int32, mode SignalDeliveryMode) e
return fmt.Errorf("PID (%d) cannot be set when signaling all processes", pid)
}
// Check that the container has actually started before signaling it.
- _, _, err := l.threadGroupFromID(execID{cid: cid})
- if err != nil {
+ if _, err := l.threadGroupFromID(execID{cid: cid}); err != nil {
return err
}
if err := l.signalAllProcesses(cid, signo); err != nil {
@@ -990,16 +1144,16 @@ func (l *Loader) signal(cid string, pid, signo int32, mode SignalDeliveryMode) e
}
func (l *Loader) signalProcess(cid string, tgid kernel.ThreadID, signo int32) error {
- execTG, _, err := l.threadGroupFromID(execID{cid: cid, pid: tgid})
+ execTG, err := l.threadGroupFromID(execID{cid: cid, pid: tgid})
if err == nil {
// Send signal directly to the identified process.
- return execTG.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(execTG, &arch.SignalInfo{Signo: signo})
}
// The caller may be signaling a process not started directly via exec.
// In this case, find the process in the container's PID namespace and
// signal it.
- initTG, _, err := l.threadGroupFromID(execID{cid: cid})
+ initTG, err := l.threadGroupFromID(execID{cid: cid})
if err != nil {
return fmt.Errorf("no thread group found: %v", err)
}
@@ -1010,25 +1164,43 @@ func (l *Loader) signalProcess(cid string, tgid kernel.ThreadID, signo int32) er
if tg.Leader().ContainerID() != cid {
return fmt.Errorf("process %d is part of a different container: %q", tgid, tg.Leader().ContainerID())
}
- return tg.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
}
+// signalForegrondProcessGroup looks up foreground process group from the TTY
+// for the given "tgid" inside container "cid", and send the signal to it.
func (l *Loader) signalForegrondProcessGroup(cid string, tgid kernel.ThreadID, signo int32) error {
- // Lookup foreground process group from the TTY for the given process,
- // and send the signal to it.
- tg, tty, err := l.threadGroupFromID(execID{cid: cid, pid: tgid})
+ l.mu.Lock()
+ tg, err := l.tryThreadGroupFromIDLocked(execID{cid: cid, pid: tgid})
if err != nil {
+ l.mu.Unlock()
return fmt.Errorf("no thread group found: %v", err)
}
- if tty == nil {
+ if tg == nil {
+ l.mu.Unlock()
+ return fmt.Errorf("container %q not started", cid)
+ }
+
+ tty, ttyVFS2, err := l.ttyFromIDLocked(execID{cid: cid, pid: tgid})
+ l.mu.Unlock()
+ if err != nil {
+ return fmt.Errorf("no thread group found: %v", err)
+ }
+
+ var pg *kernel.ProcessGroup
+ switch {
+ case ttyVFS2 != nil:
+ pg = ttyVFS2.ForegroundProcessGroup()
+ case tty != nil:
+ pg = tty.ForegroundProcessGroup()
+ default:
return fmt.Errorf("no TTY attached")
}
- pg := tty.ForegroundProcessGroup()
if pg == nil {
// No foreground process group has been set. Signal the
// original thread group.
log.Warningf("No foreground process group for container %q and PID %d. Sending signal directly to PID %d.", cid, tgid, tgid)
- return tg.SendSignal(&arch.SignalInfo{Signo: signo})
+ return l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo})
}
// Send the signal to all processes in the process group.
var lastErr error
@@ -1036,7 +1208,7 @@ func (l *Loader) signalForegrondProcessGroup(cid string, tgid kernel.ThreadID, s
if tg.ProcessGroup() != pg {
continue
}
- if err := tg.SendSignal(&arch.SignalInfo{Signo: signo}); err != nil {
+ if err := l.k.SendExternalSignalThreadGroup(tg, &arch.SignalInfo{Signo: signo}); err != nil {
lastErr = err
}
}
@@ -1054,33 +1226,57 @@ func (l *Loader) signalAllProcesses(cid string, signo int32) error {
return l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo})
}
-// threadGroupFromID same as threadGroupFromIDLocked except that it acquires
-// mutex before calling it.
-func (l *Loader) threadGroupFromID(key execID) (*kernel.ThreadGroup, *host.TTYFileOperations, error) {
+// threadGroupFromID is similar to tryThreadGroupFromIDLocked except that it
+// acquires mutex before calling it and fails in case container hasn't started
+// yet.
+func (l *Loader) threadGroupFromID(key execID) (*kernel.ThreadGroup, error) {
l.mu.Lock()
defer l.mu.Unlock()
- tg, tty, ok, err := l.threadGroupFromIDLocked(key)
+ tg, err := l.tryThreadGroupFromIDLocked(key)
if err != nil {
- return nil, nil, err
+ return nil, err
}
- if !ok {
- return nil, nil, fmt.Errorf("container %q not started", key.cid)
+ if tg == nil {
+ return nil, fmt.Errorf("container %q not started", key.cid)
}
- return tg, tty, nil
+ return tg, nil
}
-// threadGroupFromIDLocked returns the thread group and TTY for the given
-// execution ID. TTY may be nil if the process is not attached to a terminal.
-// Also returns a boolean indicating whether the container has already started.
-// Returns error if execution ID is invalid or if the container cannot be
-// found (maybe it has been deleted). Caller must hold 'mu'.
-func (l *Loader) threadGroupFromIDLocked(key execID) (*kernel.ThreadGroup, *host.TTYFileOperations, bool, error) {
+// tryThreadGroupFromIDLocked returns the thread group for the given execution
+// ID. It may return nil in case the container has not started yet. Returns
+// error if execution ID is invalid or if the container cannot be found (maybe
+// it has been deleted). Caller must hold 'mu'.
+func (l *Loader) tryThreadGroupFromIDLocked(key execID) (*kernel.ThreadGroup, error) {
ep := l.processes[key]
if ep == nil {
- return nil, nil, false, fmt.Errorf("container %q not found", key.cid)
+ return nil, fmt.Errorf("container %q not found", key.cid)
}
- if ep.tg == nil {
- return nil, nil, false, nil
+ return ep.tg, nil
+}
+
+// ttyFromIDLocked returns the TTY files for the given execution ID. It may
+// return nil in case the container has not started yet. Returns error if
+// execution ID is invalid or if the container cannot be found (maybe it has
+// been deleted). Caller must hold 'mu'.
+func (l *Loader) ttyFromIDLocked(key execID) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ ep := l.processes[key]
+ if ep == nil {
+ return nil, nil, fmt.Errorf("container %q not found", key.cid)
+ }
+ return ep.tty, ep.ttyVFS2, nil
+}
+
+func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) {
+ if len(stdioFDs) != 3 {
+ return nil, nil, nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs))
+ }
+
+ k := kernel.KernelFromContext(ctx)
+ fdTable := k.NewFDTable()
+ ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, console, stdioFDs)
+ if err != nil {
+ fdTable.DecRef(ctx)
+ return nil, nil, nil, err
}
- return ep.tg, ep.tty, true, nil
+ return fdTable, ttyFile, ttyFileVFS2, nil
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 147ff7703..aa3fdf96c 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -19,17 +19,20 @@ import (
"math/rand"
"os"
"reflect"
- "sync"
"syscall"
"testing"
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/control/server"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/pkg/sentry/context/contexttest"
+ "gvisor.dev/gvisor/pkg/sentry/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/fsgofer"
)
@@ -100,20 +103,29 @@ func startGofer(root string) (int, func(), error) {
return sandboxEnd, cleanup, nil
}
-func createLoader() (*Loader, func(), error) {
+func createLoader(vfsEnabled bool, spec *specs.Spec) (*Loader, func(), error) {
fd, err := server.CreateSocket(ControlSocketAddr(fmt.Sprintf("%010d", rand.Int())[:10]))
if err != nil {
return nil, nil, err
}
conf := testConfig()
- spec := testSpec()
+ conf.VFS2 = vfsEnabled
sandEnd, cleanup, err := startGofer(spec.Root.Path)
if err != nil {
return nil, nil, err
}
- stdio := []int{int(os.Stdin.Fd()), int(os.Stdout.Fd()), int(os.Stderr.Fd())}
+ // Loader takes ownership of stdio.
+ var stdio []int
+ for _, f := range []*os.File{os.Stdin, os.Stdout, os.Stderr} {
+ newFd, err := unix.Dup(int(f.Fd()))
+ if err != nil {
+ return nil, nil, err
+ }
+ stdio = append(stdio, newFd)
+ }
+
args := Args{
ID: "foo",
Spec: spec,
@@ -132,10 +144,20 @@ func createLoader() (*Loader, func(), error) {
// TestRun runs a simple application in a sandbox and checks that it succeeds.
func TestRun(t *testing.T) {
- l, cleanup, err := createLoader()
+ doRun(t, false)
+}
+
+// TestRunVFS2 runs TestRun in VFSv2.
+func TestRunVFS2(t *testing.T) {
+ doRun(t, true)
+}
+
+func doRun(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled, testSpec())
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
+
defer l.Destroy()
defer cleanup()
@@ -169,7 +191,16 @@ func TestRun(t *testing.T) {
// TestStartSignal tests that the controller Start message will cause
// WaitForStartSignal to return.
func TestStartSignal(t *testing.T) {
- l, cleanup, err := createLoader()
+ doStartSignal(t, false)
+}
+
+// TestStartSignalVFS2 does TestStartSignal with VFS2.
+func TestStartSignalVFS2(t *testing.T) {
+ doStartSignal(t, true)
+}
+
+func doStartSignal(t *testing.T, vfsEnabled bool) {
+ l, cleanup, err := createLoader(vfsEnabled, testSpec())
if err != nil {
t.Fatalf("error creating loader: %v", err)
}
@@ -217,18 +248,19 @@ func TestStartSignal(t *testing.T) {
}
-// Test that MountNamespace can be created with various specs.
-func TestCreateMountNamespace(t *testing.T) {
- testCases := []struct {
- name string
- // Spec that will be used to create the mount manager. Note
- // that we can't mount procfs without a kernel, so each spec
- // MUST contain something other than procfs mounted at /proc.
- spec specs.Spec
- // Paths that are expected to exist in the resulting fs.
- expectedPaths []string
- }{
- {
+type CreateMountTestcase struct {
+ name string
+ // Spec that will be used to create the mount manager. Note
+ // that we can't mount procfs without a kernel, so each spec
+ // MUST contain something other than procfs mounted at /proc.
+ spec specs.Spec
+ // Paths that are expected to exist in the resulting fs.
+ expectedPaths []string
+}
+
+func createMountTestcases(vfs2 bool) []*CreateMountTestcase {
+ testCases := []*CreateMountTestcase{
+ &CreateMountTestcase{
// Only proc.
name: "only proc mount",
spec: specs.Spec{
@@ -270,7 +302,7 @@ func TestCreateMountNamespace(t *testing.T) {
// /dev, and /sys.
expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- {
+ &CreateMountTestcase{
// Mounts are nested inside each other.
name: "nested mounts",
spec: specs.Spec{
@@ -314,7 +346,7 @@ func TestCreateMountNamespace(t *testing.T) {
expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux",
"/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"},
},
- {
+ &CreateMountTestcase{
name: "mount inside /dev",
spec: specs.Spec{
Root: &specs.Root{
@@ -357,40 +389,46 @@ func TestCreateMountNamespace(t *testing.T) {
},
expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"},
},
- {
- name: "mounts inside mandatory mounts",
- spec: specs.Spec{
- Root: &specs.Root{
- Path: os.TempDir(),
- Readonly: true,
+ }
+
+ vfsCase := &CreateMountTestcase{
+ name: "mounts inside mandatory mounts",
+ spec: specs.Spec{
+ Root: &specs.Root{
+ Path: os.TempDir(),
+ Readonly: true,
+ },
+ Mounts: []specs.Mount{
+ {
+ Destination: "/proc",
+ Type: "tmpfs",
},
- Mounts: []specs.Mount{
- {
- Destination: "/proc",
- Type: "tmpfs",
- },
- // We don't include /sys, and /tmp in
- // the spec, since they will be added
- // automatically.
- //
- // Instead, add submounts inside these
- // directories and make sure they are
- // visible under the mandatory mounts.
- {
- Destination: "/sys/bar",
- Type: "tmpfs",
- },
- {
- Destination: "/tmp/baz",
- Type: "tmpfs",
- },
+ // TODO (gvisor.dev/issue/1487): Re-add this case when sysfs supports
+ // MkDirAt in VFS2 (and remove the reduntant append).
+ // {
+ // Destination: "/sys/bar",
+ // Type: "tmpfs",
+ // },
+ //
+ {
+ Destination: "/tmp/baz",
+ Type: "tmpfs",
},
},
- expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"},
},
+ expectedPaths: []string{"/proc", "/sys" /* "/sys/bar" ,*/, "/tmp", "/tmp/baz"},
}
- for _, tc := range testCases {
+ if !vfs2 {
+ vfsCase.spec.Mounts = append(vfsCase.spec.Mounts, specs.Mount{Destination: "/sys/bar", Type: "tmpfs"})
+ vfsCase.expectedPaths = append(vfsCase.expectedPaths, "/sys/bar")
+ }
+ return append(testCases, vfsCase)
+}
+
+// Test that MountNamespace can be created with various specs.
+func TestCreateMountNamespace(t *testing.T) {
+ for _, tc := range createMountTestcases(false /* vfs2 */) {
t.Run(tc.name, func(t *testing.T) {
conf := testConfig()
ctx := contexttest.Context(t)
@@ -412,13 +450,59 @@ func TestCreateMountNamespace(t *testing.T) {
}
root := mns.Root()
- defer root.DecRef()
+ defer root.DecRef(ctx)
for _, p := range tc.expectedPaths {
maxTraversals := uint(0)
if d, err := mns.FindInode(ctx, root, root, p, &maxTraversals); err != nil {
t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
} else {
- d.DecRef()
+ d.DecRef(ctx)
+ }
+ }
+ })
+ }
+}
+
+// Test that MountNamespace can be created with various specs.
+func TestCreateMountNamespaceVFS2(t *testing.T) {
+ for _, tc := range createMountTestcases(true /* vfs2 */) {
+ t.Run(tc.name, func(t *testing.T) {
+ spec := testSpec()
+ spec.Mounts = tc.spec.Mounts
+ spec.Root = tc.spec.Root
+
+ t.Logf("Using root: %q", spec.Root.Path)
+ l, loaderCleanup, err := createLoader(true /* VFS2 Enabled */, spec)
+ if err != nil {
+ t.Fatalf("failed to create loader: %v", err)
+ }
+ defer l.Destroy()
+ defer loaderCleanup()
+
+ mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints)
+ if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil {
+ t.Fatalf("failed process hints: %v", err)
+ }
+
+ ctx := l.k.SupervisorContext()
+ mns, err := mntr.setupVFS2(ctx, l.root.conf, &l.root.procArgs)
+ if err != nil {
+ t.Fatalf("failed to setupVFS2: %v", err)
+ }
+
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ for _, p := range tc.expectedPaths {
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(p),
+ }
+
+ if d, err := l.k.VFS().GetDentryAt(ctx, l.root.procArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil {
+ t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err)
+ } else {
+ d.DecRef(ctx)
}
}
})
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index f98c5fd36..4e1fa7665 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -17,12 +17,16 @@ package boot
import (
"fmt"
"net"
+ "runtime"
+ "strings"
"syscall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket"
+ "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -31,6 +35,32 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+var (
+ // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and
+ // "::1/8" on "lo" interface.
+ DefaultLoopbackLink = LoopbackLink{
+ Name: "lo",
+ Addresses: []net.IP{
+ net.IP("\x7f\x00\x00\x01"),
+ net.IPv6loopback,
+ },
+ Routes: []Route{
+ {
+ Destination: net.IPNet{
+ IP: net.IPv4(0x7f, 0, 0, 0),
+ Mask: net.IPv4Mask(0xff, 0, 0, 0),
+ },
+ },
+ {
+ Destination: net.IPNet{
+ IP: net.IPv6loopback,
+ Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
+ },
+ },
+ },
+ }
+)
+
// Network exposes methods that can be used to configure a network stack.
type Network struct {
Stack *stack.Stack
@@ -48,6 +78,44 @@ type DefaultRoute struct {
Name string
}
+// QueueingDiscipline is used to specify the kind of Queueing Discipline to
+// apply for a give FDBasedLink.
+type QueueingDiscipline int
+
+const (
+ // QDiscNone disables any queueing for the underlying FD.
+ QDiscNone QueueingDiscipline = iota
+
+ // QDiscFIFO applies a simple fifo based queue to the underlying
+ // FD.
+ QDiscFIFO
+)
+
+// MakeQueueingDiscipline if possible the equivalent QueuingDiscipline for s
+// else returns an error.
+func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) {
+ switch s {
+ case "none":
+ return QDiscNone, nil
+ case "fifo":
+ return QDiscFIFO, nil
+ default:
+ return 0, fmt.Errorf("unsupported qdisc specified: %q", s)
+ }
+}
+
+// String implements fmt.Stringer.
+func (q QueueingDiscipline) String() string {
+ switch q {
+ case QDiscNone:
+ return "none"
+ case QDiscFIFO:
+ return "fifo"
+ default:
+ panic(fmt.Sprintf("Invalid queueing discipline: %d", q))
+ }
+}
+
// FDBasedLink configures an fd-based link.
type FDBasedLink struct {
Name string
@@ -56,7 +124,10 @@ type FDBasedLink struct {
Routes []Route
GSOMaxSize uint32
SoftwareGSOEnabled bool
+ TXChecksumOffload bool
+ RXChecksumOffload bool
LinkAddress net.HardwareAddr
+ QDisc QueueingDiscipline
// NumChannels controls how many underlying FD's are to be used to
// create this endpoint.
@@ -80,7 +151,8 @@ type CreateLinksAndRoutesArgs struct {
LoopbackLinks []LoopbackLink
FDBasedLinks []FDBasedLink
- DefaultGateway DefaultRoute
+ Defaultv4Gateway DefaultRoute
+ Defaultv6Gateway DefaultRoute
}
// Empty returns true if route hasn't been set.
@@ -122,10 +194,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
nicID++
nicids[link.Name] = nicID
- ep := loopback.New()
+ linkEP := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -157,7 +229,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
mac := tcpip.LinkAddress(link.LinkAddress)
- ep, err := fdbased.New(&fdbased.Options{
+ log.Infof("gso max size is: %d", link.GSOMaxSize)
+
+ linkEP, err := fdbased.New(&fdbased.Options{
FDs: FDs,
MTU: uint32(link.MTU),
EthernetHeader: true,
@@ -165,14 +239,25 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
PacketDispatchMode: fdbased.RecvMMsg,
GSOMaxSize: link.GSOMaxSize,
SoftwareGSOEnabled: link.SoftwareGSOEnabled,
- RXChecksumOffload: true,
+ TXChecksumOffload: link.TXChecksumOffload,
+ RXChecksumOffload: link.RXChecksumOffload,
})
if err != nil {
return err
}
+ switch link.QDisc {
+ case QDiscNone:
+ case QDiscFIFO:
+ log.Infof("Enabling FIFO QDisc on %q", link.Name)
+ linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000)
+ }
+
+ // Enable support for AF_PACKET sockets to receive outgoing packets.
+ linkEP = packetsocket.New(linkEP)
+
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
return err
}
@@ -186,12 +271,24 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
}
- if !args.DefaultGateway.Route.Empty() {
- nicID, ok := nicids[args.DefaultGateway.Name]
+ if !args.Defaultv4Gateway.Route.Empty() {
+ nicID, ok := nicids[args.Defaultv4Gateway.Name]
if !ok {
- return fmt.Errorf("invalid interface name %q for default route", args.DefaultGateway.Name)
+ return fmt.Errorf("invalid interface name %q for default route", args.Defaultv4Gateway.Name)
}
- route, err := args.DefaultGateway.Route.toTcpipRoute(nicID)
+ route, err := args.Defaultv4Gateway.Route.toTcpipRoute(nicID)
+ if err != nil {
+ return err
+ }
+ routes = append(routes, route)
+ }
+
+ if !args.Defaultv6Gateway.Route.Empty() {
+ nicID, ok := nicids[args.Defaultv6Gateway.Name]
+ if !ok {
+ return fmt.Errorf("invalid interface name %q for default route", args.Defaultv6Gateway.Name)
+ }
+ route, err := args.Defaultv6Gateway.Route.toTcpipRoute(nicID)
if err != nil {
return err
}
@@ -205,15 +302,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
- if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err)
- }
- } else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err)
- }
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP) error {
+ opts := stack.NICOptions{Name: name}
+ if err := n.Stack.CreateNICWithOptions(id, sniffer.New(ep), opts); err != nil {
+ return fmt.Errorf("CreateNICWithOptions(%d, _, %+v) failed: %v", id, opts, err)
}
// Always start with an arp address for the NIC.
diff --git a/runsc/boot/platforms/BUILD b/runsc/boot/platforms/BUILD
index 03391cdca..77774f43c 100644
--- a/runsc/boot/platforms/BUILD
+++ b/runsc/boot/platforms/BUILD
@@ -1,11 +1,10 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
go_library(
name = "platforms",
srcs = ["platforms.go"],
- importpath = "gvisor.dev/gvisor/runsc/boot/platforms",
visibility = [
"//runsc:__subpackages__",
],
diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD
new file mode 100644
index 000000000..29cb42b2f
--- /dev/null
+++ b/runsc/boot/pprof/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pprof",
+ srcs = ["pprof.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof/pprof.go
index 463362f02..1ded20dee 100644
--- a/runsc/boot/pprof.go
+++ b/runsc/boot/pprof/pprof.go
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+// Package pprof provides a stub to initialize custom profilers.
+package pprof
-func initializePProf() {
+// Initialize will be called at boot for initializing custom profilers.
+func Initialize() {
}
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
new file mode 100644
index 000000000..08dce8b6c
--- /dev/null
+++ b/runsc/boot/vfs.go
@@ -0,0 +1,519 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package boot
+
+import (
+ "fmt"
+ "path"
+ "sort"
+ "strings"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/devices/memdev"
+ "gvisor.dev/gvisor/pkg/sentry/devices/ttydev"
+ "gvisor.dev/gvisor/pkg/sentry/devices/tundev"
+ "gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/overlay"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+func registerFilesystems(k *kernel.Kernel) error {
+ ctx := k.SupervisorContext()
+ creds := auth.NewRootCredentials(k.RootUserNamespace())
+ vfsObj := k.VFS()
+
+ vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ // TODO(b/29356795): Users may mount this once the terminals are in a
+ // usable state.
+ AllowUserMount: false,
+ })
+ vfsObj.MustRegisterFilesystemType(devtmpfs.Name, &devtmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(overlay.Name, &overlay.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(proc.Name, &proc.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(sys.Name, &sys.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+ vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
+
+ // Setup files in devtmpfs.
+ if err := memdev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering memdev: %w", err)
+ }
+ if err := ttydev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering ttydev: %w", err)
+ }
+ tunSupported := tundev.IsNetTunSupported(inet.StackFromContext(ctx))
+ if tunSupported {
+ if err := tundev.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering tundev: %v", err)
+ }
+ }
+
+ if kernel.FUSEEnabled {
+ if err := fuse.Register(vfsObj); err != nil {
+ return fmt.Errorf("registering fusedev: %w", err)
+ }
+ }
+
+ a, err := devtmpfs.NewAccessor(ctx, vfsObj, creds, devtmpfs.Name)
+ if err != nil {
+ return fmt.Errorf("creating devtmpfs accessor: %w", err)
+ }
+ defer a.Release(ctx)
+
+ if err := a.UserspaceInit(ctx); err != nil {
+ return fmt.Errorf("initializing userspace: %w", err)
+ }
+ if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating memdev devtmpfs files: %w", err)
+ }
+ if err := ttydev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating ttydev devtmpfs files: %w", err)
+ }
+ if tunSupported {
+ if err := tundev.CreateDevtmpfsFiles(ctx, a); err != nil {
+ return fmt.Errorf("creating tundev devtmpfs files: %v", err)
+ }
+ }
+
+ if kernel.FUSEEnabled {
+ if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil {
+ return fmt.Errorf("creating fusedev devtmpfs files: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ mns, err := mntr.setupVFS2(ctx, conf, procArgs)
+ if err != nil {
+ return fmt.Errorf("failed to setupFS: %w", err)
+ }
+ procArgs.MountNamespaceVFS2 = mns
+
+ // Resolve the executable path from working dir and environment.
+ resolved, err := user.ResolveExecutablePath(ctx, procArgs)
+ if err != nil {
+ return err
+ }
+ procArgs.Filename = resolved
+ return nil
+}
+
+func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) {
+ log.Infof("Configuring container's file system with VFS2")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootCreds := auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = rootCreds
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := procArgs.NewContext(c.k)
+
+ mns, err := c.createMountNamespaceVFS2(rootCtx, conf, rootCreds)
+ if err != nil {
+ return nil, fmt.Errorf("creating mount namespace: %w", err)
+ }
+ rootProcArgs.MountNamespaceVFS2 = mns
+
+ // Mount submounts.
+ if err := c.mountSubmountsVFS2(rootCtx, conf, mns, rootCreds); err != nil {
+ return nil, fmt.Errorf("mounting submounts vfs2: %w", err)
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *Config, creds *auth.Credentials) (*vfs.MountNamespace, error) {
+ fd := c.fds.remove()
+ opts := p9MountData(fd, conf.FileAccess, true /* vfs2 */)
+
+ if conf.OverlayfsStaleRead {
+ // We can't check for overlayfs here because sandbox is chroot'ed and gofer
+ // can only send mount options for specs.Mounts (specs.Root is missing
+ // Options field). So assume root is always on top of overlayfs.
+ opts = append(opts, "overlayfs_stale_read")
+ }
+
+ log.Infof("Mounting root over 9P, ioFD: %d", fd)
+ mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{
+ Data: strings.Join(opts, ","),
+ })
+ if err != nil {
+ return nil, fmt.Errorf("setting up mount namespace: %w", err)
+ }
+ return mns, nil
+}
+
+func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error {
+ mounts, err := c.prepareMountsVFS2()
+ if err != nil {
+ return err
+ }
+
+ for i := range mounts {
+ submount := &mounts[i]
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options)
+ if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() {
+ if err := c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint); err != nil {
+ return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err)
+ }
+ } else {
+ if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil {
+ return fmt.Errorf("mount submount %q: %w", submount.Destination, err)
+ }
+ }
+ }
+
+ if err := c.mountTmpVFS2(ctx, conf, creds, mns); err != nil {
+ return fmt.Errorf(`mount submount "\tmp": %w`, err)
+ }
+ return nil
+}
+
+type mountAndFD struct {
+ specs.Mount
+ fd int
+}
+
+func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) {
+ // Associate bind mounts with their FDs before sorting since there is an
+ // undocumented assumption that FDs are dispensed in the order in which
+ // they are required by mounts.
+ var mounts []mountAndFD
+ for _, m := range c.mounts {
+ fd := -1
+ // Only bind mounts use host FDs; see
+ // containerMounter.getMountNameAndOptionsVFS2.
+ if m.Type == bind {
+ fd = c.fds.remove()
+ }
+ mounts = append(mounts, mountAndFD{
+ Mount: m,
+ fd: fd,
+ })
+ }
+ if err := c.checkDispenser(); err != nil {
+ return nil, err
+ }
+
+ // Sort the mounts so that we don't place children before parents.
+ sort.Slice(mounts, func(i, j int) bool {
+ return len(mounts[i].Destination) < len(mounts[j].Destination)
+ })
+
+ return mounts, nil
+}
+
+func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error {
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(submount.Destination),
+ }
+ fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, submount)
+ if err != nil {
+ return fmt.Errorf("mountOptions failed: %w", err)
+ }
+ if len(fsName) == 0 {
+ // Filesystem is not supported (e.g. cgroup), just skip it.
+ return nil
+ }
+
+ if err := c.makeSyntheticMount(ctx, submount.Destination, root, creds); err != nil {
+ return err
+ }
+ if err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts); err != nil {
+ return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts)
+ }
+ log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data)
+ return nil
+}
+
+// getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values
+// used for mounts.
+func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndFD) (string, *vfs.MountOptions, error) {
+ fsName := m.Type
+ var data []string
+
+ // Find filesystem name and FS specific data field.
+ switch m.Type {
+ case devpts.Name, devtmpfs.Name, proc.Name, sys.Name:
+ // Nothing to do.
+
+ case nonefs:
+ fsName = sys.Name
+
+ case tmpfs.Name:
+ var err error
+ data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...)
+ if err != nil {
+ return "", nil, err
+ }
+
+ case bind:
+ fsName = gofer.Name
+ if m.fd == 0 {
+ // Check that an FD was provided to fails fast. Technically FD=0 is valid,
+ // but unlikely to be correct in this context.
+ return "", nil, fmt.Errorf("9P mount requires a connection FD")
+ }
+ data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */)
+
+ default:
+ log.Warningf("ignoring unknown filesystem type %q", m.Type)
+ return "", nil, nil
+ }
+
+ opts := &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: strings.Join(data, ","),
+ },
+ InternalMount: true,
+ }
+
+ for _, o := range m.Options {
+ switch o {
+ case "rw":
+ opts.ReadOnly = false
+ case "ro":
+ opts.ReadOnly = true
+ case "noatime":
+ opts.Flags.NoATime = true
+ case "noexec":
+ opts.Flags.NoExec = true
+ default:
+ log.Warningf("ignoring unknown mount option %q", o)
+ }
+ }
+
+ if conf.Overlay {
+ // All writes go to upper, be paranoid and make lower readonly.
+ opts.ReadOnly = true
+ }
+ return fsName, opts, nil
+}
+
+func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath string, root vfs.VirtualDentry, creds *auth.Credentials) error {
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(currentPath),
+ }
+ _, err := c.k.VFS().StatAt(ctx, creds, target, &vfs.StatOptions{})
+ if err == nil {
+ log.Debugf("Mount point %q already exists", currentPath)
+ return nil
+ }
+ if err != syserror.ENOENT {
+ return fmt.Errorf("stat failed for %q during mount point creation: %w", currentPath, err)
+ }
+
+ // Recurse to ensure parent is created and then create the mount point.
+ if err := c.makeSyntheticMount(ctx, path.Dir(currentPath), root, creds); err != nil {
+ return err
+ }
+ log.Debugf("Creating dir %q for mount point", currentPath)
+ mkdirOpts := &vfs.MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true}
+ if err := c.k.VFS().MkdirAt(ctx, creds, target, mkdirOpts); err != nil {
+ return fmt.Errorf("failed to create directory %q for mount: %w", currentPath, err)
+ }
+ return nil
+}
+
+// mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so.
+// Technically we don't have to mount tmpfs at /tmp, as we could just rely on
+// the host /tmp, but this is a nice optimization, and fixes some apps that call
+// mknod in /tmp. It's unsafe to mount tmpfs if:
+// 1. /tmp is mounted explicitly: we should not override user's wish
+// 2. /tmp is not empty: mounting tmpfs would hide existing files in /tmp
+//
+// Note that when there are submounts inside of '/tmp', directories for the
+// mount points must be present, making '/tmp' not empty anymore.
+func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds *auth.Credentials, mns *vfs.MountNamespace) error {
+ for _, m := range c.mounts {
+ // m.Destination has been cleaned, so it's to use equality here.
+ if m.Destination == "/tmp" {
+ log.Debugf(`Explict "/tmp" mount found, skipping internal tmpfs, mount: %+v`, m)
+ return nil
+ }
+ }
+
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse("/tmp"),
+ }
+ // TODO(gvisor.dev/issue/2782): Use O_PATH when available.
+ fd, err := c.k.VFS().OpenAt(ctx, creds, &pop, &vfs.OpenOptions{Flags: linux.O_RDONLY | linux.O_DIRECTORY})
+ switch err {
+ case nil:
+ defer fd.DecRef(ctx)
+
+ err := fd.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ if dirent.Name != "." && dirent.Name != ".." {
+ return syserror.ENOTEMPTY
+ }
+ return nil
+ }))
+ switch err {
+ case nil:
+ log.Infof(`Mounting internal tmpfs on top of empty "/tmp"`)
+ case syserror.ENOTEMPTY:
+ // If more than "." and ".." is found, skip internal tmpfs to prevent
+ // hiding existing files.
+ log.Infof(`Skipping internal tmpfs mount for "/tmp" because it's not empty`)
+ return nil
+ default:
+ return err
+ }
+ fallthrough
+
+ case syserror.ENOENT:
+ // No '/tmp' found (or fallthrough from above). It's safe to mount internal
+ // tmpfs.
+ tmpMount := specs.Mount{
+ Type: tmpfs.Name,
+ Destination: "/tmp",
+ // Sticky bit is added to prevent accidental deletion of files from
+ // another user. This is normally done for /tmp.
+ Options: []string{"mode=01777"},
+ }
+ return c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount})
+
+ case syserror.ENOTDIR:
+ // Not a dir?! Let it be.
+ return nil
+
+ default:
+ return fmt.Errorf(`opening "/tmp" inside container: %w`, err)
+ }
+}
+
+// processHintsVFS2 processes annotations that container hints about how volumes
+// should be mounted (e.g. a volume shared between containers). It must be
+// called for the root container only.
+func (c *containerMounter) processHintsVFS2(conf *Config, creds *auth.Credentials) error {
+ ctx := c.k.SupervisorContext()
+ for _, hint := range c.hints.mounts {
+ // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a
+ // common gofer to mount all shared volumes.
+ if hint.mount.Type != tmpfs.Name {
+ continue
+ }
+
+ log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
+ mnt, err := c.mountSharedMasterVFS2(ctx, conf, hint, creds)
+ if err != nil {
+ return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+ }
+ hint.vfsMount = mnt
+ }
+ return nil
+}
+
+// mountSharedMasterVFS2 mounts the master of a volume that is shared among
+// containers in a pod.
+func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) {
+ // Map mount type to filesystem name, and parse out the options that we are
+ // capable of dealing with.
+ mntFD := &mountAndFD{Mount: hint.mount}
+ fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, mntFD)
+ if err != nil {
+ return nil, err
+ }
+ if len(fsName) == 0 {
+ return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type)
+ }
+ return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts)
+}
+
+// mountSharedSubmount binds mount to a previously mounted volume that is shared
+// among containers in the same pod.
+func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) error {
+ if err := source.checkCompatible(mount); err != nil {
+ return err
+ }
+
+ _, opts, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount})
+ if err != nil {
+ return err
+ }
+ newMnt, err := c.k.VFS().NewDisconnectedMount(source.vfsMount.Filesystem(), source.vfsMount.Root(), opts)
+ if err != nil {
+ return err
+ }
+ defer newMnt.DecRef(ctx)
+
+ root := mns.Root()
+ defer root.DecRef(ctx)
+ if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil {
+ return err
+ }
+
+ target := &vfs.PathOperation{
+ Root: root,
+ Start: root,
+ Path: fspath.Parse(mount.Destination),
+ }
+ if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil {
+ return err
+ }
+ log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name)
+ return nil
+}
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
index d6165f9e5..37f4253ba 100644
--- a/runsc/cgroup/BUILD
+++ b/runsc/cgroup/BUILD
@@ -1,17 +1,16 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
go_library(
name = "cgroup",
srcs = ["cgroup.go"],
- importpath = "gvisor.dev/gvisor/runsc/cgroup",
visibility = ["//:sandbox"],
deps = [
+ "//pkg/cleanup",
"//pkg/log",
- "//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
@@ -19,6 +18,10 @@ go_test(
name = "cgroup_test",
size = "small",
srcs = ["cgroup_test.go"],
- embed = [":cgroup"],
+ library = ":cgroup",
tags = ["local"],
+ deps = [
+ "//pkg/test/testutil",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ ],
)
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index ab3a25b9b..8fbc3887a 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -19,6 +19,7 @@ package cgroup
import (
"bufio"
"context"
+ "errors"
"fmt"
"io/ioutil"
"os"
@@ -30,29 +31,31 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/runsc/specutils"
)
const (
cgroupRoot = "/sys/fs/cgroup"
)
-var controllers = map[string]controller{
- "blkio": &blockIO{},
- "cpu": &cpu{},
- "cpuset": &cpuSet{},
- "memory": &memory{},
- "net_cls": &networkClass{},
- "net_prio": &networkPrio{},
+var controllers = map[string]config{
+ "blkio": config{ctrlr: &blockIO{}},
+ "cpu": config{ctrlr: &cpu{}},
+ "cpuset": config{ctrlr: &cpuSet{}},
+ "hugetlb": config{ctrlr: &hugeTLB{}, optional: true},
+ "memory": config{ctrlr: &memory{}},
+ "net_cls": config{ctrlr: &networkClass{}},
+ "net_prio": config{ctrlr: &networkPrio{}},
+ "pids": config{ctrlr: &pids{}},
// These controllers either don't have anything in the OCI spec or is
- // irrevalant for a sandbox, e.g. pids.
- "devices": &noop{},
- "freezer": &noop{},
- "perf_event": &noop{},
- "pids": &noop{},
- "systemd": &noop{},
+ // irrelevant for a sandbox.
+ "devices": config{ctrlr: &noop{}},
+ "freezer": config{ctrlr: &noop{}},
+ "perf_event": config{ctrlr: &noop{}},
+ "rdma": config{ctrlr: &noop{}, optional: true},
+ "systemd": config{ctrlr: &noop{}},
}
func setOptionalValueInt(path, name string, val *int64) error {
@@ -89,7 +92,17 @@ func setOptionalValueUint16(path, name string, val *uint16) error {
func setValue(path, name, data string) error {
fullpath := filepath.Join(path, name)
- return ioutil.WriteFile(fullpath, []byte(data), 0700)
+
+ // Retry writes on EINTR; see:
+ // https://github.com/golang/go/issues/38033
+ for {
+ err := ioutil.WriteFile(fullpath, []byte(data), 0700)
+ if err == nil {
+ return nil
+ } else if !errors.Is(err, syscall.EINTR) {
+ return err
+ }
+ }
}
func getValue(path, name string) (string, error) {
@@ -101,6 +114,14 @@ func getValue(path, name string) (string, error) {
return string(out), nil
}
+func getInt(path, name string) (int, error) {
+ s, err := getValue(path, name)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(s))
+}
+
// fillFromAncestor sets the value of a cgroup file from the first ancestor
// that has content. It does nothing if the file in 'path' has already been set.
func fillFromAncestor(path string) (string, error) {
@@ -114,15 +135,23 @@ func fillFromAncestor(path string) (string, error) {
return val, nil
}
- // File is not set, recurse to parent and then set here.
+ // File is not set, recurse to parent and then set here.
name := filepath.Base(path)
parent := filepath.Dir(filepath.Dir(path))
val, err = fillFromAncestor(filepath.Join(parent, name))
if err != nil {
return "", err
}
- if err := ioutil.WriteFile(path, []byte(val), 0700); err != nil {
- return "", err
+
+ // Retry writes on EINTR; see:
+ // https://github.com/golang/go/issues/38033
+ for {
+ err := ioutil.WriteFile(path, []byte(val), 0700)
+ if err == nil {
+ break
+ } else if !errors.Is(err, syscall.EINTR) {
+ return "", err
+ }
}
return val, nil
}
@@ -188,8 +217,9 @@ func LoadPaths(pid string) (map[string]string, error) {
return paths, nil
}
-// Cgroup represents a group inside all controllers. For example: Name='/foo/bar'
-// maps to /sys/fs/cgroup/<controller>/foo/bar on all controllers.
+// Cgroup represents a group inside all controllers. For example:
+// Name='/foo/bar' maps to /sys/fs/cgroup/<controller>/foo/bar on
+// all controllers.
type Cgroup struct {
Name string `json:"name"`
Parents map[string]string `json:"parents"`
@@ -234,16 +264,20 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
// The Cleanup object cleans up partially created cgroups when an error occurs.
// Errors occuring during cleanup itself are ignored.
- clean := specutils.MakeCleanup(func() { _ = c.Uninstall() })
+ clean := cleanup.Make(func() { _ = c.Uninstall() })
defer clean.Clean()
- for key, ctrl := range controllers {
+ for key, cfg := range controllers {
path := c.makePath(key)
if err := os.MkdirAll(path, 0755); err != nil {
+ if cfg.optional && errors.Is(err, syscall.EROFS) {
+ log.Infof("Skipping cgroup %q", key)
+ continue
+ }
return err
}
if res != nil {
- if err := ctrl.set(res, path); err != nil {
+ if err := cfg.ctrlr.set(res, path); err != nil {
return err
}
}
@@ -313,16 +347,35 @@ func (c *Cgroup) Join() (func(), error) {
}
// Now join the cgroups.
- for key := range controllers {
+ for key, cfg := range controllers {
path := c.makePath(key)
log.Debugf("Joining cgroup %q", path)
if err := setValue(path, "cgroup.procs", "0"); err != nil {
+ if cfg.optional && os.IsNotExist(err) {
+ continue
+ }
return undo, err
}
}
return undo, nil
}
+func (c *Cgroup) CPUQuota() (float64, error) {
+ path := c.makePath("cpu")
+ quota, err := getInt(path, "cpu.cfs_quota_us")
+ if err != nil {
+ return -1, err
+ }
+ period, err := getInt(path, "cpu.cfs_period_us")
+ if err != nil {
+ return -1, err
+ }
+ if quota <= 0 || period <= 0 {
+ return -1, err
+ }
+ return float64(quota) / float64(period), nil
+}
+
// NumCPU returns the number of CPUs configured in 'cpuset/cpuset.cpus'.
func (c *Cgroup) NumCPU() (int, error) {
path := c.makePath("cpuset")
@@ -351,6 +404,11 @@ func (c *Cgroup) makePath(controllerName string) string {
return filepath.Join(cgroupRoot, controllerName, path)
}
+type config struct {
+ ctrlr controller
+ optional bool
+}
+
type controller interface {
set(*specs.LinuxResources, string) error
}
@@ -406,7 +464,13 @@ func (*cpu) set(spec *specs.LinuxResources, path string) error {
if err := setOptionalValueInt(path, "cpu.cfs_quota_us", spec.CPU.Quota); err != nil {
return err
}
- return setOptionalValueUint(path, "cpu.cfs_period_us", spec.CPU.Period)
+ if err := setOptionalValueUint(path, "cpu.cfs_period_us", spec.CPU.Period); err != nil {
+ return err
+ }
+ if err := setOptionalValueUint(path, "cpu.rt_period_us", spec.CPU.RealtimePeriod); err != nil {
+ return err
+ }
+ return setOptionalValueInt(path, "cpu.rt_runtime_us", spec.CPU.RealtimeRuntime)
}
type cpuSet struct{}
@@ -447,13 +511,17 @@ func (*blockIO) set(spec *specs.LinuxResources, path string) error {
}
for _, dev := range spec.BlockIO.WeightDevice {
- val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, dev.Weight)
- if err := setValue(path, "blkio.weight_device", val); err != nil {
- return err
+ if dev.Weight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.Weight)
+ if err := setValue(path, "blkio.weight_device", val); err != nil {
+ return err
+ }
}
- val = fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, dev.LeafWeight)
- if err := setValue(path, "blkio.leaf_weight_device", val); err != nil {
- return err
+ if dev.LeafWeight != nil {
+ val := fmt.Sprintf("%d:%d %d", dev.Major, dev.Minor, *dev.LeafWeight)
+ if err := setValue(path, "blkio.leaf_weight_device", val); err != nil {
+ return err
+ }
}
}
if err := setThrottle(path, "blkio.throttle.read_bps_device", spec.BlockIO.ThrottleReadBpsDevice); err != nil {
@@ -501,3 +569,26 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error {
}
return nil
}
+
+type pids struct{}
+
+func (*pids) set(spec *specs.LinuxResources, path string) error {
+ if spec.Pids == nil || spec.Pids.Limit <= 0 {
+ return nil
+ }
+ val := strconv.FormatInt(spec.Pids.Limit, 10)
+ return setValue(path, "pids.max", val)
+}
+
+type hugeTLB struct{}
+
+func (*hugeTLB) set(spec *specs.LinuxResources, path string) error {
+ for _, limit := range spec.HugepageLimits {
+ name := fmt.Sprintf("hugetlb.%s.limit_in_bytes", limit.Pagesize)
+ val := strconv.FormatUint(limit.Limit, 10)
+ if err := setValue(path, name, val); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go
index 548c80e9a..4db5ee5c3 100644
--- a/runsc/cgroup/cgroup_test.go
+++ b/runsc/cgroup/cgroup_test.go
@@ -15,7 +15,14 @@
package cgroup
import (
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
"testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
func TestUninstallEnoent(t *testing.T) {
@@ -65,3 +72,578 @@ func TestCountCpuset(t *testing.T) {
})
}
}
+
+func uint16Ptr(v uint16) *uint16 {
+ return &v
+}
+
+func uint32Ptr(v uint32) *uint32 {
+ return &v
+}
+
+func int64Ptr(v int64) *int64 {
+ return &v
+}
+
+func uint64Ptr(v uint64) *uint64 {
+ return &v
+}
+
+func boolPtr(v bool) *bool {
+ return &v
+}
+
+func checkDir(t *testing.T, dir string, contents map[string]string) {
+ all, err := ioutil.ReadDir(dir)
+ if err != nil {
+ t.Fatalf("ReadDir(%q): %v", dir, err)
+ }
+ fileCount := 0
+ for _, file := range all {
+ if file.IsDir() {
+ // Only want to compare files.
+ continue
+ }
+ fileCount++
+
+ want, ok := contents[file.Name()]
+ if !ok {
+ t.Errorf("file not expected: %q", file.Name())
+ continue
+ }
+ gotBytes, err := ioutil.ReadFile(filepath.Join(dir, file.Name()))
+ if err != nil {
+ t.Fatal(err.Error())
+ }
+ got := strings.TrimSuffix(string(gotBytes), "\n")
+ if got != want {
+ t.Errorf("wrong file content, file: %q, want: %q, got: %q", file.Name(), want, got)
+ }
+ }
+ if fileCount != len(contents) {
+ t.Errorf("file is missing, want: %v, got: %v", contents, all)
+ }
+}
+
+func makeLinuxWeightDevice(major, minor int64, weight, leafWeight *uint16) specs.LinuxWeightDevice {
+ rv := specs.LinuxWeightDevice{
+ Weight: weight,
+ LeafWeight: leafWeight,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func makeLinuxThrottleDevice(major, minor int64, rate uint64) specs.LinuxThrottleDevice {
+ rv := specs.LinuxThrottleDevice{
+ Rate: rate,
+ }
+ rv.Major = major
+ rv.Minor = minor
+ return rv
+}
+
+func TestBlockIO(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxBlockIO
+ wants map[string]string
+ }{
+ {
+ name: "simple",
+ spec: &specs.LinuxBlockIO{
+ Weight: uint16Ptr(1),
+ LeafWeight: uint16Ptr(2),
+ },
+ wants: map[string]string{
+ "blkio.weight": "1",
+ "blkio.leaf_weight": "2",
+ },
+ },
+ {
+ name: "weight_device",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, uint16Ptr(3), uint16Ptr(4)),
+ },
+ },
+ wants: map[string]string{
+ "blkio.weight_device": "1:2 3",
+ "blkio.leaf_weight_device": "1:2 4",
+ },
+ },
+ {
+ name: "weight_device_nil_values",
+ spec: &specs.LinuxBlockIO{
+ WeightDevice: []specs.LinuxWeightDevice{
+ makeLinuxWeightDevice(1, 2, nil, nil),
+ },
+ },
+ },
+ {
+ name: "throttle",
+ spec: &specs.LinuxBlockIO{
+ ThrottleReadBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(1, 2, 3),
+ },
+ ThrottleReadIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(4, 5, 6),
+ },
+ ThrottleWriteBpsDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(7, 8, 9),
+ },
+ ThrottleWriteIOPSDevice: []specs.LinuxThrottleDevice{
+ makeLinuxThrottleDevice(10, 11, 12),
+ },
+ },
+ wants: map[string]string{
+ "blkio.throttle.read_bps_device": "1:2 3",
+ "blkio.throttle.read_iops_device": "4:5 6",
+ "blkio.throttle.write_bps_device": "7:8 9",
+ "blkio.throttle.write_iops_device": "10:11 12",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxBlockIO{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ BlockIO: tc.spec,
+ }
+ ctrlr := blockIO{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPU(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Shares: uint64Ptr(1),
+ Quota: int64Ptr(2),
+ Period: uint64Ptr(3),
+ RealtimeRuntime: int64Ptr(4),
+ RealtimePeriod: uint64Ptr(5),
+ },
+ wants: map[string]string{
+ "cpu.shares": "1",
+ "cpu.cfs_quota_us": "2",
+ "cpu.cfs_period_us": "3",
+ "cpu.rt_runtime_us": "4",
+ "cpu.rt_period_us": "5",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpu{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestCPUSet(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxCPU{
+ Cpus: "foo",
+ Mems: "bar",
+ },
+ wants: map[string]string{
+ "cpuset.cpus": "foo",
+ "cpuset.mems": "bar",
+ },
+ },
+ // Don't test nil values because they are copied from the parent.
+ // See TestCPUSetAncestor().
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+// TestCPUSetAncestor checks that, when not available, value is read from
+// parent directory.
+func TestCPUSetAncestor(t *testing.T) {
+ // Prepare master directory with cgroup files that will be propagated to
+ // children.
+ grandpa, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(grandpa)
+
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.cpus"), []byte("parent-cpus"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+ if err := ioutil.WriteFile(filepath.Join(grandpa, "cpuset.mems"), []byte("parent-mems"), 0666); err != nil {
+ t.Fatalf("ioutil.WriteFile(): %v", err)
+ }
+
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxCPU
+ }{
+ {
+ name: "nil_values",
+ spec: &specs.LinuxCPU{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // Create empty files in intermediate directory. They should be ignored
+ // when reading, and then populated from parent.
+ parent, err := ioutil.TempDir(grandpa, "parent")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(parent)
+ if _, err := os.Create(filepath.Join(parent, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(parent, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ // cgroup files mmust exist.
+ dir, err := ioutil.TempDir(parent, "child")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.cpus")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ if _, err := os.Create(filepath.Join(dir, "cpuset.mems")); err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ spec := &specs.LinuxResources{
+ CPU: tc.spec,
+ }
+ ctrlr := cpuSet{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ want := map[string]string{
+ "cpuset.cpus": "parent-cpus",
+ "cpuset.mems": "parent-mems",
+ }
+ // Both path and dir must have been populated from grandpa.
+ checkDir(t, parent, want)
+ checkDir(t, dir, want)
+ })
+ }
+}
+
+func TestHugeTlb(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec []specs.LinuxHugepageLimit
+ wants map[string]string
+ }{
+ {
+ name: "single",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ },
+ },
+ {
+ name: "multiple",
+ spec: []specs.LinuxHugepageLimit{
+ {
+ Pagesize: "1G",
+ Limit: 123,
+ },
+ {
+ Pagesize: "2G",
+ Limit: 456,
+ },
+ {
+ Pagesize: "1P",
+ Limit: 789,
+ },
+ },
+ wants: map[string]string{
+ "hugetlb.1G.limit_in_bytes": "123",
+ "hugetlb.2G.limit_in_bytes": "456",
+ "hugetlb.1P.limit_in_bytes": "789",
+ },
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ HugepageLimits: tc.spec,
+ }
+ ctrlr := hugeTLB{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestMemory(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxMemory
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxMemory{
+ Limit: int64Ptr(1),
+ Reservation: int64Ptr(2),
+ Swap: int64Ptr(3),
+ Kernel: int64Ptr(4),
+ KernelTCP: int64Ptr(5),
+ Swappiness: uint64Ptr(6),
+ DisableOOMKiller: boolPtr(true),
+ },
+ wants: map[string]string{
+ "memory.limit_in_bytes": "1",
+ "memory.soft_limit_in_bytes": "2",
+ "memory.memsw.limit_in_bytes": "3",
+ "memory.kmem.limit_in_bytes": "4",
+ "memory.kmem.tcp.limit_in_bytes": "5",
+ "memory.swappiness": "6",
+ "memory.oom_control": "1",
+ },
+ },
+ {
+ // Disable OOM killer should only write when set to true.
+ name: "oomkiller",
+ spec: &specs.LinuxMemory{
+ DisableOOMKiller: boolPtr(false),
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxMemory{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Memory: tc.spec,
+ }
+ ctrlr := memory{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkClass(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ ClassID: uint32Ptr(1),
+ },
+ wants: map[string]string{
+ "net_cls.classid": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkClass{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestNetworkPriority(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxNetwork
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxNetwork{
+ Priorities: []specs.LinuxInterfacePriority{
+ {
+ Name: "foo",
+ Priority: 1,
+ },
+ },
+ },
+ wants: map[string]string{
+ "net_prio.ifpriomap": "foo 1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxNetwork{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Network: tc.spec,
+ }
+ ctrlr := networkPrio{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
+
+func TestPids(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ spec *specs.LinuxPids
+ wants map[string]string
+ }{
+ {
+ name: "all",
+ spec: &specs.LinuxPids{Limit: 1},
+ wants: map[string]string{
+ "pids.max": "1",
+ },
+ },
+ {
+ name: "nil_values",
+ spec: &specs.LinuxPids{},
+ },
+ {
+ name: "nil",
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "cgroup")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ defer os.RemoveAll(dir)
+
+ spec := &specs.LinuxResources{
+ Pids: tc.spec,
+ }
+ ctrlr := pids{}
+ if err := ctrlr.set(spec, dir); err != nil {
+ t.Fatalf("ctrlr.set(): %v", err)
+ }
+ checkDir(t, dir, tc.wants)
+ })
+ }
+}
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 250845ad7..1b5178dd5 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -31,10 +31,10 @@ go_library(
"spec.go",
"start.go",
"state.go",
+ "statefile.go",
"syscalls.go",
"wait.go",
],
- importpath = "gvisor.dev/gvisor/runsc/cmd",
visibility = [
"//runsc:__subpackages__",
],
@@ -44,17 +44,21 @@ go_library(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/platform",
+ "//pkg/state/pretty",
+ "//pkg/state/statefile",
+ "//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
- "//runsc/boot/platforms",
"//runsc/console",
"//runsc/container",
+ "//runsc/flag",
"//runsc/fsgofer",
"//runsc/fsgofer/filter",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -72,20 +76,20 @@ go_test(
data = [
"//runsc",
],
- embed = [":cmd"],
+ library = ":cmd",
deps = [
"//pkg/abi/linux",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel/auth",
+ "//pkg/test/testutil",
"//pkg/urpc",
"//runsc/boot",
"//runsc/container",
"//runsc/specutils",
- "//runsc/testutil",
- "@com_github_google_go-cmp//cmp:go_default_library",
- "@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
],
)
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index b40fded5b..f4f247721 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -21,12 +21,13 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/boot"
- "gvisor.dev/gvisor/runsc/boot/platforms"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -53,10 +54,6 @@ type Boot struct {
// provided in that order.
stdioFDs intFlags
- // console is set to true if the sandbox should allow terminal ioctl(2)
- // syscalls.
- console bool
-
// applyCaps determines if capabilities defined in the spec should be applied
// to the process.
applyCaps bool
@@ -82,8 +79,13 @@ type Boot struct {
// sandbox (e.g. gofer) and sent through this FD.
mountsFD int
- // pidns is set if the sanadbox is in its own pid namespace.
+ // pidns is set if the sandbox is in its own pid namespace.
pidns bool
+
+ // attached is set to true to kill the sandbox process when the parent process
+ // terminates. This flag is set when the command execve's itself because
+ // parent death signal doesn't propagate through execve when uid/gid changes.
+ attached bool
}
// Name implements subcommands.Command.Name.
@@ -109,7 +111,6 @@ func (b *Boot) SetFlags(f *flag.FlagSet) {
f.IntVar(&b.deviceFD, "device-fd", -1, "FD for the platform device file")
f.Var(&b.ioFDs, "io-fds", "list of FDs to connect 9P clients. They must follow this order: root first, then mounts as defined in the spec")
f.Var(&b.stdioFDs, "stdio-fds", "list of FDs containing sandbox stdin, stdout, and stderr in that order")
- f.BoolVar(&b.console, "console", false, "set to true if the sandbox should allow terminal ioctl(2) syscalls")
f.BoolVar(&b.applyCaps, "apply-caps", false, "if true, apply capabilities defined in the spec to the process")
f.BoolVar(&b.setUpRoot, "setup-root", false, "if true, set up an empty root for the process")
f.BoolVar(&b.pidns, "pidns", false, "if true, the sandbox is in its own PID namespace")
@@ -118,6 +119,7 @@ func (b *Boot) SetFlags(f *flag.FlagSet) {
f.IntVar(&b.userLogFD, "user-log-fd", 0, "file descriptor to write user logs to. 0 means no logging.")
f.IntVar(&b.startSyncFD, "start-sync-fd", -1, "required FD to used to synchronize sandbox startup")
f.IntVar(&b.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to read list of mounts after they have been resolved (direct paths, no symlinks).")
+ f.BoolVar(&b.attached, "attached", false, "if attached is true, kills the sandbox process when the parent process terminates")
}
// Execute implements subcommands.Command.Execute. It starts a sandbox in a
@@ -129,33 +131,36 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
// Ensure that if there is a panic, all goroutine stacks are printed.
- debug.SetTraceback("all")
+ debug.SetTraceback("system")
conf := args[0].(*boot.Config)
+ if b.attached {
+ // Ensure this process is killed after parent process terminates when
+ // attached mode is enabled. In the unfortunate event that the parent
+ // terminates before this point, this process leaks.
+ if err := unix.Prctl(unix.PR_SET_PDEATHSIG, uintptr(unix.SIGKILL), 0, 0, 0); err != nil {
+ Fatalf("error setting parent death signal: %v", err)
+ }
+ }
+
if b.setUpRoot {
if err := setUpChroot(b.pidns); err != nil {
Fatalf("error setting up chroot: %v", err)
}
- if !b.applyCaps {
- // Remove --setup-root arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") {
- args = append(args, arg)
- }
- }
- if !conf.Rootless {
- // Note that we've already read the spec from the spec FD, and
- // we will read it again after the exec call. This works
- // because the ReadSpecFromFile function seeks to the beginning
- // of the file before reading.
- if err := callSelfAsNobody(args); err != nil {
- Fatalf("%v", err)
- }
- panic("callSelfAsNobody must never return success")
+ if !b.applyCaps && !conf.Rootless {
+ // Remove --apply-caps arg to call myself. It has already been done.
+ args := prepareArgs(b.attached, "setup-root")
+
+ // Note that we've already read the spec from the spec FD, and
+ // we will read it again after the exec call. This works
+ // because the ReadSpecFromFile function seeks to the beginning
+ // of the file before reading.
+ if err := callSelfAsNobody(args); err != nil {
+ Fatalf("%v", err)
}
+ panic("callSelfAsNobody must never return success")
}
}
@@ -173,7 +178,12 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
if caps == nil {
caps = &specs.LinuxCapabilities{}
}
- if conf.Platform == platforms.Ptrace {
+
+ gPlatform, err := platform.Lookup(conf.Platform)
+ if err != nil {
+ Fatalf("loading platform: %v", err)
+ }
+ if gPlatform.Requirements().RequiresCapSysPtrace {
// Ptrace platform requires extra capabilities.
const c = "CAP_SYS_PTRACE"
caps.Bounding = append(caps.Bounding, c)
@@ -181,13 +191,9 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
caps.Permitted = append(caps.Permitted, c)
}
- // Remove --apply-caps arg to call myself.
- var args []string
- for _, arg := range os.Args {
- if !strings.Contains(arg, "setup-root") && !strings.Contains(arg, "apply-caps") {
- args = append(args, arg)
- }
- }
+ // Remove --apply-caps and --setup-root arg to call myself. Both have
+ // already been done.
+ args := prepareArgs(b.attached, "setup-root", "apply-caps")
// Note that we've already read the spec from the spec FD, and
// we will read it again after the exec call. This works
@@ -218,7 +224,6 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Device: os.NewFile(uintptr(b.deviceFD), "platform device"),
GoferFDs: b.ioFDs.GetArray(),
StdioFDs: b.stdioFDs.GetArray(),
- Console: b.console,
NumCPU: b.cpuNum,
TotalMem: b.totalMem,
UserLogFD: b.userLogFD,
@@ -258,3 +263,22 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
l.Destroy()
return subcommands.ExitSuccess
}
+
+func prepareArgs(attached bool, exclude ...string) []string {
+ var args []string
+ for _, arg := range os.Args {
+ for _, excl := range exclude {
+ if strings.Contains(arg, excl) {
+ goto skip
+ }
+ }
+ args = append(args, arg)
+ if attached && arg == "boot" {
+ // Strategicaly place "--attached" after the command. This is needed
+ // to ensure the new process is killed when the parent process terminates.
+ args = append(args, "--attached")
+ }
+ skip:
+ }
+ return args
+}
diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go
index 0c27f7313..a84067112 100644
--- a/runsc/cmd/capability_test.go
+++ b/runsc/cmd/capability_test.go
@@ -23,10 +23,10 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/syndtr/gocapability/capability"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
)
func init() {
@@ -85,21 +85,20 @@ func TestCapabilities(t *testing.T) {
Inheritable: caps,
}
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
// Use --network=host to make sandbox use spec's capabilities.
conf.Network = boot.NetworkHost
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := container.Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index d8b3a8573..8a29e521e 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -20,11 +20,11 @@ import (
"path/filepath"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
index b5a0ce17d..189244765 100644
--- a/runsc/cmd/chroot.go
+++ b/runsc/cmd/chroot.go
@@ -50,7 +50,7 @@ func pivotRoot(root string) error {
// new_root, so after umounting the old_root, we will see only
// the new_root in "/".
if err := syscall.PivotRoot(".", "."); err != nil {
- return fmt.Errorf("error changing root filesystem: %v", err)
+ return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err)
}
if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil {
diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go
index a4e3071b3..910e97577 100644
--- a/runsc/cmd/create.go
+++ b/runsc/cmd/create.go
@@ -16,10 +16,11 @@ package cmd
import (
"context"
- "flag"
+
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index 7313e473f..742f8c344 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -22,12 +22,12 @@ import (
"syscall"
"time"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Debug implements subcommands.Command for the "debug" command.
@@ -37,11 +37,14 @@ type Debug struct {
signal int
profileHeap string
profileCPU string
- profileDelay int
+ profileBlock string
+ profileMutex string
trace string
strace string
logLevel string
logPackets string
+ duration time.Duration
+ ps bool
}
// Name implements subcommands.Command.
@@ -65,12 +68,15 @@ func (d *Debug) SetFlags(f *flag.FlagSet) {
f.BoolVar(&d.stacks, "stacks", false, "if true, dumps all sandbox stacks to the log")
f.StringVar(&d.profileHeap, "profile-heap", "", "writes heap profile to the given file.")
f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.")
- f.IntVar(&d.profileDelay, "profile-delay", 5, "amount of time to wait before stoping CPU profile")
+ f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.")
+ f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.")
+ f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles")
f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.")
f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox")
f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`)
f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).")
f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.")
+ f.BoolVar(&d.ps, "ps", false, "lists processes")
}
// Execute implements subcommands.Command.Execute.
@@ -145,6 +151,30 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof("Heap profile written to %q", d.profileHeap)
}
+ if d.profileBlock != "" {
+ f, err := os.Create(d.profileBlock)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.BlockProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Block profile written to %q", d.profileBlock)
+ }
+ if d.profileMutex != "" {
+ f, err := os.Create(d.profileMutex)
+ if err != nil {
+ return Errorf(err.Error())
+ }
+ defer f.Close()
+
+ if err := c.Sandbox.MutexProfile(f); err != nil {
+ return Errorf(err.Error())
+ }
+ log.Infof("Mutex profile written to %q", d.profileMutex)
+ }
delay := false
if d.profileCPU != "" {
@@ -163,7 +193,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
if err := c.Sandbox.StartCPUProfile(f); err != nil {
return Errorf(err.Error())
}
- log.Infof("CPU profile started for %d sec, writing to %q", d.profileDelay, d.profileCPU)
+ log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU)
}
if d.trace != "" {
delay = true
@@ -181,8 +211,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
if err := c.Sandbox.StartTrace(f); err != nil {
return Errorf(err.Error())
}
- log.Infof("Tracing started for %d sec, writing to %q", d.profileDelay, d.trace)
-
+ log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace)
}
if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 {
@@ -241,9 +270,20 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
}
log.Infof("Logging options changed")
}
+ if d.ps {
+ pList, err := c.Processes()
+ if err != nil {
+ Fatalf("getting processes for container: %v", err)
+ }
+ o, err := control.ProcessListToJSON(pList)
+ if err != nil {
+ Fatalf("generating JSON: %v", err)
+ }
+ log.Infof(o)
+ }
if delay {
- time.Sleep(time.Duration(d.profileDelay) * time.Second)
+ time.Sleep(d.duration)
}
return subcommands.ExitSuccess
diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go
index 30d8164b1..0e4863f50 100644
--- a/runsc/cmd/delete.go
+++ b/runsc/cmd/delete.go
@@ -19,11 +19,11 @@ import (
"fmt"
"os"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Delete implements subcommands.Command for the "delete" command.
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 9a8a49054..7d1310c96 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -27,12 +27,12 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -166,15 +166,33 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
return Errorf("Error write spec: %v", err)
}
- runArgs := container.Args{
+ containerArgs := container.Args{
ID: cid,
Spec: spec,
BundleDir: tmpDir,
Attached: true,
}
- ws, err := container.Run(conf, runArgs)
+ ct, err := container.New(conf, containerArgs)
if err != nil {
- return Errorf("running container: %v", err)
+ return Errorf("creating container: %v", err)
+ }
+ defer ct.Destroy()
+
+ if err := ct.Start(conf); err != nil {
+ return Errorf("starting container: %v", err)
+ }
+
+ // Forward signals to init in the container. Thus if we get SIGINT from
+ // ^C, the container gracefully exit, and we can clean up.
+ //
+ // N.B. There is a still a window before this where a signal may kill
+ // this process, skipping cleanup.
+ stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
+ defer stopForwarding()
+
+ ws, err := ct.Wait()
+ if err != nil {
+ return Errorf("waiting for container: %v", err)
}
*waitStatus = ws
@@ -237,20 +255,27 @@ func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) {
for _, cmd := range cmds {
log.Debugf("Run %q", cmd)
args := strings.Split(cmd, " ")
- c := exec.Command(args[0], args[1:]...)
- if err := c.Run(); err != nil {
+ cmd := exec.Command(args[0], args[1:]...)
+ if err := cmd.Run(); err != nil {
+ c.cleanupNet(cid, dev, "", "", "")
return nil, fmt.Errorf("failed to run %q: %v", cmd, err)
}
}
- if err := makeFile("/etc/resolv.conf", "nameserver 8.8.8.8\n", spec); err != nil {
+ resolvPath, err := makeFile("/etc/resolv.conf", "nameserver 8.8.8.8\n", spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, "", "", "")
return nil, err
}
- if err := makeFile("/etc/hostname", cid+"\n", spec); err != nil {
+ hostnamePath, err := makeFile("/etc/hostname", cid+"\n", spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, resolvPath, "", "")
return nil, err
}
hosts := fmt.Sprintf("127.0.0.1\tlocalhost\n%s\t%s\n", c.ip, cid)
- if err := makeFile("/etc/hosts", hosts, spec); err != nil {
+ hostsPath, err := makeFile("/etc/hosts", hosts, spec)
+ if err != nil {
+ c.cleanupNet(cid, dev, resolvPath, hostnamePath, "")
return nil, err
}
@@ -263,19 +288,22 @@ func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) {
}
spec.Linux.Namespaces = append(spec.Linux.Namespaces, netns)
- return func() { c.cleanNet(cid, dev) }, nil
+ return func() { c.cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath) }, nil
}
-func (c *Do) cleanNet(cid, dev string) {
- veth, peer := deviceNames(cid)
+// cleanupNet tries to cleanup the network setup in setupNet.
+//
+// It may be called when setupNet is only partially complete, in which case it
+// will cleanup as much as possible, logging warnings for the rest.
+//
+// Unfortunately none of this can be automatically cleaned up on process exit,
+// we must do so explicitly.
+func (c *Do) cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath string) {
+ _, peer := deviceNames(cid)
cmds := []string{
fmt.Sprintf("ip link delete %s", peer),
fmt.Sprintf("ip netns delete %s", cid),
-
- fmt.Sprintf("iptables -t nat -D POSTROUTING -s %s/24 -o %s -j MASQUERADE", c.ip, dev),
- fmt.Sprintf("iptables -D FORWARD -i %s -o %s -j ACCEPT", dev, veth),
- fmt.Sprintf("iptables -D FORWARD -o %s -i %s -j ACCEPT", dev, veth),
}
for _, cmd := range cmds {
@@ -286,6 +314,10 @@ func (c *Do) cleanNet(cid, dev string) {
log.Warningf("Failed to run %q: %v", cmd, err)
}
}
+
+ tryRemove(resolvPath)
+ tryRemove(hostnamePath)
+ tryRemove(hostsPath)
}
func deviceNames(cid string) (string, string) {
@@ -306,13 +338,16 @@ func defaultDevice() (string, error) {
return parts[4], nil
}
-func makeFile(dest, content string, spec *specs.Spec) error {
+func makeFile(dest, content string, spec *specs.Spec) (string, error) {
tmpFile, err := ioutil.TempFile("", filepath.Base(dest))
if err != nil {
- return err
+ return "", err
}
if _, err := tmpFile.WriteString(content); err != nil {
- return err
+ if err := os.Remove(tmpFile.Name()); err != nil {
+ log.Warningf("Failed to remove %q: %v", tmpFile, err)
+ }
+ return "", err
}
spec.Mounts = append(spec.Mounts, specs.Mount{
Source: tmpFile.Name(),
@@ -320,7 +355,17 @@ func makeFile(dest, content string, spec *specs.Spec) error {
Type: "bind",
Options: []string{"ro"},
})
- return nil
+ return tmpFile.Name(), nil
+}
+
+func tryRemove(path string) {
+ if path == "" {
+ return
+ }
+
+ if err := os.Remove(path); err != nil {
+ log.Warningf("Failed to remove %q: %v", path, err)
+ }
}
func calculatePeerIP(ip string) (string, error) {
diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go
index 3972e9224..51f6a98ed 100644
--- a/runsc/cmd/events.go
+++ b/runsc/cmd/events.go
@@ -20,11 +20,11 @@ import (
"os"
"time"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Events implements subcommands.Command for the "events" command.
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index d1e99243b..d9a94903e 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -27,7 +27,6 @@ import (
"syscall"
"time"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
@@ -37,6 +36,7 @@ import (
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/console"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 4831210c0..3966e2d21 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -21,17 +21,17 @@ import (
"os"
"path/filepath"
"strings"
- "sync"
"syscall"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/fsgofer"
"gvisor.dev/gvisor/runsc/fsgofer/filter"
"gvisor.dev/gvisor/runsc/specutils"
@@ -168,7 +168,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Start with root mount, then add any other additional mount as needed.
ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
- ROMount: spec.Root.Readonly,
+ ROMount: spec.Root.Readonly || conf.Overlay,
PanicOnWrite: g.panicOnWrite,
})
if err != nil {
@@ -181,7 +181,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
for _, m := range spec.Mounts {
if specutils.Is9PMount(m) {
cfg := fsgofer.Config{
- ROMount: isReadonlyMount(m.Options),
+ ROMount: isReadonlyMount(m.Options) || conf.Overlay,
PanicOnWrite: g.panicOnWrite,
HostUDS: conf.FSGoferHostUDS,
}
@@ -272,9 +272,8 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
root := spec.Root.Path
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // FIXME: runsc can't be re-executed without
- // /proc, so we create a tmpfs mount, mount ./proc and ./root
- // there, then move this mount to the root and after
+ // runsc can't be re-executed without /proc, so we create a tmpfs mount,
+ // mount ./proc and ./root there, then move this mount to the root and after
// setCapsAndCallSelf, runsc will chroot into /root.
//
// We need a directory to construct a new root and we know that
@@ -307,7 +306,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
}
// Replace the current spec, with the clean spec with symlinks resolved.
- if err := setupMounts(spec.Mounts, root); err != nil {
+ if err := setupMounts(conf, spec.Mounts, root); err != nil {
Fatalf("error setting up FS: %v", err)
}
@@ -323,7 +322,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
}
// Check if root needs to be remounted as readonly.
- if spec.Root.Readonly {
+ if spec.Root.Readonly || conf.Overlay {
// If root is a mount point but not read-only, we can change mount options
// to make it read-only for extra safety.
log.Infof("Remounting root as readonly: %q", root)
@@ -335,7 +334,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
if !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot {
if err := pivotRoot("/proc"); err != nil {
- Fatalf("faild to change the root file system: %v", err)
+ Fatalf("failed to change the root file system: %v", err)
}
if err := os.Chdir("/"); err != nil {
Fatalf("failed to change working directory")
@@ -347,7 +346,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error {
// setupMounts binds mount all mounts specified in the spec in their correct
// location inside root. It will resolve relative paths and symlinks. It also
// creates directories as needed.
-func setupMounts(mounts []specs.Mount, root string) error {
+func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error {
for _, m := range mounts {
if m.Type != "bind" || !specutils.IsSupportedDevMount(m) {
continue
@@ -359,6 +358,11 @@ func setupMounts(mounts []specs.Mount, root string) error {
}
flags := specutils.OptionsToFlags(m.Options) | syscall.MS_BIND
+ if conf.Overlay {
+ // Force mount read-only if writes are not going to be sent to it.
+ flags |= syscall.MS_RDONLY
+ }
+
log.Infof("Mounting src: %q, dst: %q, flags: %#x", m.Source, dst, flags)
if err := specutils.Mount(m.Source, dst, m.Type, flags); err != nil {
return fmt.Errorf("mounting %v: %v", m, err)
diff --git a/runsc/cmd/help.go b/runsc/cmd/help.go
index ff4f901cb..cd85dabbb 100644
--- a/runsc/cmd/help.go
+++ b/runsc/cmd/help.go
@@ -1,4 +1,4 @@
-// Copyright 2018 Google LLC
+// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@ import (
"context"
"fmt"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
)
// NewHelp returns a help command for the given commander.
@@ -65,16 +65,10 @@ func (h *Help) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}
switch f.NArg() {
case 0:
fmt.Fprintf(h.cdr.Output, "Usage: %s <flags> <subcommand> <subcommand args>\n\n", h.cdr.Name())
- fmt.Fprintf(h.cdr.Output, `runsc is a command line client for running applications packaged in the Open
-Container Initiative (OCI) format. Applications run by runsc are run in an
-isolated gVisor sandbox that emulates a Linux environment.
+ fmt.Fprintf(h.cdr.Output, `runsc is the gVisor container runtime.
-gVisor is a user-space kernel, written in Go, that implements a substantial
-portion of the Linux system call interface. It provides an additional layer
-of isolation between running applications and the host operating system.
-
-Functionality is provided by subcommands. For additonal help on individual
-subcommands use "%s %s <subcommand>".
+Functionality is provided by subcommands. For help with a specific subcommand,
+use "%s %s <subcommand>".
`, h.cdr.Name(), h.Name())
h.cdr.VisitGroups(func(g *subcommands.CommandGroup) {
diff --git a/runsc/cmd/install.go b/runsc/cmd/install.go
index 441c1db0d..2e223e3be 100644
--- a/runsc/cmd/install.go
+++ b/runsc/cmd/install.go
@@ -23,8 +23,8 @@ import (
"os"
"path"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Install implements subcommands.Command.
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
index 6c1f197a6..8282ea0e0 100644
--- a/runsc/cmd/kill.go
+++ b/runsc/cmd/kill.go
@@ -21,11 +21,11 @@ import (
"strings"
"syscall"
- "flag"
"github.com/google/subcommands"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Kill implements subcommands.Command for the "kill" command.
diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go
index dd2d99a6b..d8d906fe3 100644
--- a/runsc/cmd/list.go
+++ b/runsc/cmd/list.go
@@ -22,11 +22,11 @@ import (
"text/tabwriter"
"time"
- "flag"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// List implements subcommands.Command for the "list" command for the "list" command.
diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go
index 9c0e92001..6f95a9837 100644
--- a/runsc/cmd/pause.go
+++ b/runsc/cmd/pause.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Pause implements subcommands.Command for the "pause" command.
diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go
index 45c644f3f..7fb8041af 100644
--- a/runsc/cmd/ps.go
+++ b/runsc/cmd/ps.go
@@ -18,11 +18,11 @@ import (
"context"
"fmt"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// PS implements subcommands.Command for the "ps" command.
diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go
index 7be60cd7d..72584b326 100644
--- a/runsc/cmd/restore.go
+++ b/runsc/cmd/restore.go
@@ -19,10 +19,10 @@ import (
"path/filepath"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go
index b2df5c640..61a55a554 100644
--- a/runsc/cmd/resume.go
+++ b/runsc/cmd/resume.go
@@ -17,10 +17,10 @@ package cmd
import (
"context"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Resume implements subcommands.Command for the "resume" command.
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index 33f4bc12b..cf41581ad 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -18,10 +18,10 @@ import (
"context"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go
index 344da13ba..55194e641 100644
--- a/runsc/cmd/spec.go
+++ b/runsc/cmd/spec.go
@@ -16,118 +16,122 @@ package cmd
import (
"context"
- "io/ioutil"
+ "encoding/json"
+ "io"
"os"
"path/filepath"
- "flag"
"github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/flag"
)
-var specTemplate = []byte(`{
- "ociVersion": "1.0.0",
- "process": {
- "terminal": true,
- "user": {
- "uid": 0,
- "gid": 0
- },
- "args": [
- "sh"
- ],
- "env": [
- "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
- "TERM=xterm"
- ],
- "cwd": "/",
- "capabilities": {
- "bounding": [
- "CAP_AUDIT_WRITE",
- "CAP_KILL",
- "CAP_NET_BIND_SERVICE"
- ],
- "effective": [
- "CAP_AUDIT_WRITE",
- "CAP_KILL",
- "CAP_NET_BIND_SERVICE"
- ],
- "inheritable": [
- "CAP_AUDIT_WRITE",
- "CAP_KILL",
- "CAP_NET_BIND_SERVICE"
- ],
- "permitted": [
- "CAP_AUDIT_WRITE",
- "CAP_KILL",
- "CAP_NET_BIND_SERVICE"
- ],
- "ambient": [
- "CAP_AUDIT_WRITE",
- "CAP_KILL",
- "CAP_NET_BIND_SERVICE"
- ]
- },
- "rlimits": [
- {
- "type": "RLIMIT_NOFILE",
- "hard": 1024,
- "soft": 1024
- }
- ]
- },
- "root": {
- "path": "rootfs",
- "readonly": true
- },
- "hostname": "runsc",
- "mounts": [
- {
- "destination": "/proc",
- "type": "proc",
- "source": "proc"
+func writeSpec(w io.Writer, cwd string, netns string, args []string) error {
+ spec := &specs.Spec{
+ Version: "1.0.0",
+ Process: &specs.Process{
+ Terminal: true,
+ User: specs.User{
+ UID: 0,
+ GID: 0,
+ },
+ Args: args,
+ Env: []string{
+ "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
+ "TERM=xterm",
+ },
+ Cwd: cwd,
+ Capabilities: &specs.LinuxCapabilities{
+ Bounding: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Effective: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Inheritable: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ Permitted: []string{
+ "CAP_AUDIT_WRITE",
+ "CAP_KILL",
+ "CAP_NET_BIND_SERVICE",
+ },
+ // TODO(gvisor.dev/issue/3166): support ambient capabilities
+ },
+ Rlimits: []specs.POSIXRlimit{
+ {
+ Type: "RLIMIT_NOFILE",
+ Hard: 1024,
+ Soft: 1024,
+ },
+ },
},
- {
- "destination": "/dev",
- "type": "tmpfs",
- "source": "tmpfs",
- "options": []
+ Root: &specs.Root{
+ Path: "rootfs",
+ Readonly: true,
},
- {
- "destination": "/sys",
- "type": "sysfs",
- "source": "sysfs",
- "options": [
- "nosuid",
- "noexec",
- "nodev",
- "ro"
- ]
- }
- ],
- "linux": {
- "namespaces": [
+ Hostname: "runsc",
+ Mounts: []specs.Mount{
{
- "type": "pid"
+ Destination: "/proc",
+ Type: "proc",
+ Source: "proc",
},
{
- "type": "network"
+ Destination: "/dev",
+ Type: "tmpfs",
+ Source: "tmpfs",
},
{
- "type": "ipc"
+ Destination: "/sys",
+ Type: "sysfs",
+ Source: "sysfs",
+ Options: []string{
+ "nosuid",
+ "noexec",
+ "nodev",
+ "ro",
+ },
},
- {
- "type": "uts"
+ },
+ Linux: &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ },
+ {
+ Type: "network",
+ Path: netns,
+ },
+ {
+ Type: "ipc",
+ },
+ {
+ Type: "uts",
+ },
+ {
+ Type: "mount",
+ },
},
- {
- "type": "mount"
- }
- ]
+ },
}
-}`)
+
+ e := json.NewEncoder(w)
+ e.SetIndent("", " ")
+ return e.Encode(spec)
+}
// Spec implements subcommands.Command for the "spec" command.
type Spec struct {
bundle string
+ cwd string
+ netns string
}
// Name implements subcommands.Command.Name.
@@ -142,21 +146,26 @@ func (*Spec) Synopsis() string {
// Usage implements subcommands.Command.Usage.
func (*Spec) Usage() string {
- return `spec [options] - create a new OCI bundle specification file.
+ return `spec [options] [-- args...] - create a new OCI bundle specification file.
+
+The spec command creates a new specification file (config.json) for a new OCI
+bundle.
-The spec command creates a new specification file (config.json) for a new OCI bundle.
+The specification file is a starter file that runs the command specified by
+'args' in the container. If 'args' is not specified the default is to run the
+'sh' program.
-The specification file is a starter file that runs the "sh" command in the container. You
-should edit the file to suit your needs. You can find out more about the format of the
-specification file by visiting the OCI runtime spec repository:
+While a number of flags are provided to change values in the specification, you
+can examine the file and edit it to suit your needs after this command runs.
+You can find out more about the format of the specification file by visiting
+the OCI runtime spec repository:
https://github.com/opencontainers/runtime-spec/
EXAMPLE:
$ mkdir -p bundle/rootfs
$ cd bundle
- $ runsc spec
+ $ runsc spec -- /hello
$ docker export $(docker create hello-world) | tar -xf - -C rootfs
- $ sed -i 's;"sh";"/hello";' config.json
$ sudo runsc run hello
`
@@ -165,16 +174,31 @@ EXAMPLE:
// SetFlags implements subcommands.Command.SetFlags.
func (s *Spec) SetFlags(f *flag.FlagSet) {
f.StringVar(&s.bundle, "bundle", ".", "path to the root of the OCI bundle")
+ f.StringVar(&s.cwd, "cwd", "/", "working directory that will be set for the executable, "+
+ "this value MUST be an absolute path")
+ f.StringVar(&s.netns, "netns", "", "network namespace path")
}
// Execute implements subcommands.Command.Execute.
func (s *Spec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Grab the arguments.
+ containerArgs := f.Args()
+ if len(containerArgs) == 0 {
+ containerArgs = []string{"sh"}
+ }
+
confPath := filepath.Join(s.bundle, "config.json")
if _, err := os.Stat(confPath); !os.IsNotExist(err) {
Fatalf("file %q already exists", confPath)
}
- if err := ioutil.WriteFile(confPath, specTemplate, 0664); err != nil {
+ configFile, err := os.OpenFile(confPath, os.O_WRONLY|os.O_CREATE, 0664)
+ if err != nil {
+ Fatalf("opening file %q: %v", confPath, err)
+ }
+
+ err = writeSpec(configFile, s.cwd, s.netns, containerArgs)
+ if err != nil {
Fatalf("writing to %q: %v", confPath, err)
}
diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go
index de2115dff..0205fd9f7 100644
--- a/runsc/cmd/start.go
+++ b/runsc/cmd/start.go
@@ -16,10 +16,11 @@ package cmd
import (
"context"
- "flag"
+
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Start implements subcommands.Command for the "start" command.
diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go
index e9f41cbd8..cf2413deb 100644
--- a/runsc/cmd/state.go
+++ b/runsc/cmd/state.go
@@ -19,11 +19,11 @@ import (
"encoding/json"
"os"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
// State implements subcommands.Command for the "state" command.
diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go
new file mode 100644
index 000000000..daed9e728
--- /dev/null
+++ b/runsc/cmd/statefile.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/state/pretty"
+ "gvisor.dev/gvisor/pkg/state/statefile"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Statefile implements subcommands.Command for the "statefile" command.
+type Statefile struct {
+ list bool
+ get string
+ key string
+ output string
+ html bool
+}
+
+// Name implements subcommands.Command.
+func (*Statefile) Name() string {
+ return "state"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Statefile) Synopsis() string {
+ return "shows information about a statefile"
+}
+
+// Usage implements subcommands.Command.
+func (*Statefile) Usage() string {
+ return `statefile [flags] <statefile>`
+}
+
+// SetFlags implements subcommands.Command.
+func (s *Statefile) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&s.list, "list", false, "lists the metdata in the statefile.")
+ f.StringVar(&s.get, "get", "", "extracts the given metadata key.")
+ f.StringVar(&s.key, "key", "", "the integrity key for the file.")
+ f.StringVar(&s.output, "output", "", "target to write the result.")
+ f.BoolVar(&s.html, "html", false, "outputs in HTML format.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Check arguments.
+ if s.list && s.get != "" {
+ Fatalf("error: can't specify -list and -get simultaneously.")
+ }
+
+ // Setup output.
+ var output = os.Stdout // Default.
+ if s.output != "" {
+ f, err := os.OpenFile(s.output, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
+ if err != nil {
+ Fatalf("error opening output: %v", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ Fatalf("error flushing output: %v", err)
+ }
+ }()
+ output = f
+ }
+
+ // Open the file.
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ input, err := os.Open(f.Arg(0))
+ if err != nil {
+ Fatalf("error opening input: %v\n", err)
+ }
+
+ if s.html {
+ fmt.Fprintf(output, "<html><body>\n")
+ defer fmt.Fprintf(output, "</body></html>\n")
+ }
+
+ // Dump the full file?
+ if !s.list && s.get == "" {
+ var key []byte
+ if s.key != "" {
+ key = []byte(s.key)
+ }
+ rc, _, err := statefile.NewReader(input, key)
+ if err != nil {
+ Fatalf("error parsing statefile: %v", err)
+ }
+ if s.html {
+ if err := pretty.PrintHTML(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ } else {
+ if err := pretty.PrintText(output, rc); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ }
+ return subcommands.ExitSuccess
+ }
+
+ // Load just the metadata.
+ metadata, err := statefile.MetadataUnsafe(input)
+ if err != nil {
+ Fatalf("error reading metadata: %v", err)
+ }
+
+ // Is it a single key?
+ if s.get != "" {
+ val, ok := metadata[s.get]
+ if !ok {
+ Fatalf("metadata key %s: not found", s.get)
+ }
+ fmt.Fprintf(output, "%s\n", val)
+ return subcommands.ExitSuccess
+ }
+
+ // List all keys.
+ if s.html {
+ fmt.Fprintf(output, " <ul>\n")
+ defer fmt.Fprintf(output, " </ul>\n")
+ }
+ for key := range metadata {
+ if s.html {
+ fmt.Fprintf(output, " <li>%s</li>\n", key)
+ } else {
+ fmt.Fprintf(output, "%s\n", key)
+ }
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/syscalls.go b/runsc/cmd/syscalls.go
index fb6c1ab29..a37d66139 100644
--- a/runsc/cmd/syscalls.go
+++ b/runsc/cmd/syscalls.go
@@ -25,16 +25,17 @@ import (
"strconv"
"text/tabwriter"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/runsc/flag"
)
// Syscalls implements subcommands.Command for the "syscalls" command.
type Syscalls struct {
- output string
- os string
- arch string
+ format string
+ os string
+ arch string
+ filename string
}
// CompatibilityInfo is a map of system and architecture to compatibility doc.
@@ -95,16 +96,17 @@ func (*Syscalls) Usage() string {
// SetFlags implements subcommands.Command.SetFlags.
func (s *Syscalls) SetFlags(f *flag.FlagSet) {
- f.StringVar(&s.output, "o", "table", "Output format (table, csv, json).")
+ f.StringVar(&s.format, "format", "table", "Output format (table, csv, json).")
f.StringVar(&s.os, "os", osAll, "The OS (e.g. linux)")
f.StringVar(&s.arch, "arch", archAll, "The CPU architecture (e.g. amd64).")
+ f.StringVar(&s.filename, "filename", "", "Output filename (otherwise stdout).")
}
// Execute implements subcommands.Command.Execute.
func (s *Syscalls) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- out, ok := outputMap[s.output]
+ out, ok := outputMap[s.format]
if !ok {
- Fatalf("Unsupported output format %q", s.output)
+ Fatalf("Unsupported output format %q", s.format)
}
// Build map of all supported architectures.
@@ -124,7 +126,14 @@ func (s *Syscalls) Execute(_ context.Context, f *flag.FlagSet, args ...interface
Fatalf("%v", err)
}
- if err := out(os.Stdout, info); err != nil {
+ w := os.Stdout // Default.
+ if s.filename != "" {
+ w, err = os.OpenFile(s.filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
+ if err != nil {
+ Fatalf("Error opening %q: %v", s.filename, err)
+ }
+ }
+ if err := out(w, info); err != nil {
Fatalf("Error writing output: %v", err)
}
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
index 046489687..29c0a15f0 100644
--- a/runsc/cmd/wait.go
+++ b/runsc/cmd/wait.go
@@ -20,10 +20,10 @@ import (
"os"
"syscall"
- "flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/flag"
)
const (
diff --git a/runsc/console/BUILD b/runsc/console/BUILD
index e623c1a0f..06924bccd 100644
--- a/runsc/console/BUILD
+++ b/runsc/console/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -7,7 +7,6 @@ go_library(
srcs = [
"console.go",
],
- importpath = "gvisor.dev/gvisor/runsc/console",
visibility = [
"//runsc:__subpackages__",
],
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 26d1cd5ab..9a9ee7e2a 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -7,60 +7,68 @@ go_library(
srcs = [
"container.go",
"hook.go",
+ "state_file.go",
"status.go",
],
- importpath = "gvisor.dev/gvisor/runsc/container",
visibility = [
"//runsc:__subpackages__",
"//test:__subpackages__",
],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/cleanup",
"//pkg/log",
"//pkg/sentry/control",
+ "//pkg/sentry/sighandling",
+ "//pkg/sync",
"//runsc/boot",
"//runsc/cgroup",
"//runsc/sandbox",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_gofrs_flock//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
],
)
go_test(
name = "container_test",
- size = "medium",
+ size = "large",
srcs = [
"console_test.go",
+ "container_norace_test.go",
+ "container_race_test.go",
"container_test.go",
"multi_container_test.go",
"shared_volume_test.go",
],
data = [
"//runsc",
- "//runsc/container/test_app",
+ "//test/cmd/test_app",
],
- embed = [":container"],
- shard_count = 5,
+ library = ":container",
+ shard_count = 10,
tags = [
"requires-kvm",
],
deps = [
"//pkg/abi/linux",
"//pkg/bits",
+ "//pkg/cleanup",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sync",
+ "//pkg/test/testutil",
"//pkg/unet",
"//pkg/urpc",
"//runsc/boot",
"//runsc/boot/platforms",
"//runsc/specutils",
- "//runsc/testutil",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_kr_pty//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 7d67c3a75..995d4e267 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -20,7 +20,6 @@ import (
"io"
"os"
"path/filepath"
- "sync"
"syscall"
"testing"
"time"
@@ -28,9 +27,11 @@ import (
"github.com/kr/pty"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
- "gvisor.dev/gvisor/runsc/testutil"
)
// socketPath creates a path inside bundleDir and ensures that the returned
@@ -57,25 +58,26 @@ func socketPath(bundleDir string) (string, error) {
}
// createConsoleSocket creates a socket at the given path that will receive a
-// console fd from the sandbox. If no error occurs, it returns the server
-// socket and a cleanup function.
-func createConsoleSocket(path string) (*unet.ServerSocket, func() error, error) {
+// console fd from the sandbox. If an error occurs, t.Fatalf will be called.
+// The function returning should be deferred as cleanup.
+func createConsoleSocket(t *testing.T, path string) (*unet.ServerSocket, func()) {
+ t.Helper()
srv, err := unet.BindAndListen(path, false)
if err != nil {
- return nil, nil, fmt.Errorf("error binding and listening to socket %q: %v", path, err)
+ t.Fatalf("error binding and listening to socket %q: %v", path, err)
}
- cleanup := func() error {
+ cleanup := func() {
+ // Log errors; nothing can be done.
if err := srv.Close(); err != nil {
- return fmt.Errorf("error closing socket %q: %v", path, err)
+ t.Logf("error closing socket %q: %v", path, err)
}
if err := os.Remove(path); err != nil {
- return fmt.Errorf("error removing socket %q: %v", path, err)
+ t.Logf("error removing socket %q: %v", path, err)
}
- return nil
}
- return srv, cleanup, nil
+ return srv, cleanup
}
// receiveConsolePTY accepts a connection on the server socket and reads fds.
@@ -117,63 +119,60 @@ func receiveConsolePTY(srv *unet.ServerSocket) (*os.File, error) {
// Test that an pty FD is sent over the console socket if one is provided.
func TestConsoleSocket(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
- spec := testutil.NewSpecWithArgs("true")
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("true")
+ spec.Process.Terminal = true
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- sock, err := socketPath(bundleDir)
- if err != nil {
- t.Fatalf("error getting socket path: %v", err)
- }
- srv, cleanup, err := createConsoleSocket(sock)
- if err != nil {
- t.Fatalf("error creating socket at %q: %v", sock, err)
- }
- defer cleanup()
-
- // Create the container and pass the socket name.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- ConsoleSocket: sock,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
+ sock, err := socketPath(bundleDir)
+ if err != nil {
+ t.Fatalf("error getting socket path: %v", err)
+ }
+ srv, cleanup := createConsoleSocket(t, sock)
+ defer cleanup()
+
+ // Create the container and pass the socket name.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ ConsoleSocket: sock,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
- // Make sure we get a console PTY.
- ptyMaster, err := receiveConsolePTY(srv)
- if err != nil {
- t.Fatalf("error receiving console FD: %v", err)
- }
- ptyMaster.Close()
+ // Make sure we get a console PTY.
+ ptyMaster, err := receiveConsolePTY(srv)
+ if err != nil {
+ t.Fatalf("error receiving console FD: %v", err)
+ }
+ ptyMaster.Close()
+ })
}
}
// Test that job control signals work on a console created with "exec -ti".
func TestJobControlSignalExec(t *testing.T) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -195,7 +194,10 @@ func TestJobControlSignalExec(t *testing.T) {
defer ptyMaster.Close()
defer ptySlave.Close()
- // Exec bash and attach a terminal.
+ // Exec bash and attach a terminal. Note that occasionally /bin/sh
+ // may be a different shell or have a different configuration (such
+ // as disabling interactive mode and job control). Since we want to
+ // explicitly test interactive mode, use /bin/bash. See b/116981926.
execArgs := &control.ExecArgs{
Filename: "/bin/bash",
// Don't let bash execute from profile or rc files, otherwise
@@ -219,9 +221,9 @@ func TestJobControlSignalExec(t *testing.T) {
// Make sure all the processes are running.
expectedPL := []*control.Process{
// Root container process.
- {PID: 1, Cmd: "sleep"},
+ {PID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{1}},
// Bash from exec process.
- {PID: 2, Cmd: "bash"},
+ {PID: 2, Cmd: "bash", Threads: []kernel.ThreadID{2}},
}
if err := waitForProcessList(c, expectedPL); err != nil {
t.Error(err)
@@ -231,7 +233,7 @@ func TestJobControlSignalExec(t *testing.T) {
ptyMaster.Write([]byte("sleep 100\n"))
// Wait for it to start. Sleep's PPID is bash's PID.
- expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep"})
+ expectedPL = append(expectedPL, &control.Process{PID: 3, PPID: 2, Cmd: "sleep", Threads: []kernel.ThreadID{3}})
if err := waitForProcessList(c, expectedPL); err != nil {
t.Error(err)
}
@@ -282,32 +284,28 @@ func TestJobControlSignalExec(t *testing.T) {
// Test that job control signals work on a console created with "run -ti".
func TestJobControlSignalRootContainer(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
// Don't let bash execute from profile or rc files, otherwise our PID
// counts get messed up.
spec := testutil.NewSpecWithArgs("/bin/bash", "--noprofile", "--norc")
spec.Process.Terminal = true
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
sock, err := socketPath(bundleDir)
if err != nil {
t.Fatalf("error getting socket path: %v", err)
}
- srv, cleanup, err := createConsoleSocket(sock)
- if err != nil {
- t.Fatalf("error creating socket at %q: %v", sock, err)
- }
+ srv, cleanup := createConsoleSocket(t, sock)
defer cleanup()
// Create the container and pass the socket name.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
ConsoleSocket: sock,
@@ -329,13 +327,13 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// file. Writes after a certain point will block unless we drain the
// PTY, so we must continually copy from it.
//
- // We log the output to stdout for debugabilitly, and also to a buffer,
+ // We log the output to stderr for debugabilitly, and also to a buffer,
// since we wait on particular output from bash below. We use a custom
// blockingBuffer which is thread-safe and also blocks on Read calls,
// which makes this a suitable Reader for WaitUntilRead.
ptyBuf := newBlockingBuffer()
tee := io.TeeReader(ptyMaster, ptyBuf)
- go io.Copy(os.Stdout, tee)
+ go io.Copy(os.Stderr, tee)
// Start the container.
if err := c.Start(conf); err != nil {
@@ -361,19 +359,19 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// Wait for bash to start.
expectedPL := []*control.Process{
- {PID: 1, Cmd: "bash"},
+ {PID: 1, Cmd: "bash", Threads: []kernel.ThreadID{1}},
}
if err := waitForProcessList(c, expectedPL); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
// Execute sleep via the terminal.
ptyMaster.Write([]byte("sleep 100\n"))
// Wait for sleep to start.
- expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep"})
+ expectedPL = append(expectedPL, &control.Process{PID: 2, PPID: 1, Cmd: "sleep", Threads: []kernel.ThreadID{2}})
if err := waitForProcessList(c, expectedPL); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
// Reset the pty buffer, so there is less output for us to scan later.
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 32510d427..7ad09bf23 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -17,13 +17,11 @@ package container
import (
"context"
- "encoding/json"
+ "errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
- "os/signal"
- "path/filepath"
"regexp"
"strconv"
"strings"
@@ -31,27 +29,18 @@ import (
"time"
"github.com/cenkalti/backoff"
- "github.com/gofrs/flock"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/sighandling"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cgroup"
"gvisor.dev/gvisor/runsc/sandbox"
"gvisor.dev/gvisor/runsc/specutils"
)
-const (
- // metadataFilename is the name of the metadata file relative to the
- // container root directory that holds sandbox metadata.
- metadataFilename = "meta.json"
-
- // metadataLockFilename is the name of a lock file in the container
- // root directory that is used to prevent concurrent modifications to
- // the container state and metadata.
- metadataLockFilename = "meta.lock"
-)
-
// validateID validates the container id.
func validateID(id string) error {
// See libcontainer/factory_linux.go.
@@ -99,11 +88,6 @@ type Container struct {
// BundleDir is the directory containing the container bundle.
BundleDir string `json:"bundleDir"`
- // Root is the directory containing the container metadata file. If this
- // container is the root container, Root and RootContainerDir will be the
- // same.
- Root string `json:"root"`
-
// CreatedAt is the time the container was created.
CreatedAt time.Time `json:"createdAt"`
@@ -121,21 +105,24 @@ type Container struct {
// be 0 if the gofer has been killed.
GoferPid int `json:"goferPid"`
+ // Sandbox is the sandbox this container is running in. It's set when the
+ // container is created and reset when the sandbox is destroyed.
+ Sandbox *sandbox.Sandbox `json:"sandbox"`
+
+ // Saver handles load from/save to the state file safely from multiple
+ // processes.
+ Saver StateFile `json:"saver"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
// goferIsChild is set if a gofer process is a child of the current process.
//
// This field isn't saved to json, because only a creator of a gofer
// process will have it as a child process.
goferIsChild bool
-
- // Sandbox is the sandbox this container is running in. It's set when the
- // container is created and reset when the sandbox is destroyed.
- Sandbox *sandbox.Sandbox `json:"sandbox"`
-
- // RootContainerDir is the root directory containing the metadata file of the
- // sandbox root container. It's used to lock in order to serialize creating
- // and deleting this Container's metadata directory. If this container is the
- // root container, this is the same as Root.
- RootContainerDir string
}
// loadSandbox loads all containers that belong to the sandbox with the given
@@ -166,43 +153,35 @@ func loadSandbox(rootDir, id string) ([]*Container, error) {
return containers, nil
}
-// Load loads a container with the given id from a metadata file. id may be an
-// abbreviation of the full container id, in which case Load loads the
-// container to which id unambiguously refers to.
-// Returns ErrNotExist if container doesn't exist.
-func Load(rootDir, id string) (*Container, error) {
- log.Debugf("Load container %q %q", rootDir, id)
- if err := validateID(id); err != nil {
+// Load loads a container with the given id from a metadata file. partialID may
+// be an abbreviation of the full container id, in which case Load loads the
+// container to which id unambiguously refers to. Returns ErrNotExist if
+// container doesn't exist.
+func Load(rootDir, partialID string) (*Container, error) {
+ log.Debugf("Load container %q %q", rootDir, partialID)
+ if err := validateID(partialID); err != nil {
return nil, fmt.Errorf("validating id: %v", err)
}
- cRoot, err := findContainerRoot(rootDir, id)
+ id, err := findContainerID(rootDir, partialID)
if err != nil {
// Preserve error so that callers can distinguish 'not found' errors.
return nil, err
}
- // Lock the container metadata to prevent other runsc instances from
- // writing to it while we are reading it.
- unlock, err := lockContainerMetadata(cRoot)
- if err != nil {
- return nil, err
+ state := StateFile{
+ RootDir: rootDir,
+ ID: id,
}
- defer unlock()
+ defer state.close()
- // Read the container metadata file and create a new Container from it.
- metaFile := filepath.Join(cRoot, metadataFilename)
- metaBytes, err := ioutil.ReadFile(metaFile)
- if err != nil {
+ c := &Container{}
+ if err := state.load(c); err != nil {
if os.IsNotExist(err) {
// Preserve error so that callers can distinguish 'not found' errors.
return nil, err
}
- return nil, fmt.Errorf("reading container metadata file %q: %v", metaFile, err)
- }
- var c Container
- if err := json.Unmarshal(metaBytes, &c); err != nil {
- return nil, fmt.Errorf("unmarshaling container metadata from %q: %v", metaFile, err)
+ return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
}
// If the status is "Running" or "Created", check that the sandbox
@@ -223,57 +202,37 @@ func Load(rootDir, id string) (*Container, error) {
}
}
- return &c, nil
+ return c, nil
}
-func findContainerRoot(rootDir, partialID string) (string, error) {
+func findContainerID(rootDir, partialID string) (string, error) {
// Check whether the id fully specifies an existing container.
- cRoot := filepath.Join(rootDir, partialID)
- if _, err := os.Stat(cRoot); err == nil {
- return cRoot, nil
+ stateFile := buildStatePath(rootDir, partialID)
+ if _, err := os.Stat(stateFile); err == nil {
+ return partialID, nil
}
// Now see whether id could be an abbreviation of exactly 1 of the
// container ids. If id is ambiguous (it could match more than 1
// container), it is an error.
- cRoot = ""
ids, err := List(rootDir)
if err != nil {
return "", err
}
+ rv := ""
for _, id := range ids {
if strings.HasPrefix(id, partialID) {
- if cRoot != "" {
- return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, cRoot, id)
+ if rv != "" {
+ return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id)
}
- cRoot = id
+ rv = id
}
}
- if cRoot == "" {
+ if rv == "" {
return "", os.ErrNotExist
}
- log.Debugf("abbreviated id %q resolves to full id %q", partialID, cRoot)
- return filepath.Join(rootDir, cRoot), nil
-}
-
-// List returns all container ids in the given root directory.
-func List(rootDir string) ([]string, error) {
- log.Debugf("List containers %q", rootDir)
- fs, err := ioutil.ReadDir(rootDir)
- if err != nil {
- return nil, fmt.Errorf("reading dir %q: %v", rootDir, err)
- }
- var out []string
- for _, f := range fs {
- // Filter out directories that do no belong to a container.
- cid := f.Name()
- if validateID(cid) == nil {
- if _, err := os.Stat(filepath.Join(rootDir, cid, metadataFilename)); err == nil {
- out = append(out, f.Name())
- }
- }
- }
- return out, nil
+ log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv)
+ return rv, nil
}
// Args is used to configure a new container.
@@ -316,44 +275,34 @@ func New(conf *boot.Config, args Args) (*Container, error) {
return nil, err
}
- unlockRoot, err := maybeLockRootContainer(args.Spec, conf.RootDir)
- if err != nil {
- return nil, err
+ if err := os.MkdirAll(conf.RootDir, 0711); err != nil {
+ return nil, fmt.Errorf("creating container root directory %q: %v", conf.RootDir, err)
}
- defer unlockRoot()
+
+ c := &Container{
+ ID: args.ID,
+ Spec: args.Spec,
+ ConsoleSocket: args.ConsoleSocket,
+ BundleDir: args.BundleDir,
+ Status: Creating,
+ CreatedAt: time.Now(),
+ Owner: os.Getenv("USER"),
+ Saver: StateFile{
+ RootDir: conf.RootDir,
+ ID: args.ID,
+ },
+ }
+ // The Cleanup object cleans up partially created containers when an error
+ // occurs. Any errors occurring during cleanup itself are ignored.
+ cu := cleanup.Make(func() { _ = c.Destroy() })
+ defer cu.Clean()
// Lock the container metadata file to prevent concurrent creations of
// containers with the same id.
- containerRoot := filepath.Join(conf.RootDir, args.ID)
- unlock, err := lockContainerMetadata(containerRoot)
- if err != nil {
+ if err := c.Saver.lockForNew(); err != nil {
return nil, err
}
- defer unlock()
-
- // Check if the container already exists by looking for the metadata
- // file.
- if _, err := os.Stat(filepath.Join(containerRoot, metadataFilename)); err == nil {
- return nil, fmt.Errorf("container with id %q already exists", args.ID)
- } else if !os.IsNotExist(err) {
- return nil, fmt.Errorf("looking for existing container in %q: %v", containerRoot, err)
- }
-
- c := &Container{
- ID: args.ID,
- Spec: args.Spec,
- ConsoleSocket: args.ConsoleSocket,
- BundleDir: args.BundleDir,
- Root: containerRoot,
- Status: Creating,
- CreatedAt: time.Now(),
- Owner: os.Getenv("USER"),
- RootContainerDir: conf.RootDir,
- }
- // The Cleanup object cleans up partially created containers when an error occurs.
- // Any errors occuring during cleanup itself are ignored.
- cu := specutils.MakeCleanup(func() { _ = c.Destroy() })
- defer cu.Clean()
+ defer c.Saver.unlock()
// If the metadata annotations indicate that this container should be
// started in an existing sandbox, we must do so. The metadata will
@@ -375,7 +324,7 @@ func New(conf *boot.Config, args Args) (*Container, error) {
}
}
if err := runInCgroup(cg, func() error {
- ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir)
+ ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir, args.Attached)
if err != nil {
return err
}
@@ -431,7 +380,7 @@ func New(conf *boot.Config, args Args) (*Container, error) {
c.changeStatus(Created)
// Save the metadata file.
- if err := c.save(); err != nil {
+ if err := c.saveLocked(); err != nil {
return nil, err
}
@@ -451,17 +400,12 @@ func New(conf *boot.Config, args Args) (*Container, error) {
func (c *Container) Start(conf *boot.Config) error {
log.Debugf("Start container %q", c.ID)
- unlockRoot, err := maybeLockRootContainer(c.Spec, c.RootContainerDir)
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlockRoot()
+ unlock := cleanup.Make(func() { c.Saver.unlock() })
+ defer unlock.Clean()
- unlock, err := c.lock()
- if err != nil {
- return err
- }
- defer unlock()
if err := c.requireStatus("start", Created); err != nil {
return err
}
@@ -479,11 +423,11 @@ func (c *Container) Start(conf *boot.Config) error {
return err
}
} else {
- // Join cgroup to strt gofer process to ensure it's part of the cgroup from
+ // Join cgroup to start gofer process to ensure it's part of the cgroup from
// the start (and all their children processes).
if err := runInCgroup(c.Sandbox.Cgroup, func() error {
// Create the gofer process.
- ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir)
+ ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir, false)
if err != nil {
return err
}
@@ -509,14 +453,15 @@ func (c *Container) Start(conf *boot.Config) error {
}
c.changeStatus(Running)
- if err := c.save(); err != nil {
+ if err := c.saveLocked(); err != nil {
return err
}
- // Adjust the oom_score_adj for sandbox. This must be done after
- // save().
- err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false)
- if err != nil {
+ // Release lock before adjusting OOM score because the lock is acquired there.
+ unlock.Clean()
+
+ // Adjust the oom_score_adj for sandbox. This must be done after saveLocked().
+ if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil {
return err
}
@@ -529,11 +474,10 @@ func (c *Container) Start(conf *boot.Config) error {
// to restore a container from its state file.
func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error {
log.Debugf("Restore container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if err := c.requireStatus("restore", Created); err != nil {
return err
@@ -551,7 +495,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str
return err
}
c.changeStatus(Running)
- return c.save()
+ return c.saveLocked()
}
// Run is a helper that calls Create + Start + Wait.
@@ -563,7 +507,7 @@ func Run(conf *boot.Config, args Args) (syscall.WaitStatus, error) {
}
// Clean up partially created container if an error occurs.
// Any errors returned by Destroy() itself are ignored.
- cu := specutils.MakeCleanup(func() {
+ cu := cleanup.Make(func() {
c.Destroy()
})
defer cu.Clean()
@@ -679,21 +623,15 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
// forwarding signals.
func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
log.Debugf("Forwarding all signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh)
- go func() {
- for s := range sigCh {
- log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", s, c.ID, pid, fgProcess)
- if err := c.Sandbox.SignalProcess(c.ID, pid, s.(syscall.Signal), fgProcess); err != nil {
- log.Warningf("error forwarding signal %d to container %q: %v", s, c.ID, err)
- }
+ stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
+ log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", sig, c.ID, pid, fgProcess)
+ if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil {
+ log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err)
}
- log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
- }()
-
+ })
return func() {
- signal.Stop(sigCh)
- close(sigCh)
+ log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+ stop()
}
}
@@ -711,11 +649,10 @@ func (c *Container) Checkpoint(f *os.File) error {
// The call only succeeds if the container's status is created or running.
func (c *Container) Pause() error {
log.Debugf("Pausing container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if c.Status != Created && c.Status != Running {
return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status)
@@ -725,18 +662,17 @@ func (c *Container) Pause() error {
return fmt.Errorf("pausing container: %v", err)
}
c.changeStatus(Paused)
- return c.save()
+ return c.saveLocked()
}
// Resume unpauses the container and its kernel.
// The call only succeeds if the container's status is paused.
func (c *Container) Resume() error {
log.Debugf("Resuming container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if c.Status != Paused {
return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status)
@@ -745,7 +681,7 @@ func (c *Container) Resume() error {
return fmt.Errorf("resuming container: %v", err)
}
c.changeStatus(Running)
- return c.save()
+ return c.saveLocked()
}
// State returns the metadata of the container.
@@ -773,6 +709,17 @@ func (c *Container) Processes() ([]*control.Process, error) {
func (c *Container) Destroy() error {
log.Debugf("Destroy container %q", c.ID)
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer func() {
+ c.Saver.unlock()
+ c.Saver.close()
+ }()
+
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
// We must perform the following cleanup steps:
// * stop the container and gofer processes,
// * remove the container filesystem on the host, and
@@ -782,48 +729,43 @@ func (c *Container) Destroy() error {
// do our best to perform all of the cleanups. Hence, we keep a slice
// of errors return their concatenation.
var errs []string
-
- unlock, err := maybeLockRootContainer(c.Spec, c.RootContainerDir)
- if err != nil {
- return err
- }
- defer unlock()
-
- // Stored for later use as stop() sets c.Sandbox to nil.
- sb := c.Sandbox
-
if err := c.stop(); err != nil {
err = fmt.Errorf("stopping container: %v", err)
log.Warningf("%v", err)
errs = append(errs, err.Error())
}
- if err := os.RemoveAll(c.Root); err != nil && !os.IsNotExist(err) {
- err = fmt.Errorf("deleting container root directory %q: %v", c.Root, err)
+ if err := c.Saver.destroy(); err != nil {
+ err = fmt.Errorf("deleting container state files: %v", err)
log.Warningf("%v", err)
errs = append(errs, err.Error())
}
c.changeStatus(Stopped)
- // Adjust oom_score_adj for the sandbox. This must be done after the
- // container is stopped and the directory at c.Root is removed.
- // We must test if the sandbox is nil because Destroy should be
- // idempotent.
- if sb != nil {
- if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil {
+ // Adjust oom_score_adj for the sandbox. This must be done after the container
+ // is stopped and the directory at c.Root is removed. Adjustment can be
+ // skipped if the root container is exiting, because it brings down the entire
+ // sandbox.
+ //
+ // Use 'sb' to tell whether it has been executed before because Destroy must
+ // be idempotent.
+ if sb != nil && !isRoot(c.Spec) {
+ if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil {
errs = append(errs, err.Error())
}
}
// "If any poststop hook fails, the runtime MUST log a warning, but the
- // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec.
- // Based on the OCI, "The post-stop hooks MUST be called after the container is
- // deleted but before the delete operation returns"
+ // remaining hooks and lifecycle continue as if the hook had
+ // succeeded" - OCI spec.
+ //
+ // Based on the OCI, "The post-stop hooks MUST be called after the container
+ // is deleted but before the delete operation returns"
// Run it here to:
// 1) Conform to the OCI.
- // 2) Make sure it only runs once, because the root has been deleted, the container
- // can't be loaded again.
+ // 2) Make sure it only runs once, because the root has been deleted, the
+ // container can't be loaded again.
if c.Spec.Hooks != nil {
executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State())
}
@@ -834,18 +776,13 @@ func (c *Container) Destroy() error {
return fmt.Errorf(strings.Join(errs, "\n"))
}
-// save saves the container metadata to a file.
+// saveLocked saves the container metadata to a file.
//
// Precondition: container must be locked with container.lock().
-func (c *Container) save() error {
+func (c *Container) saveLocked() error {
log.Debugf("Save container %q", c.ID)
- metaFile := filepath.Join(c.Root, metadataFilename)
- meta, err := json.Marshal(c)
- if err != nil {
- return fmt.Errorf("invalid container metadata: %v", err)
- }
- if err := ioutil.WriteFile(metaFile, meta, 0640); err != nil {
- return fmt.Errorf("writing container metadata: %v", err)
+ if err := c.Saver.saveLocked(c); err != nil {
+ return fmt.Errorf("saving container metadata: %v", err)
}
return nil
}
@@ -924,7 +861,7 @@ func (c *Container) waitForStopped() error {
return backoff.Retry(op, b)
}
-func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string) ([]*os.File, *os.File, error) {
+func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string, attached bool) ([]*os.File, *os.File, error) {
// Start with the general config flags.
args := conf.ToFlags()
@@ -1018,6 +955,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund
cmd.ExtraFiles = goferEnds
cmd.Args[0] = "runsc-gofer"
+ if attached {
+ // The gofer is attached to the lifetime of this process, so it
+ // should synchronously die when this process dies.
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Pdeathsig: syscall.SIGKILL,
+ }
+ }
+
// Enter new namespaces to isolate from the rest of the system. Don't unshare
// cgroup because gofer is added to a cgroup in the caller's namespace.
nss := []specs.LinuxNamespace{
@@ -1106,48 +1051,6 @@ func (c *Container) requireStatus(action string, statuses ...Status) error {
return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status)
}
-// lock takes a file lock on the container metadata lock file.
-func (c *Container) lock() (func() error, error) {
- return lockContainerMetadata(filepath.Join(c.Root, c.ID))
-}
-
-// lockContainerMetadata takes a file lock on the metadata lock file in the
-// given container root directory.
-func lockContainerMetadata(containerRootDir string) (func() error, error) {
- if err := os.MkdirAll(containerRootDir, 0711); err != nil {
- return nil, fmt.Errorf("creating container root directory %q: %v", containerRootDir, err)
- }
- f := filepath.Join(containerRootDir, metadataLockFilename)
- l := flock.NewFlock(f)
- if err := l.Lock(); err != nil {
- return nil, fmt.Errorf("acquiring lock on container lock file %q: %v", f, err)
- }
- return l.Unlock, nil
-}
-
-// maybeLockRootContainer locks the sandbox root container. It is used to
-// prevent races to create and delete child container sandboxes.
-func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, error) {
- if isRoot(spec) {
- return func() error { return nil }, nil
- }
-
- sbid, ok := specutils.SandboxID(spec)
- if !ok {
- return nil, fmt.Errorf("no sandbox ID found when locking root container")
- }
- sb, err := Load(rootDir, sbid)
- if err != nil {
- return nil, err
- }
-
- unlock, err := sb.lock()
- if err != nil {
- return nil, err
- }
- return unlock, nil
-}
-
func isRoot(spec *specs.Spec) bool {
return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer
}
@@ -1168,22 +1071,19 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer.
func (c *Container) adjustGoferOOMScoreAdj() error {
- if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
+ if c.GoferPid == 0 || c.Spec.Process.OOMScoreAdj == nil {
+ return nil
}
-
- return nil
+ return setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj)
}
// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
-// TODO(gvisor.dev/issue/512): This call could race with other containers being
+// TODO(gvisor.dev/issue/238): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
-// sandbox.
+// sandbox. Use rpc client to synchronize.
func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
@@ -1251,24 +1151,29 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
- }
-
- return nil
+ return setOOMScoreAdj(s.Pid, lowScore)
}
// setOOMScoreAdj sets oom_score_adj to the given value for the given PID.
// /proc must be available and mounted read-write. scoreAdj should be between
-// -1000 and 1000.
+// -1000 and 1000. It's a noop if the process has already exited.
func setOOMScoreAdj(pid int, scoreAdj int) error {
f, err := os.OpenFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid), os.O_WRONLY, 0644)
if err != nil {
+ // Ignore NotExist errors because it can race with process exit.
+ if os.IsNotExist(err) {
+ log.Warningf("Process (%d) not found setting oom_score_adj", pid)
+ return nil
+ }
return err
}
defer f.Close()
if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil {
- return err
+ if errors.Is(err, syscall.ESRCH) {
+ log.Warningf("Process (%d) exited while setting oom_score_adj", pid)
+ return nil
+ }
+ return fmt.Errorf("setting oom_score_adj to %q: %v", scoreAdj, err)
}
return nil
}
diff --git a/runsc/container/container_norace_test.go b/runsc/container/container_norace_test.go
new file mode 100644
index 000000000..838c1e20a
--- /dev/null
+++ b/runsc/container/container_norace_test.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build !race
+
+package container
+
+// Allow both kvm and ptrace for non-race builds.
+var platformOptions = []configOption{ptrace, kvm}
diff --git a/pkg/sentry/socket/rpcinet/device.go b/runsc/container/container_race_test.go
index 8cfd5f6e5..9fb4c4fc0 100644
--- a/pkg/sentry/socket/rpcinet/device.go
+++ b/runsc/container/container_race_test.go
@@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package rpcinet
+// +build race
-import "gvisor.dev/gvisor/pkg/sentry/device"
+package container
-var socketDevice = device.NewAnonDevice()
+// Only enabled ptrace with race builds.
+var platformOptions = []configOption{ptrace}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 07eacaac0..5e8247bc8 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -20,13 +20,13 @@ import (
"fmt"
"io"
"io/ioutil"
+ "math"
"os"
"path"
"path/filepath"
"reflect"
"strconv"
"strings"
- "sync"
"syscall"
"testing"
"time"
@@ -37,11 +37,13 @@ import (
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
)
// waitForProcessList waits for the given process list to show up in the container.
@@ -69,6 +71,7 @@ func waitForProcessCount(cont *Container, want int) error {
return &backoff.PermanentError{Err: err}
}
if got := len(pss); got != want {
+ log.Infof("Waiting for process count to reach %d. Current: %d", want, got)
return fmt.Errorf("wrong process count, got: %d, want: %d", got, want)
}
return nil
@@ -89,37 +92,72 @@ func blockUntilWaitable(pid int) error {
return err
}
-// procListsEqual is used to check whether 2 Process lists are equal for all
-// implemented fields.
-func procListsEqual(got, want []*control.Process) bool {
- if len(got) != len(want) {
+// procListsEqual is used to check whether 2 Process lists are equal. Fields
+// set to -1 in wants are ignored. Timestamp and threads fields are always
+// ignored.
+func procListsEqual(gots, wants []*control.Process) bool {
+ if len(gots) != len(wants) {
return false
}
- for i := range got {
- pd1 := got[i]
- pd2 := want[i]
- // Zero out unimplemented and timing dependant fields.
- pd1.Time = ""
- pd1.STime = ""
- pd1.C = 0
- if *pd1 != *pd2 {
+ for i := range gots {
+ got := gots[i]
+ want := wants[i]
+
+ if want.UID != math.MaxUint32 && want.UID != got.UID {
+ return false
+ }
+ if want.PID != -1 && want.PID != got.PID {
+ return false
+ }
+ if want.PPID != -1 && want.PPID != got.PPID {
+ return false
+ }
+ if len(want.TTY) != 0 && want.TTY != got.TTY {
+ return false
+ }
+ if len(want.Cmd) != 0 && want.Cmd != got.Cmd {
return false
}
}
return true
}
-// getAndCheckProcLists is similar to waitForProcessList, but does not wait and retry the
-// test for equality. This is because we already confirmed that exec occurred.
-func getAndCheckProcLists(cont *Container, want []*control.Process) error {
- got, err := cont.Processes()
- if err != nil {
- return fmt.Errorf("error getting process data from container: %v", err)
- }
- if procListsEqual(got, want) {
- return nil
+type processBuilder struct {
+ process control.Process
+}
+
+func newProcessBuilder() *processBuilder {
+ return &processBuilder{
+ process: control.Process{
+ UID: math.MaxUint32,
+ PID: -1,
+ PPID: -1,
+ },
}
- return fmt.Errorf("container got process list: %s, want: %s", procListToString(got), procListToString(want))
+}
+
+func (p *processBuilder) Cmd(cmd string) *processBuilder {
+ p.process.Cmd = cmd
+ return p
+}
+
+func (p *processBuilder) PID(pid kernel.ThreadID) *processBuilder {
+ p.process.PID = pid
+ return p
+}
+
+func (p *processBuilder) PPID(ppid kernel.ThreadID) *processBuilder {
+ p.process.PPID = ppid
+ return p
+}
+
+func (p *processBuilder) UID(uid auth.KUID) *processBuilder {
+ p.process.UID = uid
+ return p
+}
+
+func (p *processBuilder) Process() *control.Process {
+ return &p.process
}
func procListToString(pl []*control.Process) string {
@@ -145,7 +183,7 @@ func createWriteableOutputFile(path string) (*os.File, error) {
return outputFile, nil
}
-func waitForFile(f *os.File) error {
+func waitForFileNotEmpty(f *os.File) error {
op := func() error {
fi, err := f.Stat()
if err != nil {
@@ -160,6 +198,17 @@ func waitForFile(f *os.File) error {
return testutil.Poll(op, 30*time.Second)
}
+func waitForFileExist(path string) error {
+ op := func() error {
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ return err
+ }
+ return nil
+ }
+
+ return testutil.Poll(op, 30*time.Second)
+}
+
// readOutputNum reads a file at given filepath and returns the int at the
// requested position.
func readOutputNum(file string, position int) (int, error) {
@@ -169,7 +218,7 @@ func readOutputNum(file string, position int) (int, error) {
}
// Ensure that there is content in output file.
- if err := waitForFile(f); err != nil {
+ if err := waitForFileNotEmpty(f); err != nil {
return 0, fmt.Errorf("error waiting for output file: %v", err)
}
@@ -202,16 +251,15 @@ func readOutputNum(file string, position int) (int, error) {
// run starts the sandbox and waits for it to exit, checking that the
// application succeeded.
func run(spec *specs.Spec, conf *boot.Config) error {
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
return fmt.Errorf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create, start and wait for the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
Attached: true,
@@ -230,39 +278,64 @@ type configOption int
const (
overlay configOption = iota
+ ptrace
kvm
nonExclusiveFS
)
-var noOverlay = []configOption{kvm, nonExclusiveFS}
-var all = append(noOverlay, overlay)
+var (
+ noOverlay = append(platformOptions, nonExclusiveFS)
+ all = append(noOverlay, overlay)
+)
// configs generates different configurations to run tests.
-func configs(opts ...configOption) []*boot.Config {
+func configs(t *testing.T, opts ...configOption) map[string]*boot.Config {
// Always load the default config.
- cs := []*boot.Config{testutil.TestConfig()}
-
+ cs := make(map[string]*boot.Config)
for _, o := range opts {
- c := testutil.TestConfig()
switch o {
case overlay:
+ c := testutil.TestConfig(t)
c.Overlay = true
+ cs["overlay"] = c
+ case ptrace:
+ c := testutil.TestConfig(t)
+ c.Platform = platforms.Ptrace
+ cs["ptrace"] = c
case kvm:
- // TODO(b/112165693): KVM tests are flaky. Disable until fixed.
- continue
-
+ c := testutil.TestConfig(t)
c.Platform = platforms.KVM
+ cs["kvm"] = c
case nonExclusiveFS:
+ c := testutil.TestConfig(t)
c.FileAccess = boot.FileAccessShared
+ cs["non-exclusive"] = c
default:
panic(fmt.Sprintf("unknown config option %v", o))
-
}
- cs = append(cs, c)
}
return cs
}
+func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*boot.Config {
+ vfs1 := configs(t, opts...)
+
+ var optsVFS2 []configOption
+ for _, opt := range opts {
+ // TODO(gvisor.dev/issue/1487): Enable overlay tests.
+ if opt != overlay {
+ optsVFS2 = append(optsVFS2, opt)
+ }
+ }
+
+ for key, value := range configs(t, optsVFS2...) {
+ value.VFS2 = true
+ vfs1[key+"VFS2"] = value
+ }
+
+ return vfs1
+}
+
// TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle.
// It verifies after each step that the container can be loaded from disk, and
// has the correct status.
@@ -272,132 +345,126 @@ func TestLifecycle(t *testing.T) {
childReaper.Start()
defer childReaper.Stop()
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
- // The container will just sleep for a long time. We will kill it before
- // it finishes sleeping.
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ // The container will just sleep for a long time. We will kill it before
+ // it finishes sleeping.
+ spec := testutil.NewSpecWithArgs("sleep", "100")
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
-
- // expectedPL lists the expected process state of the container.
- expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- }
- // Create the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
-
- // Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
- if err != nil {
- t.Fatalf("error loading container: %v", err)
- }
- if got, want := c.Status, Created; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // List should return the container id.
- ids, err := List(rootDir)
- if err != nil {
- t.Fatalf("error listing containers: %v", err)
- }
- if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) {
- t.Errorf("container list got %v, want %v", got, want)
- }
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ // Create the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
- // Start the container.
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ if got, want := c.Status, Created; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
- // Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
- if err != nil {
- t.Fatalf("error loading container: %v", err)
- }
- if got, want := c.Status, Running; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ // List should return the container id.
+ ids, err := List(rootDir)
+ if err != nil {
+ t.Fatalf("error listing containers: %v", err)
+ }
+ if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) {
+ t.Errorf("container list got %v, want %v", got, want)
+ }
- // Verify that "sleep 100" is running.
- if err := waitForProcessList(c, expectedPL); err != nil {
- t.Error(err)
- }
+ // Start the container.
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Wait on the container.
- var wg sync.WaitGroup
- wg.Add(1)
- ch := make(chan struct{})
- go func() {
- ch <- struct{}{}
- ws, err := c.Wait()
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
if err != nil {
- t.Fatalf("error waiting on container: %v", err)
+ t.Fatalf("error loading container: %v", err)
}
- if got, want := ws.Signal(), syscall.SIGTERM; got != want {
- t.Fatalf("got signal %v, want %v", got, want)
+ if got, want := c.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
}
- wg.Done()
- }()
- // Wait a bit to ensure that we've started waiting on the
- // container before we signal.
- <-ch
- time.Sleep(100 * time.Millisecond)
- // Send the container a SIGTERM which will cause it to stop.
- if err := c.SignalContainer(syscall.SIGTERM, false); err != nil {
- t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err)
- }
- // Wait for it to die.
- wg.Wait()
+ // Verify that "sleep 100" is running.
+ if err := waitForProcessList(c, expectedPL); err != nil {
+ t.Error(err)
+ }
- // Load the container from disk and check the status.
- c, err = Load(rootDir, args.ID)
- if err != nil {
- t.Fatalf("error loading container: %v", err)
- }
- if got, want := c.Status, Stopped; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ // Wait on the container.
+ ch := make(chan error)
+ go func() {
+ ws, err := c.Wait()
+ if err != nil {
+ ch <- err
+ }
+ if got, want := ws.Signal(), syscall.SIGTERM; got != want {
+ ch <- fmt.Errorf("got signal %v, want %v", got, want)
+ }
+ ch <- nil
+ }()
- // Destroy the container.
- if err := c.Destroy(); err != nil {
- t.Fatalf("error destroying container: %v", err)
- }
+ // Wait a bit to ensure that we've started waiting on
+ // the container before we signal.
+ time.Sleep(time.Second)
- // List should not return the container id.
- ids, err = List(rootDir)
- if err != nil {
- t.Fatalf("error listing containers: %v", err)
- }
- if len(ids) != 0 {
- t.Errorf("expected container list to be empty, but got %v", ids)
- }
+ // Send the container a SIGTERM which will cause it to stop.
+ if err := c.SignalContainer(syscall.SIGTERM, false); err != nil {
+ t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err)
+ }
- // Loading the container by id should fail.
- if _, err = Load(rootDir, args.ID); err == nil {
- t.Errorf("expected loading destroyed container to fail, but it did not")
- }
+ // Wait for it to die.
+ if err := <-ch; err != nil {
+ t.Fatalf("error waiting for container: %v", err)
+ }
+
+ // Load the container from disk and check the status.
+ c, err = Load(rootDir, args.ID)
+ if err != nil {
+ t.Fatalf("error loading container: %v", err)
+ }
+ if got, want := c.Status, Stopped; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
+
+ // Destroy the container.
+ if err := c.Destroy(); err != nil {
+ t.Fatalf("error destroying container: %v", err)
+ }
+
+ // List should not return the container id.
+ ids, err = List(rootDir)
+ if err != nil {
+ t.Fatalf("error listing containers: %v", err)
+ }
+ if len(ids) != 0 {
+ t.Errorf("expected container list to be empty, but got %v", ids)
+ }
+
+ // Loading the container by id should fail.
+ if _, err = Load(rootDir, args.ID); err == nil {
+ t.Errorf("expected loading destroyed container to fail, but it did not")
+ }
+ })
}
}
@@ -406,12 +473,14 @@ func TestExePath(t *testing.T) {
// Create two directories that will be prepended to PATH.
firstPath, err := ioutil.TempDir(testutil.TmpDir(), "first")
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error creating temporary directory: %v", err)
}
+ defer os.RemoveAll(firstPath)
secondPath, err := ioutil.TempDir(testutil.TmpDir(), "second")
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error creating temporary directory: %v", err)
}
+ defer os.RemoveAll(secondPath)
// Create two minimal executables in the second path, two of which
// will be masked by files in first path.
@@ -419,11 +488,11 @@ func TestExePath(t *testing.T) {
path := filepath.Join(secondPath, p)
f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0777)
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error opening path: %v", err)
}
defer f.Close()
if _, err := io.WriteString(f, "#!/bin/true\n"); err != nil {
- t.Fatal(err)
+ t.Fatalf("error writing contents: %v", err)
}
}
@@ -432,7 +501,7 @@ func TestExePath(t *testing.T) {
nonExecutable := filepath.Join(firstPath, "masked1")
f2, err := os.OpenFile(nonExecutable, os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error opening file: %v", err)
}
f2.Close()
@@ -440,85 +509,95 @@ func TestExePath(t *testing.T) {
// executable in the second.
nonRegular := filepath.Join(firstPath, "masked2")
if err := os.Mkdir(nonRegular, 0777); err != nil {
- t.Fatal(err)
- }
-
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
- for _, test := range []struct {
- path string
- success bool
- }{
- {path: "true", success: true},
- {path: "bin/true", success: true},
- {path: "/bin/true", success: true},
- {path: "thisfiledoesntexit", success: false},
- {path: "bin/thisfiledoesntexit", success: false},
- {path: "/bin/thisfiledoesntexit", success: false},
-
- {path: "unmasked", success: true},
- {path: filepath.Join(firstPath, "unmasked"), success: false},
- {path: filepath.Join(secondPath, "unmasked"), success: true},
-
- {path: "masked1", success: true},
- {path: filepath.Join(firstPath, "masked1"), success: false},
- {path: filepath.Join(secondPath, "masked1"), success: true},
-
- {path: "masked2", success: true},
- {path: filepath.Join(firstPath, "masked2"), success: false},
- {path: filepath.Join(secondPath, "masked2"), success: true},
- } {
- spec := testutil.NewSpecWithArgs(test.path)
- spec.Process.Env = []string{
- fmt.Sprintf("PATH=%s:%s:%s", firstPath, secondPath, os.Getenv("PATH")),
- }
-
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("exec: %s, error setting up container: %v", test.path, err)
- }
-
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- Attached: true,
- }
- ws, err := Run(conf, args)
-
- os.RemoveAll(rootDir)
- os.RemoveAll(bundleDir)
-
- if test.success {
- if err != nil {
- t.Errorf("exec: %s, error running container: %v", test.path, err)
- }
- if ws.ExitStatus() != 0 {
- t.Errorf("exec: %s, got exit status %v want %v", test.path, ws.ExitStatus(), 0)
- }
- } else {
- if err == nil {
- t.Errorf("exec: %s, got: no error, want: error", test.path)
- }
+ t.Fatalf("error making directory: %v", err)
+ }
+
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ for _, test := range []struct {
+ path string
+ success bool
+ }{
+ {path: "true", success: true},
+ {path: "bin/true", success: true},
+ {path: "/bin/true", success: true},
+ {path: "thisfiledoesntexit", success: false},
+ {path: "bin/thisfiledoesntexit", success: false},
+ {path: "/bin/thisfiledoesntexit", success: false},
+
+ {path: "unmasked", success: true},
+ {path: filepath.Join(firstPath, "unmasked"), success: false},
+ {path: filepath.Join(secondPath, "unmasked"), success: true},
+
+ {path: "masked1", success: true},
+ {path: filepath.Join(firstPath, "masked1"), success: false},
+ {path: filepath.Join(secondPath, "masked1"), success: true},
+
+ {path: "masked2", success: true},
+ {path: filepath.Join(firstPath, "masked2"), success: false},
+ {path: filepath.Join(secondPath, "masked2"), success: true},
+ } {
+ t.Run(fmt.Sprintf("path=%s,success=%t", test.path, test.success), func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs(test.path)
+ spec.Process.Env = []string{
+ fmt.Sprintf("PATH=%s:%s:%s", firstPath, secondPath, os.Getenv("PATH")),
+ }
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("exec: error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ Attached: true,
+ }
+ ws, err := Run(conf, args)
+
+ if test.success {
+ if err != nil {
+ t.Errorf("exec: error running container: %v", err)
+ }
+ if ws.ExitStatus() != 0 {
+ t.Errorf("exec: got exit status %v want %v", ws.ExitStatus(), 0)
+ }
+ } else {
+ if err == nil {
+ t.Errorf("exec: got: no error, want: error")
+ }
+ }
+ })
}
- }
+ })
}
}
// Test the we can retrieve the application exit status from the container.
func TestAppExitStatus(t *testing.T) {
+ doAppExitStatus(t, false)
+}
+
+// This is TestAppExitStatus for VFSv2.
+func TestAppExitStatusVFS2(t *testing.T) {
+ doAppExitStatus(t, true)
+}
+
+func doAppExitStatus(t *testing.T, vfs2 bool) {
// First container will succeed.
succSpec := testutil.NewSpecWithArgs("true")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(succSpec, conf)
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(succSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: succSpec,
BundleDir: bundleDir,
Attached: true,
@@ -535,15 +614,14 @@ func TestAppExitStatus(t *testing.T) {
wantStatus := 123
errSpec := testutil.NewSpecWithArgs("bash", "-c", fmt.Sprintf("exit %d", wantStatus))
- rootDir2, bundleDir2, err := testutil.SetupContainer(errSpec, conf)
+ _, bundleDir2, cleanup2, err := testutil.SetupContainer(errSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir2)
- defer os.RemoveAll(bundleDir2)
+ defer cleanup2()
args2 := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: errSpec,
BundleDir: bundleDir2,
Attached: true,
@@ -559,164 +637,271 @@ func TestAppExitStatus(t *testing.T) {
// TestExec verifies that a container can exec a new program.
func TestExec(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "exec-test")
+ if err != nil {
+ t.Fatalf("error creating temporary directory: %v", err)
+ }
+ // Note that some shells may exec the final command in a sequence as
+ // an optimization. We avoid this here by adding the exit 0.
+ cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100 && exit 0", dir)
+ spec := testutil.NewSpecWithArgs("sh", "-c", cmd)
- const uid = 343
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Wait until sleep is running to ensure the symlink was created.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sh").Process(),
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("waitForProcessList: %v", err)
+ }
- // expectedPL lists the expected process state of the container.
- expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- {
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- }
+ for _, tc := range []struct {
+ name string
+ args control.ExecArgs
+ }{
+ {
+ name: "complete",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename",
+ args: control.ExecArgs{
+ Filename: "/bin/true",
+ },
+ },
+ {
+ name: "argv",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/true"},
+ },
+ },
+ {
+ name: "filename resolution",
+ args: control.ExecArgs{
+ Filename: "true",
+ Envv: []string{"PATH=/bin"},
+ },
+ },
+ {
+ name: "argv resolution",
+ args: control.ExecArgs{
+ Argv: []string{"true"},
+ Envv: []string{"PATH=/bin"},
+ },
+ },
+ {
+ name: "argv symlink",
+ args: control.ExecArgs{
+ Argv: []string{filepath.Join(dir, "symlink")},
+ },
+ },
+ {
+ name: "working dir",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${PWD}" != "/tmp" ]]; then exit 1; fi`},
+ WorkingDirectory: "/tmp",
+ },
+ },
+ {
+ name: "user",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -u)" != "343" ]]; then exit 1; fi`},
+ KUID: 343,
+ },
+ },
+ {
+ name: "group",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "$(id -g)" != "343" ]]; then exit 1; fi`},
+ KGID: 343,
+ },
+ },
+ {
+ name: "env",
+ args: control.ExecArgs{
+ Argv: []string{"/bin/sh", "-c", `if [[ "${FOO}" != "123" ]]; then exit 1; fi`},
+ Envv: []string{"FOO=123"},
+ },
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ // t.Parallel()
+ if ws, err := cont.executeSync(&tc.args); err != nil {
+ t.Fatalf("executeAsync(%+v): %v", tc.args, err)
+ } else if ws != 0 {
+ t.Fatalf("executeAsync(%+v) failed with exit: %v", tc.args, ws)
+ }
+ })
+ }
+ })
+ }
+}
- // Verify that "sleep 100" is running.
- if err := waitForProcessList(cont, expectedPL[:1]); err != nil {
- t.Error(err)
- }
+// TestExecProcList verifies that a container can exec a new program and it
+// shows correcly in the process list.
+func TestExecProcList(t *testing.T) {
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ const uid = 343
+ spec := testutil.NewSpecWithArgs("sleep", "100")
- execArgs := &control.ExecArgs{
- Filename: "/bin/sleep",
- Argv: []string{"/bin/sleep", "5"},
- WorkingDirectory: "/",
- KUID: uid,
- }
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Verify that "sleep 100" and "sleep 5" are running after exec.
- // First, start running exec (whick blocks).
- status := make(chan error, 1)
- go func() {
- exitStatus, err := cont.executeSync(execArgs)
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
if err != nil {
- log.Debugf("error executing: %v", err)
- status <- err
- } else if exitStatus != 0 {
- log.Debugf("bad status: %d", exitStatus)
- status <- fmt.Errorf("failed with exit status: %v", exitStatus)
- } else {
- status <- nil
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
}
- }()
- if err := waitForProcessList(cont, expectedPL); err != nil {
- t.Fatal(err)
- }
+ execArgs := &control.ExecArgs{
+ Filename: "/bin/sleep",
+ Argv: []string{"/bin/sleep", "5"},
+ WorkingDirectory: "/",
+ KUID: uid,
+ }
+
+ // Verify that "sleep 100" and "sleep 5" are running after exec. First,
+ // start running exec (which blocks).
+ ch := make(chan error)
+ go func() {
+ exitStatus, err := cont.executeSync(execArgs)
+ if err != nil {
+ ch <- err
+ } else if exitStatus != 0 {
+ ch <- fmt.Errorf("failed with exit status: %v", exitStatus)
+ } else {
+ ch <- nil
+ }
+ }()
- // Ensure that exec finished without error.
- select {
- case <-time.After(10 * time.Second):
- t.Fatalf("container timed out waiting for exec to finish.")
- case st := <-status:
- if st != nil {
- t.Errorf("container failed to exec %v: %v", args, err)
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").UID(0).Process(),
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").UID(uid).Process(),
}
- }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
+
+ // Ensure that exec finished without error.
+ select {
+ case <-time.After(10 * time.Second):
+ t.Fatalf("container timed out waiting for exec to finish.")
+ case err := <-ch:
+ if err != nil {
+ t.Errorf("container failed to exec %v: %v", args, err)
+ }
+ }
+ })
}
}
// TestKillPid verifies that we can signal individual exec'd processes.
func TestKillPid(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
-
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
- if err != nil {
- t.Fatal("error finding test_app:", err)
- }
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
- const nProcs = 4
- spec := testutil.NewSpecWithArgs(app, "task-tree", "--depth", strconv.Itoa(nProcs-1), "--width=1", "--pause=true")
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ const nProcs = 4
+ spec := testutil.NewSpecWithArgs(app, "task-tree", "--depth", strconv.Itoa(nProcs-1), "--width=1", "--pause=true")
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Verify that all processes are running.
- if err := waitForProcessCount(cont, nProcs); err != nil {
- t.Fatalf("timed out waiting for processes to start: %v", err)
- }
+ // Verify that all processes are running.
+ if err := waitForProcessCount(cont, nProcs); err != nil {
+ t.Fatalf("timed out waiting for processes to start: %v", err)
+ }
- // Kill the child process with the largest PID.
- procs, err := cont.Processes()
- if err != nil {
- t.Fatalf("failed to get process list: %v", err)
- }
- var pid int32
- for _, p := range procs {
- if pid < int32(p.PID) {
- pid = int32(p.PID)
+ // Kill the child process with the largest PID.
+ procs, err := cont.Processes()
+ if err != nil {
+ t.Fatalf("failed to get process list: %v", err)
+ }
+ var pid int32
+ for _, p := range procs {
+ if pid < int32(p.PID) {
+ pid = int32(p.PID)
+ }
+ }
+ if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil {
+ t.Fatalf("failed to signal process %d: %v", pid, err)
}
- }
- if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil {
- t.Fatalf("failed to signal process %d: %v", pid, err)
- }
- // Verify that one process is gone.
- if err := waitForProcessCount(cont, nProcs-1); err != nil {
- t.Fatal(err)
- }
+ // Verify that one process is gone.
+ if err := waitForProcessCount(cont, nProcs-1); err != nil {
+ t.Fatalf("error waiting for processes: %v", err)
+ }
- procs, err = cont.Processes()
- if err != nil {
- t.Fatalf("failed to get process list: %v", err)
- }
- for _, p := range procs {
- if pid == int32(p.PID) {
- t.Fatalf("pid %d is still alive, which should be killed", pid)
+ procs, err = cont.Processes()
+ if err != nil {
+ t.Fatalf("failed to get process list: %v", err)
}
- }
+ for _, p := range procs {
+ if pid == int32(p.PID) {
+ t.Fatalf("pid %d is still alive, which should be killed", pid)
+ }
+ }
+ })
}
}
@@ -727,160 +912,160 @@ func TestKillPid(t *testing.T) {
// be the next consecutive number after the last number from the checkpointed container.
func TestCheckpointRestore(t *testing.T) {
// Skip overlay because test requires writing to host file.
- for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
-
- dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test")
- if err != nil {
- t.Fatalf("ioutil.TempDir failed: %v", err)
- }
- if err := os.Chmod(dir, 0777); err != nil {
- t.Fatalf("error chmoding file: %q, %v", dir, err)
- }
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "checkpoint-test")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
+ if err := os.Chmod(dir, 0777); err != nil {
+ t.Fatalf("error chmoding file: %q, %v", dir, err)
+ }
- outputPath := filepath.Join(dir, "output")
- outputFile, err := createWriteableOutputFile(outputPath)
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer outputFile.Close()
+ outputPath := filepath.Join(dir, "output")
+ outputFile, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile.Close()
- script := fmt.Sprintf("for ((i=0; ;i++)); do echo $i >> %q; sleep 1; done", outputPath)
- spec := testutil.NewSpecWithArgs("bash", "-c", script)
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ script := fmt.Sprintf("for ((i=0; ;i++)); do echo $i >> %q; sleep 1; done", outputPath)
+ spec := testutil.NewSpecWithArgs("bash", "-c", script)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Set the image path, which is where the checkpoint image will be saved.
- imagePath := filepath.Join(dir, "test-image-file")
+ // Set the image path, which is where the checkpoint image will be saved.
+ imagePath := filepath.Join(dir, "test-image-file")
- // Create the image file and open for writing.
- file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
- if err != nil {
- t.Fatalf("error opening new file at imagePath: %v", err)
- }
- defer file.Close()
+ // Create the image file and open for writing.
+ file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if err != nil {
+ t.Fatalf("error opening new file at imagePath: %v", err)
+ }
+ defer file.Close()
- // Wait until application has ran.
- if err := waitForFile(outputFile); err != nil {
- t.Fatalf("Failed to wait for output file: %v", err)
- }
+ // Wait until application has ran.
+ if err := waitForFileNotEmpty(outputFile); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
- // Checkpoint running container; save state into new file.
- if err := cont.Checkpoint(file); err != nil {
- t.Fatalf("error checkpointing container to empty file: %v", err)
- }
- defer os.RemoveAll(imagePath)
+ // Checkpoint running container; save state into new file.
+ if err := cont.Checkpoint(file); err != nil {
+ t.Fatalf("error checkpointing container to empty file: %v", err)
+ }
+ defer os.RemoveAll(imagePath)
- lastNum, err := readOutputNum(outputPath, -1)
- if err != nil {
- t.Fatalf("error with outputFile: %v", err)
- }
+ lastNum, err := readOutputNum(outputPath, -1)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
- // Delete and recreate file before restoring.
- if err := os.Remove(outputPath); err != nil {
- t.Fatalf("error removing file")
- }
- outputFile2, err := createWriteableOutputFile(outputPath)
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer outputFile2.Close()
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile2, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile2.Close()
- // Restore into a new container.
- args2 := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont2, err := New(conf, args2)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont2.Destroy()
+ // Restore into a new container.
+ args2 := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont2, err := New(conf, args2)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont2.Destroy()
- if err := cont2.Restore(spec, conf, imagePath); err != nil {
- t.Fatalf("error restoring container: %v", err)
- }
+ if err := cont2.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
- // Wait until application has ran.
- if err := waitForFile(outputFile2); err != nil {
- t.Fatalf("Failed to wait for output file: %v", err)
- }
+ // Wait until application has ran.
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
- firstNum, err := readOutputNum(outputPath, 0)
- if err != nil {
- t.Fatalf("error with outputFile: %v", err)
- }
+ firstNum, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
- // Check that lastNum is one less than firstNum and that the container picks
- // up from where it left off.
- if lastNum+1 != firstNum {
- t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum)
- }
- cont2.Destroy()
+ // Check that lastNum is one less than firstNum and that the container picks
+ // up from where it left off.
+ if lastNum+1 != firstNum {
+ t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum)
+ }
+ cont2.Destroy()
- // Restore into another container!
- // Delete and recreate file before restoring.
- if err := os.Remove(outputPath); err != nil {
- t.Fatalf("error removing file")
- }
- outputFile3, err := createWriteableOutputFile(outputPath)
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer outputFile3.Close()
+ // Restore into another container!
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile3, err := createWriteableOutputFile(outputPath)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile3.Close()
- // Restore into a new container.
- args3 := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont3, err := New(conf, args3)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont3.Destroy()
+ // Restore into a new container.
+ args3 := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont3, err := New(conf, args3)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont3.Destroy()
- if err := cont3.Restore(spec, conf, imagePath); err != nil {
- t.Fatalf("error restoring container: %v", err)
- }
+ if err := cont3.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
- // Wait until application has ran.
- if err := waitForFile(outputFile3); err != nil {
- t.Fatalf("Failed to wait for output file: %v", err)
- }
+ // Wait until application has ran.
+ if err := waitForFileNotEmpty(outputFile3); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
- firstNum2, err := readOutputNum(outputPath, 0)
- if err != nil {
- t.Fatalf("error with outputFile: %v", err)
- }
+ firstNum2, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
- // Check that lastNum is one less than firstNum and that the container picks
- // up from where it left off.
- if lastNum+1 != firstNum2 {
- t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum2)
- }
- cont3.Destroy()
+ // Check that lastNum is one less than firstNum and that the container picks
+ // up from where it left off.
+ if lastNum+1 != firstNum2 {
+ t.Errorf("error numbers not in order, previous: %d, next: %d", lastNum, firstNum2)
+ }
+ cont3.Destroy()
+ })
}
}
@@ -888,256 +1073,213 @@ func TestCheckpointRestore(t *testing.T) {
// with filesystem Unix Domain Socket use.
func TestUnixDomainSockets(t *testing.T) {
// Skip overlay because test requires writing to host file.
- for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
-
- // UDS path is limited to 108 chars for compatibility with older systems.
- // Use '/tmp' (instead of testutil.TmpDir) to ensure the size limit is
- // not exceeded. Assumes '/tmp' exists in the system.
- dir, err := ioutil.TempDir("/tmp", "uds-test")
- if err != nil {
- t.Fatalf("ioutil.TempDir failed: %v", err)
- }
- defer os.RemoveAll(dir)
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ // UDS path is limited to 108 chars for compatibility with older systems.
+ // Use '/tmp' (instead of testutil.TmpDir) to ensure the size limit is
+ // not exceeded. Assumes '/tmp' exists in the system.
+ dir, err := ioutil.TempDir("/tmp", "uds-test")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
- outputPath := filepath.Join(dir, "uds_output")
- outputFile, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer outputFile.Close()
+ outputPath := filepath.Join(dir, "uds_output")
+ outputFile, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile.Close()
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
- if err != nil {
- t.Fatal("error finding test_app:", err)
- }
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
- socketPath := filepath.Join(dir, "uds_socket")
- defer os.Remove(socketPath)
+ socketPath := filepath.Join(dir, "uds_socket")
+ defer os.Remove(socketPath)
- spec := testutil.NewSpecWithArgs(app, "uds", "--file", outputPath, "--socket", socketPath)
- spec.Process.User = specs.User{
- UID: uint32(os.Getuid()),
- GID: uint32(os.Getgid()),
- }
- spec.Mounts = []specs.Mount{{
- Type: "bind",
- Destination: dir,
- Source: dir,
- }}
+ spec := testutil.NewSpecWithArgs(app, "uds", "--file", outputPath, "--socket", socketPath)
+ spec.Process.User = specs.User{
+ UID: uint32(os.Getuid()),
+ GID: uint32(os.Getgid()),
+ }
+ spec.Mounts = []specs.Mount{{
+ Type: "bind",
+ Destination: dir,
+ Source: dir,
+ }}
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Set the image path, the location where the checkpoint image will be saved.
- imagePath := filepath.Join(dir, "test-image-file")
+ // Set the image path, the location where the checkpoint image will be saved.
+ imagePath := filepath.Join(dir, "test-image-file")
- // Create the image file and open for writing.
- file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
- if err != nil {
- t.Fatalf("error opening new file at imagePath: %v", err)
- }
- defer file.Close()
- defer os.RemoveAll(imagePath)
+ // Create the image file and open for writing.
+ file, err := os.OpenFile(imagePath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0644)
+ if err != nil {
+ t.Fatalf("error opening new file at imagePath: %v", err)
+ }
+ defer file.Close()
+ defer os.RemoveAll(imagePath)
- // Wait until application has ran.
- if err := waitForFile(outputFile); err != nil {
- t.Fatalf("Failed to wait for output file: %v", err)
- }
+ // Wait until application has ran.
+ if err := waitForFileNotEmpty(outputFile); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
- // Checkpoint running container; save state into new file.
- if err := cont.Checkpoint(file); err != nil {
- t.Fatalf("error checkpointing container to empty file: %v", err)
- }
+ // Checkpoint running container; save state into new file.
+ if err := cont.Checkpoint(file); err != nil {
+ t.Fatalf("error checkpointing container to empty file: %v", err)
+ }
- // Read last number outputted before checkpoint.
- lastNum, err := readOutputNum(outputPath, -1)
- if err != nil {
- t.Fatalf("error with outputFile: %v", err)
- }
+ // Read last number outputted before checkpoint.
+ lastNum, err := readOutputNum(outputPath, -1)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
- // Delete and recreate file before restoring.
- if err := os.Remove(outputPath); err != nil {
- t.Fatalf("error removing file")
- }
- outputFile2, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer outputFile2.Close()
+ // Delete and recreate file before restoring.
+ if err := os.Remove(outputPath); err != nil {
+ t.Fatalf("error removing file")
+ }
+ outputFile2, err := os.OpenFile(outputPath, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0666)
+ if err != nil {
+ t.Fatalf("error creating output file: %v", err)
+ }
+ defer outputFile2.Close()
- // Restore into a new container.
- argsRestore := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- contRestore, err := New(conf, argsRestore)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer contRestore.Destroy()
+ // Restore into a new container.
+ argsRestore := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ contRestore, err := New(conf, argsRestore)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer contRestore.Destroy()
- if err := contRestore.Restore(spec, conf, imagePath); err != nil {
- t.Fatalf("error restoring container: %v", err)
- }
+ if err := contRestore.Restore(spec, conf, imagePath); err != nil {
+ t.Fatalf("error restoring container: %v", err)
+ }
- // Wait until application has ran.
- if err := waitForFile(outputFile2); err != nil {
- t.Fatalf("Failed to wait for output file: %v", err)
- }
+ // Wait until application has ran.
+ if err := waitForFileNotEmpty(outputFile2); err != nil {
+ t.Fatalf("Failed to wait for output file: %v", err)
+ }
- // Read first number outputted after restore.
- firstNum, err := readOutputNum(outputPath, 0)
- if err != nil {
- t.Fatalf("error with outputFile: %v", err)
- }
+ // Read first number outputted after restore.
+ firstNum, err := readOutputNum(outputPath, 0)
+ if err != nil {
+ t.Fatalf("error with outputFile: %v", err)
+ }
- // Check that lastNum is one less than firstNum.
- if lastNum+1 != firstNum {
- t.Errorf("error numbers not consecutive, previous: %d, next: %d", lastNum, firstNum)
- }
- contRestore.Destroy()
+ // Check that lastNum is one less than firstNum.
+ if lastNum+1 != firstNum {
+ t.Errorf("error numbers not consecutive, previous: %d, next: %d", lastNum, firstNum)
+ }
+ contRestore.Destroy()
+ })
}
}
// TestPauseResume tests that we can successfully pause and resume a container.
-// It checks starts running sleep and executes another sleep. It pauses and checks
-// that both processes are still running: sleep will be paused and still exist.
-// It will then unpause and confirm that both processes are running. Then it will
-// wait until one sleep completes and check to make sure the other is running.
+// The container will keep touching a file to indicate it's running. The test
+// pauses the container, removes the file, and checks that it doesn't get
+// recreated. Then it resumes the container, verify that the file gets created
+// again.
func TestPauseResume(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
- const uid = 343
- spec := testutil.NewSpecWithArgs("sleep", "20")
-
- lock, err := ioutil.TempFile(testutil.TmpDir(), "lock")
- if err != nil {
- t.Fatalf("error creating output file: %v", err)
- }
- defer lock.Close()
-
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
-
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
-
- // expectedPL lists the expected process state of the container.
- expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- {
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "bash",
- },
- }
-
- script := fmt.Sprintf("while [[ -f %q ]]; do sleep 0.1; done", lock.Name())
- execArgs := &control.ExecArgs{
- Filename: "/bin/bash",
- Argv: []string{"bash", "-c", script},
- WorkingDirectory: "/",
- KUID: uid,
- }
+ for name, conf := range configs(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "lock")
+ if err != nil {
+ t.Fatalf("error creating temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
- // First, start running exec.
- _, err = cont.Execute(execArgs)
- if err != nil {
- t.Fatalf("error executing: %v", err)
- }
+ running := path.Join(tmpDir, "running")
+ script := fmt.Sprintf("while [[ true ]]; do touch %q; sleep 0.1; done", running)
+ spec := testutil.NewSpecWithArgs("/bin/bash", "-c", script)
- // Verify that "sleep 5" is running.
- if err := waitForProcessList(cont, expectedPL); err != nil {
- t.Fatal(err)
- }
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Pause the running container.
- if err := cont.Pause(); err != nil {
- t.Errorf("error pausing container: %v", err)
- }
- if got, want := cont.Status, Paused; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- if err := os.Remove(lock.Name()); err != nil {
- t.Fatalf("os.Remove(lock) failed: %v", err)
- }
- // Script loops and sleeps for 100ms. Give a bit a time for it to exit in
- // case pause didn't work.
- time.Sleep(200 * time.Millisecond)
+ // Wait until container starts running, observed by the existence of running
+ // file.
+ if err := waitForFileExist(running); err != nil {
+ t.Errorf("error waiting for container to start: %v", err)
+ }
- // Verify that the two processes still exist.
- if err := getAndCheckProcLists(cont, expectedPL); err != nil {
- t.Fatal(err)
- }
+ // Pause the running container.
+ if err := cont.Pause(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Paused; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
- // Resume the running container.
- if err := cont.Resume(); err != nil {
- t.Errorf("error pausing container: %v", err)
- }
- if got, want := cont.Status, Running; got != want {
- t.Errorf("container status got %v, want %v", got, want)
- }
+ if err := os.Remove(running); err != nil {
+ t.Fatalf("os.Remove(%q) failed: %v", running, err)
+ }
+ // Script touches the file every 100ms. Give a bit a time for it to run to
+ // catch the case that pause didn't work.
+ time.Sleep(200 * time.Millisecond)
+ if _, err := os.Stat(running); !os.IsNotExist(err) {
+ t.Fatalf("container did not pause: file exist check: %v", err)
+ }
- expectedPL2 := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- }
+ // Resume the running container.
+ if err := cont.Resume(); err != nil {
+ t.Errorf("error pausing container: %v", err)
+ }
+ if got, want := cont.Status, Running; got != want {
+ t.Errorf("container status got %v, want %v", got, want)
+ }
- // Verify that deleting the file triggered the process to exit.
- if err := waitForProcessList(cont, expectedPL2); err != nil {
- t.Fatal(err)
- }
+ // Verify that the file is once again created by container.
+ if err := waitForFileExist(running); err != nil {
+ t.Fatalf("error resuming container: file exist check: %v", err)
+ }
+ })
}
}
@@ -1146,17 +1288,16 @@ func TestPauseResume(t *testing.T) {
// occurs given the correct state.
func TestPauseResumeStatus(t *testing.T) {
spec := testutil.NewSpecWithArgs("sleep", "20")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -1212,357 +1353,350 @@ func TestCapabilities(t *testing.T) {
uid := auth.KUID(os.Getuid() + 1)
gid := auth.KGID(os.Getgid() + 1)
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- spec := testutil.NewSpecWithArgs("sleep", "100")
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("sleep", "100")
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create and start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer cont.Destroy()
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // expectedPL lists the expected process state of the container.
- expectedPL := []*control.Process{
- {
- UID: 0,
- PID: 1,
- PPID: 0,
- C: 0,
- Cmd: "sleep",
- },
- {
- UID: uid,
- PID: 2,
- PPID: 0,
- C: 0,
- Cmd: "exe",
- },
- }
- if err := waitForProcessList(cont, expectedPL[:1]); err != nil {
- t.Fatalf("Failed to wait for sleep to start, err: %v", err)
- }
+ // expectedPL lists the expected process state of the container.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(cont, expectedPL); err != nil {
+ t.Fatalf("Failed to wait for sleep to start, err: %v", err)
+ }
- // Create an executable that can't be run with the specified UID:GID.
- // This shouldn't be callable within the container until we add the
- // CAP_DAC_OVERRIDE capability to skip the access check.
- exePath := filepath.Join(rootDir, "exe")
- if err := ioutil.WriteFile(exePath, []byte("#!/bin/sh\necho hello"), 0770); err != nil {
- t.Fatalf("couldn't create executable: %v", err)
- }
- defer os.Remove(exePath)
-
- // Need to traverse the intermediate directory.
- os.Chmod(rootDir, 0755)
-
- execArgs := &control.ExecArgs{
- Filename: exePath,
- Argv: []string{exePath},
- WorkingDirectory: "/",
- KUID: uid,
- KGID: gid,
- Capabilities: &auth.TaskCapabilities{},
- }
+ // Create an executable that can't be run with the specified UID:GID.
+ // This shouldn't be callable within the container until we add the
+ // CAP_DAC_OVERRIDE capability to skip the access check.
+ exePath := filepath.Join(rootDir, "exe")
+ if err := ioutil.WriteFile(exePath, []byte("#!/bin/sh\necho hello"), 0770); err != nil {
+ t.Fatalf("couldn't create executable: %v", err)
+ }
+ defer os.Remove(exePath)
+
+ // Need to traverse the intermediate directory.
+ os.Chmod(rootDir, 0755)
+
+ execArgs := &control.ExecArgs{
+ Filename: exePath,
+ Argv: []string{exePath},
+ WorkingDirectory: "/",
+ KUID: uid,
+ KGID: gid,
+ Capabilities: &auth.TaskCapabilities{},
+ }
- // "exe" should fail because we don't have the necessary permissions.
- if _, err := cont.executeSync(execArgs); err == nil {
- t.Fatalf("container executed without error, but an error was expected")
- }
+ // "exe" should fail because we don't have the necessary permissions.
+ if _, err := cont.executeSync(execArgs); err == nil {
+ t.Fatalf("container executed without error, but an error was expected")
+ }
- // Now we run with the capability enabled and should succeed.
- execArgs.Capabilities = &auth.TaskCapabilities{
- EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
- }
- // "exe" should not fail this time.
- if _, err := cont.executeSync(execArgs); err != nil {
- t.Fatalf("container failed to exec %v: %v", args, err)
- }
+ // Now we run with the capability enabled and should succeed.
+ execArgs.Capabilities = &auth.TaskCapabilities{
+ EffectiveCaps: auth.CapabilitySetOf(linux.CAP_DAC_OVERRIDE),
+ }
+ // "exe" should not fail this time.
+ if _, err := cont.executeSync(execArgs); err != nil {
+ t.Fatalf("container failed to exec %v: %v", args, err)
+ }
+ })
}
}
// TestRunNonRoot checks that sandbox can be configured when running as
// non-privileged user.
func TestRunNonRoot(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
-
- spec := testutil.NewSpecWithArgs("/bin/true")
-
- // Set a random user/group with no access to "blocked" dir.
- spec.Process.User.UID = 343
- spec.Process.User.GID = 2401
- spec.Process.Capabilities = nil
+ for name, conf := range configsWithVFS2(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("/bin/true")
+
+ // Set a random user/group with no access to "blocked" dir.
+ spec.Process.User.UID = 343
+ spec.Process.User.GID = 2401
+ spec.Process.Capabilities = nil
+
+ // User running inside container can't list '$TMP/blocked' and would fail to
+ // mount it.
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ if err := os.Chmod(dir, 0700); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
+ dir = path.Join(dir, "test")
+ if err := os.Mkdir(dir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
- // User running inside container can't list '$TMP/blocked' and would fail to
- // mount it.
- dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
- if err := os.Chmod(dir, 0700); err != nil {
- t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
- }
- dir = path.Join(dir, "test")
- if err := os.Mkdir(dir, 0755); err != nil {
- t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
- }
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
- src, err := ioutil.TempDir(testutil.TmpDir(), "src")
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: dir,
- Source: src,
- Type: "bind",
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("error running sandbox: %v", err)
+ }
})
-
- if err := run(spec, conf); err != nil {
- t.Fatalf("error running sandbox: %v", err)
- }
}
}
// TestMountNewDir checks that runsc will create destination directory if it
// doesn't exit.
func TestMountNewDir(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ root, err := ioutil.TempDir(testutil.TmpDir(), "root")
+ if err != nil {
+ t.Fatal("ioutil.TempDir() failed:", err)
+ }
- root, err := ioutil.TempDir(testutil.TmpDir(), "root")
- if err != nil {
- t.Fatal("ioutil.TempDir() failed:", err)
- }
+ srcDir := path.Join(root, "src", "dir", "anotherdir")
+ if err := os.MkdirAll(srcDir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", srcDir, err)
+ }
- srcDir := path.Join(root, "src", "dir", "anotherdir")
- if err := os.MkdirAll(srcDir, 0755); err != nil {
- t.Fatalf("os.MkDir(%q) failed: %v", srcDir, err)
- }
+ mountDir := path.Join(root, "dir", "anotherdir")
- mountDir := path.Join(root, "dir", "anotherdir")
+ spec := testutil.NewSpecWithArgs("/bin/ls", mountDir)
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: mountDir,
+ Source: srcDir,
+ Type: "bind",
+ })
- spec := testutil.NewSpecWithArgs("/bin/ls", mountDir)
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: mountDir,
- Source: srcDir,
- Type: "bind",
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("error running sandbox: %v", err)
+ }
})
-
- if err := run(spec, conf); err != nil {
- t.Fatalf("error running sandbox: %v", err)
- }
}
}
func TestReadonlyRoot(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
-
- spec := testutil.NewSpecWithArgs("/bin/touch", "/foo")
- spec.Root.Readonly = true
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ spec := testutil.NewSpecWithArgs("/bin/touch", "/foo")
+ spec.Root.Readonly = true
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create, start and wait for the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- ws, err := c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
- t.Fatalf("container failed, waitStatus: %v", ws)
- }
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ })
}
}
func TestUIDMap(t *testing.T) {
- for _, conf := range configs(noOverlay...) {
- t.Logf("Running test with conf: %+v", conf)
- testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(testDir)
- testFile := path.Join(testDir, "testfile")
-
- spec := testutil.NewSpecWithArgs("touch", "/tmp/testfile")
- uid := os.Getuid()
- gid := os.Getgid()
- spec.Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{
- {Type: specs.UserNamespace},
- {Type: specs.PIDNamespace},
- {Type: specs.MountNamespace},
- },
- UIDMappings: []specs.LinuxIDMapping{
- {
- ContainerID: 0,
- HostID: uint32(uid),
- Size: 1,
+ for name, conf := range configsWithVFS2(t, noOverlay...) {
+ t.Run(name, func(t *testing.T) {
+ testDir, err := ioutil.TempDir(testutil.TmpDir(), "test-mount")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ defer os.RemoveAll(testDir)
+ testFile := path.Join(testDir, "testfile")
+
+ spec := testutil.NewSpecWithArgs("touch", "/tmp/testfile")
+ uid := os.Getuid()
+ gid := os.Getgid()
+ spec.Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {Type: specs.UserNamespace},
+ {Type: specs.PIDNamespace},
+ {Type: specs.MountNamespace},
},
- },
- GIDMappings: []specs.LinuxIDMapping{
- {
- ContainerID: 0,
- HostID: uint32(gid),
- Size: 1,
+ UIDMappings: []specs.LinuxIDMapping{
+ {
+ ContainerID: 0,
+ HostID: uint32(uid),
+ Size: 1,
+ },
},
- },
- }
+ GIDMappings: []specs.LinuxIDMapping{
+ {
+ ContainerID: 0,
+ HostID: uint32(gid),
+ Size: 1,
+ },
+ },
+ }
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: "/tmp",
- Source: testDir,
- Type: "bind",
- })
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: "/tmp",
+ Source: testDir,
+ Type: "bind",
+ })
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create, start and wait for the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- ws, err := c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if !ws.Exited() || ws.ExitStatus() != 0 {
- t.Fatalf("container failed, waitStatus: %v", ws)
- }
- st := syscall.Stat_t{}
- if err := syscall.Stat(testFile, &st); err != nil {
- t.Fatalf("error stat /testfile: %v", err)
- }
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ st := syscall.Stat_t{}
+ if err := syscall.Stat(testFile, &st); err != nil {
+ t.Fatalf("error stat /testfile: %v", err)
+ }
- if st.Uid != uint32(uid) || st.Gid != uint32(gid) {
- t.Fatalf("UID: %d (%d) GID: %d (%d)", st.Uid, uid, st.Gid, gid)
- }
+ if st.Uid != uint32(uid) || st.Gid != uint32(gid) {
+ t.Fatalf("UID: %d (%d) GID: %d (%d)", st.Uid, uid, st.Gid, gid)
+ }
+ })
}
}
func TestReadonlyMount(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
-
- dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
- spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Destination: dir,
- Source: dir,
- Type: "bind",
- Options: []string{"ro"},
- })
- spec.Root.Readonly = false
-
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount")
+ spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: dir,
+ Type: "bind",
+ Options: []string{"ro"},
+ })
+ spec.Root.Readonly = false
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create, start and wait for the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create, start and wait for the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- ws, err := c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
- t.Fatalf("container failed, waitStatus: %v", ws)
- }
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ t.Fatalf("container failed, waitStatus: %v", ws)
+ }
+ })
}
}
// TestAbbreviatedIDs checks that runsc supports using abbreviated container
// IDs in place of full IDs.
func TestAbbreviatedIDs(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ doAbbreviatedIDsTest(t, false)
+}
+
+func TestAbbreviatedIDsVFS2(t *testing.T) {
+ doAbbreviatedIDsTest(t, true)
+}
+
+func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
+ conf.VFS2 = vfs2
cids := []string{
- "foo-" + testutil.UniqueContainerID(),
- "bar-" + testutil.UniqueContainerID(),
- "baz-" + testutil.UniqueContainerID(),
+ "foo-" + testutil.RandomContainerID(),
+ "bar-" + testutil.RandomContainerID(),
+ "baz-" + testutil.RandomContainerID(),
}
for _, cid := range cids {
spec := testutil.NewSpecWithArgs("sleep", "100")
- bundleDir, err := testutil.SetupBundleDir(spec)
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
@@ -1605,18 +1739,27 @@ func TestAbbreviatedIDs(t *testing.T) {
}
func TestGoferExits(t *testing.T) {
+ doGoferExitTest(t, false)
+}
+
+func TestGoferExitsVFS2(t *testing.T) {
+ doGoferExitTest(t, true)
+}
+
+func doGoferExitTest(t *testing.T, vfs2 bool) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "10000")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -1645,7 +1788,7 @@ func TestGoferExits(t *testing.T) {
}
func TestRootNotMount(t *testing.T) {
- appSym, err := testutil.FindFile("runsc/container/test_app/test_app")
+ appSym, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
@@ -1675,27 +1818,26 @@ func TestRootNotMount(t *testing.T) {
spec.Root.Readonly = true
spec.Mounts = nil
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
if err := run(spec, conf); err != nil {
t.Fatalf("error running sandbox: %v", err)
}
}
func TestUserLog(t *testing.T) {
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
// sched_rr_get_interval = 148 - not implemented in gvisor.
spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall=148")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
dir, err := ioutil.TempDir(testutil.TmpDir(), "user_log_test")
if err != nil {
@@ -1705,7 +1847,7 @@ func TestUserLog(t *testing.T) {
// Create, start and wait for the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
UserLog: userLog,
@@ -1723,78 +1865,85 @@ func TestUserLog(t *testing.T) {
if err != nil {
t.Fatalf("error opening user log file %q: %v", userLog, err)
}
- if want := "Unsupported syscall: sched_rr_get_interval"; !strings.Contains(string(out), want) {
+ if want := "Unsupported syscall sched_rr_get_interval("; !strings.Contains(string(out), want) {
t.Errorf("user log file doesn't contain %q, out: %s", want, string(out))
}
}
func TestWaitOnExitedSandbox(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- // Run a shell that sleeps for 1 second and then exits with a
- // non-zero code.
- const wantExit = 17
- cmd := fmt.Sprintf("sleep 1; exit %d", wantExit)
- spec := testutil.NewSpecWithArgs("/bin/sh", "-c", cmd)
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ // Run a shell that sleeps for 1 second and then exits with a
+ // non-zero code.
+ const wantExit = 17
+ cmd := fmt.Sprintf("sleep 1; exit %d", wantExit)
+ spec := testutil.NewSpecWithArgs("/bin/sh", "-c", cmd)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- // Create and Start the container.
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- c, err := New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- defer c.Destroy()
- if err := c.Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ // Create and Start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- // Wait on the sandbox. This will make an RPC to the sandbox
- // and get the actual exit status of the application.
- ws, err := c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if got := ws.ExitStatus(); got != wantExit {
- t.Errorf("got exit status %d, want %d", got, wantExit)
- }
+ // Wait on the sandbox. This will make an RPC to the sandbox
+ // and get the actual exit status of the application.
+ ws, err := c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if got := ws.ExitStatus(); got != wantExit {
+ t.Errorf("got exit status %d, want %d", got, wantExit)
+ }
- // Now the sandbox has exited, but the zombie sandbox process
- // still exists. Calling Wait() now will return the sandbox
- // exit status.
- ws, err = c.Wait()
- if err != nil {
- t.Fatalf("error waiting on container: %v", err)
- }
- if got := ws.ExitStatus(); got != wantExit {
- t.Errorf("got exit status %d, want %d", got, wantExit)
- }
+ // Now the sandbox has exited, but the zombie sandbox process
+ // still exists. Calling Wait() now will return the sandbox
+ // exit status.
+ ws, err = c.Wait()
+ if err != nil {
+ t.Fatalf("error waiting on container: %v", err)
+ }
+ if got := ws.ExitStatus(); got != wantExit {
+ t.Errorf("got exit status %d, want %d", got, wantExit)
+ }
+ })
}
}
func TestDestroyNotStarted(t *testing.T) {
+ doDestroyNotStartedTest(t, false)
+}
+
+func TestDestroyNotStartedVFS2(t *testing.T) {
+ doDestroyNotStartedTest(t, true)
+}
+
+func doDestroyNotStartedTest(t *testing.T, vfs2 bool) {
spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create the container and check that it can be destroyed.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -1809,19 +1958,27 @@ func TestDestroyNotStarted(t *testing.T) {
// TestDestroyStarting attempts to force a race between start and destroy.
func TestDestroyStarting(t *testing.T) {
+ doDestroyNotStartedTest(t, false)
+}
+
+func TestDestroyStartedVFS2(t *testing.T) {
+ doDestroyNotStartedTest(t, true)
+}
+
+func doDestroyStartingTest(t *testing.T, vfs2 bool) {
for i := 0; i < 10; i++ {
spec := testutil.NewSpecWithArgs("/bin/sleep", "100")
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create the container and check that it can be destroyed.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -1856,23 +2013,23 @@ func TestDestroyStarting(t *testing.T) {
}
func TestCreateWorkingDir(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
-
- tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create")
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
- dir := path.Join(tmpDir, "new/working/dir")
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ dir := path.Join(tmpDir, "new/working/dir")
- // touch will fail if the directory doesn't exist.
- spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
- spec.Process.Cwd = dir
- spec.Root.Readonly = true
+ // touch will fail if the directory doesn't exist.
+ spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file"))
+ spec.Process.Cwd = dir
+ spec.Root.Readonly = true
- if err := run(spec, conf); err != nil {
- t.Fatalf("Error running container: %v", err)
- }
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ })
}
}
@@ -1929,16 +2086,15 @@ func TestMountPropagation(t *testing.T) {
},
}
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -1980,87 +2136,87 @@ func TestMountPropagation(t *testing.T) {
}
func TestMountSymlink(t *testing.T) {
- for _, conf := range configs(overlay) {
- t.Logf("Running test with conf: %+v", conf)
-
- dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink")
- if err != nil {
- t.Fatalf("ioutil.TempDir() failed: %v", err)
- }
+ for name, conf := range configsWithVFS2(t, overlay) {
+ t.Run(name, func(t *testing.T) {
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ defer os.RemoveAll(dir)
- source := path.Join(dir, "source")
- target := path.Join(dir, "target")
- for _, path := range []string{source, target} {
- if err := os.MkdirAll(path, 0777); err != nil {
- t.Fatalf("os.MkdirAll(): %v", err)
+ source := path.Join(dir, "source")
+ target := path.Join(dir, "target")
+ for _, path := range []string{source, target} {
+ if err := os.MkdirAll(path, 0777); err != nil {
+ t.Fatalf("os.MkdirAll(): %v", err)
+ }
}
- }
- f, err := os.Create(path.Join(source, "file"))
- if err != nil {
- t.Fatalf("os.Create(): %v", err)
- }
- f.Close()
+ f, err := os.Create(path.Join(source, "file"))
+ if err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+ f.Close()
- link := path.Join(dir, "link")
- if err := os.Symlink(target, link); err != nil {
- t.Fatalf("os.Symlink(%q, %q): %v", target, link, err)
- }
+ link := path.Join(dir, "link")
+ if err := os.Symlink(target, link); err != nil {
+ t.Fatalf("os.Symlink(%q, %q): %v", target, link, err)
+ }
- spec := testutil.NewSpecWithArgs("/bin/sleep", "1000")
+ spec := testutil.NewSpecWithArgs("/bin/sleep", "1000")
- // Mount to a symlink to ensure the mount code will follow it and mount
- // at the symlink target.
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Type: "bind",
- Destination: link,
- Source: source,
- })
+ // Mount to a symlink to ensure the mount code will follow it and mount
+ // at the symlink target.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: link,
+ Source: source,
+ })
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
- if err != nil {
- t.Fatalf("error setting up container: %v", err)
- }
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
- args := Args{
- ID: testutil.UniqueContainerID(),
- Spec: spec,
- BundleDir: bundleDir,
- }
- cont, err := New(conf, args)
- if err != nil {
- t.Fatalf("creating container: %v", err)
- }
- defer cont.Destroy()
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("creating container: %v", err)
+ }
+ defer cont.Destroy()
- if err := cont.Start(conf); err != nil {
- t.Fatalf("starting container: %v", err)
- }
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("starting container: %v", err)
+ }
- // Check that symlink was resolved and mount was created where the symlink
- // is pointing to.
- file := path.Join(target, "file")
- execArgs := &control.ExecArgs{
- Filename: "/usr/bin/test",
- Argv: []string{"test", "-f", file},
- }
- if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
- t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err)
- }
+ // Check that symlink was resolved and mount was created where the symlink
+ // is pointing to.
+ file := path.Join(target, "file")
+ execArgs := &control.ExecArgs{
+ Filename: "/usr/bin/test",
+ Argv: []string{"test", "-f", file},
+ }
+ if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 {
+ t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err)
+ }
+ })
}
}
// Check that --net-raw disables the CAP_NET_RAW capability.
func TestNetRaw(t *testing.T) {
capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
for _, enableRaw := range []bool{true, false} {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.EnableRaw = enableRaw
test := "--enabled"
@@ -2075,40 +2231,98 @@ func TestNetRaw(t *testing.T) {
}
}
-// TestOverlayfsStaleRead most basic test that '--overlayfs-stale-read' works.
-func TestOverlayfsStaleRead(t *testing.T) {
- conf := testutil.TestConfig()
- conf.OverlayfsStaleRead = true
+// TestTTYField checks TTY field returned by container.Processes().
+func TestTTYField(t *testing.T) {
+ stop := testutil.StartReaper()
+ defer stop()
- in, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.in")
+ testApp, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
- t.Fatalf("ioutil.TempFile() failed: %v", err)
- }
- defer in.Close()
- if _, err := in.WriteString("stale data"); err != nil {
- t.Fatalf("in.Write() failed: %v", err)
+ t.Fatal("error finding test_app:", err)
}
- out, err := ioutil.TempFile(testutil.TmpDir(), "stale-read.out")
- if err != nil {
- t.Fatalf("ioutil.TempFile() failed: %v", err)
+ testCases := []struct {
+ name string
+ useTTY bool
+ wantTTYField string
+ }{
+ {
+ name: "no tty",
+ useTTY: false,
+ wantTTYField: "?",
+ },
+ {
+ name: "tty used",
+ useTTY: true,
+ wantTTYField: "pts/0",
+ },
}
- defer out.Close()
- const want = "foobar"
- cmd := fmt.Sprintf("cat %q && echo %q> %q && cp %q %q", in.Name(), want, in.Name(), in.Name(), out.Name())
- spec := testutil.NewSpecWithArgs("/bin/bash", "-c", cmd)
- if err := run(spec, conf); err != nil {
- t.Fatalf("Error running container: %v", err)
- }
+ for _, test := range testCases {
+ for _, vfs2 := range []bool{false, true} {
+ name := test.name
+ if vfs2 {
+ name += "-vfs2"
+ }
+ t.Run(name, func(t *testing.T) {
+ conf := testutil.TestConfig(t)
+ conf.VFS2 = vfs2
+
+ // We will run /bin/sleep, possibly with an open TTY.
+ cmd := []string{"/bin/sleep", "10000"}
+ if test.useTTY {
+ // Run inside the "pty-runner".
+ cmd = append([]string{testApp, "pty-runner"}, cmd...)
+ }
- gotBytes, err := ioutil.ReadAll(out)
- if err != nil {
- t.Fatalf("out.Read() failed: %v", err)
- }
- got := strings.TrimSpace(string(gotBytes))
- if want != got {
- t.Errorf("Wrong content in out file, got: %q. want: %q", got, want)
+ spec := testutil.NewSpecWithArgs(cmd...)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ // Create and start the container.
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ c, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ defer c.Destroy()
+ if err := c.Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
+
+ // Wait for sleep to be running, and check the TTY
+ // field.
+ var gotTTYField string
+ cb := func() error {
+ ps, err := c.Processes()
+ if err != nil {
+ err = fmt.Errorf("error getting process data from container: %v", err)
+ return &backoff.PermanentError{Err: err}
+ }
+ for _, p := range ps {
+ if strings.Contains(p.Cmd, "sleep") {
+ gotTTYField = p.TTY
+ return nil
+ }
+ }
+ return fmt.Errorf("sleep not running")
+ }
+ if err := testutil.Poll(cb, 30*time.Second); err != nil {
+ t.Fatalf("error waiting for sleep process: %v", err)
+ }
+
+ if gotTTYField != test.wantTTYField {
+ t.Errorf("tty field got %q, want %q", gotTTYField, test.wantTTYField)
+ }
+ })
+ }
}
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index a5a62378c..e189648f4 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -22,23 +22,24 @@ import (
"path"
"path/filepath"
"strings"
- "sync"
"syscall"
"testing"
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
)
func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
var specs []*specs.Spec
var ids []string
- rootID := testutil.UniqueContainerID()
+ rootID := testutil.RandomContainerID()
for i, cmd := range cmds {
spec := testutil.NewSpecWithArgs(cmd...)
@@ -52,7 +53,7 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
specutils.ContainerdSandboxIDAnnotation: rootID,
}
- ids = append(ids, testutil.UniqueContainerID())
+ ids = append(ids, testutil.RandomContainerID())
}
specs = append(specs, spec)
}
@@ -64,23 +65,16 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C
panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.")
}
+ cu := cleanup.Cleanup{}
+ defer cu.Clean()
+
var containers []*Container
- var bundles []string
- cleanup := func() {
- for _, c := range containers {
- c.Destroy()
- }
- for _, b := range bundles {
- os.RemoveAll(b)
- }
- }
for i, spec := range specs {
- bundleDir, err := testutil.SetupBundleDir(spec)
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
- cleanup()
return nil, nil, fmt.Errorf("error setting up container: %v", err)
}
- bundles = append(bundles, bundleDir)
+ cu.Add(cleanup)
args := Args{
ID: ids[i],
@@ -89,45 +83,46 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*C
}
cont, err := New(conf, args)
if err != nil {
- cleanup()
return nil, nil, fmt.Errorf("error creating container: %v", err)
}
+ cu.Add(func() { cont.Destroy() })
containers = append(containers, cont)
if err := cont.Start(conf); err != nil {
- cleanup()
return nil, nil, fmt.Errorf("error starting container: %v", err)
}
}
- return containers, cleanup, nil
+
+ return containers, cu.Release(), nil
}
type execDesc struct {
c *Container
cmd []string
want int
- desc string
+ name string
}
-func execMany(execs []execDesc) error {
+func execMany(t *testing.T, execs []execDesc) {
for _, exec := range execs {
- args := &control.ExecArgs{Argv: exec.cmd}
- if ws, err := exec.c.executeSync(args); err != nil {
- return fmt.Errorf("error executing %+v: %v", args, err)
- } else if ws.ExitStatus() != exec.want {
- return fmt.Errorf("%q: exec %q got exit status: %d, want: %d", exec.desc, exec.cmd, ws.ExitStatus(), exec.want)
- }
+ t.Run(exec.name, func(t *testing.T) {
+ args := &control.ExecArgs{Argv: exec.cmd}
+ if ws, err := exec.c.executeSync(args); err != nil {
+ t.Errorf("error executing %+v: %v", args, err)
+ } else if ws.ExitStatus() != exec.want {
+ t.Errorf("%q: exec %q got exit status: %d, want: %d", exec.name, exec.cmd, ws.ExitStatus(), exec.want)
+ }
+ })
}
- return nil
}
func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) {
for _, spec := range pod {
- spec.Annotations[path.Join(boot.MountPrefix, name, "source")] = mount.Source
- spec.Annotations[path.Join(boot.MountPrefix, name, "type")] = mount.Type
- spec.Annotations[path.Join(boot.MountPrefix, name, "share")] = "pod"
+ spec.Annotations[boot.MountPrefix+name+".source"] = mount.Source
+ spec.Annotations[boot.MountPrefix+name+".type"] = mount.Type
+ spec.Annotations[boot.MountPrefix+name+".share"] = "pod"
if len(mount.Options) > 0 {
- spec.Annotations[path.Join(boot.MountPrefix, name, "options")] = strings.Join(mount.Options, ",")
+ spec.Annotations[boot.MountPrefix+name+".options"] = strings.Join(mount.Options, ",")
}
}
}
@@ -135,161 +130,161 @@ func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) {
// TestMultiContainerSanity checks that it is possible to run 2 dead-simple
// containers in the same sandbox.
func TestMultiContainerSanity(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- specs, ids := createSpecs(sleep, sleep)
- containers, cleanup, err := startContainers(conf, specs, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- // Check via ps that multiple processes are running.
- expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
- expectedPL = []*control.Process{
- {PID: 2, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[1], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
}
}
// TestMultiPIDNS checks that it is possible to run 2 dead-simple
// containers in the same sandbox with different pidns.
func TestMultiPIDNS(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- testSpecs, ids := createSpecs(sleep, sleep)
- testSpecs[1].Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{
- {
- Type: "pid",
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ testSpecs, ids := createSpecs(sleep, sleep)
+ testSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ },
},
- },
- }
+ }
- containers, cleanup, err := startContainers(conf, testSpecs, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, testSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- // Check via ps that multiple processes are running.
- expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
- expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[1], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
}
}
// TestMultiPIDNSPath checks the pidns path.
func TestMultiPIDNSPath(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- testSpecs, ids := createSpecs(sleep, sleep, sleep)
- testSpecs[0].Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{
- {
- Type: "pid",
- Path: "/proc/1/ns/pid",
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ testSpecs, ids := createSpecs(sleep, sleep, sleep)
+ testSpecs[0].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/1/ns/pid",
+ },
},
- },
- }
- testSpecs[1].Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{
- {
- Type: "pid",
- Path: "/proc/1/ns/pid",
+ }
+ testSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/1/ns/pid",
+ },
},
- },
- }
- testSpecs[2].Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{
- {
- Type: "pid",
- Path: "/proc/2/ns/pid",
+ }
+ testSpecs[2].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{
+ {
+ Type: "pid",
+ Path: "/proc/2/ns/pid",
+ },
},
- },
- }
+ }
- containers, cleanup, err := startContainers(conf, testSpecs, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, testSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- // Check via ps that multiple processes are running.
- expectedPL := []*control.Process{
- {PID: 1, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
- if err := waitForProcessList(containers[2], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Check via ps that multiple processes are running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ if err := waitForProcessList(containers[2], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
- expectedPL = []*control.Process{
- {PID: 2, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[1], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+ })
}
}
func TestMultiContainerWait(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// The first container should run the entire duration of the test.
@@ -306,7 +301,7 @@ func TestMultiContainerWait(t *testing.T) {
// Check via ps that multiple processes are running.
expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep"},
+ newProcessBuilder().PID(2).PPID(0).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -351,7 +346,7 @@ func TestMultiContainerWait(t *testing.T) {
// After Wait returns, ensure that the root container is running and
// the child has finished.
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep"},
+ newProcessBuilder().Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Errorf("failed to wait for %q to start: %v", strings.Join(containers[0].Spec.Process.Args, " "), err)
@@ -361,13 +356,13 @@ func TestMultiContainerWait(t *testing.T) {
// TestExecWait ensures what we can wait containers and individual processes in the
// sandbox that have already exited.
func TestExecWait(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// The first container should run the entire duration of the test.
@@ -383,7 +378,7 @@ func TestExecWait(t *testing.T) {
// Check via ps that process is running.
expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep"},
+ newProcessBuilder().Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL); err != nil {
t.Fatalf("failed to wait for sleep to start: %v", err)
@@ -418,7 +413,7 @@ func TestExecWait(t *testing.T) {
// Wait for the exec'd process to exit.
expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep"},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL); err != nil {
t.Fatalf("failed to wait for second container to stop: %v", err)
@@ -457,13 +452,13 @@ func TestMultiContainerMount(t *testing.T) {
})
// Setup the containers.
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
containers, cleanup, err := startContainers(conf, sps, ids)
@@ -484,175 +479,177 @@ func TestMultiContainerMount(t *testing.T) {
// TestMultiContainerSignal checks that it is possible to signal individual
// containers without killing the entire sandbox.
func TestMultiContainerSignal(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- specs, ids := createSpecs(sleep, sleep)
- containers, cleanup, err := startContainers(conf, specs, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
- // Check via ps that container 1 process is running.
- expectedPL := []*control.Process{
- {PID: 2, Cmd: "sleep"},
- }
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ specs, ids := createSpecs(sleep, sleep)
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- if err := waitForProcessList(containers[1], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Check via ps that container 1 process is running.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[1], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
- // Kill process 2.
- if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil {
- t.Errorf("failed to kill process 2: %v", err)
- }
+ // Kill process 2.
+ if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil {
+ t.Errorf("failed to kill process 2: %v", err)
+ }
- // Make sure process 1 is still running.
- expectedPL = []*control.Process{
- {PID: 1, Cmd: "sleep"},
- }
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Make sure process 1 is still running.
+ expectedPL = []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
- // goferPid is reset when container is destroyed.
- goferPid := containers[1].GoferPid
+ // goferPid is reset when container is destroyed.
+ goferPid := containers[1].GoferPid
- // Destroy container and ensure container's gofer process has exited.
- if err := containers[1].Destroy(); err != nil {
- t.Errorf("failed to destroy container: %v", err)
- }
- _, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) {
- cpid, err := syscall.Wait4(goferPid, nil, 0, nil)
- return uintptr(cpid), 0, err
- })
- if err != syscall.ECHILD {
- t.Errorf("error waiting for gofer to exit: %v", err)
- }
- // Make sure process 1 is still running.
- if err := waitForProcessList(containers[0], expectedPL); err != nil {
- t.Errorf("failed to wait for sleep to start: %v", err)
- }
+ // Destroy container and ensure container's gofer process has exited.
+ if err := containers[1].Destroy(); err != nil {
+ t.Errorf("failed to destroy container: %v", err)
+ }
+ _, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) {
+ cpid, err := syscall.Wait4(goferPid, nil, 0, nil)
+ return uintptr(cpid), 0, err
+ })
+ if err != syscall.ECHILD {
+ t.Errorf("error waiting for gofer to exit: %v", err)
+ }
+ // Make sure process 1 is still running.
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
- // Now that process 2 is gone, ensure we get an error trying to
- // signal it again.
- if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil {
- t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID)
- }
+ // Now that process 2 is gone, ensure we get an error trying to
+ // signal it again.
+ if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil {
+ t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID)
+ }
- // Kill process 1.
- if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil {
- t.Errorf("failed to kill process 1: %v", err)
- }
+ // Kill process 1.
+ if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil {
+ t.Errorf("failed to kill process 1: %v", err)
+ }
- // Ensure that container's gofer and sandbox process are no more.
- err = blockUntilWaitable(containers[0].GoferPid)
- if err != nil && err != syscall.ECHILD {
- t.Errorf("error waiting for gofer to exit: %v", err)
- }
+ // Ensure that container's gofer and sandbox process are no more.
+ err = blockUntilWaitable(containers[0].GoferPid)
+ if err != nil && err != syscall.ECHILD {
+ t.Errorf("error waiting for gofer to exit: %v", err)
+ }
- err = blockUntilWaitable(containers[0].Sandbox.Pid)
- if err != nil && err != syscall.ECHILD {
- t.Errorf("error waiting for sandbox to exit: %v", err)
- }
+ err = blockUntilWaitable(containers[0].Sandbox.Pid)
+ if err != nil && err != syscall.ECHILD {
+ t.Errorf("error waiting for sandbox to exit: %v", err)
+ }
- // The sentry should be gone, so signaling should yield an error.
- if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil {
- t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID)
- }
+ // The sentry should be gone, so signaling should yield an error.
+ if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil {
+ t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID)
+ }
- if err := containers[0].Destroy(); err != nil {
- t.Errorf("failed to destroy container: %v", err)
- }
+ if err := containers[0].Destroy(); err != nil {
+ t.Errorf("failed to destroy container: %v", err)
+ }
+ })
}
}
// TestMultiContainerDestroy checks that container are properly cleaned-up when
// they are destroyed.
func TestMultiContainerDestroy(t *testing.T) {
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // First container will remain intact while the second container is killed.
- podSpecs, ids := createSpecs(
- []string{"sleep", "100"},
- []string{app, "fork-bomb"})
-
- // Run the fork bomb in a PID namespace to prevent processes to be
- // re-parented to PID=1 in the root container.
- podSpecs[1].Linux = &specs.Linux{
- Namespaces: []specs.LinuxNamespace{{Type: "pid"}},
- }
- containers, cleanup, err := startContainers(conf, podSpecs, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // First container will remain intact while the second container is killed.
+ podSpecs, ids := createSpecs(
+ []string{"sleep", "100"},
+ []string{app, "fork-bomb"})
+
+ // Run the fork bomb in a PID namespace to prevent processes to be
+ // re-parented to PID=1 in the root container.
+ podSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{{Type: "pid"}},
+ }
+ containers, cleanup, err := startContainers(conf, podSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- // Exec more processes to ensure signal all works for exec'd processes too.
- args := &control.ExecArgs{
- Filename: app,
- Argv: []string{app, "fork-bomb"},
- }
- if _, err := containers[1].Execute(args); err != nil {
- t.Fatalf("error exec'ing: %v", err)
- }
+ // Exec more processes to ensure signal all works for exec'd processes too.
+ args := &control.ExecArgs{
+ Filename: app,
+ Argv: []string{app, "fork-bomb"},
+ }
+ if _, err := containers[1].Execute(args); err != nil {
+ t.Fatalf("error exec'ing: %v", err)
+ }
- // Let it brew...
- time.Sleep(500 * time.Millisecond)
+ // Let it brew...
+ time.Sleep(500 * time.Millisecond)
- if err := containers[1].Destroy(); err != nil {
- t.Fatalf("error destroying container: %v", err)
- }
+ if err := containers[1].Destroy(); err != nil {
+ t.Fatalf("error destroying container: %v", err)
+ }
- // Check that destroy killed all processes belonging to the container and
- // waited for them to exit before returning.
- pss, err := containers[0].Sandbox.Processes("")
- if err != nil {
- t.Fatalf("error getting process data from sandbox: %v", err)
- }
- expectedPL := []*control.Process{{PID: 1, Cmd: "sleep"}}
- if !procListsEqual(pss, expectedPL) {
- t.Errorf("container got process list: %s, want: %s", procListToString(pss), procListToString(expectedPL))
- }
+ // Check that destroy killed all processes belonging to the container and
+ // waited for them to exit before returning.
+ pss, err := containers[0].Sandbox.Processes("")
+ if err != nil {
+ t.Fatalf("error getting process data from sandbox: %v", err)
+ }
+ expectedPL := []*control.Process{
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
+ }
+ if !procListsEqual(pss, expectedPL) {
+ t.Errorf("container got process list: %s, want: %s: error: %v",
+ procListToString(pss), procListToString(expectedPL), err)
+ }
- // Check that cont.Destroy is safe to call multiple times.
- if err := containers[1].Destroy(); err != nil {
- t.Errorf("error destroying container: %v", err)
- }
+ // Check that cont.Destroy is safe to call multiple times.
+ if err := containers[1].Destroy(); err != nil {
+ t.Errorf("error destroying container: %v", err)
+ }
+ })
}
}
func TestMultiContainerProcesses(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Note: use curly braces to keep 'sh' process around. Otherwise, shell
@@ -669,7 +666,7 @@ func TestMultiContainerProcesses(t *testing.T) {
// Check root's container process list doesn't include other containers.
expectedPL0 := []*control.Process{
- {PID: 1, Cmd: "sleep"},
+ newProcessBuilder().PID(1).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[0], expectedPL0); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
@@ -677,8 +674,8 @@ func TestMultiContainerProcesses(t *testing.T) {
// Same for the other container.
expectedPL1 := []*control.Process{
- {PID: 2, Cmd: "sh"},
- {PID: 3, PPID: 2, Cmd: "sleep"},
+ newProcessBuilder().PID(2).Cmd("sh").Process(),
+ newProcessBuilder().PID(3).PPID(2).Cmd("sleep").Process(),
}
if err := waitForProcessList(containers[1], expectedPL1); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
@@ -692,7 +689,7 @@ func TestMultiContainerProcesses(t *testing.T) {
if _, err := containers[1].Execute(args); err != nil {
t.Fatalf("error exec'ing: %v", err)
}
- expectedPL1 = append(expectedPL1, &control.Process{PID: 4, Cmd: "sleep"})
+ expectedPL1 = append(expectedPL1, newProcessBuilder().PID(4).Cmd("sleep").Process())
if err := waitForProcessList(containers[1], expectedPL1); err != nil {
t.Errorf("failed to wait for process to start: %v", err)
}
@@ -705,13 +702,13 @@ func TestMultiContainerProcesses(t *testing.T) {
// TestMultiContainerKillAll checks that all process that belong to a container
// are killed when SIGKILL is sent to *all* processes in that container.
func TestMultiContainerKillAll(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
for _, tc := range []struct {
@@ -720,7 +717,7 @@ func TestMultiContainerKillAll(t *testing.T) {
{killContainer: true},
{killContainer: false},
} {
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
@@ -738,11 +735,11 @@ func TestMultiContainerKillAll(t *testing.T) {
// Wait until all processes are created.
rootProcCount := int(math.Pow(2, 3) - 1)
if err := waitForProcessCount(containers[0], rootProcCount); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waitting for processes: %v", err)
}
procCount := int(math.Pow(2, 5) - 1)
if err := waitForProcessCount(containers[1], procCount); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
// Exec more processes to ensure signal works for exec'd processes too.
@@ -756,7 +753,7 @@ func TestMultiContainerKillAll(t *testing.T) {
// Wait for these new processes to start.
procCount += int(math.Pow(2, 3) - 1)
if err := waitForProcessCount(containers[1], procCount); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
if tc.killContainer {
@@ -789,11 +786,11 @@ func TestMultiContainerKillAll(t *testing.T) {
// Check that all processes are gone.
if err := waitForProcessCount(containers[1], 0); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
// Check that root container was not affected.
if err := waitForProcessCount(containers[0], rootProcCount); err != nil {
- t.Fatal(err)
+ t.Fatalf("error waiting for processes: %v", err)
}
}
}
@@ -803,18 +800,17 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) {
[]string{"/bin/sleep", "100"},
[]string{"/bin/sleep", "100"})
- conf := testutil.TestConfig()
- rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf)
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(specs[0], conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(rootBundleDir)
+ defer cleanup()
rootArgs := Args{
ID: ids[0],
Spec: specs[0],
- BundleDir: rootBundleDir,
+ BundleDir: bundleDir,
}
root, err := New(conf, rootArgs)
if err != nil {
@@ -826,11 +822,11 @@ func TestMultiContainerDestroyNotStarted(t *testing.T) {
}
// Create and destroy sub-container.
- bundleDir, err := testutil.SetupBundleDir(specs[1])
+ bundleDir, cleanupSub, err := testutil.SetupBundleDir(specs[1])
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer cleanupSub()
args := Args{
ID: ids[1],
@@ -857,18 +853,17 @@ func TestMultiContainerDestroyStarting(t *testing.T) {
}
specs, ids := createSpecs(cmds...)
- conf := testutil.TestConfig()
- rootDir, rootBundleDir, err := testutil.SetupContainer(specs[0], conf)
+ conf := testutil.TestConfig(t)
+ rootDir, bundleDir, cleanup, err := testutil.SetupContainer(specs[0], conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(rootBundleDir)
+ defer cleanup()
rootArgs := Args{
ID: ids[0],
Spec: specs[0],
- BundleDir: rootBundleDir,
+ BundleDir: bundleDir,
}
root, err := New(conf, rootArgs)
if err != nil {
@@ -885,16 +880,16 @@ func TestMultiContainerDestroyStarting(t *testing.T) {
continue // skip root container
}
- bundleDir, err := testutil.SetupBundleDir(specs[i])
+ bundleDir, cleanup, err := testutil.SetupBundleDir(specs[i])
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
rootArgs := Args{
ID: ids[i],
Spec: specs[i],
- BundleDir: rootBundleDir,
+ BundleDir: bundleDir,
}
cont, err := New(conf, rootArgs)
if err != nil {
@@ -936,13 +931,13 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) {
script := fmt.Sprintf("if [ -f %q ]; then exit 1; else touch %q; fi", filename, filename)
cmd := []string{"sh", "-c", script}
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Make sure overlay is enabled, and none of the root filesystems are
@@ -976,7 +971,7 @@ func TestMultiContainerDifferentFilesystems(t *testing.T) {
// TestMultiContainerContainerDestroyStress tests that IO operations continue
// to work after containers have been stopped and gofers killed.
func TestMultiContainerContainerDestroyStress(t *testing.T) {
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
@@ -1005,13 +1000,12 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) {
childrenSpecs := allSpecs[1:]
childrenIDs := allIDs[1:]
- conf := testutil.TestConfig()
- rootDir, bundleDir, err := testutil.SetupContainer(rootSpec, conf)
+ conf := testutil.TestConfig(t)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(rootSpec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Start root container.
rootArgs := Args{
@@ -1037,11 +1031,11 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) {
var children []*Container
for j, spec := range specs {
- bundleDir, err := testutil.SetupBundleDir(spec)
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
args := Args{
ID: ids[j],
@@ -1079,355 +1073,348 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) {
// Test that pod shared mounts are properly mounted in 2 containers and that
// changes from one container is reflected in the other.
func TestMultiContainerSharedMount(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- podSpec, ids := createSpecs(sleep, sleep)
- mnt0 := specs.Mount{
- Destination: "/mydir/test",
- Source: "/some/dir",
- Type: "tmpfs",
- Options: nil,
- }
- podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: nil,
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
- mnt1 := mnt0
- mnt1.Destination = "/mydir2/test2"
- podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
- createSharedMount(mnt0, "test-mount", podSpec...)
+ createSharedMount(mnt0, "test-mount", podSpec...)
- containers, cleanup, err := startContainers(conf, podSpec, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- file0 := path.Join(mnt0.Destination, "abc")
- file1 := path.Join(mnt1.Destination, "abc")
- execs := []execDesc{
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
- desc: "directory is mounted in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
- desc: "directory is mounted in container1",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/touch", file0},
- desc: "create file in container0",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-f", file0},
- desc: "file appears in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-f", file1},
- desc: "file appears in container1",
- },
- {
- c: containers[1],
- cmd: []string{"/bin/rm", file1},
- desc: "file removed from container1",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "!", "-f", file0},
- desc: "file removed from container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "!", "-f", file1},
- desc: "file removed from container1",
- },
- {
- c: containers[1],
- cmd: []string{"/bin/mkdir", file1},
- desc: "create directory in container1",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-d", file0},
- desc: "dir appears in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-d", file1},
- desc: "dir appears in container1",
- },
- {
- c: containers[0],
- cmd: []string{"/bin/rmdir", file0},
- desc: "create directory in container0",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "!", "-d", file0},
- desc: "dir removed from container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "!", "-d", file1},
- desc: "dir removed from container1",
- },
- }
- if err := execMany(execs); err != nil {
- t.Fatal(err.Error())
- }
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ name: "create file in container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file appears in container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/rm", file1},
+ name: "remove file from container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-f", file0},
+ name: "file removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-f", file1},
+ name: "file removed from container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/mkdir", file1},
+ name: "create directory in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", file0},
+ name: "dir appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", file1},
+ name: "dir appears in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/rmdir", file0},
+ name: "remove directory from container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-d", file0},
+ name: "dir removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-d", file1},
+ name: "dir removed from container1",
+ },
+ }
+ execMany(t, execs)
+ })
}
}
// Test that pod mounts are mounted as readonly when requested.
func TestMultiContainerSharedMountReadonly(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- podSpec, ids := createSpecs(sleep, sleep)
- mnt0 := specs.Mount{
- Destination: "/mydir/test",
- Source: "/some/dir",
- Type: "tmpfs",
- Options: []string{"ro"},
- }
- podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: []string{"ro"},
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
- mnt1 := mnt0
- mnt1.Destination = "/mydir2/test2"
- podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
- createSharedMount(mnt0, "test-mount", podSpec...)
+ createSharedMount(mnt0, "test-mount", podSpec...)
- containers, cleanup, err := startContainers(conf, podSpec, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- file0 := path.Join(mnt0.Destination, "abc")
- file1 := path.Join(mnt1.Destination, "abc")
- execs := []execDesc{
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
- desc: "directory is mounted in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
- desc: "directory is mounted in container1",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/touch", file0},
- want: 1,
- desc: "fails to write to container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/touch", file1},
- want: 1,
- desc: "fails to write to container1",
- },
- }
- if err := execMany(execs); err != nil {
- t.Fatal(err.Error())
- }
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ want: 1,
+ name: "fails to write to container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/touch", file1},
+ want: 1,
+ name: "fails to write to container1",
+ },
+ }
+ execMany(t, execs)
+ })
}
}
// Test that shared pod mounts continue to work after container is restarted.
func TestMultiContainerSharedMountRestart(t *testing.T) {
- for _, conf := range configs(all...) {
- t.Logf("Running test with conf: %+v", conf)
-
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"sleep", "100"}
- podSpec, ids := createSpecs(sleep, sleep)
- mnt0 := specs.Mount{
- Destination: "/mydir/test",
- Source: "/some/dir",
- Type: "tmpfs",
- Options: nil,
- }
- podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+ //TODO(gvisor.dev/issue/1487): This is failing with VFS2.
+ for name, conf := range configs(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: nil,
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
- mnt1 := mnt0
- mnt1.Destination = "/mydir2/test2"
- podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
- createSharedMount(mnt0, "test-mount", podSpec...)
+ createSharedMount(mnt0, "test-mount", podSpec...)
- containers, cleanup, err := startContainers(conf, podSpec, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- file0 := path.Join(mnt0.Destination, "abc")
- file1 := path.Join(mnt1.Destination, "abc")
- execs := []execDesc{
- {
- c: containers[0],
- cmd: []string{"/usr/bin/touch", file0},
- desc: "create file in container0",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-f", file0},
- desc: "file appears in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-f", file1},
- desc: "file appears in container1",
- },
- }
- if err := execMany(execs); err != nil {
- t.Fatal(err.Error())
- }
+ file0 := path.Join(mnt0.Destination, "abc")
+ file1 := path.Join(mnt1.Destination, "abc")
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/bin/touch", file0},
+ name: "create file in container0",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file appears in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file appears in container1",
+ },
+ }
+ execMany(t, execs)
- containers[1].Destroy()
+ containers[1].Destroy()
- bundleDir, err := testutil.SetupBundleDir(podSpec[1])
- if err != nil {
- t.Fatalf("error restarting container: %v", err)
- }
- defer os.RemoveAll(bundleDir)
+ bundleDir, cleanup, err := testutil.SetupBundleDir(podSpec[1])
+ if err != nil {
+ t.Fatalf("error restarting container: %v", err)
+ }
+ defer cleanup()
- args := Args{
- ID: ids[1],
- Spec: podSpec[1],
- BundleDir: bundleDir,
- }
- containers[1], err = New(conf, args)
- if err != nil {
- t.Fatalf("error creating container: %v", err)
- }
- if err := containers[1].Start(conf); err != nil {
- t.Fatalf("error starting container: %v", err)
- }
+ args := Args{
+ ID: ids[1],
+ Spec: podSpec[1],
+ BundleDir: bundleDir,
+ }
+ containers[1], err = New(conf, args)
+ if err != nil {
+ t.Fatalf("error creating container: %v", err)
+ }
+ if err := containers[1].Start(conf); err != nil {
+ t.Fatalf("error starting container: %v", err)
+ }
- execs = []execDesc{
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-f", file0},
- desc: "file is still in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-f", file1},
- desc: "file is still in container1",
- },
- {
- c: containers[1],
- cmd: []string{"/bin/rm", file1},
- desc: "file removed from container1",
- },
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "!", "-f", file0},
- desc: "file removed from container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "!", "-f", file1},
- desc: "file removed from container1",
- },
- }
- if err := execMany(execs); err != nil {
- t.Fatal(err.Error())
- }
+ execs = []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-f", file0},
+ name: "file is still in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-f", file1},
+ name: "file is still in container1",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/bin/rm", file1},
+ name: "file removed from container1",
+ },
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "!", "-f", file0},
+ name: "file removed from container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "!", "-f", file1},
+ name: "file removed from container1",
+ },
+ }
+ execMany(t, execs)
+ })
}
}
// Test that unsupported pod mounts options are ignored when matching master and
// slave mounts.
func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
-
- conf := testutil.TestConfig()
- conf.RootDir = rootDir
-
- // Setup the containers.
- sleep := []string{"/bin/sleep", "100"}
- podSpec, ids := createSpecs(sleep, sleep)
- mnt0 := specs.Mount{
- Destination: "/mydir/test",
- Source: "/some/dir",
- Type: "tmpfs",
- Options: []string{"rw", "rbind", "relatime"},
- }
- podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
+ for name, conf := range configsWithVFS2(t, all...) {
+ t.Run(name, func(t *testing.T) {
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Setup the containers.
+ sleep := []string{"/bin/sleep", "100"}
+ podSpec, ids := createSpecs(sleep, sleep)
+ mnt0 := specs.Mount{
+ Destination: "/mydir/test",
+ Source: "/some/dir",
+ Type: "tmpfs",
+ Options: []string{"rw", "rbind", "relatime"},
+ }
+ podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0)
- mnt1 := mnt0
- mnt1.Destination = "/mydir2/test2"
- mnt1.Options = []string{"rw", "nosuid"}
- podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
+ mnt1 := mnt0
+ mnt1.Destination = "/mydir2/test2"
+ mnt1.Options = []string{"rw", "nosuid"}
+ podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1)
- createSharedMount(mnt0, "test-mount", podSpec...)
+ createSharedMount(mnt0, "test-mount", podSpec...)
- containers, cleanup, err := startContainers(conf, podSpec, ids)
- if err != nil {
- t.Fatalf("error starting containers: %v", err)
- }
- defer cleanup()
+ containers, cleanup, err := startContainers(conf, podSpec, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
- execs := []execDesc{
- {
- c: containers[0],
- cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
- desc: "directory is mounted in container0",
- },
- {
- c: containers[1],
- cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
- desc: "directory is mounted in container1",
- },
- }
- if err := execMany(execs); err != nil {
- t.Fatal(err.Error())
+ execs := []execDesc{
+ {
+ c: containers[0],
+ cmd: []string{"/usr/bin/test", "-d", mnt0.Destination},
+ name: "directory is mounted in container0",
+ },
+ {
+ c: containers[1],
+ cmd: []string{"/usr/bin/test", "-d", mnt1.Destination},
+ name: "directory is mounted in container1",
+ },
+ }
+ execMany(t, execs)
+ })
}
}
// Test that one container can send an FD to another container, even though
// they have distinct MountNamespaces.
func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
- app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ app, err := testutil.FindFile("test/cmd/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
@@ -1456,13 +1443,13 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
Type: "tmpfs",
}
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Create the specs.
@@ -1493,13 +1480,13 @@ func TestMultiContainerMultiRootCanHandleFDs(t *testing.T) {
// Test that container is destroyed when Gofer is killed.
func TestMultiContainerGoferKilled(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
sleep := []string{"sleep", "100"}
@@ -1513,7 +1500,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Ensure container is running
c := containers[2]
expectedPL := []*control.Process{
- {PID: 3, Cmd: "sleep"},
+ newProcessBuilder().PID(3).Cmd("sleep").Process(),
}
if err := waitForProcessList(c, expectedPL); err != nil {
t.Errorf("failed to wait for sleep to start: %v", err)
@@ -1541,7 +1528,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
continue // container[2] has been killed.
}
pl := []*control.Process{
- {PID: kernel.ThreadID(i + 1), Cmd: "sleep"},
+ newProcessBuilder().PID(kernel.ThreadID(i + 1)).Cmd("sleep").Process(),
}
if err := waitForProcessList(c, pl); err != nil {
t.Errorf("Container %q was affected by another container: %v", c.ID, err)
@@ -1561,7 +1548,7 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Wait until sandbox stops. waitForProcessList will loop until sandbox exits
// and RPC errors out.
impossiblePL := []*control.Process{
- {PID: 100, Cmd: "non-existent-process"},
+ newProcessBuilder().Cmd("non-existent-process").Process(),
}
if err := waitForProcessList(c, impossiblePL); err == nil {
t.Fatalf("Sandbox was not killed after gofer death")
@@ -1580,13 +1567,13 @@ func TestMultiContainerLoadSandbox(t *testing.T) {
sleep := []string{"sleep", "100"}
specs, ids := createSpecs(sleep, sleep, sleep)
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
// Create containers for the sandbox.
@@ -1613,7 +1600,7 @@ func TestMultiContainerLoadSandbox(t *testing.T) {
}
// Create a valid but empty container directory.
- randomCID := testutil.UniqueContainerID()
+ randomCID := testutil.RandomContainerID()
dir = filepath.Join(conf.RootDir, randomCID)
if err := os.MkdirAll(dir, 0755); err != nil {
t.Fatalf("os.MkdirAll(%q)=%v", dir, err)
@@ -1680,13 +1667,13 @@ func TestMultiContainerRunNonRoot(t *testing.T) {
Type: "bind",
})
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
t.Fatalf("error creating root dir: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.RootDir = rootDir
pod, cleanup, err := startContainers(conf, podSpecs, ids)
@@ -1705,3 +1692,83 @@ func TestMultiContainerRunNonRoot(t *testing.T) {
t.Fatalf("child container failed, waitStatus: %v", ws)
}
}
+
+// TestMultiContainerHomeEnvDir tests that the HOME environment variable is set
+// for root containers, sub-containers, and execed processes.
+func TestMultiContainerHomeEnvDir(t *testing.T) {
+ // TODO(gvisor.dev/issue/1487): VFSv2 configs failing.
+ // NOTE: Don't use overlay since we need changes to persist to the temp dir
+ // outside the sandbox.
+ for testName, conf := range configs(t, noOverlay...) {
+ t.Run(testName, func(t *testing.T) {
+
+ rootDir, cleanup, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
+ }
+ defer cleanup()
+ conf.RootDir = rootDir
+
+ // Create temp files we can write the value of $HOME to.
+ homeDirs := map[string]*os.File{}
+ for _, name := range []string{"root", "sub", "exec"} {
+ homeFile, err := ioutil.TempFile(testutil.TmpDir(), name)
+ if err != nil {
+ t.Fatalf("creating temp file: %v", err)
+ }
+ homeDirs[name] = homeFile
+ }
+
+ // We will sleep in the root container in order to ensure that
+ // the root container doesn't terminate before sub containers can be
+ // created.
+ rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())}
+ subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())}
+ execCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())}
+
+ // Setup the containers, a root container and sub container.
+ specConfig, ids := createSpecs(rootCmd, subCmd)
+ containers, cleanup, err := startContainers(conf, specConfig, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Exec into the root container synchronously.
+ args := &control.ExecArgs{Argv: execCmd}
+ if _, err := containers[0].executeSync(args); err != nil {
+ t.Errorf("error executing %+v: %v", args, err)
+ }
+
+ // Wait for the subcontainer to finish.
+ _, err = containers[1].Wait()
+ if err != nil {
+ t.Errorf("wait on child container: %v", err)
+ }
+
+ // Wait for the root container to run.
+ expectedPL := []*control.Process{
+ newProcessBuilder().Cmd("sh").Process(),
+ newProcessBuilder().Cmd("sleep").Process(),
+ }
+ if err := waitForProcessList(containers[0], expectedPL); err != nil {
+ t.Errorf("failed to wait for sleep to start: %v", err)
+ }
+
+ // Check the written files.
+ for name, tmpFile := range homeDirs {
+ dirBytes, err := ioutil.ReadAll(tmpFile)
+ if err != nil {
+ t.Fatalf("reading %s temp file: %v", name, err)
+ }
+ got := string(dirBytes)
+
+ want := "/"
+ if got != want {
+ t.Errorf("%s $HOME incorrect: got: %q, want: %q", name, got, want)
+ }
+ }
+
+ })
+ }
+}
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
index dc4194134..bac177a88 100644
--- a/runsc/container/shared_volume_test.go
+++ b/runsc/container/shared_volume_test.go
@@ -24,16 +24,15 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/boot"
- "gvisor.dev/gvisor/runsc/testutil"
)
// TestSharedVolume checks that modifications to a volume mount are propagated
// into and out of the sandbox.
func TestSharedVolume(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.FileAccess = boot.FileAccessShared
- t.Logf("Running test with conf: %+v", conf)
// Main process just sleeps. We will use "exec" to probe the state of
// the filesystem.
@@ -44,16 +43,15 @@ func TestSharedVolume(t *testing.T) {
t.Fatalf("TempDir failed: %v", err)
}
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
@@ -190,9 +188,8 @@ func checkFile(c *Container, filename string, want []byte) error {
// TestSharedVolumeFile tests that changes to file content outside the sandbox
// is reflected inside.
func TestSharedVolumeFile(t *testing.T) {
- conf := testutil.TestConfig()
+ conf := testutil.TestConfig(t)
conf.FileAccess = boot.FileAccessShared
- t.Logf("Running test with conf: %+v", conf)
// Main process just sleeps. We will use "exec" to probe the state of
// the filesystem.
@@ -203,16 +200,15 @@ func TestSharedVolumeFile(t *testing.T) {
t.Fatalf("TempDir failed: %v", err)
}
- rootDir, bundleDir, err := testutil.SetupContainer(spec, conf)
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
}
- defer os.RemoveAll(rootDir)
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
// Create and start the container.
args := Args{
- ID: testutil.UniqueContainerID(),
+ ID: testutil.RandomContainerID(),
Spec: spec,
BundleDir: bundleDir,
}
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
new file mode 100644
index 000000000..17a251530
--- /dev/null
+++ b/runsc/container/state_file.go
@@ -0,0 +1,185 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+
+ "github.com/gofrs/flock"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+const stateFileExtension = ".state"
+
+// StateFile handles load from/save to container state safely from multiple
+// processes. It uses a lock file to provide synchronization between operations.
+//
+// The lock file is located at: "${s.RootDir}/${s.ID}.lock".
+// The state file is located at: "${s.RootDir}/${s.ID}.state".
+type StateFile struct {
+ // RootDir is the directory containing the container metadata file.
+ RootDir string `json:"rootDir"`
+
+ // ID is the container ID.
+ ID string `json:"id"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
+ once sync.Once
+ flock *flock.Flock
+}
+
+// List returns all container ids in the given root directory.
+func List(rootDir string) ([]string, error) {
+ log.Debugf("List containers %q", rootDir)
+ list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension))
+ if err != nil {
+ return nil, err
+ }
+ var out []string
+ for _, path := range list {
+ // Filter out files that do no belong to a container.
+ fileName := filepath.Base(path)
+ if len(fileName) < len(stateFileExtension) {
+ panic(fmt.Sprintf("invalid file match %q", path))
+ }
+ // Remove the extension.
+ cid := fileName[:len(fileName)-len(stateFileExtension)]
+ if validateID(cid) == nil {
+ out = append(out, cid)
+ }
+ }
+ return out, nil
+}
+
+// lock globally locks all locking operations for the container.
+func (s *StateFile) lock() error {
+ s.once.Do(func() {
+ s.flock = flock.NewFlock(s.lockPath())
+ })
+
+ if err := s.flock.Lock(); err != nil {
+ return fmt.Errorf("acquiring lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// lockForNew acquires the lock and checks if the state file doesn't exist. This
+// is done to ensure that more than one creation didn't race to create
+// containers with the same ID.
+func (s *StateFile) lockForNew() error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+
+ // Checks if the container already exists by looking for the metadata file.
+ if _, err := os.Stat(s.statePath()); err == nil {
+ s.unlock()
+ return fmt.Errorf("container already exists")
+ } else if !os.IsNotExist(err) {
+ s.unlock()
+ return fmt.Errorf("looking for existing container: %v", err)
+ }
+ return nil
+}
+
+// unlock globally unlocks all locking operations for the container.
+func (s *StateFile) unlock() error {
+ if !s.flock.Locked() {
+ panic("unlock called without lock held")
+ }
+
+ if err := s.flock.Unlock(); err != nil {
+ log.Warningf("Error to release lock on %q: %v", s.flock, err)
+ return fmt.Errorf("releasing lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// saveLocked saves 'v' to the state file.
+//
+// Preconditions: lock() must been called before.
+func (s *StateFile) saveLocked(v interface{}) error {
+ if !s.flock.Locked() {
+ panic("saveLocked called without lock held")
+ }
+
+ meta, err := json.Marshal(v)
+ if err != nil {
+ return err
+ }
+ if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil {
+ return fmt.Errorf("writing json file: %v", err)
+ }
+ return nil
+}
+
+func (s *StateFile) load(v interface{}) error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+ defer s.unlock()
+
+ metaBytes, err := ioutil.ReadFile(s.statePath())
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(metaBytes, &v)
+}
+
+func (s *StateFile) close() error {
+ if s.flock == nil {
+ return nil
+ }
+ if s.flock.Locked() {
+ panic("Closing locked file")
+ }
+ return s.flock.Close()
+}
+
+func buildStatePath(rootDir, id string) string {
+ return filepath.Join(rootDir, id+stateFileExtension)
+}
+
+// statePath is the full path to the state file.
+func (s *StateFile) statePath() string {
+ return buildStatePath(s.RootDir, s.ID)
+}
+
+// lockPath is the full path to the lock file.
+func (s *StateFile) lockPath() string {
+ return filepath.Join(s.RootDir, s.ID+".lock")
+}
+
+// destroy deletes all state created by the stateFile. It may be called with the
+// lock file held. In that case, the lock file must still be unlocked and
+// properly closed after destroy returns.
+func (s *StateFile) destroy() error {
+ if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ return nil
+}
diff --git a/runsc/debian/description b/runsc/debian/description
index 6e3b1b2c0..9e8e08805 100644
--- a/runsc/debian/description
+++ b/runsc/debian/description
@@ -1,5 +1 @@
-gVisor is a user-space kernel, written in Go, that implements a substantial
-portion of the Linux system surface. It includes an Open Container Initiative
-(OCI) runtime called runsc that provides an isolation boundary between the
-application and the host kernel. The runsc runtime integrates with Docker and
-Kubernetes, making it simple to run sandboxed containers.
+gVisor container sandbox runtime
diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh
index dc7aeee87..d1e28e17b 100755
--- a/runsc/debian/postinst.sh
+++ b/runsc/debian/postinst.sh
@@ -18,7 +18,14 @@ if [ "$1" != configure ]; then
exit 0
fi
+# Update docker configuration.
if [ -f /etc/docker/daemon.json ]; then
runsc install
- systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
+ if systemctl status docker 2>/dev/null; then
+ systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
+ fi
fi
+
+# For containerd-based installers, we don't automatically update the
+# configuration. If it uses a v2 shim, then it will find the package binaries
+# automatically when provided the appropriate annotation.
diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD
deleted file mode 100644
index 0e0423504..000000000
--- a/runsc/dockerutil/BUILD
+++ /dev/null
@@ -1,15 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "dockerutil",
- testonly = 1,
- srcs = ["dockerutil.go"],
- importpath = "gvisor.dev/gvisor/runsc/dockerutil",
- visibility = ["//:sandbox"],
- deps = [
- "//runsc/testutil",
- "@com_github_kr_pty//:go_default_library",
- ],
-)
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
deleted file mode 100644
index 57f6ae8de..000000000
--- a/runsc/dockerutil/dockerutil.go
+++ /dev/null
@@ -1,467 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package dockerutil is a collection of utility functions, primarily for
-// testing.
-package dockerutil
-
-import (
- "encoding/json"
- "flag"
- "fmt"
- "io/ioutil"
- "log"
- "os"
- "os/exec"
- "path"
- "regexp"
- "strconv"
- "strings"
- "syscall"
- "time"
-
- "github.com/kr/pty"
- "gvisor.dev/gvisor/runsc/testutil"
-)
-
-var (
- runtime = flag.String("runtime", "runsc", "specify which runtime to use")
- config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
-)
-
-// EnsureSupportedDockerVersion checks if correct docker is installed.
-func EnsureSupportedDockerVersion() {
- cmd := exec.Command("docker", "version")
- out, err := cmd.CombinedOutput()
- if err != nil {
- log.Fatalf("Error running %q: %v", "docker version", err)
- }
- re := regexp.MustCompile(`Version:\s+(\d+)\.(\d+)\.\d.*`)
- matches := re.FindStringSubmatch(string(out))
- if len(matches) != 3 {
- log.Fatalf("Invalid docker output: %s", out)
- }
- major, _ := strconv.Atoi(matches[1])
- minor, _ := strconv.Atoi(matches[2])
- if major < 17 || (major == 17 && minor < 9) {
- log.Fatalf("Docker version 17.09.0 or greater is required, found: %02d.%02d", major, minor)
- }
-}
-
-// RuntimePath returns the binary path for the current runtime.
-func RuntimePath() (string, error) {
- // Read the configuration data; the file must exist.
- configBytes, err := ioutil.ReadFile(*config)
- if err != nil {
- return "", err
- }
-
- // Unmarshal the configuration.
- c := make(map[string]interface{})
- if err := json.Unmarshal(configBytes, &c); err != nil {
- return "", err
- }
-
- // Decode the expected configuration.
- r, ok := c["runtimes"]
- if !ok {
- return "", fmt.Errorf("no runtimes declared: %v", c)
- }
- rs, ok := r.(map[string]interface{})
- if !ok {
- // The runtimes are not a map.
- return "", fmt.Errorf("unexpected format: %v", c)
- }
- r, ok = rs[*runtime]
- if !ok {
- // The expected runtime is not declared.
- return "", fmt.Errorf("runtime %q not found: %v", *runtime, c)
- }
- rs, ok = r.(map[string]interface{})
- if !ok {
- // The runtime is not a map.
- return "", fmt.Errorf("unexpected format: %v", c)
- }
- p, ok := rs["path"].(string)
- if !ok {
- // The runtime does not declare a path.
- return "", fmt.Errorf("unexpected format: %v", c)
- }
- return p, nil
-}
-
-// MountMode describes if the mount should be ro or rw.
-type MountMode int
-
-const (
- // ReadOnly is what the name says.
- ReadOnly MountMode = iota
- // ReadWrite is what the name says.
- ReadWrite
-)
-
-// String returns the mount mode argument for this MountMode.
-func (m MountMode) String() string {
- switch m {
- case ReadOnly:
- return "ro"
- case ReadWrite:
- return "rw"
- }
- panic(fmt.Sprintf("invalid mode: %d", m))
-}
-
-// MountArg formats the volume argument to mount in the container.
-func MountArg(source, target string, mode MountMode) string {
- return fmt.Sprintf("-v=%s:%s:%v", source, target, mode)
-}
-
-// LinkArg formats the link argument.
-func LinkArg(source *Docker, target string) string {
- return fmt.Sprintf("--link=%s:%s", source.Name, target)
-}
-
-// PrepareFiles creates temp directory to copy files there. The sandbox doesn't
-// have access to files in the test dir.
-func PrepareFiles(names ...string) (string, error) {
- dir, err := ioutil.TempDir("", "image-test")
- if err != nil {
- return "", fmt.Errorf("ioutil.TempDir failed: %v", err)
- }
- if err := os.Chmod(dir, 0777); err != nil {
- return "", fmt.Errorf("os.Chmod(%q, 0777) failed: %v", dir, err)
- }
- for _, name := range names {
- src := getLocalPath(name)
- dst := path.Join(dir, name)
- if err := testutil.Copy(src, dst); err != nil {
- return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
- }
- }
- return dir, nil
-}
-
-func getLocalPath(file string) string {
- return path.Join(".", file)
-}
-
-// do executes docker command.
-func do(args ...string) (string, error) {
- log.Printf("Running: docker %s\n", args)
- cmd := exec.Command("docker", args...)
- out, err := cmd.CombinedOutput()
- if err != nil {
- return "", fmt.Errorf("error executing docker %s: %v\nout: %s", args, err, out)
- }
- return string(out), nil
-}
-
-// doWithPty executes docker command with stdio attached to a pty.
-func doWithPty(args ...string) (*exec.Cmd, *os.File, error) {
- log.Printf("Running with pty: docker %s\n", args)
- cmd := exec.Command("docker", args...)
- ptmx, err := pty.Start(cmd)
- if err != nil {
- return nil, nil, fmt.Errorf("error executing docker %s with a pty: %v", args, err)
- }
- return cmd, ptmx, nil
-}
-
-// Pull pulls a docker image. This is used in tests to isolate the
-// time to pull the image off the network from the time to actually
-// start the container, to avoid timeouts over slow networks.
-func Pull(image string) error {
- _, err := do("pull", image)
- return err
-}
-
-// Docker contains the name and the runtime of a docker container.
-type Docker struct {
- Runtime string
- Name string
-}
-
-// MakeDocker sets up the struct for a Docker container.
-// Names of containers will be unique.
-func MakeDocker(namePrefix string) Docker {
- return Docker{
- Name: testutil.RandomName(namePrefix),
- Runtime: *runtime,
- }
-}
-
-// logDockerID logs a container id, which is needed to find container runsc logs.
-func (d *Docker) logDockerID() {
- id, err := d.ID()
- if err != nil {
- log.Printf("%v\n", err)
- }
- log.Printf("Name: %s ID: %v\n", d.Name, id)
-}
-
-// Create calls 'docker create' with the arguments provided.
-func (d *Docker) Create(args ...string) error {
- a := []string{"create", "--runtime", d.Runtime, "--name", d.Name}
- a = append(a, args...)
- _, err := do(a...)
- if err == nil {
- d.logDockerID()
- }
- return err
-}
-
-// Start calls 'docker start'.
-func (d *Docker) Start() error {
- if _, err := do("start", d.Name); err != nil {
- return fmt.Errorf("error starting container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Stop calls 'docker stop'.
-func (d *Docker) Stop() error {
- if _, err := do("stop", d.Name); err != nil {
- return fmt.Errorf("error stopping container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Run calls 'docker run' with the arguments provided. The container starts
-// running in the background and the call returns immediately.
-func (d *Docker) Run(args ...string) error {
- a := d.runArgs("-d")
- a = append(a, args...)
- _, err := do(a...)
- if err == nil {
- d.logDockerID()
- }
- return err
-}
-
-// RunWithPty is like Run but with an attached pty.
-func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
- a := d.runArgs("-it")
- a = append(a, args...)
- return doWithPty(a...)
-}
-
-// RunFg calls 'docker run' with the arguments provided in the foreground. It
-// blocks until the container exits and returns the output.
-func (d *Docker) RunFg(args ...string) (string, error) {
- a := d.runArgs(args...)
- out, err := do(a...)
- if err == nil {
- d.logDockerID()
- }
- return string(out), err
-}
-
-func (d *Docker) runArgs(args ...string) []string {
- // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added
- // to the log name, so one can easily identify the corresponding logs for
- // this test.
- rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name}
- return append(rv, args...)
-}
-
-// Logs calls 'docker logs'.
-func (d *Docker) Logs() (string, error) {
- return do("logs", d.Name)
-}
-
-// Exec calls 'docker exec' with the arguments provided.
-func (d *Docker) Exec(args ...string) (string, error) {
- return d.ExecWithFlags(nil, args...)
-}
-
-// ExecWithFlags calls 'docker exec <flags> name <args>'.
-func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) {
- a := []string{"exec"}
- a = append(a, flags...)
- a = append(a, d.Name)
- a = append(a, args...)
- return do(a...)
-}
-
-// ExecAsUser calls 'docker exec' as the given user with the arguments
-// provided.
-func (d *Docker) ExecAsUser(user string, args ...string) (string, error) {
- a := []string{"exec", "--user", user, d.Name}
- a = append(a, args...)
- return do(a...)
-}
-
-// ExecWithTerminal calls 'docker exec -it' with the arguments provided and
-// attaches a pty to stdio.
-func (d *Docker) ExecWithTerminal(args ...string) (*exec.Cmd, *os.File, error) {
- a := []string{"exec", "-it", d.Name}
- a = append(a, args...)
- return doWithPty(a...)
-}
-
-// Pause calls 'docker pause'.
-func (d *Docker) Pause() error {
- if _, err := do("pause", d.Name); err != nil {
- return fmt.Errorf("error pausing container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Unpause calls 'docker pause'.
-func (d *Docker) Unpause() error {
- if _, err := do("unpause", d.Name); err != nil {
- return fmt.Errorf("error unpausing container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Checkpoint calls 'docker checkpoint'.
-func (d *Docker) Checkpoint(name string) error {
- if _, err := do("checkpoint", "create", d.Name, name); err != nil {
- return fmt.Errorf("error pausing container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Restore calls 'docker start --checkname [name]'.
-func (d *Docker) Restore(name string) error {
- if _, err := do("start", "--checkpoint", name, d.Name); err != nil {
- return fmt.Errorf("error starting container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// Remove calls 'docker rm'.
-func (d *Docker) Remove() error {
- if _, err := do("rm", d.Name); err != nil {
- return fmt.Errorf("error deleting container %q: %v", d.Name, err)
- }
- return nil
-}
-
-// CleanUp kills and deletes the container (best effort).
-func (d *Docker) CleanUp() {
- d.logDockerID()
- if _, err := do("kill", d.Name); err != nil {
- if strings.Contains(err.Error(), "is not running") {
- // Nothing to kill. Don't log the error in this case.
- } else {
- log.Printf("error killing container %q: %v", d.Name, err)
- }
- }
- if err := d.Remove(); err != nil {
- log.Print(err)
- }
-}
-
-// FindPort returns the host port that is mapped to 'sandboxPort'. This calls
-// docker to allocate a free port in the host and prevent conflicts.
-func (d *Docker) FindPort(sandboxPort int) (int, error) {
- format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort)
- out, err := do("inspect", "-f", format, d.Name)
- if err != nil {
- return -1, fmt.Errorf("error retrieving port: %v", err)
- }
- port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing port %q: %v", out, err)
- }
- return port, nil
-}
-
-// SandboxPid returns the PID to the sandbox process.
-func (d *Docker) SandboxPid() (int, error) {
- out, err := do("inspect", "-f={{.State.Pid}}", d.Name)
- if err != nil {
- return -1, fmt.Errorf("error retrieving pid: %v", err)
- }
- pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- return -1, fmt.Errorf("error parsing pid %q: %v", out, err)
- }
- return pid, nil
-}
-
-// ID returns the container ID.
-func (d *Docker) ID() (string, error) {
- out, err := do("inspect", "-f={{.Id}}", d.Name)
- if err != nil {
- return "", fmt.Errorf("error retrieving ID: %v", err)
- }
- return strings.TrimSpace(string(out)), nil
-}
-
-// Wait waits for container to exit, up to the given timeout. Returns error if
-// wait fails or timeout is hit. Returns the application return code otherwise.
-// Note that the application may have failed even if err == nil, always check
-// the exit code.
-func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) {
- timeoutChan := time.After(timeout)
- waitChan := make(chan (syscall.WaitStatus))
- errChan := make(chan (error))
-
- go func() {
- out, err := do("wait", d.Name)
- if err != nil {
- errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err)
- }
- exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n"))
- if err != nil {
- errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err)
- }
- waitChan <- syscall.WaitStatus(uint32(exit))
- }()
-
- select {
- case ws := <-waitChan:
- return ws, nil
- case err := <-errChan:
- return syscall.WaitStatus(1), err
- case <-timeoutChan:
- return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name)
- }
-}
-
-// WaitForOutput calls 'docker logs' to retrieve containers output and searches
-// for the given pattern.
-func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) {
- matches, err := d.WaitForOutputSubmatch(pattern, timeout)
- if err != nil {
- return "", err
- }
- if len(matches) == 0 {
- return "", nil
- }
- return matches[0], nil
-}
-
-// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and
-// searches for the given pattern. It returns any regexp submatches as well.
-func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) {
- re := regexp.MustCompile(pattern)
- var out string
- for exp := time.Now().Add(timeout); time.Now().Before(exp); {
- var err error
- out, err = d.Logs()
- if err != nil {
- return nil, err
- }
- if matches := re.FindStringSubmatch(out); matches != nil {
- // Success!
- return matches, nil
- }
- time.Sleep(100 * time.Millisecond)
- }
- return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), out)
-}
diff --git a/runsc/flag/BUILD b/runsc/flag/BUILD
new file mode 100644
index 000000000..5cb7604a8
--- /dev/null
+++ b/runsc/flag/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "flag",
+ srcs = ["flag.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/sentry/kernel/pipe/buffer_test.go b/runsc/flag/flag.go
index ee1b90115..0ca4829d7 100644
--- a/pkg/sentry/kernel/pipe/buffer_test.go
+++ b/runsc/flag/flag.go
@@ -12,21 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package pipe
+package flag
import (
- "testing"
- "unsafe"
+ "flag"
+)
+
+type FlagSet = flag.FlagSet
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+var (
+ NewFlagSet = flag.NewFlagSet
+ String = flag.String
+ Bool = flag.Bool
+ Int = flag.Int
+ Uint = flag.Uint
+ CommandLine = flag.CommandLine
+ Parse = flag.Parse
)
-func TestBufferSize(t *testing.T) {
- bufferSize := unsafe.Sizeof(buffer{})
- if bufferSize < usermem.PageSize {
- t.Errorf("buffer is less than a page")
- }
- if bufferSize > (2 * usermem.PageSize) {
- t.Errorf("buffer is greater than two pages")
- }
-}
+const ContinueOnError = flag.ContinueOnError
diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD
index 80a4aa2fe..05e3637f7 100644
--- a/runsc/fsgofer/BUILD
+++ b/runsc/fsgofer/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -6,19 +6,19 @@ go_library(
name = "fsgofer",
srcs = [
"fsgofer.go",
+ "fsgofer_amd64_unsafe.go",
+ "fsgofer_arm64_unsafe.go",
"fsgofer_unsafe.go",
],
- importpath = "gvisor.dev/gvisor/runsc/fsgofer",
- visibility = [
- "//runsc:__subpackages__",
- ],
+ visibility = ["//runsc:__subpackages__"],
deps = [
"//pkg/abi/linux",
+ "//pkg/cleanup",
"//pkg/fd",
"//pkg/log",
"//pkg/p9",
+ "//pkg/sync",
"//pkg/syserr",
- "//runsc/specutils",
"@org_golang_x_sys//unix:go_default_library",
],
)
@@ -27,9 +27,10 @@ go_test(
name = "fsgofer_test",
size = "small",
srcs = ["fsgofer_test.go"],
- embed = [":fsgofer"],
+ library = ":fsgofer",
deps = [
"//pkg/log",
"//pkg/p9",
+ "//pkg/test/testutil",
],
)
diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD
index 02168ad1b..82b48ef32 100644
--- a/runsc/fsgofer/filter/BUILD
+++ b/runsc/fsgofer/filter/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -6,12 +6,13 @@ go_library(
name = "filter",
srcs = [
"config.go",
+ "config_amd64.go",
+ "config_arm64.go",
"extra_filters.go",
"extra_filters_msan.go",
"extra_filters_race.go",
"filter.go",
],
- importpath = "gvisor.dev/gvisor/runsc/fsgofer/filter",
visibility = [
"//runsc:__subpackages__",
],
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 2ea95f8fb..88814b83c 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -25,11 +25,7 @@ import (
// allowedSyscalls is the set of syscalls executed by the gofer.
var allowedSyscalls = seccomp.SyscallRules{
- syscall.SYS_ACCEPT: {},
- syscall.SYS_ARCH_PRCTL: []seccomp.Rule{
- {seccomp.AllowValue(linux.ARCH_GET_FS)},
- {seccomp.AllowValue(linux.ARCH_SET_FS)},
- },
+ syscall.SYS_ACCEPT: {},
syscall.SYS_CLOCK_GETTIME: {},
syscall.SYS_CLONE: []seccomp.Rule{
{
@@ -132,6 +128,19 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_MADVISE: {},
unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
syscall.SYS_MKDIRAT: {},
+ syscall.SYS_MKNODAT: {},
+ // Used by the Go runtime as a temporarily workaround for a Linux
+ // 5.2-5.4 bug.
+ //
+ // See src/runtime/os_linux_x86.go.
+ //
+ // TODO(b/148688965): Remove once this is gone from Go.
+ syscall.SYS_MLOCK: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(4096),
+ },
+ },
syscall.SYS_MMAP: []seccomp.Rule{
{
seccomp.AllowAny{},
@@ -155,7 +164,6 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_MPROTECT: {},
syscall.SYS_MUNMAP: {},
syscall.SYS_NANOSLEEP: {},
- syscall.SYS_NEWFSTATAT: {},
syscall.SYS_OPENAT: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go
new file mode 100644
index 000000000..a4b28cb8b
--- /dev/null
+++ b/runsc/fsgofer/filter/config_amd64.go
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package filter
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/seccomp"
+)
+
+func init() {
+ allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
+ {seccomp.AllowValue(linux.ARCH_GET_FS)},
+ {seccomp.AllowValue(linux.ARCH_SET_FS)},
+ }
+
+ allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{}
+}
diff --git a/pkg/fspath/builder_unsafe.go b/runsc/fsgofer/filter/config_arm64.go
index 75606808d..d2697deb7 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/runsc/fsgofer/filter/config_arm64.go
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+// +build arm64
+
+package filter
import (
- "unsafe"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/seccomp"
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
+func init() {
+ allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{}
}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 3fceecb3d..c6694c278 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -29,15 +29,15 @@ import (
"path/filepath"
"runtime"
"strconv"
- "sync"
"syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
- "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/pkg/sync"
)
const (
@@ -48,36 +48,6 @@ const (
openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC
)
-type fileType int
-
-const (
- regular fileType = iota
- directory
- symlink
- socket
- unknown
-)
-
-// String implements fmt.Stringer.
-func (f fileType) String() string {
- switch f {
- case regular:
- return "regular"
- case directory:
- return "directory"
- case symlink:
- return "symlink"
- case socket:
- return "socket"
- }
- return "unknown"
-}
-
-// ControlSocketAddr generates an abstract unix socket name for the given id.
-func ControlSocketAddr(id string) string {
- return fmt.Sprintf("\x00runsc-gofer.%s", id)
-}
-
// Config sets configuration options for each attach point.
type Config struct {
// ROMount is set to true if this is a readonly mount.
@@ -132,19 +102,19 @@ func (a *attachPoint) Attach() (p9.File, error) {
return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
}
- f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) {
+ f, readable, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) {
return fd.Open(a.prefix, openFlags|mode, 0)
})
if err != nil {
return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err)
}
- stat, err := stat(f.FD())
+ stat, err := fstat(f.FD())
if err != nil {
return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err)
}
- lf, err := newLocalFile(a, f, a.prefix, stat)
+ lf, err := newLocalFile(a, f, a.prefix, readable, stat)
if err != nil {
return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err)
}
@@ -175,8 +145,6 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID {
log.Warningf("first 8 bytes of host inode id %x will be truncated to construct virtual inode id", stat.Ino)
}
ino := uint64(dev)<<56 | maskedIno
- log.Debugf("host inode %x on device %x mapped to virtual inode %x", stat.Ino, stat.Dev, ino)
-
return p9.QID{
Type: p9.FileMode(stat.Mode).QIDType(),
Path: ino,
@@ -199,9 +167,8 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID {
// The reason that the file is not opened initially as read-write is for better
// performance with 'overlay2' storage driver. overlay2 eagerly copies the
// entire file up when it's opened in write mode, and would perform badly when
+// multiple files are only being opened for read (esp. startup).
type localFile struct {
- p9.DefaultWalkGetAttr
-
// attachPoint is the attachPoint that serves this localFile.
attachPoint *attachPoint
@@ -213,12 +180,19 @@ type localFile struct {
// opened with.
file *fd.FD
+ // controlReadable tells whether 'file' was opened with read permissions
+ // during a walk.
+ controlReadable bool
+
// mode is the mode in which the file was opened. Set to invalidMode
// if localFile isn't opened.
mode p9.OpenFlags
- // ft is the fileType for this file.
- ft fileType
+ // fileType for this file. It is equivalent to:
+ // syscall.Stat_t.Mode & syscall.S_IFMT
+ fileType uint32
+
+ qid p9.QID
// readDirMu protects against concurrent Readdir calls.
readDirMu sync.Mutex
@@ -252,83 +226,88 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) {
return fd.New(d), nil
}
-func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, error) {
+func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) {
path := path.Join(parent.hostPath, name)
- f, err := openAnyFile(path, func(mode int) (*fd.FD, error) {
+ f, readable, err := openAnyFile(path, func(mode int) (*fd.FD, error) {
return fd.OpenAt(parent.file, name, openFlags|mode, 0)
})
- return f, path, err
+ return f, path, readable, err
}
// openAnyFile attempts to open the file in O_RDONLY and if it fails fallsback
// to O_PATH. 'path' is used for logging messages only. 'fn' is what does the
// actual file open and is customizable by the caller.
-func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) {
+func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) {
// Attempt to open file in the following mode in order:
// 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs.
// Use non-blocking to prevent getting stuck inside open(2) for
// FIFOs. This option has no effect on regular files.
// 2. PATH: for symlinks, sockets.
- modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH}
+ options := []struct {
+ mode int
+ readable bool
+ }{
+ {
+ mode: syscall.O_RDONLY | syscall.O_NONBLOCK,
+ readable: true,
+ },
+ {
+ mode: unix.O_PATH,
+ readable: false,
+ },
+ }
var err error
- var file *fd.FD
- for i, mode := range modes {
- file, err = fn(mode)
+ for i, option := range options {
+ var file *fd.FD
+ file, err = fn(option.mode)
if err == nil {
- // openat succeeded, we're done.
- break
+ // Succeeded opening the file, we're done.
+ return file, option.readable, nil
}
switch e := extractErrno(err); e {
case syscall.ENOENT:
// File doesn't exist, no point in retrying.
- return nil, e
+ return nil, false, e
}
- // openat failed. Try again with next mode, preserving 'err' in case this
- // was the last attempt.
- log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|mode, path, err)
+ // File failed to open. Try again with next mode, preserving 'err' in case
+ // this was the last attempt.
+ log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, path, err)
}
- if err != nil {
- // All attempts to open file have failed, return the last error.
- log.Debugf("Failed to open file, path: %q, err: %v", path, err)
- return nil, extractErrno(err)
- }
-
- return file, nil
+ // All attempts to open file have failed, return the last error.
+ log.Debugf("Failed to open file, path: %q, err: %v", path, err)
+ return nil, false, extractErrno(err)
}
-func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
- var ft fileType
+func checkSupportedFileType(stat syscall.Stat_t, permitSocket bool) error {
switch stat.Mode & syscall.S_IFMT {
- case syscall.S_IFREG:
- ft = regular
- case syscall.S_IFDIR:
- ft = directory
- case syscall.S_IFLNK:
- ft = symlink
+ case syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK:
+ return nil
+
case syscall.S_IFSOCK:
if !permitSocket {
- return unknown, syscall.EPERM
+ return syscall.EPERM
}
- ft = socket
+ return nil
+
default:
- return unknown, syscall.EPERM
+ return syscall.EPERM
}
- return ft, nil
}
-func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
- ft, err := getSupportedFileType(stat, a.conf.HostUDS)
- if err != nil {
+func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat syscall.Stat_t) (*localFile, error) {
+ if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil {
return nil, err
}
return &localFile{
- attachPoint: a,
- hostPath: path,
- file: file,
- mode: invalidMode,
- ft: ft,
+ attachPoint: a,
+ hostPath: path,
+ file: file,
+ mode: invalidMode,
+ fileType: stat.Mode & syscall.S_IFMT,
+ qid: a.makeQID(stat),
+ controlReadable: readable,
}, nil
}
@@ -347,13 +326,13 @@ func newFDMaybe(file *fd.FD) *fd.FD {
// fd is blocking; non-blocking is required.
if err := syscall.SetNonblock(dup.FD(), true); err != nil {
- dup.Close()
+ _ = dup.Close()
return nil
}
return dup
}
-func stat(fd int) (syscall.Stat_t, error) {
+func fstat(fd int) (syscall.Stat_t, error) {
var stat syscall.Stat_t
if err := syscall.Fstat(fd, &stat); err != nil {
return syscall.Stat_t{}, err
@@ -361,43 +340,44 @@ func stat(fd int) (syscall.Stat_t, error) {
return stat, nil
}
+func stat(path string) (syscall.Stat_t, error) {
+ var stat syscall.Stat_t
+ if err := syscall.Stat(path, &stat); err != nil {
+ return syscall.Stat_t{}, err
+ }
+ return stat, nil
+}
+
func fchown(fd int, uid p9.UID, gid p9.GID) error {
return syscall.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW)
}
// Open implements p9.File.
-func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
+func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
if l.isOpen() {
panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath))
}
// Check if control file can be used or if a new open must be created.
var newFile *fd.FD
- if mode == p9.ReadOnly {
- log.Debugf("Open reusing control file, mode: %v, %q", mode, l.hostPath)
+ if flags == p9.ReadOnly && l.controlReadable {
+ log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath)
newFile = l.file
} else {
// Ideally reopen would call name_to_handle_at (with empty name) and
// open_by_handle_at to reopen the file without using 'hostPath'. However,
// name_to_handle_at and open_by_handle_at aren't supported by overlay2.
- log.Debugf("Open reopening file, mode: %v, %q", mode, l.hostPath)
+ log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath)
var err error
- newFile, err = reopenProcFd(l.file, openFlags|mode.OSFlags())
+ // Constrain open flags to the open mode and O_TRUNC.
+ newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC)))
if err != nil {
return nil, p9.QID{}, 0, extractErrno(err)
}
}
- stat, err := stat(newFile.FD())
- if err != nil {
- if newFile != l.file {
- newFile.Close()
- }
- return nil, p9.QID{}, 0, extractErrno(err)
- }
-
var fd *fd.FD
- if stat.Mode&syscall.S_IFMT == syscall.S_IFREG {
+ if l.fileType == syscall.S_IFREG {
// Donate FD for regular files only.
fd = newFDMaybe(newFile)
}
@@ -409,8 +389,8 @@ func (l *localFile) Open(mode p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) {
}
l.file = newFile
}
- l.mode = mode
- return fd, l.attachPoint.makeQID(stat), 0, nil
+ l.mode = flags & p9.OpenFlagsModeMask
+ return fd, l.qid, 0, nil
}
// Create implements p9.File.
@@ -437,8 +417,8 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid
if err != nil {
return nil, nil, p9.QID{}, 0, extractErrno(err)
}
- cu := specutils.MakeCleanup(func() {
- child.Close()
+ cu := cleanup.Make(func() {
+ _ = child.Close()
// Best effort attempt to remove the file in case of failure.
if err := syscall.Unlinkat(l.file.FD(), name); err != nil {
log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err)
@@ -449,7 +429,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid
if err := fchown(child.FD(), uid, gid); err != nil {
return nil, nil, p9.QID{}, 0, extractErrno(err)
}
- stat, err := stat(child.FD())
+ stat, err := fstat(child.FD())
if err != nil {
return nil, nil, p9.QID{}, 0, extractErrno(err)
}
@@ -459,10 +439,12 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid
hostPath: path.Join(l.hostPath, name),
file: child,
mode: mode,
+ fileType: syscall.S_IFREG,
+ qid: l.attachPoint.makeQID(stat),
}
cu.Release()
- return newFDMaybe(c.file), c, l.attachPoint.makeQID(stat), 0, nil
+ return newFDMaybe(c.file), c, c.qid, 0, nil
}
// Mkdir implements p9.File.
@@ -478,7 +460,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
if err := syscall.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil {
return p9.QID{}, extractErrno(err)
}
- cu := specutils.MakeCleanup(func() {
+ cu := cleanup.Make(func() {
// Best effort attempt to remove the dir in case of failure.
if err := unix.Unlinkat(l.file.FD(), name, unix.AT_REMOVEDIR); err != nil {
log.Warningf("error unlinking dir %q after failure: %v", path.Join(l.hostPath, name), err)
@@ -497,7 +479,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
if err := fchown(f.FD(), uid, gid); err != nil {
return p9.QID{}, extractErrno(err)
}
- stat, err := stat(f.FD())
+ stat, err := fstat(f.FD())
if err != nil {
return p9.QID{}, extractErrno(err)
}
@@ -508,55 +490,74 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID)
// Walk implements p9.File.
func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) {
+ qids, file, _, err := l.walk(names)
+ return qids, file, err
+}
+
+// WalkGetAttr implements p9.File.
+func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) {
+ qids, file, stat, err := l.walk(names)
+ if err != nil {
+ return nil, nil, p9.AttrMask{}, p9.Attr{}, err
+ }
+ mask, attr := l.fillAttr(stat)
+ return qids, file, mask, attr, nil
+}
+
+func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, error) {
// Duplicate current file if 'names' is empty.
if len(names) == 0 {
- newFile, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) {
+ newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) {
return reopenProcFd(l.file, openFlags|mode)
})
if err != nil {
- return nil, nil, extractErrno(err)
+ return nil, nil, syscall.Stat_t{}, extractErrno(err)
}
- stat, err := stat(newFile.FD())
+ stat, err := fstat(newFile.FD())
if err != nil {
- newFile.Close()
- return nil, nil, extractErrno(err)
+ _ = newFile.Close()
+ return nil, nil, syscall.Stat_t{}, extractErrno(err)
}
c := &localFile{
- attachPoint: l.attachPoint,
- hostPath: l.hostPath,
- file: newFile,
- mode: invalidMode,
+ attachPoint: l.attachPoint,
+ hostPath: l.hostPath,
+ file: newFile,
+ mode: invalidMode,
+ fileType: l.fileType,
+ qid: l.attachPoint.makeQID(stat),
+ controlReadable: readable,
}
- return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil
+ return []p9.QID{c.qid}, c, stat, nil
}
var qids []p9.QID
+ var lastStat syscall.Stat_t
last := l
for _, name := range names {
- f, path, err := openAnyFileFromParent(last, name)
+ f, path, readable, err := openAnyFileFromParent(last, name)
if last != l {
- last.Close()
+ _ = last.Close()
}
if err != nil {
- return nil, nil, extractErrno(err)
+ return nil, nil, syscall.Stat_t{}, extractErrno(err)
}
- stat, err := stat(f.FD())
+ lastStat, err = fstat(f.FD())
if err != nil {
- f.Close()
- return nil, nil, extractErrno(err)
+ _ = f.Close()
+ return nil, nil, syscall.Stat_t{}, extractErrno(err)
}
- c, err := newLocalFile(last.attachPoint, f, path, stat)
+ c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat)
if err != nil {
- f.Close()
- return nil, nil, extractErrno(err)
+ _ = f.Close()
+ return nil, nil, syscall.Stat_t{}, extractErrno(err)
}
- qids = append(qids, l.attachPoint.makeQID(stat))
+ qids = append(qids, c.qid)
last = c
}
- return qids, last, nil
+ return qids, last, lastStat, nil
}
// StatFS implements p9.File.
@@ -592,16 +593,20 @@ func (l *localFile) FSync() error {
// GetAttr implements p9.File.
func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) {
- stat, err := stat(l.file.FD())
+ stat, err := fstat(l.file.FD())
if err != nil {
return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err)
}
+ mask, attr := l.fillAttr(stat)
+ return l.qid, mask, attr, nil
+}
+func (l *localFile) fillAttr(stat syscall.Stat_t) (p9.AttrMask, p9.Attr) {
attr := p9.Attr{
Mode: p9.FileMode(stat.Mode),
UID: p9.UID(stat.Uid),
GID: p9.GID(stat.Gid),
- NLink: stat.Nlink,
+ NLink: uint64(stat.Nlink),
RDev: stat.Rdev,
Size: uint64(stat.Size),
BlockSize: uint64(stat.Blksize),
@@ -625,8 +630,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error)
MTime: true,
CTime: true,
}
-
- return l.attachPoint.makeQID(stat), valid, attr, nil
+ return valid, attr
}
// SetAttr implements p9.File. Due to mismatch in file API, options
@@ -667,7 +671,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
// Check if it's possible to use cached file, or if another one needs to be
// opened for write.
f := l.file
- if l.ft == regular && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite {
+ if l.fileType == syscall.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite {
var err error
f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY)
if err != nil {
@@ -723,7 +727,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
}
}
- if l.ft == symlink {
+ if l.fileType == syscall.S_IFLNK {
// utimensat operates different that other syscalls. To operate on a
// symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty
// name.
@@ -765,6 +769,22 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
+func (*localFile) GetXattr(string, uint64) (string, error) {
+ return "", syscall.EOPNOTSUPP
+}
+
+func (*localFile) SetXattr(string, string, uint32) error {
+ return syscall.EOPNOTSUPP
+}
+
+func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
+ return nil, syscall.EOPNOTSUPP
+}
+
+func (*localFile) RemoveXattr(string) error {
+ return syscall.EOPNOTSUPP
+}
+
// Allocate implements p9.File.
func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error {
if !l.isOpen() {
@@ -778,7 +798,7 @@ func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error
}
// Rename implements p9.File; this should never be called.
-func (l *localFile) Rename(p9.File, string) error {
+func (*localFile) Rename(p9.File, string) error {
panic("rename called directly")
}
@@ -846,7 +866,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
if err := unix.Symlinkat(target, l.file.FD(), newName); err != nil {
return p9.QID{}, extractErrno(err)
}
- cu := specutils.MakeCleanup(func() {
+ cu := cleanup.Make(func() {
// Best effort attempt to remove the symlink in case of failure.
if err := syscall.Unlinkat(l.file.FD(), newName); err != nil {
log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, newName), err)
@@ -864,7 +884,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.
if err := fchown(f.FD(), uid, gid); err != nil {
return p9.QID{}, extractErrno(err)
}
- stat, err := stat(f.FD())
+ stat, err := fstat(f.FD())
if err != nil {
return p9.QID{}, extractErrno(err)
}
@@ -891,13 +911,39 @@ func (l *localFile) Link(target p9.File, newName string) error {
}
// Mknod implements p9.File.
-//
-// Not implemented.
-func (*localFile) Mknod(_ string, _ p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) {
+func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) {
+ conf := l.attachPoint.conf
+ if conf.ROMount {
+ if conf.PanicOnWrite {
+ panic("attempt to write to RO mount")
+ }
+ return p9.QID{}, syscall.EROFS
+ }
+
+ hostPath := path.Join(l.hostPath, name)
+
+ // Return EEXIST if the file already exists.
+ if _, err := stat(hostPath); err == nil {
+ return p9.QID{}, syscall.EEXIST
+ }
+
// From mknod(2) man page:
// "EPERM: [...] if the filesystem containing pathname does not support
// the type of node requested."
- return p9.QID{}, syscall.EPERM
+ if mode.FileType() != p9.ModeRegular {
+ return p9.QID{}, syscall.EPERM
+ }
+
+ // Allow Mknod to create regular files.
+ if err := syscall.Mknod(hostPath, uint32(mode), 0); err != nil {
+ return p9.QID{}, err
+ }
+
+ stat, err := stat(hostPath)
+ if err != nil {
+ return p9.QID{}, extractErrno(err)
+ }
+ return l.attachPoint.makeQID(stat), nil
}
// UnlinkAt implements p9.File.
@@ -933,9 +979,12 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) {
skip := uint64(0)
- // Check if the file is at the correct position already. If not, seek to the
- // beginning and read the entire directory again.
- if l.lastDirentOffset != offset {
+ // Check if the file is at the correct position already. If not, seek to
+ // the beginning and read the entire directory again. We always seek if
+ // offset is 0, since this is side-effectual (equivalent to rewinddir(3),
+ // which causes the directory stream to resynchronize with the directory's
+ // current contents).
+ if l.lastDirentOffset != offset || offset == 0 {
if _, err := syscall.Seek(l.file.FD(), 0, 0); err != nil {
return nil, extractErrno(err)
}
@@ -955,14 +1004,14 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) {
}
func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) ([]p9.Dirent, error) {
+ var dirents []p9.Dirent
+
// Limit 'count' to cap the slice size that is returned.
const maxCount = 100000
if count > maxCount {
count = maxCount
}
- dirents := make([]p9.Dirent, 0, count)
-
// Pre-allocate buffers that will be reused to get partial results.
direntsBuf := make([]byte, 8192)
names := make([]string, 0, 100)
@@ -1063,13 +1112,13 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) {
}
if err := syscall.SetNonblock(f, true); err != nil {
- syscall.Close(f)
+ _ = syscall.Close(f)
return nil, err
}
sa := syscall.SockaddrUnix{Name: l.hostPath}
if err := syscall.Connect(f, &sa); err != nil {
- syscall.Close(f)
+ _ = syscall.Close(f)
return nil, err
}
diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go
new file mode 100644
index 000000000..5d4aab597
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build amd64
+
+package fsgofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+func statAt(dirFd int, name string) (syscall.Stat_t, error) {
+ nameBytes, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return syscall.Stat_t{}, err
+ }
+ namePtr := unsafe.Pointer(nameBytes)
+
+ var stat syscall.Stat_t
+ statPtr := unsafe.Pointer(&stat)
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_NEWFSTATAT,
+ uintptr(dirFd),
+ uintptr(namePtr),
+ uintptr(statPtr),
+ linux.AT_SYMLINK_NOFOLLOW,
+ 0,
+ 0); errno != 0 {
+
+ return syscall.Stat_t{}, syserr.FromHost(errno).ToError()
+ }
+ return stat, nil
+}
diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go
new file mode 100644
index 000000000..8041fd352
--- /dev/null
+++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go
@@ -0,0 +1,49 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build arm64
+
+package fsgofer
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+func statAt(dirFd int, name string) (syscall.Stat_t, error) {
+ nameBytes, err := syscall.BytePtrFromString(name)
+ if err != nil {
+ return syscall.Stat_t{}, err
+ }
+ namePtr := unsafe.Pointer(nameBytes)
+
+ var stat syscall.Stat_t
+ statPtr := unsafe.Pointer(&stat)
+
+ if _, _, errno := syscall.Syscall6(
+ syscall.SYS_FSTATAT,
+ uintptr(dirFd),
+ uintptr(namePtr),
+ uintptr(statPtr),
+ linux.AT_SYMLINK_NOFOLLOW,
+ 0,
+ 0); errno != 0 {
+
+ return syscall.Stat_t{}, syserr.FromHost(errno).ToError()
+ }
+ return stat, nil
+}
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index 05af7e397..94f167417 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -26,6 +26,19 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/p9"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
+
+var (
+ allTypes = []uint32{syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK}
+
+ // allConfs is set in init().
+ allConfs []Config
+
+ rwConfs = []Config{{ROMount: false}}
+ roConfs = []Config{{ROMount: true}}
)
func init() {
@@ -39,6 +52,13 @@ func init() {
}
}
+func configTestName(config *Config) string {
+ if config.ROMount {
+ return "ROMount"
+ }
+ return "RWMount"
+}
+
func assertPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
@@ -88,71 +108,76 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error {
return nil
}
-var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite}
-
-var (
- allTypes = []fileType{regular, directory, symlink}
-
- // allConfs is set in init() above.
- allConfs []Config
-
- rwConfs = []Config{{ROMount: false}}
- roConfs = []Config{{ROMount: true}}
-)
-
type state struct {
- root *localFile
- file *localFile
- conf Config
- ft fileType
+ root *localFile
+ file *localFile
+ conf Config
+ fileType uint32
}
func (s state) String() string {
- return fmt.Sprintf("type(%v)", s.ft)
+ return fmt.Sprintf("type(%v)", s.fileType)
+}
+
+func typeName(fileType uint32) string {
+ switch fileType {
+ case syscall.S_IFREG:
+ return "file"
+ case syscall.S_IFDIR:
+ return "directory"
+ case syscall.S_IFLNK:
+ return "symlink"
+ default:
+ panic(fmt.Sprintf("invalid file type for test: %d", fileType))
+ }
}
func runAll(t *testing.T, test func(*testing.T, state)) {
runCustom(t, allTypes, allConfs, test)
}
-func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) {
+func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, state)) {
for _, c := range confs {
- t.Logf("Config: %+v", c)
-
for _, ft := range types {
- t.Logf("File type: %v", ft)
+ name := fmt.Sprintf("%s/%s", configTestName(&c), typeName(ft))
+ t.Run(name, func(t *testing.T) {
+ path, name, err := setup(ft)
+ if err != nil {
+ t.Fatalf("%v", err)
+ }
+ defer os.RemoveAll(path)
- path, name, err := setup(ft)
- if err != nil {
- t.Fatalf("%v", err)
- }
- defer os.RemoveAll(path)
+ a, err := NewAttachPoint(path, c)
+ if err != nil {
+ t.Fatalf("NewAttachPoint failed: %v", err)
+ }
+ root, err := a.Attach()
+ if err != nil {
+ t.Fatalf("Attach failed, err: %v", err)
+ }
- a, err := NewAttachPoint(path, c)
- if err != nil {
- t.Fatalf("NewAttachPoint failed: %v", err)
- }
- root, err := a.Attach()
- if err != nil {
- t.Fatalf("Attach failed, err: %v", err)
- }
+ _, file, err := root.Walk([]string{name})
+ if err != nil {
+ root.Close()
+ t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err)
+ }
- _, file, err := root.Walk([]string{name})
- if err != nil {
+ st := state{
+ root: root.(*localFile),
+ file: file.(*localFile),
+ conf: c,
+ fileType: ft,
+ }
+ test(t, st)
+ file.Close()
root.Close()
- t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err)
- }
-
- st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft}
- test(t, st)
- file.Close()
- root.Close()
+ })
}
}
}
-func setup(ft fileType) (string, string, error) {
- path, err := ioutil.TempDir("", "root-")
+func setup(fileType uint32) (string, string, error) {
+ path, err := ioutil.TempDir(testutil.TmpDir(), "root-")
if err != nil {
return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err)
}
@@ -169,26 +194,26 @@ func setup(ft fileType) (string, string, error) {
defer root.Close()
var name string
- switch ft {
- case regular:
+ switch fileType {
+ case syscall.S_IFREG:
name = "file"
_, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err)
}
defer f.Close()
- case directory:
+ case syscall.S_IFDIR:
name = "dir"
if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err)
}
- case symlink:
+ case syscall.S_IFLNK:
name = "symlink"
if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err)
}
default:
- panic(fmt.Sprintf("unknown file type %v", ft))
+ panic(fmt.Sprintf("unknown file type %v", fileType))
}
return path, name, nil
}
@@ -202,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) {
}
func TestReadWrite(t *testing.T) {
- runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
child, err := createFile(s.file, "test")
if err != nil {
t.Fatalf("%v: createFile() failed, err: %v", s, err)
@@ -232,7 +257,7 @@ func TestReadWrite(t *testing.T) {
}
func TestCreate(t *testing.T) {
- runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
for i, flags := range allOpenFlags {
_, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid()))
if err != nil {
@@ -249,7 +274,7 @@ func TestCreate(t *testing.T) {
// TestReadWriteDup tests that a file opened in any mode can be dup'ed and
// reopened in any other mode.
func TestReadWriteDup(t *testing.T) {
- runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
child, err := createFile(s.file, "test")
if err != nil {
t.Fatalf("%v: createFile() failed, err: %v", s, err)
@@ -291,7 +316,7 @@ func TestReadWriteDup(t *testing.T) {
}
func TestUnopened(t *testing.T) {
- runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFREG}, allConfs, func(t *testing.T, s state) {
b := []byte("foobar")
if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF {
t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err)
@@ -308,6 +333,32 @@ func TestUnopened(t *testing.T) {
})
}
+// TestOpenOPath is a regression test to ensure that a file that cannot be open
+// for read is allowed to be open. This was happening because the control file
+// was open with O_PATH, but Open() was not checking for it and allowing the
+// control file to be reused.
+func TestOpenOPath(t *testing.T) {
+ runCustom(t, []uint32{syscall.S_IFREG}, rwConfs, func(t *testing.T, s state) {
+ // Fist remove all permissions on the file.
+ if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil {
+ t.Fatalf("SetAttr(): %v", err)
+ }
+ // Then walk to the file again to open a new control file.
+ filename := filepath.Base(s.file.hostPath)
+ _, newFile, err := s.root.Walk([]string{filename})
+ if err != nil {
+ t.Fatalf("root.Walk(%q): %v", filename, err)
+ }
+
+ if newFile.(*localFile).controlReadable {
+ t.Fatalf("control file didn't open with O_PATH: %+v", newFile)
+ }
+ if _, _, _, err := newFile.Open(p9.ReadOnly); err != syscall.EACCES {
+ t.Fatalf("Open() should have failed, got: %v, wanted: EACCES", err)
+ }
+ })
+}
+
func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, error) {
if err := l.SetAttr(valid, attr); err != nil {
return p9.Attr{}, err
@@ -324,7 +375,7 @@ func TestSetAttrPerm(t *testing.T) {
valid := p9.SetAttrMask{Permissions: true}
attr := p9.SetAttr{Permissions: 0777}
got, err := SetGetAttr(s.file, valid, attr)
- if s.ft == symlink {
+ if s.fileType == syscall.S_IFLNK {
if err == nil {
t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions)
}
@@ -345,7 +396,7 @@ func TestSetAttrSize(t *testing.T) {
valid := p9.SetAttrMask{Size: true}
attr := p9.SetAttr{Size: size}
got, err := SetGetAttr(s.file, valid, attr)
- if s.ft == symlink || s.ft == directory {
+ if s.fileType == syscall.S_IFLNK || s.fileType == syscall.S_IFDIR {
if err == nil {
t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions)
}
@@ -427,7 +478,7 @@ func TestLink(t *testing.T) {
}
err = dir.Link(s.file, linkFile)
- if s.ft == directory {
+ if s.fileType == syscall.S_IFDIR {
if err != syscall.EPERM {
t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err)
}
@@ -485,7 +536,7 @@ func TestROMountPanics(t *testing.T) {
}
func TestWalkNotFound(t *testing.T) {
- runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFDIR}, allConfs, func(t *testing.T, s state) {
if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT {
t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err)
}
@@ -506,7 +557,7 @@ func TestWalkDup(t *testing.T) {
}
func TestReaddir(t *testing.T) {
- runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) {
+ runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) {
name := "dir"
if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil {
t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err)
diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go
index ff2556aee..542b54365 100644
--- a/runsc/fsgofer/fsgofer_unsafe.go
+++ b/runsc/fsgofer/fsgofer_unsafe.go
@@ -18,34 +18,9 @@ import (
"syscall"
"unsafe"
- "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/syserr"
)
-func statAt(dirFd int, name string) (syscall.Stat_t, error) {
- nameBytes, err := syscall.BytePtrFromString(name)
- if err != nil {
- return syscall.Stat_t{}, err
- }
- namePtr := unsafe.Pointer(nameBytes)
-
- var stat syscall.Stat_t
- statPtr := unsafe.Pointer(&stat)
-
- if _, _, errno := syscall.Syscall6(
- syscall.SYS_NEWFSTATAT,
- uintptr(dirFd),
- uintptr(namePtr),
- uintptr(statPtr),
- linux.AT_SYMLINK_NOFOLLOW,
- 0,
- 0); errno != 0 {
-
- return syscall.Stat_t{}, syserr.FromHost(errno).ToError()
- }
- return stat, nil
-}
-
func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error {
// utimensat(2) doesn't accept empty name, instead name must be nil to make it
// operate directly on 'dirFd' unlike other *at syscalls.
diff --git a/runsc/main.go b/runsc/main.go
index ae906c661..69cb505fa 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -26,8 +26,7 @@ import (
"path/filepath"
"strings"
"syscall"
-
- "flag"
+ "time"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
@@ -35,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cmd"
+ "gvisor.dev/gvisor/runsc/flag"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -46,15 +46,19 @@ var (
logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.")
debug = flag.Bool("debug", false, "enable debug logging.")
showVersion = flag.Bool("version", false, "show version and exit.")
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.")
// These flags are unique to runsc, and are used to configure parts of the
// system that are not covered by the runtime spec.
// Debugging flags.
debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.")
+ panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.")
logPackets = flag.Bool("log-packets", false, "enable network packet logging.")
logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.")
debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.")
+ panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.")
debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.")
alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.")
@@ -67,11 +71,14 @@ var (
platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.")
network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.")
- softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware ofload can't be enabled.")
+ softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.")
+ txChecksumOffload = flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.")
+ rxChecksumOffload = flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.")
+ qDisc = flag.String("qdisc", "fifo", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.")
fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
- overlayfsStaleRead = flag.Bool("overlayfs-stale-read", false, "reopen cached FDs after a file is opened for write to workaround overlayfs limitation on kernels before 4.19.")
+ overlayfsStaleRead = flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem")
watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).")
@@ -79,6 +86,9 @@ var (
numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.")
rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
+ cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)")
+ vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.")
+ fuseEnabled = flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.")
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
@@ -113,8 +123,8 @@ func main() {
subcommands.Register(new(cmd.Resume), "")
subcommands.Register(new(cmd.Run), "")
subcommands.Register(new(cmd.Spec), "")
- subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.State), "")
+ subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.Wait), "")
// Register internal commands with the internal group name. This causes
@@ -124,6 +134,7 @@ func main() {
subcommands.Register(new(cmd.Boot), internalGroup)
subcommands.Register(new(cmd.Debug), internalGroup)
subcommands.Register(new(cmd.Gofer), internalGroup)
+ subcommands.Register(new(cmd.Statefile), internalGroup)
// All subcommands must be registered before flag parsing.
flag.Parse()
@@ -136,6 +147,12 @@ func main() {
os.Exit(0)
}
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ if *systemdCgroup {
+ fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193")
+ os.Exit(1)
+ }
+
var errorLogger io.Writer
if *logFD > -1 {
errorLogger = os.NewFile(uintptr(*logFD), "error log file")
@@ -185,6 +202,11 @@ func main() {
cmd.Fatalf("%v", err)
}
+ queueingDiscipline, err := boot.MakeQueueingDiscipline(*qDisc)
+ if err != nil {
+ cmd.Fatalf("%s", err)
+ }
+
// Sets the reference leak check mode. Also set it in config below to
// propagate it to child processes.
refs.SetLeakMode(refsLeakMode)
@@ -196,6 +218,7 @@ func main() {
LogFilename: *logFilename,
LogFormat: *logFormat,
DebugLog: *debugLog,
+ PanicLog: *panicLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
FSGoferHostUDS: *fsGoferHostUDS,
@@ -203,6 +226,8 @@ func main() {
Network: netType,
HardwareGSO: *hardwareGSO,
SoftwareGSO: *softwareGSO,
+ TXChecksumOffload: *txChecksumOffload,
+ RXChecksumOffload: *rxChecksumOffload,
LogPackets: *logPackets,
Platform: platformType,
Strace: *strace,
@@ -216,7 +241,10 @@ func main() {
AlsoLogToStderr: *alsoLogToStderr,
ReferenceLeakMode: refsLeakMode,
OverlayfsStaleRead: *overlayfsStaleRead,
-
+ CPUNumFromQuota: *cpuNumFromQuota,
+ VFS2: *vfs2Enabled,
+ FUSE: *fuseEnabled,
+ QDisc: queueingDiscipline,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
TestOnlyTestNameEnv: *testOnlyTestNameEnv,
}
@@ -229,26 +257,24 @@ func main() {
log.SetLevel(log.Debug)
}
+ // Logging will include the local date and time via the time package.
+ //
+ // On first use, time.Local initializes the local time zone, which
+ // involves opening tzdata files on the host. Since this requires
+ // opening host files, it must be done before syscall filter
+ // installation.
+ //
+ // Generally there will be a log message before filter installation
+ // that will force initialization, but force initialization here in
+ // case that does not occur.
+ _ = time.Local.String()
+
subcommand := flag.CommandLine.Arg(0)
var e log.Emitter
if *debugLogFD > -1 {
f := os.NewFile(uintptr(*debugLogFD), "debug log file")
- // Quick sanity check to make sure no other commands get passed
- // a log fd (they should use log dir instead).
- if subcommand != "boot" && subcommand != "gofer" {
- cmd.Fatalf("flag --debug-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
- }
-
- // If we are the boot process, then we own our stdio FDs and can do what we
- // want with them. Since Docker and Containerd both eat boot's stderr, we
- // dup our stderr to the provided log FD so that panics will appear in the
- // logs, rather than just disappear.
- if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
- cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
- }
-
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
@@ -264,8 +290,26 @@ func main() {
e = newEmitter("text", ioutil.Discard)
}
- if *alsoLogToStderr {
- e = log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
+ if *panicLogFD > -1 || *debugLogFD > -1 {
+ fd := *panicLogFD
+ if fd < 0 {
+ fd = *debugLogFD
+ }
+ // Quick sanity check to make sure no other commands get passed
+ // a log fd (they should use log dir instead).
+ if subcommand != "boot" && subcommand != "gofer" {
+ cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand)
+ }
+
+ // If we are the boot process, then we own our stdio FDs and can do what we
+ // want with them. Since Docker and Containerd both eat boot's stderr, we
+ // dup our stderr to the provided log FD so that panics will appear in the
+ // logs, rather than just disappear.
+ if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
+ }
+ } else if *alsoLogToStderr {
+ e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)}
}
log.SetTarget(e)
@@ -281,6 +325,7 @@ func main() {
log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay)
log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets)
log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls)
+ log.Infof("\t\tVFS2 enabled: %v", conf.VFS2)
log.Infof("***************************")
if *testOnlyAllowRunAsCurrentUserWithoutChroot {
@@ -297,7 +342,7 @@ func main() {
log.Infof("Exiting with status: %v", ws)
if ws.Signaled() {
// No good way to return it, emulate what the shell does. Maybe raise
- // signall to self?
+ // signal to self?
os.Exit(128 + int(ws.Signal()))
}
os.Exit(ws.ExitStatus())
@@ -310,11 +355,11 @@ func main() {
func newEmitter(format string, logFile io.Writer) log.Emitter {
switch format {
case "text":
- return &log.GoogleEmitter{&log.Writer{Next: logFile}}
+ return log.GoogleEmitter{&log.Writer{Next: logFile}}
case "json":
- return &log.JSONEmitter{log.Writer{Next: logFile}}
+ return log.JSONEmitter{&log.Writer{Next: logFile}}
case "json-k8s":
- return &log.K8sJSONEmitter{log.Writer{Next: logFile}}
+ return log.K8sJSONEmitter{&log.Writer{Next: logFile}}
}
cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format)
panic("unreachable")
diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD
index 27459e6d1..2b9d4549d 100644
--- a/runsc/sandbox/BUILD
+++ b/runsc/sandbox/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -9,16 +9,18 @@ go_library(
"network_unsafe.go",
"sandbox.go",
],
- importpath = "gvisor.dev/gvisor/runsc/sandbox",
visibility = [
"//runsc:__subpackages__",
],
deps = [
+ "//pkg/cleanup",
"//pkg/control/client",
"//pkg/control/server",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/platform",
+ "//pkg/sync",
+ "//pkg/tcpip/header",
"//pkg/tcpip/stack",
"//pkg/urpc",
"//runsc/boot",
@@ -27,7 +29,7 @@ go_library(
"//runsc/console",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
"@com_github_vishvananda_netlink//:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index d42de0176..817a923ad 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -21,13 +21,13 @@ import (
"path/filepath"
"runtime"
"strconv"
- "strings"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot"
@@ -62,7 +62,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
// Build the path to the net namespace of the sandbox process.
// This is what we will copy.
nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net")
- if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.NumNetworkChannels); err != nil {
+ if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.TXChecksumOffload, conf.RXChecksumOffload, conf.NumNetworkChannels, conf.QDisc); err != nil {
return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err)
}
case boot.NetworkHost:
@@ -74,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
}
func createDefaultLoopbackInterface(conn *urpc.Client) error {
- link := boot.LoopbackLink{
- Name: "lo",
- Addresses: []net.IP{
- net.IP("\x7f\x00\x00\x01"),
- net.IPv6loopback,
- },
- Routes: []boot.Route{
- {
- Destination: net.IPNet{
-
- IP: net.IPv4(0x7f, 0, 0, 0),
- Mask: net.IPv4Mask(0xff, 0, 0, 0),
- },
- },
- {
- Destination: net.IPNet{
- IP: net.IPv6loopback,
- Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
- },
- },
- },
- }
if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{
- LoopbackLinks: []boot.LoopbackLink{link},
+ LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink},
}, nil); err != nil {
return fmt.Errorf("creating loopback link and routes: %v", err)
}
@@ -137,7 +115,7 @@ func isRootNS() (bool, error) {
// createInterfacesAndRoutesFromNS scrapes the interface and routes from the
// net namespace with the given path, creates them in the sandbox, and removes
// them from the host.
-func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, numNetworkChannels int) error {
+func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc boot.QueueingDiscipline) error {
// Join the network namespace that we will be copying.
restore, err := joinNetNS(nsPath)
if err != nil {
@@ -156,7 +134,6 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
return err
}
if isRoot {
-
return fmt.Errorf("cannot run with network enabled in root network namespace")
}
@@ -173,53 +150,59 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err)
}
- // We build our own loopback devices.
+ // We build our own loopback device.
if iface.Flags&net.FlagLoopback != 0 {
- links, err := loopbackLinks(iface, allAddrs)
+ link, err := loopbackLink(iface, allAddrs)
if err != nil {
- return fmt.Errorf("getting loopback routes and links for iface %q: %v", iface.Name, err)
+ return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err)
}
- args.LoopbackLinks = append(args.LoopbackLinks, links...)
+ args.LoopbackLinks = append(args.LoopbackLinks, link)
continue
}
- // Keep only IPv4 addresses.
- var ip4addrs []*net.IPNet
+ var ipAddrs []*net.IPNet
for _, ifaddr := range allAddrs {
ipNet, ok := ifaddr.(*net.IPNet)
if !ok {
return fmt.Errorf("address is not IPNet: %+v", ifaddr)
}
- if ipNet.IP.To4() == nil {
- log.Warningf("IPv6 is not supported, skipping: %v", ipNet)
- continue
- }
- ip4addrs = append(ip4addrs, ipNet)
+ ipAddrs = append(ipAddrs, ipNet)
}
- if len(ip4addrs) == 0 {
- log.Warningf("No IPv4 address found for interface %q, skipping", iface.Name)
+ if len(ipAddrs) == 0 {
+ log.Warningf("No usable IP addresses found for interface %q, skipping", iface.Name)
continue
}
// Scrape the routes before removing the address, since that
// will remove the routes as well.
- routes, def, err := routesForIface(iface)
+ routes, defv4, defv6, err := routesForIface(iface)
if err != nil {
return fmt.Errorf("getting routes for interface %q: %v", iface.Name, err)
}
- if def != nil {
- if !args.DefaultGateway.Route.Empty() {
- return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, def, args.DefaultGateway)
+ if defv4 != nil {
+ if !args.Defaultv4Gateway.Route.Empty() {
+ return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv4, args.Defaultv4Gateway)
+ }
+ args.Defaultv4Gateway.Route = *defv4
+ args.Defaultv4Gateway.Name = iface.Name
+ }
+
+ if defv6 != nil {
+ if !args.Defaultv6Gateway.Route.Empty() {
+ return fmt.Errorf("more than one default route found, interface: %v, route: %v, default route: %+v", iface.Name, defv6, args.Defaultv6Gateway)
}
- args.DefaultGateway.Route = *def
- args.DefaultGateway.Name = iface.Name
+ args.Defaultv6Gateway.Route = *defv6
+ args.Defaultv6Gateway.Name = iface.Name
}
link := boot.FDBasedLink{
- Name: iface.Name,
- MTU: iface.MTU,
- Routes: routes,
- NumChannels: numNetworkChannels,
+ Name: iface.Name,
+ MTU: iface.MTU,
+ Routes: routes,
+ TXChecksumOffload: txChecksumOffload,
+ RXChecksumOffload: rxChecksumOffload,
+ NumChannels: numNetworkChannels,
+ QDisc: qDisc,
}
// Get the link for the interface.
@@ -247,6 +230,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
}
args.FilePayload.Files = append(args.FilePayload.Files, socketEntry.deviceFile)
}
+
if link.GSOMaxSize == 0 && softwareGSO {
// Hardware GSO is disabled. Let's enable software GSO.
link.GSOMaxSize = stack.SoftwareGSOMaxSize
@@ -255,7 +239,7 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
// Collect the addresses for the interface, enable forwarding,
// and remove them from the host.
- for _, addr := range ip4addrs {
+ for _, addr := range ipAddrs {
link.Addresses = append(link.Addresses, addr.IP)
// Steal IP address from NIC.
@@ -316,81 +300,96 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
}
}
- // Use SO_RCVBUFFORCE because on linux the receive buffer for an
- // AF_PACKET socket is capped by "net.core.rmem_max". rmem_max
- // defaults to a unusually low value of 208KB. This is too low
- // for gVisor to be able to receive packets at high throughputs
- // without incurring packet drops.
- const rcvBufSize = 4 << 20 // 4MB.
+ // Use SO_RCVBUFFORCE/SO_SNDBUFFORCE because on linux the receive/send buffer
+ // for an AF_PACKET socket is capped by "net.core.rmem_max/wmem_max".
+ // wmem_max/rmem_max default to a unusually low value of 208KB. This is too low
+ // for gVisor to be able to receive packets at high throughputs without
+ // incurring packet drops.
+ const bufSize = 4 << 20 // 4MB.
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", bufSize, err)
+ }
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, rcvBufSize); err != nil {
- return nil, fmt.Errorf("failed to increase socket rcv buffer to %d: %v", rcvBufSize, err)
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil {
+ return nil, fmt.Errorf("failed to increase socket snd buffer to %d: %v", bufSize, err)
}
+
return &socketEntry{deviceFile, gsoMaxSize}, nil
}
-// loopbackLinks collects the links for a loopback interface.
-func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) {
- var links []boot.LoopbackLink
+// loopbackLink returns the link with addresses and routes for a loopback
+// interface.
+func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) {
+ link := boot.LoopbackLink{
+ Name: iface.Name,
+ }
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
- return nil, fmt.Errorf("address is not IPNet: %+v", addr)
+ return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr)
}
dst := *ipNet
dst.IP = dst.IP.Mask(dst.Mask)
- links = append(links, boot.LoopbackLink{
- Name: iface.Name,
- Addresses: []net.IP{ipNet.IP},
- Routes: []boot.Route{{
- Destination: dst,
- }},
+ link.Addresses = append(link.Addresses, ipNet.IP)
+ link.Routes = append(link.Routes, boot.Route{
+ Destination: dst,
})
}
- return links, nil
+ return link, nil
}
// routesForIface iterates over all routes for the given interface and converts
-// them to boot.Routes.
-func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, error) {
+// them to boot.Routes. It also returns the a default v4/v6 route if found.
+func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, *boot.Route, error) {
link, err := netlink.LinkByIndex(iface.Index)
if err != nil {
- return nil, nil, err
+ return nil, nil, nil, err
}
rs, err := netlink.RouteList(link, netlink.FAMILY_ALL)
if err != nil {
- return nil, nil, fmt.Errorf("getting routes from %q: %v", iface.Name, err)
+ return nil, nil, nil, fmt.Errorf("getting routes from %q: %v", iface.Name, err)
}
- var def *boot.Route
+ var defv4, defv6 *boot.Route
var routes []boot.Route
for _, r := range rs {
// Is it a default route?
if r.Dst == nil {
if r.Gw == nil {
- return nil, nil, fmt.Errorf("default route with no gateway %q: %+v", iface.Name, r)
- }
- if r.Gw.To4() == nil {
- log.Warningf("IPv6 is not supported, skipping default route: %v", r)
- continue
- }
- if def != nil {
- return nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, def, r)
+ return nil, nil, nil, fmt.Errorf("default route with no gateway %q: %+v", iface.Name, r)
}
// Create a catch all route to the gateway.
- def = &boot.Route{
- Destination: net.IPNet{
- IP: net.IPv4zero,
- Mask: net.IPMask(net.IPv4zero),
- },
- Gateway: r.Gw,
+ switch len(r.Gw) {
+ case header.IPv4AddressSize:
+ if defv4 != nil {
+ return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv4, r)
+ }
+ defv4 = &boot.Route{
+ Destination: net.IPNet{
+ IP: net.IPv4zero,
+ Mask: net.IPMask(net.IPv4zero),
+ },
+ Gateway: r.Gw,
+ }
+ case header.IPv6AddressSize:
+ if defv6 != nil {
+ return nil, nil, nil, fmt.Errorf("more than one default route found %q, def: %+v, route: %+v", iface.Name, defv6, r)
+ }
+
+ defv6 = &boot.Route{
+ Destination: net.IPNet{
+ IP: net.IPv6zero,
+ Mask: net.IPMask(net.IPv6zero),
+ },
+ Gateway: r.Gw,
+ }
+ default:
+ return nil, nil, nil, fmt.Errorf("unexpected address size for gateway: %+v for route: %+v", r.Gw, r)
}
continue
}
- if r.Dst.IP.To4() == nil {
- log.Warningf("IPv6 is not supported, skipping route: %v", r)
- continue
- }
+
dst := *r.Dst
dst.IP = dst.IP.Mask(dst.Mask)
routes = append(routes, boot.Route{
@@ -398,7 +397,7 @@ func routesForIface(iface net.Interface) ([]boot.Route, *boot.Route, error) {
Gateway: r.Gw,
})
}
- return routes, def, nil
+ return routes, defv4, defv6, nil
}
// removeAddress removes IP address from network device. It's equivalent to:
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index ee9327fc8..36bb0c9c9 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -18,21 +18,25 @@ package sandbox
import (
"context"
"fmt"
+ "io"
+ "math"
"os"
"os/exec"
"strconv"
- "sync"
+ "strings"
"syscall"
"time"
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/control/client"
"gvisor.dev/gvisor/pkg/control/server"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/urpc"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
@@ -116,7 +120,7 @@ func New(conf *boot.Config, args *Args) (*Sandbox, error) {
s := &Sandbox{ID: args.ID, Cgroup: args.Cgroup}
// The Cleanup object cleans up partially created sandboxes when an error
// occurs. Any errors occurring during cleanup itself are ignored.
- c := specutils.MakeCleanup(func() {
+ c := cleanup.Make(func() {
err := s.destroy()
log.Warningf("error destroying sandbox: %v", err)
})
@@ -141,7 +145,19 @@ func New(conf *boot.Config, args *Args) (*Sandbox, error) {
// Wait until the sandbox has booted.
b := make([]byte, 1)
if l, err := clientSyncFile.Read(b); err != nil || l != 1 {
- return nil, fmt.Errorf("waiting for sandbox to start: %v", err)
+ err := fmt.Errorf("waiting for sandbox to start: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), io.EOF.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return nil, fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return nil, err
}
c.Release()
@@ -368,8 +384,24 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
cmd.Args = append(cmd.Args, "--debug-log-fd="+strconv.Itoa(nextFD))
nextFD++
}
+ if conf.PanicLog != "" {
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
- cmd.Args = append(cmd.Args, "--panic-signal="+strconv.Itoa(int(syscall.SIGTERM)))
+ panicLogFile, err := specutils.DebugLogFile(conf.PanicLog, "panic", test)
+ if err != nil {
+ return fmt.Errorf("opening debug log file in %q: %v", conf.PanicLog, err)
+ }
+ defer panicLogFile.Close()
+ cmd.ExtraFiles = append(cmd.ExtraFiles, panicLogFile)
+ cmd.Args = append(cmd.Args, "--panic-log-fd="+strconv.Itoa(nextFD))
+ nextFD++
+ }
// Add the "boot" command to the args.
//
@@ -415,9 +447,13 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
- // If the platform needs a device FD we must pass it in.
- if deviceFile, err := deviceFileForPlatform(conf.Platform); err != nil {
+ gPlatform, err := platform.Lookup(conf.Platform)
+ if err != nil {
return err
+ }
+
+ if deviceFile, err := gPlatform.OpenDevice(); err != nil {
+ return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err)
} else if deviceFile != nil {
defer deviceFile.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile)
@@ -425,6 +461,12 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ // TODO(b/151157106): syscall tests fail by timeout if asyncpreemptoff
+ // isn't set.
+ if conf.Platform == "kvm" {
+ cmd.Env = append(cmd.Env, "GODEBUG=asyncpreemptoff=1")
+ }
+
// The current process' stdio must be passed to the application via the
// --stdio-fds flag. The stdio of the sandbox process itself must not
// be connected to the same FDs, otherwise we risk leaking sandbox
@@ -436,9 +478,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
// If the console control socket file is provided, then create a new
// pty master/slave pair and set the TTY on the sandbox process.
- if args.ConsoleSocket != "" {
- cmd.Args = append(cmd.Args, "--console=true")
-
+ if args.Spec.Process.Terminal && args.ConsoleSocket != "" {
// console.NewWithSocket will send the master on the given
// socket, and return the slave.
tty, err := console.NewWithSocket(args.ConsoleSocket)
@@ -502,7 +542,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
{Type: specs.UTSNamespace},
}
- if conf.Platform == platforms.Ptrace {
+ if gPlatform.Requirements().RequiresCurrentPIDNS {
// TODO(b/75837838): Also set a new PID namespace so that we limit
// access to other host processes.
log.Infof("Sandbox will be started in the current PID namespace")
@@ -563,45 +603,32 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nss = append(nss, specs.LinuxNamespace{Type: specs.UserNamespace})
cmd.Args = append(cmd.Args, "--setup-root")
+ const nobody = 65534
if conf.Rootless {
- log.Infof("Rootless mode: sandbox will run as root inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
+ log.Infof("Rootless mode: sandbox will run as nobody inside user namespace, mapped to the current user, uid: %d, gid: %d", os.Getuid(), os.Getgid())
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getuid(),
Size: 1,
},
}
cmd.SysProcAttr.GidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
+ ContainerID: nobody,
HostID: os.Getgid(),
Size: 1,
},
}
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: 0}
} else {
// Map nobody in the new namespace to nobody in the parent namespace.
//
// A sandbox process will construct an empty
- // root for itself, so it has to have the CAP_SYS_ADMIN
- // capability.
- //
- // FIXME(b/122554829): The current implementations of
- // os/exec doesn't allow to set ambient capabilities if
- // a process is started in a new user namespace. As a
- // workaround, we start the sandbox process with the 0
- // UID and then it constructs a chroot and sets UID to
- // nobody. https://github.com/golang/go/issues/2315
- const nobody = 65534
+ // root for itself, so it has to have
+ // CAP_SYS_ADMIN and CAP_SYS_CHROOT capabilities.
cmd.SysProcAttr.UidMappings = []syscall.SysProcIDMap{
{
- ContainerID: 0,
- HostID: nobody - 1,
- Size: 1,
- },
- {
ContainerID: nobody,
HostID: nobody,
Size: 1,
@@ -614,11 +641,11 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
Size: 1,
},
}
-
- // Set credentials to run as user and group nobody.
- cmd.SysProcAttr.Credential = &syscall.Credential{Uid: 0, Gid: nobody}
}
+ // Set credentials to run as user and group nobody.
+ cmd.SysProcAttr.Credential = &syscall.Credential{Uid: nobody, Gid: nobody}
+ cmd.SysProcAttr.AmbientCaps = append(cmd.SysProcAttr.AmbientCaps, uintptr(capability.CAP_SYS_ADMIN), uintptr(capability.CAP_SYS_CHROOT))
} else {
return fmt.Errorf("can't run sandbox process as user nobody since we don't have CAP_SETUID or CAP_SETGID")
}
@@ -631,6 +658,26 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
if err != nil {
return fmt.Errorf("getting cpu count from cgroups: %v", err)
}
+ if conf.CPUNumFromQuota {
+ // Dropping below 2 CPUs can trigger application to disable
+ // locks that can lead do hard to debug errors, so just
+ // leaving two cores as reasonable default.
+ const minCPUs = 2
+
+ quota, err := s.Cgroup.CPUQuota()
+ if err != nil {
+ return fmt.Errorf("getting cpu qouta from cgroups: %v", err)
+ }
+ if n := int(math.Ceil(quota)); n > 0 {
+ if n < minCPUs {
+ n = minCPUs
+ }
+ if n < cpuNum {
+ // Only lower the cpu number.
+ cpuNum = n
+ }
+ }
+ }
cmd.Args = append(cmd.Args, "--cpu-num", strconv.Itoa(cpuNum))
mem, err := s.Cgroup.MemoryLimit()
@@ -656,6 +703,13 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
+ if args.Attached {
+ // Kill sandbox if parent process exits in attached mode.
+ cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
+ // Tells boot that any process it creates must have pdeathsig set.
+ cmd.Args = append(cmd.Args, "--attached")
+ }
+
// Add container as the last argument.
cmd.Args = append(cmd.Args, s.ID)
@@ -664,15 +718,22 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
log.Debugf("Donating FD %d: %q", i+3, f.Name())
}
- if args.Attached {
- // Kill sandbox if parent process exits in attached mode.
- cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
- }
-
log.Debugf("Starting sandbox: %s %v", binPath, cmd.Args)
log.Debugf("SysProcAttr: %+v", cmd.SysProcAttr)
if err := specutils.StartInNS(cmd, nss); err != nil {
- return fmt.Errorf("Sandbox: %v", err)
+ err := fmt.Errorf("starting sandbox: %v", err)
+ // If the sandbox failed to start, it may be because the binary
+ // permissions were incorrect. Check the bits and return a more helpful
+ // error message.
+ //
+ // NOTE: The error message is checked because error types are lost over
+ // rpc calls.
+ if strings.Contains(err.Error(), syscall.EACCES.Error()) {
+ if permsErr := checkBinaryPermissions(conf); permsErr != nil {
+ return fmt.Errorf("%v: %v", err, permsErr)
+ }
+ }
+ return err
}
s.child = true
s.Pid = cmd.Process.Pid
@@ -951,6 +1012,46 @@ func (s *Sandbox) StopCPUProfile() error {
return nil
}
+// BlockProfile writes a block profile to the given file.
+func (s *Sandbox) BlockProfile(f *os.File) error {
+ log.Debugf("Block profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err)
+ }
+ return nil
+}
+
+// MutexProfile writes a mutex profile to the given file.
+func (s *Sandbox) MutexProfile(f *os.File) error {
+ log.Debugf("Mutex profile %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ opts := control.ProfileOpts{
+ FilePayload: urpc.FilePayload{
+ Files: []*os.File{f},
+ },
+ }
+ if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil {
+ return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err)
+ }
+ return nil
+}
+
// StartTrace start trace writing to the given file.
func (s *Sandbox) StartTrace(f *os.File) error {
log.Debugf("Trace start %q", s.ID)
@@ -1004,16 +1105,22 @@ func (s *Sandbox) ChangeLogging(args control.LoggingArgs) error {
// DestroyContainer destroys the given container. If it is the root container,
// then the entire sandbox is destroyed.
func (s *Sandbox) DestroyContainer(cid string) error {
+ if err := s.destroyContainer(cid); err != nil {
+ // If the sandbox isn't running, the container has already been destroyed,
+ // ignore the error in this case.
+ if s.IsRunning() {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *Sandbox) destroyContainer(cid string) error {
if s.IsRootContainer(cid) {
log.Debugf("Destroying root container %q by destroying sandbox", cid)
return s.destroy()
}
- if !s.IsRunning() {
- // Sandbox isn't running anymore, container is already destroyed.
- return nil
- }
-
log.Debugf("Destroying container %q in sandbox %q", cid, s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -1069,3 +1176,31 @@ func deviceFileForPlatform(name string) (*os.File, error) {
}
return f, nil
}
+
+// checkBinaryPermissions verifies that the required binary bits are set on
+// the runsc executable.
+func checkBinaryPermissions(conf *boot.Config) error {
+ // All platforms need the other exe bit
+ neededBits := os.FileMode(0001)
+ if conf.Platform == platforms.Ptrace {
+ // Ptrace needs the other read bit
+ neededBits |= os.FileMode(0004)
+ }
+
+ exePath, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("getting exe path: %v", err)
+ }
+
+ // Check the permissions of the runsc binary and print an error if it
+ // doesn't match expectations.
+ info, err := os.Stat(exePath)
+ if err != nil {
+ return fmt.Errorf("stat file: %v", err)
+ }
+
+ if info.Mode().Perm()&neededBits != neededBits {
+ return fmt.Errorf(specutils.FaqErrorMsg("runsc-perms", fmt.Sprintf("%s does not have the correct permissions", exePath)))
+ }
+ return nil
+}
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index 205638803..43851a22f 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,7 +10,6 @@ go_library(
"namespace.go",
"specutils.go",
],
- importpath = "gvisor.dev/gvisor/runsc/specutils",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
@@ -18,7 +17,8 @@ go_library(
"//pkg/log",
"//pkg/sentry/kernel/auth",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
],
@@ -28,6 +28,6 @@ go_test(
name = "specutils_test",
size = "small",
srcs = ["specutils_test.go"],
- embed = [":specutils"],
- deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"],
+ library = ":specutils",
+ deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"],
)
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index c7dd3051c..23001d67c 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -18,6 +18,7 @@ import (
"fmt"
"os"
"os/exec"
+ "os/signal"
"path/filepath"
"runtime"
"syscall"
@@ -252,13 +253,27 @@ func MaybeRunAsRoot() error {
},
Credential: &syscall.Credential{Uid: 0, Gid: 0},
GidMappingsEnableSetgroups: false,
+
+ // Make sure child is killed when the parent terminates.
+ Pdeathsig: syscall.SIGKILL,
}
cmd.Env = os.Environ()
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
- if err := cmd.Run(); err != nil {
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("re-executing self: %w", err)
+ }
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch)
+ go func() {
+ for {
+ // Forward all signals to child process.
+ cmd.Process.Signal(<-ch)
+ }
+ }()
+ if err := cmd.Wait(); err != nil {
if exit, ok := err.(*exec.ExitError); ok {
if ws, ok := exit.Sys().(syscall.WaitStatus); ok {
os.Exit(ws.ExitStatus())
@@ -266,7 +281,7 @@ func MaybeRunAsRoot() error {
log.Warningf("No wait status provided, exiting with -1: %v", err)
os.Exit(-1)
}
- return fmt.Errorf("re-executing self: %v", err)
+ return err
}
// Child completed with success.
os.Exit(0)
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index d3c2e4e78..5015c3a84 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -29,6 +29,7 @@ import (
"time"
"github.com/cenkalti/backoff"
+ "github.com/mohae/deepcopy"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
@@ -44,20 +45,31 @@ var ExePath = "/proc/self/exe"
var Version = specs.Version
// LogSpec logs the spec in a human-friendly way.
-func LogSpec(spec *specs.Spec) {
- log.Debugf("Spec: %+v", spec)
- log.Debugf("Spec.Hooks: %+v", spec.Hooks)
- log.Debugf("Spec.Linux: %+v", spec.Linux)
- if spec.Linux != nil && spec.Linux.Resources != nil {
- res := spec.Linux.Resources
- log.Debugf("Spec.Linux.Resources.Memory: %+v", res.Memory)
- log.Debugf("Spec.Linux.Resources.CPU: %+v", res.CPU)
- log.Debugf("Spec.Linux.Resources.BlockIO: %+v", res.BlockIO)
- log.Debugf("Spec.Linux.Resources.Network: %+v", res.Network)
- }
- log.Debugf("Spec.Process: %+v", spec.Process)
- log.Debugf("Spec.Root: %+v", spec.Root)
- log.Debugf("Spec.Mounts: %+v", spec.Mounts)
+func LogSpec(orig *specs.Spec) {
+ if !log.IsLogging(log.Debug) {
+ return
+ }
+
+ // Strip down parts of the spec that are not interesting.
+ spec := deepcopy.Copy(orig).(*specs.Spec)
+ if spec.Process != nil {
+ spec.Process.Capabilities = nil
+ }
+ if spec.Linux != nil {
+ spec.Linux.Seccomp = nil
+ spec.Linux.MaskedPaths = nil
+ spec.Linux.ReadonlyPaths = nil
+ if spec.Linux.Resources != nil {
+ spec.Linux.Resources.Devices = nil
+ }
+ }
+
+ out, err := json.MarshalIndent(spec, "", " ")
+ if err != nil {
+ log.Debugf("Failed to marshal spec: %v", err)
+ return
+ }
+ log.Debugf("Spec:\n%s", out)
}
// ValidateSpec validates that the spec is compatible with runsc.
@@ -92,6 +104,12 @@ func ValidateSpec(spec *specs.Spec) error {
log.Warningf("AppArmor profile %q is being ignored", spec.Process.ApparmorProfile)
}
+ // PR_SET_NO_NEW_PRIVS is assumed to always be set.
+ // See kernel.Task.updateCredsForExecLocked.
+ if !spec.Process.NoNewPrivileges {
+ log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.")
+ }
+
// TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox.
if spec.Linux != nil && spec.Linux.Seccomp != nil {
log.Warningf("Seccomp spec is being ignored")
@@ -438,36 +456,6 @@ func ContainsStr(strs []string, str string) bool {
return false
}
-// Cleanup allows defers to be aborted when cleanup needs to happen
-// conditionally. Usage:
-// c := MakeCleanup(func() { f.Close() })
-// defer c.Clean() // any failure before release is called will close the file.
-// ...
-// c.Release() // on success, aborts closing the file and return it.
-// return f
-type Cleanup struct {
- clean func()
-}
-
-// MakeCleanup creates a new Cleanup object.
-func MakeCleanup(f func()) Cleanup {
- return Cleanup{clean: f}
-}
-
-// Clean calls the cleanup function.
-func (c *Cleanup) Clean() {
- if c.clean != nil {
- c.clean()
- c.clean = nil
- }
-}
-
-// Release releases the cleanup from its duties, i.e. cleanup function is not
-// called after this point.
-func (c *Cleanup) Release() {
- c.clean = nil
-}
-
// RetryEintr retries the function until an error different than EINTR is
// returned.
func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
@@ -528,3 +516,8 @@ func EnvVar(env []string, name string) (string, bool) {
}
return "", false
}
+
+// FaqErrorMsg returns an error message pointing to the FAQ.
+func FaqErrorMsg(anchor, msg string) string {
+ return fmt.Sprintf("%s; see https://gvisor.dev/faq#%s for more details", msg, anchor)
+}
diff --git a/runsc/version_test.sh b/runsc/version_test.sh
index cc0ca3f05..747350654 100755
--- a/runsc/version_test.sh
+++ b/runsc/version_test.sh
@@ -16,7 +16,7 @@
set -euf -x -o pipefail
-readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc"
+readonly runsc="$1"
readonly version=$($runsc --version)
# Version should should not match VERSION, which is the default and which will
diff --git a/scripts/build.sh b/scripts/build.sh
deleted file mode 100755
index 0b3d1b316..000000000
--- a/scripts/build.sh
+++ /dev/null
@@ -1,79 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Install required packages for make_repository.sh et al.
-sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
-
-# Build runsc.
-runsc=$(build -c opt //runsc)
-
-# Build packages.
-pkg=$(build -c opt //runsc:runsc-debian)
-
-# Build a repository, if the key is available.
-if [[ -v KOKORO_REPO_KEY ]]; then
- repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg})
-fi
-
-# Install installs artifacts.
-install() {
- local -r binaries_dir="$1"
- local -r repo_dir="$2"
- mkdir -p "${binaries_dir}"
- cp -f "${runsc}" "${binaries_dir}"/runsc
- sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512
- if [[ -v repo ]]; then
- rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")"
- cp -a "${repo}" "${repo_dir}"
- fi
-}
-
-# Move the runsc binary into "latest" directory, and also a directory with the
-# current date. If the current commit happens to correpond to a tag, then we
-# will also move everything into a directory named after the given tag.
-if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then
- # The "latest" directory and current date.
- stamp="$(date -Idate)"
- install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \
- "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest"
- install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \
- "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}"
- else
- # Is it a tagged release? Build that instead. In that case, we also try to
- # update the base release directory, in case this is an update. Finally, we
- # update the "release" directory, which has the last released version.
- tags="$(git tag --points-at HEAD)"
- if ! [[ -z "${tags}" ]]; then
- # Note that a given commit can match any number of tags. We have to
- # iterate through all possible tags and produce associated artifacts.
- for tag in ${tags}; do
- name=$(echo "${tag}" | cut -d'-' -f2)
- base=$(echo "${name}" | cut -d'.' -f1)
- install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \
- "${KOKORO_ARTIFACTS_DIR}/dists/${name}"
- if [[ "${base}" != "${tag}" ]]; then
- install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \
- "${KOKORO_ARTIFACTS_DIR}/dists/${base}"
- fi
- install "${KOKORO_ARTIFACTS_DIR}/release/latest" \
- "${KOKORO_ARTIFACTS_DIR}/dists/latest"
- done
- fi
- fi
-fi
diff --git a/scripts/common.sh b/scripts/common.sh
index 6dabad141..3ca699e4a 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -16,12 +16,18 @@
set -xeou pipefail
-if [[ -f $(dirname $0)/common_google.sh ]]; then
- source $(dirname $0)/common_google.sh
+# Get the path to the directory this script lives in.
+# If this script is being called with `source`, $0 will be the path of the
+# *sourcing* script, so we can't use `dirname $0` to find scripts in this
+# directory.
+if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then
+ declare -r script_dir="$(dirname "$BASH_SOURCE")"
else
- source $(dirname $0)/common_bazel.sh
+ declare -r script_dir="$(dirname "$0")"
fi
+source "${script_dir}/common_build.sh"
+
# Ensure it attempts to collect logs in all cases.
trap collect_logs EXIT
@@ -73,7 +79,7 @@ function install_runsc() {
sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
# Clear old logs files that may exist.
- sudo rm -f "${RUNSC_LOGS_DIR}"/*
+ sudo rm -f "${RUNSC_LOGS_DIR}"/'*'
# Restart docker to pick up the new runtime configuration.
sudo systemctl restart docker
diff --git a/scripts/common_bazel.sh b/scripts/common_build.sh
index f8ec967b1..d4a6c4908 100755
--- a/scripts/common_bazel.sh
+++ b/scripts/common_build.sh
@@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh latest) || which bazel
+which bazel
bazel version
# Switch into the workspace; only necessary if run with kokoro.
@@ -26,34 +25,30 @@ elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
fi
# Set the standard bazel flags.
-declare -r BAZEL_FLAGS=(
+declare -a BAZEL_FLAGS=(
"--show_timestamps"
"--test_output=errors"
"--keep_going"
"--verbose_failures=true"
)
-if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]] || [[ -v RBE_PROJECT_ID ]]; then
- declare -r RBE_PROJECT_ID="${RBE_PROJECT_ID:-gvisor-rbe}"
- declare -r BAZEL_RBE_FLAGS=(
+# If running via kokoro, use the remote config.
+if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
+ BAZEL_FLAGS+=(
"--config=remote"
- "--project_id=${RBE_PROJECT_ID}"
- "--remote_instance_name=projects/${RBE_PROJECT_ID}/instances/default_instance"
- )
-fi
-if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
- declare -r BAZEL_RBE_AUTH_FLAGS=(
- "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
)
fi
+declare -r BAZEL_FLAGS
# Wrap bazel.
function build() {
- bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
- tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
+ bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \
+ | tee /dev/fd/2 \
+ | grep -E '^ bazel-bin/' \
+ | awk '{ print $1; }'
}
function test() {
- bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
+ bazel test "${BAZEL_FLAGS[@]}" "$@"
}
function run() {
@@ -68,9 +63,22 @@ function run_as_root() {
bazel run --run_under="sudo" "${binary}" -- "$@"
}
+function query() {
+ bazel query "$@"
+}
+
function collect_logs() {
# Zip out everything into a convenient form.
if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
+ # Merge results files of all shards for each test suite.
+ for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
+ junitparser merge `find $d -name test.xml` $d/test.xml
+ cat $d/shard_*_of_*/test.log > $d/test.log
+ if ls -ld $d/shard_*_of_*/test.outputs 2>/dev/null; then
+ zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs
+ fi
+ done
+ find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
# Move test logs to Kokoro directory. tar is used to conveniently perform
# renames while moving files.
find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
@@ -88,12 +96,21 @@ function collect_logs() {
echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
fi
- tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ time tar \
+ --verbose \
+ --create \
+ --gzip \
+ --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
+ --directory "${RUNSC_LOGS_DIR}" \
+ .
fi
fi
fi
}
function find_branch_name() {
- git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+ git branch --show-current \
+ || git rev-parse HEAD \
+ || bazel info workspace \
+ | xargs basename
}
diff --git a/scripts/dev.sh b/scripts/dev.sh
index c67003018..a9107f33e 100755
--- a/scripts/dev.sh
+++ b/scripts/dev.sh
@@ -54,9 +54,10 @@ declare OUTPUT="$(build //runsc)"
if [[ ${REFRESH} -eq 0 ]]; then
install_runsc "${RUNTIME}" --net-raw
install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
+ install_runsc "${RUNTIME}-p" --net-raw --profile
echo
- echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup."
+ echo "Runtimes ${RUNTIME}, ${RUNTIME}-d (debug enabled), and ${RUNTIME}-p installed."
echo "Use --runtime="${RUNTIME}" with your Docker command."
echo " docker run --rm --runtime="${RUNTIME}" hello-world"
echo
@@ -65,6 +66,7 @@ if [[ ${REFRESH} -eq 0 ]]; then
else
mkdir -p "$(dirname ${RUNSC_BIN})"
cp -f ${OUTPUT} "${RUNSC_BIN}"
+ chmod a+rx "${RUNSC_BIN}"
echo
echo "Runtime ${RUNTIME} refreshed."
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
index 72ba05260..4f3867d05 100755
--- a/scripts/docker_tests.sh
+++ b/scripts/docker_tests.sh
@@ -16,5 +16,10 @@
source $(dirname $0)/common.sh
+make load-all-images
+
install_runsc_for_test docker
test_runsc //test/image:image_test //test/e2e:integration_test
+
+install_runsc_for_test docker --vfs2
+test_runsc //test/e2e:integration_test //test/image:image_test
diff --git a/scripts/fuse_tests.sh b/scripts/fuse_tests.sh
new file mode 100755
index 000000000..bbaaa99fc
--- /dev/null
+++ b/scripts/fuse_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all vfs2_fuse system call tests.
+test --test_tag_filters=fuse //test/fuse/...
diff --git a/scripts/go.sh b/scripts/go.sh
index 0dbfb7747..626ed8fa4 100755
--- a/scripts/go.sh
+++ b/scripts/go.sh
@@ -25,6 +25,8 @@ tools/go_branch.sh
# Checkout the new branch.
git checkout go && git clean -f
+go version
+
# Build everything.
go build ./...
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
index 41298293d..992db50dd 100755
--- a/scripts/hostnet_tests.sh
+++ b/scripts/hostnet_tests.sh
@@ -16,6 +16,8 @@
source $(dirname $0)/common.sh
+make load-all-images
+
# Install the runtime and perform basic tests.
install_runsc_for_test hostnet --network=host
test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/iptables_tests.sh b/scripts/iptables_tests.sh
new file mode 100755
index 000000000..8299a7c8b
--- /dev/null
+++ b/scripts/iptables_tests.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+make load-iptables
+
+# Needed by ip6tables.
+sudo modprobe ip6table_filter
+
+install_runsc_for_test iptables --net-raw
+test //test/iptables:iptables_test "--test_arg=--runtime=runc"
+test //test/iptables:iptables_test "--test_arg=--runtime=${RUNTIME}"
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
index 5662401df..619571c74 100755
--- a/scripts/kvm_tests.sh
+++ b/scripts/kvm_tests.sh
@@ -16,6 +16,8 @@
source $(dirname $0)/common.sh
+make load-all-images
+
# Ensure that KVM is loaded, and we can use it.
(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
sudo chmod a+rw /dev/kvm
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
index 79426756d..dbf1bba77 100755
--- a/scripts/make_tests.sh
+++ b/scripts/make_tests.sh
@@ -16,10 +16,5 @@
source $(dirname $0)/common.sh
-top_level=$(git rev-parse --show-toplevel 2>/dev/null)
-[[ $? -eq 0 ]] && cd "${top_level}" || exit 1
-
-make
make runsc
-make BAZEL_OPTIONS="build //..." bazel
make bazel-shutdown
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
index 2a1f12c0b..448864953 100755
--- a/scripts/overlay_tests.sh
+++ b/scripts/overlay_tests.sh
@@ -16,6 +16,8 @@
source $(dirname $0)/common.sh
+make load-all-images
+
# Install the runtime and perform basic tests.
install_runsc_for_test overlay --overlay
test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/kokoro/ubuntu1604/build.sh b/scripts/packetdrill_tests.sh
index d664a3a76..1a8181ac8 100755
--- a/kokoro/ubuntu1604/build.sh
+++ b/scripts/packetdrill_tests.sh
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -xeo pipefail
+source $(dirname $0)/common.sh
-# Run the image_build.sh script with appropriate parameters.
-IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1604-lts $(dirname $0)/../../tools/image_build.sh $(dirname $0)/??_*.sh
+make load-packetdrill
+
+install_runsc_for_test runsc-d
+QUERY_RESULT=$(query "attr(tags, manual, tests(//test/packetdrill/...))")
+test_runsc $QUERY_RESULT
diff --git a/kokoro/ubuntu1804/build.sh b/scripts/packetimpact_tests.sh
index 2b5c9a6f2..77fb84bc3 100755
--- a/kokoro/ubuntu1804/build.sh
+++ b/scripts/packetimpact_tests.sh
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -xeo pipefail
+source $(dirname $0)/common.sh
-# Run the image_build.sh script with appropriate parameters.
-IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1804-lts $(dirname $0)/../../tools/image_build.sh $(dirname $0)/??_*.sh
+make load-packetimpact
+
+install_runsc_for_test runsc-d
+QUERY_RESULT=$(query "attr(tags, packetimpact, tests(//test/packetimpact/...))")
+test_runsc $QUERY_RESULT
diff --git a/scripts/release.sh b/scripts/release.sh
deleted file mode 100755
index b936bcc77..000000000
--- a/scripts/release.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-# Tag a release only if provided.
-if ! [[ -v KOKORO_RELEASE_COMMIT ]]; then
- echo "No KOKORO_RELEASE_COMMIT provided." >&2
- exit 1
-fi
-if ! [[ -v KOKORO_RELEASE_TAG ]]; then
- echo "No KOKORO_RELEASE_TAG provided." >&2
- exit 1
-fi
-
-# Unless an explicit releaser is provided, use the bot e-mail.
-declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
-declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com}
-
-# Ensure we have an appropriate configuration for the tag.
-git config --get user.name || git config user.name "gVisor-bot"
-git config --get user.email || git config user.email "${EMAIL}"
-
-# Run the release tool, which pushes to the origin repository.
-tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
index 4e4fcc76b..3eb735e62 100755
--- a/scripts/root_tests.sh
+++ b/scripts/root_tests.sh
@@ -16,16 +16,10 @@
source $(dirname $0)/common.sh
-# Reinstall the latest containerd shim.
-declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
-declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
-declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
-wget --no-verbose "${base}"/latest -O ${latest}
-wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
-chmod +x ${shim_path}
-sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim
+make load-all-images
+CONTAINERD_VERSION=1.3.4 make sudo TARGETS="tools/installers:containerd"
+make sudo TARGETS="tools/installers:shim"
# Run the tests that require root.
install_runsc_for_test root
run_as_root //test/root:root_test --runtime=${RUNTIME}
-
diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh
new file mode 100755
index 000000000..85e95d45d
--- /dev/null
+++ b/scripts/runtime_tests.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Check that a runtime is provided.
+if [ ! -v RUNTIME_TEST_NAME ]; then
+ echo "Must set $RUNTIME_TEST_NAME" >&2
+ exit 1
+fi
+
+# Download language runtime image.
+make -C images/ "load-runtimes_${RUNTIME_TEST_NAME}"
+
+install_runsc_for_test runtimes
+test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}"
diff --git a/scripts/swgso_tests.sh b/scripts/swgso_tests.sh
index 0de2df1d2..c67f2fe5c 100755
--- a/scripts/swgso_tests.sh
+++ b/scripts/swgso_tests.sh
@@ -16,6 +16,8 @@
source $(dirname $0)/common.sh
+make load-all-images
+
# Install the runtime and perform basic tests.
install_runsc_for_test swgso --software-gso=true --gso=false
test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/syscall_kvm_tests.sh b/scripts/syscall_kvm_tests.sh
new file mode 100755
index 000000000..0e5d86727
--- /dev/null
+++ b/scripts/syscall_kvm_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all ptrace-variants of the system call tests.
+test --test_tag_filters=runsc_kvm //test/syscalls/...
diff --git a/shim/BUILD b/shim/BUILD
new file mode 100644
index 000000000..e581618b2
--- /dev/null
+++ b/shim/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "pkg_tar")
+
+package(licenses = ["notice"])
+
+pkg_tar(
+ name = "config",
+ srcs = [
+ "runsc.toml",
+ ],
+ mode = "0644",
+ package_dir = "/etc/containerd",
+ visibility = [
+ "//runsc:__pkg__",
+ ],
+)
diff --git a/shim/README.md b/shim/README.md
new file mode 100644
index 000000000..75daf00ac
--- /dev/null
+++ b/shim/README.md
@@ -0,0 +1,10 @@
+# Shim Overview
+
+Integration with containerd is done via a [shim][shims]. There are various shims
+supported for different versions of [containerd][containerd].
+
+- [Containerd 1.2+ (shim v2)](https://gvisor.dev/docs/user_guide/containerd/quick_start/)
+- [Containerd 1.1 (shim v1)](https://gvisor.dev/docs/user_guide/containerd/containerd_11/)
+
+[containerd]: https://github.com/containerd/containerd
+[shims]: https://iximiuz.com/en/posts/implementing-container-runtime-shim/
diff --git a/shim/runsc.toml b/shim/runsc.toml
new file mode 100644
index 000000000..e1c7de1bb
--- /dev/null
+++ b/shim/runsc.toml
@@ -0,0 +1,6 @@
+# This is an example configuration file for runsc.
+#
+# By default, it will be parsed from /etc/containerd/runsc.toml, but see the
+# static path configured in v1/main.go. Note that the configuration mechanism
+# for newer container shim versions is different: see the documentation in v2.
+[runsc_config]
diff --git a/shim/v1/BUILD b/shim/v1/BUILD
new file mode 100644
index 000000000..4c9e2c2c6
--- /dev/null
+++ b/shim/v1/BUILD
@@ -0,0 +1,30 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "gvisor-containerd-shim",
+ srcs = [
+ "api.go",
+ "config.go",
+ "main.go",
+ ],
+ static = True,
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/shim/runsc",
+ "//pkg/shim/v1/shim",
+ "@com_github_burntsushi_toml//:go_default_library",
+ "@com_github_containerd_containerd//events:go_default_library",
+ "@com_github_containerd_containerd//namespaces:go_default_library",
+ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library",
+ "@com_github_containerd_containerd//sys:go_default_library",
+ "@com_github_containerd_containerd//sys/reaper:go_default_library",
+ "@com_github_containerd_ttrpc//:go_default_library",
+ "@com_github_containerd_typeurl//:go_default_library",
+ "@com_github_gogo_protobuf//types:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/shim/v1/api.go b/shim/v1/api.go
new file mode 100644
index 000000000..2444d23f1
--- /dev/null
+++ b/shim/v1/api.go
@@ -0,0 +1,24 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ shim "github.com/containerd/containerd/runtime/v1/shim/v1"
+)
+
+type KillRequest = shim.KillRequest
+
+var registerShimService = shim.RegisterShimService
diff --git a/shim/v1/config.go b/shim/v1/config.go
new file mode 100644
index 000000000..a72cc7754
--- /dev/null
+++ b/shim/v1/config.go
@@ -0,0 +1,40 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import "github.com/BurntSushi/toml"
+
+// config is the configuration for gvisor containerd shim.
+type config struct {
+ // RuncShim is the shim binary path for standard containerd-shim for runc.
+ // When the runtime is `runc`, gvisor containerd shim will exec current
+ // process to standard containerd-shim. This is a work around for containerd
+ // 1.1. In containerd 1.2, containerd will choose different containerd-shims
+ // based on runtime.
+ RuncShim string `toml:"runc_shim"`
+ // RunscConfig is configuration for runsc. The key value will be converted
+ // to runsc flags --key=value directly.
+ RunscConfig map[string]string `toml:"runsc_config"`
+}
+
+// loadConfig load gvisor containerd shim config from config file.
+func loadConfig(path string) (*config, error) {
+ var c config
+ _, err := toml.DecodeFile(path, &c)
+ if err != nil {
+ return &c, err
+ }
+ return &c, nil
+}
diff --git a/shim/v1/main.go b/shim/v1/main.go
new file mode 100644
index 000000000..3159923af
--- /dev/null
+++ b/shim/v1/main.go
@@ -0,0 +1,265 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "bytes"
+ "context"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "strings"
+ "sync"
+ "syscall"
+
+ "github.com/containerd/containerd/events"
+ "github.com/containerd/containerd/namespaces"
+ "github.com/containerd/containerd/sys"
+ "github.com/containerd/containerd/sys/reaper"
+ "github.com/containerd/ttrpc"
+ "github.com/containerd/typeurl"
+ "github.com/gogo/protobuf/types"
+ "golang.org/x/sys/unix"
+
+ "gvisor.dev/gvisor/pkg/shim/runsc"
+ "gvisor.dev/gvisor/pkg/shim/v1/shim"
+)
+
+var (
+ debugFlag bool
+ namespaceFlag string
+ socketFlag string
+ addressFlag string
+ workdirFlag string
+ runtimeRootFlag string
+ containerdBinaryFlag string
+ shimConfigFlag string
+)
+
+// Containerd defaults to runc, unless another runtime is explicitly specified.
+// We keep the same default to make the default behavior consistent.
+const defaultRoot = "/run/containerd/runc"
+
+func init() {
+ flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs")
+ flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim")
+ flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve")
+ flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd")
+ flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data")
+ flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime")
+
+ // Currently, the `containerd publish` utility is embedded in the
+ // daemon binary. The daemon invokes `containerd-shim
+ // -containerd-binary ...` with its own os.Executable() path.
+ flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)")
+ flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file")
+}
+
+func main() {
+ flag.Parse()
+
+ // This is a hack. Exec current process to run standard containerd-shim
+ // if runtime root is not `runsc`. We don't need this for shim v2 api.
+ if filepath.Base(runtimeRootFlag) != "runsc" {
+ if err := executeRuncShim(); err != nil {
+ fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
+ os.Exit(1)
+ }
+ }
+
+ // Run regular shim if needed.
+ if err := executeShim(); err != nil {
+ fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err)
+ os.Exit(1)
+ }
+}
+
+// executeRuncShim execs current process to a containerd-shim process and
+// retains all flags and envs.
+func executeRuncShim() error {
+ c, err := loadConfig(shimConfigFlag)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to load shim config: %w", err)
+ }
+ shimPath := c.RuncShim
+ if shimPath == "" {
+ shimPath, err = exec.LookPath("containerd-shim")
+ if err != nil {
+ return fmt.Errorf("lookup containerd-shim failed: %w", err)
+ }
+ }
+
+ args := append([]string{shimPath}, os.Args[1:]...)
+ if err := syscall.Exec(shimPath, args, os.Environ()); err != nil {
+ return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err)
+ }
+ return nil
+}
+
+func executeShim() error {
+ // start handling signals as soon as possible so that things are
+ // properly reaped or if runtime exits before we hit the handler.
+ signals, err := setupSignals()
+ if err != nil {
+ return err
+ }
+ path, err := os.Getwd()
+ if err != nil {
+ return err
+ }
+ server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser()))
+ if err != nil {
+ return fmt.Errorf("failed creating server: %w", err)
+ }
+ c, err := loadConfig(shimConfigFlag)
+ if err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to load shim config: %w", err)
+ }
+ sv, err := shim.NewService(
+ shim.Config{
+ Path: path,
+ Namespace: namespaceFlag,
+ WorkDir: workdirFlag,
+ RuntimeRoot: runtimeRootFlag,
+ RunscConfig: c.RunscConfig,
+ },
+ &remoteEventsPublisher{address: addressFlag},
+ )
+ if err != nil {
+ return err
+ }
+ registerShimService(server, sv)
+ if err := serve(server, socketFlag); err != nil {
+ return err
+ }
+ return handleSignals(signals, server, sv)
+}
+
+// serve serves the ttrpc API over a unix socket at the provided path this
+// function does not block.
+func serve(server *ttrpc.Server, path string) error {
+ var (
+ l net.Listener
+ err error
+ )
+ if path == "" {
+ l, err = net.FileListener(os.NewFile(3, "socket"))
+ path = "[inherited from parent]"
+ } else {
+ if len(path) > 106 {
+ return fmt.Errorf("%q: unix socket path too long (> 106)", path)
+ }
+ l, err = net.Listen("unix", "\x00"+path)
+ }
+ if err != nil {
+ return err
+ }
+ go func() {
+ defer l.Close()
+ err := server.Serve(context.Background(), l)
+ if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
+ log.Fatalf("ttrpc server failure: %v", err)
+ }
+ }()
+ return nil
+}
+
+// setupSignals creates a new signal handler for all signals and sets the shim
+// as a sub-reaper so that the container processes are reparented.
+func setupSignals() (chan os.Signal, error) {
+ signals := make(chan os.Signal, 32)
+ signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE)
+ // make sure runc is setup to use the monitor for waiting on processes.
+ // TODO(random-liu): Move shim/reaper.go to a separate package.
+ runsc.Monitor = reaper.Default
+ // Set the shim as the subreaper for all orphaned processes created by
+ // the container.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ return nil, err
+ }
+ return signals, nil
+}
+
+func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error {
+ var (
+ termOnce sync.Once
+ done = make(chan struct{})
+ )
+
+ for {
+ select {
+ case <-done:
+ return nil
+ case s := <-signals:
+ switch s {
+ case unix.SIGCHLD:
+ if _, err := sys.Reap(false); err != nil {
+ log.Printf("reap error: %v", err)
+ }
+ case unix.SIGTERM, unix.SIGINT:
+ go termOnce.Do(func() {
+ ctx := context.TODO()
+ if err := server.Shutdown(ctx); err != nil {
+ log.Printf("failed to shutdown server: %v", err)
+ }
+ // Ensure our child is dead if any.
+ sv.Kill(ctx, &KillRequest{
+ Signal: uint32(syscall.SIGKILL),
+ All: true,
+ })
+ sv.Delete(context.Background(), &types.Empty{})
+ close(done)
+ })
+ case unix.SIGPIPE:
+ }
+ }
+ }
+}
+
+type remoteEventsPublisher struct {
+ address string
+}
+
+func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error {
+ ns, _ := namespaces.Namespace(ctx)
+ encoded, err := typeurl.MarshalAny(event)
+ if err != nil {
+ return err
+ }
+ data, err := encoded.Marshal()
+ if err != nil {
+ return err
+ }
+ cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns)
+ cmd.Stdin = bytes.NewReader(data)
+ c, err := reaper.Default.Start(cmd)
+ if err != nil {
+ return err
+ }
+ status, err := reaper.Default.Wait(cmd, c)
+ if err != nil {
+ return fmt.Errorf("failed to publish event: %w", err)
+ }
+ if status != 0 {
+ return fmt.Errorf("failed to publish event: status %d", status)
+ }
+ return nil
+}
diff --git a/shim/v2/BUILD b/shim/v2/BUILD
new file mode 100644
index 000000000..8de9ac0ba
--- /dev/null
+++ b/shim/v2/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "containerd-shim-runsc-v1",
+ srcs = [
+ "main.go",
+ ],
+ static = True,
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/shim/v2",
+ "@com_github_containerd_containerd//runtime/v2/shim:go_default_library",
+ ],
+)
diff --git a/shim/v2/main.go b/shim/v2/main.go
new file mode 100644
index 000000000..753871eea
--- /dev/null
+++ b/shim/v2/main.go
@@ -0,0 +1,26 @@
+// Copyright 2018 The containerd Authors.
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "github.com/containerd/containerd/runtime/v2/shim"
+
+ "gvisor.dev/gvisor/pkg/shim/v2"
+)
+
+func main() {
+ shim.Run("io.containerd.runsc.v1", v2.New)
+}
diff --git a/test/BUILD b/test/BUILD
index 01fa01f2e..34b950644 100644
--- a/test/BUILD
+++ b/test/BUILD
@@ -1,44 +1 @@
-package(licenses = ["notice"]) # Apache 2.0
-
-# We need to define a bazel platform and toolchain to specify dockerPrivileged
-# and dockerRunAsRoot options, they are required to run tests on the RBE
-# cluster in Kokoro.
-alias(
- name = "rbe_ubuntu1604",
- actual = ":rbe_ubuntu1604_r346485",
-)
-
-platform(
- name = "rbe_ubuntu1604_r346485",
- constraint_values = [
- "@bazel_tools//platforms:x86_64",
- "@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains//constraints:xenial",
- "@bazel_toolchains//constraints/sanitizers:support_msan",
- ],
- remote_execution_properties = """
- properties: {
- name: "container-image"
- value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:69c9f1652941d64a46f6f7358a44c1718f25caa5cb1ced4a58ccc5281cd183b5"
- }
- properties: {
- name: "dockerAddCapabilities"
- value: "SYS_ADMIN"
- }
- properties: {
- name: "dockerPrivileged"
- value: "true"
- }
- """,
-)
-
-toolchain(
- name = "cc-toolchain-clang-x86_64-default",
- exec_compatible_with = [
- ],
- target_compatible_with = [
- ],
- toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/9.0.0/bazel_0.28.0/cc:cc-compiler-k8",
- toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
-)
+package(licenses = ["notice"])
diff --git a/test/README.md b/test/README.md
index 97fe7ea04..02bbf42ff 100644
--- a/test/README.md
+++ b/test/README.md
@@ -24,11 +24,11 @@ also used to run these tests in `kokoro`.
To run image and integration tests, run:
-`./scripts/docker_test.sh`
+`./scripts/docker_tests.sh`
To run root tests, run:
-`./scripts/root_test.sh`
+`./scripts/root_tests.sh`
There are a few other interesting variations for image and integration tests:
diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md
new file mode 100644
index 000000000..d1bbabf6f
--- /dev/null
+++ b/test/benchmarks/README.md
@@ -0,0 +1,157 @@
+# Benchmark tools
+
+This package and subpackages are for running macro benchmarks on `runsc`. They
+are meant to replace the previous //benchmarks benchmark-tools written in
+python.
+
+Benchmarks are meant to look like regular golang benchmarks using the testing.B
+library.
+
+## Setup
+
+To run benchmarks you will need:
+
+* Docker installed (17.09.0 or greater).
+
+The easiest way to setup runsc for running benchmarks is to use the make file.
+From the root directory:
+
+* Download images: `make load-all-images`
+* Install runsc suitable for benchmarking, which should probably not have
+ strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc
+ ARGS=--platform=kvm`.
+* Restart docker: `sudo service docker restart`
+
+You should now have a runtime with the following options configured in
+`/etc/docker/daemon.json`
+
+```
+"myrunsc": {
+ "path": "/tmp/myrunsc/runsc",
+ "runtimeArgs": [
+ "--debug-log",
+ "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%",
+ "--platform=kvm"
+ ]
+ },
+
+```
+
+This runtime has been configured with a debugging off and strace logs off and is
+using kvm for demonstration.
+
+## Running benchmarks
+
+Given the runtime above runtime `myrunsc`, run benchmarks with the following:
+
+```
+make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \
+ -test.bench=." OPTIONS="-c opt
+```
+
+For example, to run only the Iperf tests:
+
+```
+make sudo TARGETS=//test/benchmarks/network:network_test \
+ ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt"
+```
+
+Benchmarks are run with root as some benchmarks require root privileges to do
+things like drop caches.
+
+## Writing benchmarks
+
+Benchmarks consist of docker images as Dockerfiles and golang testing.B
+benchmarks.
+
+### Dockerfiles:
+
+* Are stored at //images.
+* New Dockerfiles go in an appropriately named directory at
+ `//images/benchmarks/my-cool-dockerfile`.
+* Dockerfiles for benchmarks should:
+ * Use explicitly versioned packages.
+ * Not use ENV and CMD statements...it is easy to add these in the API.
+* Note: A common pattern for getting access to a tmpfs mount is to copy files
+ there after container start. See: //test/benchmarks/build/bazel_test.go. You
+ can also make your own with `RunOpts.Mounts`.
+
+### testing.B packages
+
+In general, benchmarks should look like this:
+
+```golang
+
+var h harness.Harness
+
+func BenchmarkMyCoolOne(b *testing.B) {
+ machine, err := h.GetMachine()
+ // check err
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+
+ b.ResetTimer()
+
+ //Respect b.N.
+ for i := 0; i < b.N; i++ {
+ out, err := container.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/my-cool-image",
+ Env: []string{"MY_VAR=awesome"},
+ other options...see dockerutil
+ }, "sh", "-c", "echo MY_VAR")
+ //check err
+ b.StopTimer()
+
+ // Do parsing and reporting outside of the timer.
+ number := parseMyMetric(out)
+ b.ReportMetric(number, "my-cool-custom-metric")
+
+ b.StartTimer()
+ }
+}
+
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
+```
+
+Some notes on the above:
+
+* The harness is initiated in the TestMain method and made global to test
+ module. The harness will handle any presetup that needs to happen with
+ flags, remote virtual machines (eventually), and other services.
+* Respect `b.N` in that users of the benchmark may want to "run for an hour"
+ or something of the sort.
+* Use the `b.ReportMetric()` method to report custom metrics.
+* Set the timer if time is useful for reporting. There isn't a way to turn off
+ default metrics in testing.B (B/op, allocs/op, ns/op).
+* Take a look at dockerutil at //pkg/test/dockerutil to see all methods
+ available from containers. The API is based on the "official"
+ [docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker).
+* `harness.GetMachine()` marks how many machines this tests needs. If you have
+ a client and server and to mark them as multiple machines, call
+ `harness.GetMachine()` twice.
+
+## Profiling
+
+For profiling, the runtime is required to have the `--profile` flag enabled.
+This flag loosens seccomp filters so that the runtime can write profile data to
+disk. This configuration is not recommended for production.
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+ ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not
+ required, but are included for demonstration.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles fs_test test run:
+
+```
+make sudo TARGETS=//test/benchmarks/fs:fs_test \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
new file mode 100644
index 000000000..32c139204
--- /dev/null
+++ b/test/benchmarks/base/BUILD
@@ -0,0 +1,34 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "base",
+ testonly = 1,
+ srcs = [
+ "base.go",
+ ],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "base_test",
+ size = "large",
+ srcs = [
+ "size_test.go",
+ "startup_test.go",
+ "sysbench_test.go",
+ ],
+ library = ":base",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
diff --git a/test/benchmarks/base/base.go b/test/benchmarks/base/base.go
new file mode 100644
index 000000000..7bac52ff1
--- /dev/null
+++ b/test/benchmarks/base/base.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package base holds base performance benchmarks.
+package base
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var testHarness harness.Harness
+
+// TestMain is the main method for package network.
+func TestMain(m *testing.M) {
+ testHarness.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go
new file mode 100644
index 000000000..3c1364faf
--- /dev/null
+++ b/test/benchmarks/base/size_test.go
@@ -0,0 +1,220 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package base
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// BenchmarkSizeEmpty creates N empty containers and reads memory usage from
+// /proc/meminfo.
+func BenchmarkSizeEmpty(b *testing.B) {
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+ meminfo := tools.Meminfo{}
+ ctx := context.Background()
+ containers := make([]*dockerutil.Container, 0, b.N)
+
+ // DropCaches before the test.
+ harness.DropCaches(machine)
+
+ // Check available memory on 'machine'.
+ cmd, args := meminfo.MakeCmd()
+ before, err := machine.RunCommand(cmd, args...)
+ if err != nil {
+ b.Fatalf("failed to get meminfo: %v", err)
+ }
+
+ // Make N containers.
+ for i := 0; i < b.N; i++ {
+ container := machine.GetContainer(ctx, b)
+ containers = append(containers, container)
+ if err := container.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/alpine",
+ }, "sh", "-c", "echo Hello && sleep 1000"); err != nil {
+ cleanUpContainers(ctx, containers)
+ b.Fatalf("failed to run container: %v", err)
+ }
+ if _, err := container.WaitForOutputSubmatch(ctx, "Hello", 5*time.Second); err != nil {
+ cleanUpContainers(ctx, containers)
+ b.Fatalf("failed to read container output: %v", err)
+ }
+ }
+
+ // Drop caches again before second measurement.
+ harness.DropCaches(machine)
+
+ // Check available memory after containers are up.
+ after, err := machine.RunCommand(cmd, args...)
+ cleanUpContainers(ctx, containers)
+ if err != nil {
+ b.Fatalf("failed to get meminfo: %v", err)
+ }
+ meminfo.Report(b, before, after)
+}
+
+// BenchmarkSizeNginx starts N containers running Nginx, checks that they're
+// serving, and checks memory used based on /proc/meminfo.
+func BenchmarkSizeNginx(b *testing.B) {
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine with: %v", err)
+ }
+ defer machine.CleanUp()
+
+ // DropCaches for the first measurement.
+ harness.DropCaches(machine)
+
+ // Measure MemAvailable before creating containers.
+ meminfo := tools.Meminfo{}
+ cmd, args := meminfo.MakeCmd()
+ before, err := machine.RunCommand(cmd, args...)
+ if err != nil {
+ b.Fatalf("failed to run meminfo command: %v", err)
+ }
+
+ // Make N Nginx containers.
+ ctx := context.Background()
+ runOpts := dockerutil.RunOpts{
+ Image: "benchmarks/nginx",
+ }
+ const port = 80
+ servers := startServers(ctx, b,
+ serverArgs{
+ machine: machine,
+ port: port,
+ runOpts: runOpts,
+ })
+ defer cleanUpContainers(ctx, servers)
+
+ // DropCaches after servers are created.
+ harness.DropCaches(machine)
+ // Take after measurement.
+ after, err := machine.RunCommand(cmd, args...)
+ if err != nil {
+ b.Fatalf("failed to run meminfo command: %v", err)
+ }
+ meminfo.Report(b, before, after)
+}
+
+// BenchmarkSizeNode starts N containers running a Node app, checks that
+// they're serving, and checks memory used based on /proc/meminfo.
+func BenchmarkSizeNode(b *testing.B) {
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine with: %v", err)
+ }
+ defer machine.CleanUp()
+
+ // Make a redis instance for Node to connect.
+ ctx := context.Background()
+ redis, redisIP := redisInstance(ctx, b, machine)
+ defer redis.CleanUp(ctx)
+
+ // DropCaches after redis is created.
+ harness.DropCaches(machine)
+
+ // Take before measurement.
+ meminfo := tools.Meminfo{}
+ cmd, args := meminfo.MakeCmd()
+ before, err := machine.RunCommand(cmd, args...)
+ if err != nil {
+ b.Fatalf("failed to run meminfo commend: %v", err)
+ }
+
+ // Create N Node servers.
+ runOpts := dockerutil.RunOpts{
+ Image: "benchmarks/node",
+ WorkDir: "/usr/src/app",
+ Links: []string{redis.MakeLink("redis")},
+ }
+ nodeCmd := []string{"node", "index.js", redisIP.String()}
+ const port = 8080
+ servers := startServers(ctx, b,
+ serverArgs{
+ machine: machine,
+ port: port,
+ runOpts: runOpts,
+ cmd: nodeCmd,
+ })
+ defer cleanUpContainers(ctx, servers)
+
+ // DropCaches after servers are created.
+ harness.DropCaches(machine)
+ // Take after measurement.
+ cmd, args = meminfo.MakeCmd()
+ after, err := machine.RunCommand(cmd, args...)
+ if err != nil {
+ b.Fatalf("failed to run meminfo command: %v", err)
+ }
+ meminfo.Report(b, before, after)
+}
+
+// serverArgs wraps args for startServers and runServerWorkload.
+type serverArgs struct {
+ machine harness.Machine
+ port int
+ runOpts dockerutil.RunOpts
+ cmd []string
+}
+
+// startServers starts b.N containers defined by 'runOpts' and 'cmd' and uses
+// 'machine' to check that each is up.
+func startServers(ctx context.Context, b *testing.B, args serverArgs) []*dockerutil.Container {
+ b.Helper()
+ servers := make([]*dockerutil.Container, 0, b.N)
+
+ // Create N servers and wait until each of them is serving.
+ for i := 0; i < b.N; i++ {
+ server := args.machine.GetContainer(ctx, b)
+ servers = append(servers, server)
+ if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil {
+ cleanUpContainers(ctx, servers)
+ b.Fatalf("failed to spawn node instance: %v", err)
+ }
+
+ // Get the container IP.
+ servingIP, err := server.FindIP(ctx, false)
+ if err != nil {
+ cleanUpContainers(ctx, servers)
+ b.Fatalf("failed to get ip from server: %v", err)
+ }
+
+ // Wait until the server is up.
+ if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil {
+ cleanUpContainers(ctx, servers)
+ b.Fatalf("failed to wait for serving")
+ }
+ }
+ return servers
+}
+
+// cleanUpContainers cleans up a slice of containers.
+func cleanUpContainers(ctx context.Context, containers []*dockerutil.Container) {
+ for _, c := range containers {
+ if c != nil {
+ c.CleanUp(ctx)
+ }
+ }
+}
diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go
new file mode 100644
index 000000000..4628a0a41
--- /dev/null
+++ b/test/benchmarks/base/startup_test.go
@@ -0,0 +1,156 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package base
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// BenchmarkStartEmpty times startup time for an empty container.
+func BenchmarkStartupEmpty(b *testing.B) {
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ for i := 0; i < b.N; i++ {
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+ if _, err := container.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/alpine",
+ }, "true"); err != nil {
+ b.Fatalf("failed to run container: %v", err)
+ }
+ }
+}
+
+// BenchmarkStartupNginx times startup for a Nginx instance.
+// Time is measured from start until the first request is served.
+func BenchmarkStartupNginx(b *testing.B) {
+ // The machine to hold Nginx and the Node Server.
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine with: %v", err)
+ }
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ runOpts := dockerutil.RunOpts{
+ Image: "benchmarks/nginx",
+ }
+ runServerWorkload(ctx, b,
+ serverArgs{
+ machine: machine,
+ runOpts: runOpts,
+ port: 80,
+ })
+}
+
+// BenchmarkStartupNode times startup for a Node application instance.
+// Time is measured from start until the first request is served.
+// Note that the Node app connects to a Redis instance before serving.
+func BenchmarkStartupNode(b *testing.B) {
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine with: %v", err)
+ }
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ redis, redisIP := redisInstance(ctx, b, machine)
+ defer redis.CleanUp(ctx)
+ runOpts := dockerutil.RunOpts{
+ Image: "benchmarks/node",
+ WorkDir: "/usr/src/app",
+ Links: []string{redis.MakeLink("redis")},
+ }
+
+ cmd := []string{"node", "index.js", redisIP.String()}
+ runServerWorkload(ctx, b,
+ serverArgs{
+ machine: machine,
+ port: 8080,
+ runOpts: runOpts,
+ cmd: cmd,
+ })
+}
+
+// redisInstance returns a Redis container and its reachable IP.
+func redisInstance(ctx context.Context, b *testing.B, machine harness.Machine) (*dockerutil.Container, net.IP) {
+ b.Helper()
+ // Spawn a redis instance for the app to use.
+ redis := machine.GetNativeContainer(ctx, b)
+ if err := redis.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }); err != nil {
+ redis.CleanUp(ctx)
+ b.Fatalf("failed to spwan redis instance: %v", err)
+ }
+
+ if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil {
+ redis.CleanUp(ctx)
+ b.Fatalf("failed to start redis server: %v %s", err, out)
+ }
+ redisIP, err := redis.FindIP(ctx, false)
+ if err != nil {
+ redis.CleanUp(ctx)
+ b.Fatalf("failed to get IP from redis instance: %v", err)
+ }
+ return redis, redisIP
+}
+
+// runServerWorkload runs a server workload defined by 'runOpts' and 'cmd'.
+// 'clientMachine' is used to connect to the server on 'serverMachine'.
+func runServerWorkload(ctx context.Context, b *testing.B, args serverArgs) {
+ b.Helper()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if err := func() error {
+ server := args.machine.GetContainer(ctx, b)
+ defer func() {
+ b.StopTimer()
+ // Cleanup servers as we run so that we can go indefinitely.
+ server.CleanUp(ctx)
+ b.StartTimer()
+ }()
+ if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil {
+ return fmt.Errorf("failed to spawn node instance: %v", err)
+ }
+
+ servingIP, err := server.FindIP(ctx, false)
+ if err != nil {
+ return fmt.Errorf("failed to get ip from server: %v", err)
+ }
+
+ // Wait until the Client sees the server as up.
+ if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil {
+ return fmt.Errorf("failed to wait for serving: %v", err)
+ }
+ return nil
+ }(); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go
new file mode 100644
index 000000000..6fb813640
--- /dev/null
+++ b/test/benchmarks/base/sysbench_test.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package base
+
+import (
+ "context"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+type testCase struct {
+ name string
+ test tools.Sysbench
+}
+
+// BenchmarSysbench runs sysbench on the runtime.
+func BenchmarkSysbench(b *testing.B) {
+
+ testCases := []testCase{
+ testCase{
+ name: "CPU",
+ test: &tools.SysbenchCPU{
+ Base: tools.SysbenchBase{
+ Threads: 1,
+ Time: 5,
+ },
+ MaxPrime: 50000,
+ },
+ },
+ testCase{
+ name: "Memory",
+ test: &tools.SysbenchMemory{
+ Base: tools.SysbenchBase{
+ Threads: 1,
+ },
+ BlockSize: "1M",
+ TotalSize: "500G",
+ },
+ },
+ testCase{
+ name: "Mutex",
+ test: &tools.SysbenchMutex{
+ Base: tools.SysbenchBase{
+ Threads: 8,
+ },
+ Loops: 1,
+ Locks: 10000000,
+ Num: 4,
+ },
+ },
+ }
+
+ machine, err := testHarness.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ for _, tc := range testCases {
+ b.Run(tc.name, func(b *testing.B) {
+
+ ctx := context.Background()
+ sysbench := machine.GetContainer(ctx, b)
+ defer sysbench.CleanUp(ctx)
+
+ out, err := sysbench.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/sysbench",
+ }, tc.test.MakeCmd()...)
+ if err != nil {
+ b.Fatalf("failed to run sysbench: %v: logs:%s", err, out)
+ }
+ tc.test.Report(b, out)
+ })
+ }
+}
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
new file mode 100644
index 000000000..93b380e8a
--- /dev/null
+++ b/test/benchmarks/database/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "database",
+ testonly = 1,
+ srcs = ["database.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "database_test",
+ size = "enormous",
+ srcs = ["redis_test.go"],
+ library = ":database",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go
new file mode 100644
index 000000000..9eeb59f9a
--- /dev/null
+++ b/test/benchmarks/database/database.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package database holds benchmarks around database applications.
+package database
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package database.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go
new file mode 100644
index 000000000..394fce820
--- /dev/null
+++ b/test/benchmarks/database/redis_test.go
@@ -0,0 +1,123 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// All possible operations from redis. Note: "ping" will
+// run both PING_INLINE and PING_BUILD.
+var operations []string = []string{
+ "PING_INLINE",
+ "PING_BULK",
+ "SET",
+ "GET",
+ "INCR",
+ "LPUSH",
+ "RPUSH",
+ "LPOP",
+ "RPOP",
+ "SADD",
+ "HSET",
+ "SPOP",
+ "LRANGE_100",
+ "LRANGE_300",
+ "LRANGE_500",
+ "LRANGE_600",
+ "MSET",
+}
+
+// BenchmarkRedis runs redis-benchmark against a redis instance and reports
+// data in queries per second. Each is reported by named operation (e.g. LPUSH).
+func BenchmarkRedis(b *testing.B) {
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // Redis runs on port 6379 by default.
+ port := 6379
+ ctx := context.Background()
+
+ for _, operation := range operations {
+ b.Run(operation, func(b *testing.B) {
+ server := serverMachine.GetContainer(ctx, b)
+ defer server.CleanUp(ctx)
+
+ // The redis docker container takes no arguments to run a redis server.
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ Ports: []int{port},
+ }); err != nil {
+ b.Fatalf("failed to start redis server with: %v", err)
+ }
+
+ if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil {
+ b.Fatalf("failed to start redis server: %v %s", err, out)
+ }
+
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatal("failed to get IP from server: %v", err)
+ }
+
+ serverPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatal("failed to get IP from server: %v", err)
+ }
+
+ if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil {
+ b.Fatalf("failed to start redis with: %v", err)
+ }
+
+ redis := tools.Redis{
+ Operation: operation,
+ }
+
+ // Reset profiles and timer to begin the measurement.
+ server.RestartProfiles()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }, redis.MakeCmd(ip, serverPort)...)
+ if err != nil {
+ b.Fatalf("redis-benchmark failed with: %v", err)
+ }
+
+ // Stop time while we parse results.
+ b.StopTimer()
+ redis.Report(b, out)
+ b.StartTimer()
+ }
+ })
+ }
+}
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
new file mode 100644
index 000000000..45f11372b
--- /dev/null
+++ b/test/benchmarks/fs/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "fs",
+ testonly = 1,
+ srcs = ["fs.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "fs_test",
+ size = "large",
+ srcs = [
+ "bazel_test.go",
+ "fio_test.go",
+ ],
+ library = ":fs",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ "local",
+ "manual",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ ],
+)
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
new file mode 100644
index 000000000..f4236ba37
--- /dev/null
+++ b/test/benchmarks/fs/bazel_test.go
@@ -0,0 +1,119 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package fs
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// Note: CleanCache versions of this test require running with root permissions.
+func BenchmarkBuildABSL(b *testing.B) {
+ runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...")
+}
+
+// Note: CleanCache versions of this test require running with root permissions.
+// Note: This test takes on the order of 10m per permutation for runsc on kvm.
+func BenchmarkBuildRunsc(b *testing.B) {
+ runBuildBenchmark(b, "benchmarks/runsc", "/gvisor", "runsc:runsc")
+}
+
+func runBuildBenchmark(b *testing.B, image, workdir, target string) {
+ b.Helper()
+ // Get a machine from the Harness on which to run.
+ machine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ // Dimensions here are clean/dirty cache (do or don't drop caches)
+ // and if the mount on which we are compiling is a tmpfs/bind mount.
+ benchmarks := []struct {
+ name string
+ clearCache bool // clearCache drops caches before running.
+ tmpfs bool // tmpfs will run compilation on a tmpfs.
+ }{
+ {name: "CleanCache", clearCache: true, tmpfs: false},
+ {name: "DirtyCache", clearCache: false, tmpfs: false},
+ {name: "CleanCacheTmpfs", clearCache: true, tmpfs: true},
+ {name: "DirtyCacheTmpfs", clearCache: false, tmpfs: true},
+ }
+ for _, bm := range benchmarks {
+ b.Run(bm.name, func(b *testing.B) {
+ // Grab a container.
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+
+ // Start a container and sleep by an order of b.N.
+ if err := container.Spawn(ctx, dockerutil.RunOpts{
+ Image: image,
+ }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+
+ // If we are running on a tmpfs, copy to /tmp which is a tmpfs.
+ if bm.tmpfs {
+ if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
+ "cp", "-r", workdir, "/tmp/."); err != nil {
+ b.Fatal("failed to copy directory: %v %s", err, out)
+ }
+ workdir = "/tmp" + workdir
+ }
+
+ // Restart profiles after the copy.
+ container.RestartProfiles()
+ b.ResetTimer()
+ // Drop Caches and bazel clean should happen inside the loop as we may use
+ // time options with b.N. (e.g. Run for an hour.)
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ // Drop Caches for clear cache runs.
+ if bm.clearCache {
+ if err := harness.DropCaches(machine); err != nil {
+ b.Skipf("failed to drop caches: %v. You probably need root.", err)
+ }
+ }
+ b.StartTimer()
+
+ got, err := container.Exec(ctx, dockerutil.ExecOpts{
+ WorkDir: workdir,
+ }, "bazel", "build", "-c", "opt", target)
+ if err != nil {
+ b.Fatalf("build failed with: %v", err)
+ }
+ b.StopTimer()
+
+ want := "Build completed successfully"
+ if !strings.Contains(got, want) {
+ b.Fatalf("string %s not in: %s", want, got)
+ }
+ // Clean bazel in case we use b.N.
+ _, err = container.Exec(ctx, dockerutil.ExecOpts{
+ WorkDir: workdir,
+ }, "bazel", "clean")
+ if err != nil {
+ b.Fatalf("build failed with: %v", err)
+ }
+ b.StartTimer()
+ }
+ })
+ }
+}
diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go
new file mode 100644
index 000000000..65874ed8b
--- /dev/null
+++ b/test/benchmarks/fs/fio_test.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package fs
+
+import (
+ "context"
+ "fmt"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// BenchmarkFio runs fio on the runtime under test. There are 4 basic test
+// cases each run on a tmpfs mount and a bind mount. Fio requires root so that
+// caches can be dropped.
+func BenchmarkFio(b *testing.B) {
+ testCases := []tools.Fio{
+ tools.Fio{
+ Test: "write",
+ Size: "5G",
+ Blocksize: "1M",
+ Iodepth: 4,
+ },
+ tools.Fio{
+ Test: "read",
+ Size: "5G",
+ Blocksize: "1M",
+ Iodepth: 4,
+ },
+ tools.Fio{
+ Test: "randwrite",
+ Size: "5G",
+ Blocksize: "4K",
+ Iodepth: 4,
+ Time: 30,
+ },
+ tools.Fio{
+ Test: "randread",
+ Size: "5G",
+ Blocksize: "4K",
+ Iodepth: 4,
+ Time: 30,
+ },
+ }
+
+ machine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine with: %v", err)
+ }
+ defer machine.CleanUp()
+
+ for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} {
+ for _, tc := range testCases {
+ testName := strings.Title(tc.Test) + strings.Title(string(fsType))
+ b.Run(testName, func(b *testing.B) {
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+
+ // Directory and filename inside container where fio will read/write.
+ outdir := "/data"
+ outfile := filepath.Join(outdir, "test.txt")
+
+ // Make the required mount and grab a cleanup for bind mounts
+ // as they are backed by a temp directory (mktemp).
+ mnt, mountCleanup, err := makeMount(machine, fsType, outdir)
+ if err != nil {
+ b.Fatalf("failed to make mount: %v", err)
+ }
+ defer mountCleanup()
+
+ // Start the container with the mount.
+ if err := container.Spawn(
+ ctx,
+ dockerutil.RunOpts{
+ Image: "benchmarks/fio",
+ Mounts: []mount.Mount{
+ mnt,
+ },
+ },
+ // Sleep on the order of b.N.
+ "sleep", fmt.Sprintf("%d", 1000*b.N),
+ ); err != nil {
+ b.Fatalf("failed to start fio container with: %v", err)
+ }
+
+ // For reads, we need a file to read so make one inside the container.
+ if strings.Contains(tc.Test, "read") {
+ fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.Size, outfile)
+ if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
+ strings.Split(fallocateCmd, " ")...); err != nil {
+ b.Fatalf("failed to create readable file on mount: %v, %s", err, out)
+ }
+ }
+
+ // Drop caches just before running.
+ if err := harness.DropCaches(machine); err != nil {
+ b.Skipf("failed to drop caches with %v. You probably need root.", err)
+ }
+ cmd := tc.MakeCmd(outfile)
+ container.RestartProfiles()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Run fio.
+ data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...)
+ if err != nil {
+ b.Fatalf("failed to run cmd %v: %v", cmd, err)
+ }
+ b.StopTimer()
+ tc.Report(b, data)
+ // If b.N is used (i.e. we run for an hour), we should drop caches
+ // after each run.
+ if err := harness.DropCaches(machine); err != nil {
+ b.Fatalf("failed to drop caches: %v", err)
+ }
+ b.StartTimer()
+ }
+ })
+ }
+ }
+}
+
+// makeMount makes a mount and cleanup based on the requested type. Bind
+// and volume mounts are backed by a temp directory made with mktemp.
+// tmpfs mounts require no such backing and are just made.
+// It is up to the caller to call the returned cleanup.
+func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) {
+ switch mountType {
+ case mount.TypeVolume, mount.TypeBind:
+ dir, err := machine.RunCommand("mktemp", "-d")
+ if err != nil {
+ return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err)
+ }
+ dir = strings.TrimSuffix(dir, "\n")
+
+ out, err := machine.RunCommand("chmod", "777", dir)
+ if err != nil {
+ machine.RunCommand("rm", "-rf", dir)
+ return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out)
+ }
+ return mount.Mount{
+ Target: target,
+ Source: dir,
+ Type: mount.TypeBind,
+ }, func() { machine.RunCommand("rm", "-rf", dir) }, nil
+ case mount.TypeTmpfs:
+ return mount.Mount{
+ Target: target,
+ Type: mount.TypeTmpfs,
+ }, func() {}, nil
+ default:
+ return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType)
+ }
+}
diff --git a/test/benchmarks/fs/fs.go b/test/benchmarks/fs/fs.go
new file mode 100644
index 000000000..e5ca28c3b
--- /dev/null
+++ b/test/benchmarks/fs/fs.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fs holds benchmarks around filesystem performance.
+package fs
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package fs.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD
new file mode 100644
index 000000000..c2e316709
--- /dev/null
+++ b/test/benchmarks/harness/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "harness",
+ testonly = 1,
+ srcs = [
+ "harness.go",
+ "machine.go",
+ "util.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go
new file mode 100644
index 000000000..68bd7b4cf
--- /dev/null
+++ b/test/benchmarks/harness/harness.go
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package harness holds utility code for running benchmarks on Docker.
+package harness
+
+import (
+ "flag"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+)
+
+// Harness is a handle for managing state in benchmark runs.
+type Harness struct {
+}
+
+// Init performs any harness initilialization before runs.
+func (h *Harness) Init() error {
+ flag.Parse()
+ dockerutil.EnsureSupportedDockerVersion()
+ return nil
+}
+
+// GetMachine returns this run's implementation of machine.
+func (h *Harness) GetMachine() (Machine, error) {
+ return &localMachine{}, nil
+}
diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go
new file mode 100644
index 000000000..88e5e841b
--- /dev/null
+++ b/test/benchmarks/harness/machine.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package harness
+
+import (
+ "context"
+ "net"
+ "os/exec"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// Machine describes a real machine for use in benchmarks.
+type Machine interface {
+ // GetContainer gets a container from the machine. The container uses the
+ // runtime under test and is profiled if requested by flags.
+ GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container
+
+ // GetNativeContainer gets a native container from the machine. Native containers
+ // use runc by default and are not profiled.
+ GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container
+
+ // RunCommand runs cmd on this machine.
+ RunCommand(cmd string, args ...string) (string, error)
+
+ // Returns IP Address for the machine.
+ IPAddress() (net.IP, error)
+
+ // CleanUp cleans up this machine.
+ CleanUp()
+}
+
+// localMachine describes this machine.
+type localMachine struct {
+}
+
+// GetContainer implements Machine.GetContainer for localMachine.
+func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container {
+ return dockerutil.MakeContainer(ctx, logger)
+}
+
+// GetContainer implements Machine.GetContainer for localMachine.
+func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container {
+ return dockerutil.MakeNativeContainer(ctx, logger)
+}
+
+// RunCommand implements Machine.RunCommand for localMachine.
+func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) {
+ c := exec.Command(cmd, args...)
+ out, err := c.CombinedOutput()
+ return string(out), err
+}
+
+// IPAddress implements Machine.IPAddress.
+func (l *localMachine) IPAddress() (net.IP, error) {
+ conn, err := net.Dial("udp", "8.8.8.8:80")
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ addr := conn.LocalAddr().(*net.UDPAddr)
+ return addr.IP, nil
+}
+
+// CleanUp implements Machine.CleanUp and does nothing for localMachine.
+func (*localMachine) CleanUp() {
+}
diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go
new file mode 100644
index 000000000..86b863f78
--- /dev/null
+++ b/test/benchmarks/harness/util.go
@@ -0,0 +1,48 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package harness
+
+import (
+ "context"
+ "fmt"
+ "net"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+//TODO(gvisor.dev/issue/3535): move to own package or move methods to harness struct.
+
+// WaitUntilServing grabs a container from `machine` and waits for a server at
+// IP:port.
+func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error {
+ var logger testutil.DefaultLogger = "util"
+ netcat := machine.GetNativeContainer(ctx, logger)
+ defer netcat.CleanUp(ctx)
+
+ cmd := fmt.Sprintf("while ! wget -q --spider http://%s:%d; do true; done", server, port)
+ _, err := netcat.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/util",
+ }, "sh", "-c", cmd)
+ return err
+}
+
+// DropCaches drops caches on the provided machine. Requires root.
+func DropCaches(machine Machine) error {
+ if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil {
+ return fmt.Errorf("failed to drop caches: %v logs: %s", err, out)
+ }
+ return nil
+}
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
new file mode 100644
index 000000000..bb242d385
--- /dev/null
+++ b/test/benchmarks/media/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "media",
+ testonly = 1,
+ srcs = ["media.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "media_test",
+ size = "large",
+ srcs = ["ffmpeg_test.go"],
+ library = ":media",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ ],
+)
diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go
new file mode 100644
index 000000000..7822dfad7
--- /dev/null
+++ b/test/benchmarks/media/ffmpeg_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package media
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// BenchmarkFfmpeg runs ffmpeg in a container and records runtime.
+// BenchmarkFfmpeg should run as root to drop caches.
+func BenchmarkFfmpeg(b *testing.B) {
+ machine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+ cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ")
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ if err := harness.DropCaches(machine); err != nil {
+ b.Skipf("failed to drop caches: %v. You probably need root.", err)
+ }
+ b.StartTimer()
+
+ if _, err := container.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/ffmpeg",
+ }, cmd...); err != nil {
+ b.Fatalf("failed to run container: %v", err)
+ }
+ }
+}
diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go
new file mode 100644
index 000000000..c7b35b758
--- /dev/null
+++ b/test/benchmarks/media/media.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package media holds benchmarks around media processing applications.
+package media
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package media.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD
new file mode 100644
index 000000000..970f52706
--- /dev/null
+++ b/test/benchmarks/ml/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ml",
+ testonly = 1,
+ srcs = ["ml.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "ml_test",
+ size = "large",
+ srcs = ["tensorflow_test.go"],
+ library = ":ml",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ ],
+)
diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go
new file mode 100644
index 000000000..13282d7bb
--- /dev/null
+++ b/test/benchmarks/ml/ml.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ml holds benchmarks around machine learning performance.
+package ml
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package ml.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go
new file mode 100644
index 000000000..f7746897d
--- /dev/null
+++ b/test/benchmarks/ml/tensorflow_test.go
@@ -0,0 +1,69 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package ml
+
+import (
+ "context"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// BenchmarkTensorflow runs workloads from a TensorFlow tutorial.
+// See: https://github.com/aymericdamien/TensorFlow-Examples
+func BenchmarkTensorflow(b *testing.B) {
+ workloads := map[string]string{
+ "GradientDecisionTree": "2_BasicModels/gradient_boosted_decision_tree.py",
+ "Kmeans": "2_BasicModels/kmeans.py",
+ "LogisticRegression": "2_BasicModels/logistic_regression.py",
+ "NearestNeighbor": "2_BasicModels/nearest_neighbor.py",
+ "RandomForest": "2_BasicModels/random_forest.py",
+ "ConvolutionalNetwork": "3_NeuralNetworks/convolutional_network.py",
+ "MultilayerPerceptron": "3_NeuralNetworks/multilayer_perceptron.py",
+ "NeuralNetwork": "3_NeuralNetworks/neural_network.py",
+ }
+
+ machine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ for name, workload := range workloads {
+ b.Run(name, func(b *testing.B) {
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ defer container.CleanUp(ctx)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ if err := harness.DropCaches(machine); err != nil {
+ b.Skipf("failed to drop caches: %v. You probably need root.", err)
+ }
+ b.StartTimer()
+
+ if out, err := container.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/tensorflow",
+ Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"},
+ WorkDir: "/TensorFlow-Examples/examples",
+ }, "python", workload); err != nil {
+ b.Fatalf("failed to run container: %v logs: %s", err, out)
+ }
+ }
+ })
+ }
+
+}
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
new file mode 100644
index 000000000..bd3f6245c
--- /dev/null
+++ b/test/benchmarks/network/BUILD
@@ -0,0 +1,35 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "network",
+ testonly = 1,
+ srcs = ["network.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "network_test",
+ size = "large",
+ srcs = [
+ "httpd_test.go",
+ "iperf_test.go",
+ "nginx_test.go",
+ "node_test.go",
+ "ruby_test.go",
+ ],
+ library = ":network",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ "//test/benchmarks/harness",
+ "//test/benchmarks/tools",
+ ],
+)
diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go
new file mode 100644
index 000000000..336e04c91
--- /dev/null
+++ b/test/benchmarks/network/httpd_test.go
@@ -0,0 +1,181 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package network
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// see Dockerfile '//images/benchmarks/httpd'.
+var docs = map[string]string{
+ "notfound": "notfound",
+ "1Kb": "latin1k.txt",
+ "10Kb": "latin10k.txt",
+ "100Kb": "latin100k.txt",
+ "1000Kb": "latin1000k.txt",
+ "1Mb": "latin1024k.txt",
+ "10Mb": "latin10240k.txt",
+}
+
+// BenchmarkHttpdConcurrency iterates the concurrency argument and tests
+// how well the runtime under test handles requests in parallel.
+func BenchmarkHttpdConcurrency(b *testing.B) {
+ // Grab a machine for the client and server.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get client: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get server: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // The test iterates over client concurrency, so set other parameters.
+ concurrency := []int{1, 25, 50, 100, 1000}
+
+ for _, c := range concurrency {
+ b.Run(fmt.Sprintf("%d", c), func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: 10000,
+ Concurrency: c,
+ Doc: docs["10Kb"],
+ }
+ runHttpd(b, clientMachine, serverMachine, hey, false /* reverse */)
+ })
+ }
+}
+
+// BenchmarkHttpdDocSize iterates over different sized payloads, testing how
+// well the runtime handles sending different payload sizes.
+func BenchmarkHttpdDocSize(b *testing.B) {
+ benchmarkHttpdDocSize(b, false /* reverse */)
+}
+
+// BenchmarkReverseHttpdDocSize iterates over different sized payloads, testing
+// how well the runtime handles receiving different payload sizes.
+func BenchmarkReverseHttpdDocSize(b *testing.B) {
+ benchmarkHttpdDocSize(b, true /* reverse */)
+}
+
+func benchmarkHttpdDocSize(b *testing.B, reverse bool) {
+ b.Helper()
+
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ for name, filename := range docs {
+ concurrency := []int{1, 25, 50, 100, 1000}
+ for _, c := range concurrency {
+ b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: 10000,
+ Concurrency: c,
+ Doc: filename,
+ }
+ runHttpd(b, clientMachine, serverMachine, hey, reverse)
+ })
+ }
+ }
+}
+
+// runHttpd runs a single test run.
+func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey, reverse bool) {
+ b.Helper()
+
+ // Grab a container from the server.
+ ctx := context.Background()
+ var server *dockerutil.Container
+ if reverse {
+ server = serverMachine.GetNativeContainer(ctx, b)
+ } else {
+ server = serverMachine.GetContainer(ctx, b)
+ }
+
+ defer server.CleanUp(ctx)
+
+ // Copy the docs to /tmp and serve from there.
+ cmd := "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"
+ port := 80
+
+ // Start the server.
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/httpd",
+ Ports: []int{port},
+ Env: []string{
+ // Standard environmental variables for httpd.
+ "APACHE_RUN_DIR=/tmp",
+ "APACHE_RUN_USER=nobody",
+ "APACHE_RUN_GROUP=nogroup",
+ "APACHE_LOG_DIR=/tmp",
+ "APACHE_PID_FILE=/tmp/apache.pid",
+ },
+ }, "sh", "-c", cmd); err != nil {
+ b.Fatalf("failed to start server: %v", err)
+ }
+
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find server port %d: %v", port, err)
+ }
+
+ // Check the server is serving.
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
+
+ var client *dockerutil.Container
+ // Grab a client.
+ if reverse {
+ client = clientMachine.GetContainer(ctx, b)
+ } else {
+ client = clientMachine.GetNativeContainer(ctx, b)
+ }
+ defer client.CleanUp(ctx)
+
+ b.ResetTimer()
+ server.RestartProfiles()
+ for i := 0; i < b.N; i++ {
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, hey.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
+ }
+}
diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go
new file mode 100644
index 000000000..b8ab7dfb8
--- /dev/null
+++ b/test/benchmarks/network/iperf_test.go
@@ -0,0 +1,113 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package network
+
+import (
+ "context"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+func BenchmarkIperf(b *testing.B) {
+ iperf := tools.Iperf{
+ Time: 10, // time in seconds to run client.
+ }
+
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+ ctx := context.Background()
+ for _, bm := range []struct {
+ name string
+ clientFunc func(context.Context, testutil.Logger) *dockerutil.Container
+ serverFunc func(context.Context, testutil.Logger) *dockerutil.Container
+ }{
+ // We are either measuring the server or the client. The other should be
+ // runc. e.g. Upload sees how fast the runtime under test uploads to a native
+ // server.
+ {
+ name: "Upload",
+ clientFunc: clientMachine.GetContainer,
+ serverFunc: serverMachine.GetNativeContainer,
+ },
+ {
+ name: "Download",
+ clientFunc: clientMachine.GetNativeContainer,
+ serverFunc: serverMachine.GetContainer,
+ },
+ } {
+ b.Run(bm.name, func(b *testing.B) {
+ // Set up the containers.
+ server := bm.serverFunc(ctx, b)
+ defer server.CleanUp(ctx)
+ client := bm.clientFunc(ctx, b)
+ defer client.CleanUp(ctx)
+
+ // iperf serves on port 5001 by default.
+ port := 5001
+
+ // Start the server.
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/iperf",
+ Ports: []int{port},
+ }, "iperf", "-s"); err != nil {
+ b.Fatalf("failed to start server with: %v", err)
+ }
+
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find port %d: %v", port, err)
+ }
+
+ // Make sure the server is up and serving before we run.
+ if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil {
+ b.Fatalf("failed to wait for server: %v", err)
+ }
+ // Run the client.
+ b.ResetTimer()
+
+ // Restart the server profiles. If the server isn't being profiled
+ // this does nothing.
+ server.RestartProfiles()
+ for i := 0; i < b.N; i++ {
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/iperf",
+ }, iperf.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("failed to run client: %v", err)
+ }
+ b.StopTimer()
+ iperf.Report(b, out)
+ b.StartTimer()
+ }
+ })
+ }
+}
diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go
new file mode 100644
index 000000000..ce17ddb94
--- /dev/null
+++ b/test/benchmarks/network/network.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package network holds benchmarks around raw network performance.
+package network
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package network.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go
new file mode 100644
index 000000000..2bf1a3624
--- /dev/null
+++ b/test/benchmarks/network/nginx_test.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package network
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// BenchmarkNginxConcurrency iterates the concurrency argument and tests
+// how well the runtime under test handles requests in parallel.
+// TODO(gvisor.dev/issue/3536): Update with different doc sizes like Httpd.
+func BenchmarkNginxConcurrency(b *testing.B) {
+ // Grab a machine for the client and server.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get client: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get server: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ concurrency := []int{1, 5, 10, 25}
+ for _, c := range concurrency {
+ b.Run(fmt.Sprintf("%d", c), func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: 10000,
+ Concurrency: c,
+ }
+ runNginx(b, clientMachine, serverMachine, hey)
+ })
+ }
+}
+
+// runHttpd runs a single test run.
+func runNginx(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) {
+ b.Helper()
+
+ // Grab a container from the server.
+ ctx := context.Background()
+ server := serverMachine.GetContainer(ctx, b)
+ defer server.CleanUp(ctx)
+
+ port := 80
+ // Start the server.
+ if err := server.Spawn(ctx,
+ dockerutil.RunOpts{
+ Image: "benchmarks/nginx",
+ Ports: []int{port},
+ }); err != nil {
+ b.Fatalf("server failed to start: %v", err)
+ }
+
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to find server ip: %v", err)
+ }
+
+ servingPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to find server port %d: %v", port, err)
+ }
+
+ // Check the server is serving.
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
+
+ // Grab a client.
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+
+ b.ResetTimer()
+ server.RestartProfiles()
+ for i := 0; i < b.N; i++ {
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, hey.MakeCmd(ip, servingPort)...)
+ if err != nil {
+ b.Fatalf("run failed with: %v", err)
+ }
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
+ }
+}
diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go
new file mode 100644
index 000000000..52eb794c4
--- /dev/null
+++ b/test/benchmarks/network/node_test.go
@@ -0,0 +1,127 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package network
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// BenchmarkNode runs requests using 'hey' against a Node server run on
+// 'runtime'. The server responds to requests by grabbing some data in a
+// redis instance and returns the data in its reponse. The test loops through
+// increasing amounts of concurency for requests.
+func BenchmarkNode(b *testing.B) {
+ concurrency := []int{1, 5, 10, 25}
+ for _, c := range concurrency {
+ b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: b.N * c, // Requests b.N requests per thread.
+ Concurrency: c,
+ }
+ runNode(b, hey)
+ })
+ }
+}
+
+// runNode runs the test for a given # of requests and concurrency.
+func runNode(b *testing.B, hey *tools.Hey) {
+ b.Helper()
+
+ // The machine to hold Redis and the Node Server.
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatal("failed to get machine with: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // The machine to run 'hey'.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatal("failed to get machine with: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ ctx := context.Background()
+
+ // Spawn a redis instance for the app to use.
+ redis := serverMachine.GetNativeContainer(ctx, b)
+ if err := redis.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }); err != nil {
+ b.Fatalf("failed to spwan redis instance: %v", err)
+ }
+ defer redis.CleanUp(ctx)
+
+ if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil {
+ b.Fatalf("failed to start redis server: %v %s", err, out)
+ }
+ redisIP, err := redis.FindIP(ctx, false)
+ if err != nil {
+ b.Fatalf("failed to get IP from redis instance: %v", err)
+ }
+
+ // Node runs on port 8080.
+ port := 8080
+
+ // Start-up the Node server.
+ nodeApp := serverMachine.GetContainer(ctx, b)
+ if err := nodeApp.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/node",
+ WorkDir: "/usr/src/app",
+ Links: []string{redis.MakeLink("redis")},
+ Ports: []int{port},
+ }, "node", "index.js", redisIP.String()); err != nil {
+ b.Fatalf("failed to spawn node instance: %v", err)
+ }
+ defer nodeApp.CleanUp(ctx)
+
+ servingIP, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to get ip from server: %v", err)
+ }
+
+ servingPort, err := nodeApp.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to port from node instance: %v", err)
+ }
+
+ // Wait until the Client sees the server as up.
+ harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort)
+
+ heyCmd := hey.MakeCmd(servingIP, servingPort)
+
+ nodeApp.RestartProfiles()
+ b.ResetTimer()
+
+ // the client should run on Native.
+ client := clientMachine.GetNativeContainer(ctx, b)
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, heyCmd...)
+ if err != nil {
+ b.Fatalf("hey container failed: %v logs: %s", err, out)
+ }
+
+ // Stop the timer to parse the data and report stats.
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
+}
diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go
new file mode 100644
index 000000000..5e0b2b724
--- /dev/null
+++ b/test/benchmarks/network/ruby_test.go
@@ -0,0 +1,134 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package network
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+ "gvisor.dev/gvisor/test/benchmarks/tools"
+)
+
+// BenchmarkRuby runs requests using 'hey' against a ruby application server.
+// On start, ruby app generates some random data and pushes it to a redis
+// instance. On a request, the app grabs for random entries from the redis
+// server, publishes it to a document, and returns the doc to the request.
+func BenchmarkRuby(b *testing.B) {
+ concurrency := []int{1, 5, 10, 25}
+ for _, c := range concurrency {
+ b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) {
+ hey := &tools.Hey{
+ Requests: b.N * c, // b.N requests per thread.
+ Concurrency: c,
+ }
+ runRuby(b, hey)
+ })
+ }
+}
+
+// runRuby runs the test for a given # of requests and concurrency.
+func runRuby(b *testing.B, hey *tools.Hey) {
+ b.Helper()
+ // The machine to hold Redis and the Ruby Server.
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatal("failed to get machine with: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // The machine to run 'hey'.
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatal("failed to get machine with: %v", err)
+ }
+ defer clientMachine.CleanUp()
+ ctx := context.Background()
+
+ // Spawn a redis instance for the app to use.
+ redis := serverMachine.GetNativeContainer(ctx, b)
+ if err := redis.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }); err != nil {
+ b.Fatalf("failed to spwan redis instance: %v", err)
+ }
+ defer redis.CleanUp(ctx)
+
+ if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil {
+ b.Fatalf("failed to start redis server: %v %s", err, out)
+ }
+ redisIP, err := redis.FindIP(ctx, false)
+ if err != nil {
+ b.Fatalf("failed to get IP from redis instance: %v", err)
+ }
+
+ // Ruby runs on port 9292.
+ const port = 9292
+
+ // Start-up the Ruby server.
+ rubyApp := serverMachine.GetContainer(ctx, b)
+ if err := rubyApp.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/ruby",
+ WorkDir: "/app",
+ Links: []string{redis.MakeLink("redis")},
+ Ports: []int{port},
+ Env: []string{
+ fmt.Sprintf("PORT=%d", port),
+ "WEB_CONCURRENCY=20",
+ "WEB_MAX_THREADS=20",
+ "RACK_ENV=production",
+ fmt.Sprintf("HOST=%s", redisIP),
+ },
+ User: "nobody",
+ }, "sh", "-c", "/usr/bin/puma"); err != nil {
+ b.Fatalf("failed to spawn node instance: %v", err)
+ }
+ defer rubyApp.CleanUp(ctx)
+
+ servingIP, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatalf("failed to get ip from server: %v", err)
+ }
+
+ servingPort, err := rubyApp.FindPort(ctx, port)
+ if err != nil {
+ b.Fatalf("failed to port from node instance: %v", err)
+ }
+
+ // Wait until the Client sees the server as up.
+ if err := harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort); err != nil {
+ b.Fatalf("failed to wait until serving: %v", err)
+ }
+ heyCmd := hey.MakeCmd(servingIP, servingPort)
+ rubyApp.RestartProfiles()
+ b.ResetTimer()
+
+ // the client should run on Native.
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/hey",
+ }, heyCmd...)
+ if err != nil {
+ b.Fatalf("hey container failed: %v logs: %s", err, out)
+ }
+
+ // Stop the timer to parse the data and report stats.
+ b.StopTimer()
+ hey.Report(b, out)
+ b.StartTimer()
+}
diff --git a/test/benchmarks/tcp/BUILD b/test/benchmarks/tcp/BUILD
new file mode 100644
index 000000000..6dde7d9e6
--- /dev/null
+++ b/test/benchmarks/tcp/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "cc_binary", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "tcp_proxy",
+ srcs = ["tcp_proxy.go"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/adapters/gonet",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/qdisc/fifo",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+# nsjoin is a trivial replacement for nsenter. This is used because nsenter is
+# not available on all systems where this benchmark is run (and we aim to
+# minimize external dependencies.)
+
+cc_binary(
+ name = "nsjoin",
+ srcs = ["nsjoin.c"],
+ visibility = ["//:sandbox"],
+)
+
+sh_binary(
+ name = "tcp_benchmark",
+ srcs = ["tcp_benchmark.sh"],
+ data = [
+ ":nsjoin",
+ ":tcp_proxy",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/test/benchmarks/tcp/README.md b/test/benchmarks/tcp/README.md
new file mode 100644
index 000000000..38e6e69f0
--- /dev/null
+++ b/test/benchmarks/tcp/README.md
@@ -0,0 +1,87 @@
+# TCP Benchmarks
+
+This directory contains a standardized TCP benchmark. This helps to evaluate the
+performance of netstack and native networking stacks under various conditions.
+
+## `tcp_benchmark`
+
+This benchmark allows TCP throughput testing under various conditions. The setup
+consists of an iperf client, a client proxy, a server proxy and an iperf server.
+The client proxy and server proxy abstract the network mechanism used to
+communicate between the iperf client and server.
+
+The setup looks like the following:
+
+```
+ +--------------+ (native) +--------------+
+ | iperf client |[lo @ 10.0.0.1]------>| client proxy |
+ +--------------+ +--------------+
+ [client.0 @ 10.0.0.2]
+ (netstack) | | (native)
+ +------+-----+
+ |
+ [br0]
+ |
+ Network emulation applied ---> [wan.0:wan.1]
+ |
+ [br1]
+ |
+ +------+-----+
+ (netstack) | | (native)
+ [server.0 @ 10.0.0.3]
+ +--------------+ +--------------+
+ | iperf server |<------[lo @ 10.0.0.4]| server proxy |
+ +--------------+ (native) +--------------+
+```
+
+Different configurations can be run using different arguments. For example:
+
+* Native test under normal internet conditions: `tcp_benchmark`
+* Native test under ideal conditions: `tcp_benchmark --ideal`
+* Netstack client under ideal conditions: `tcp_benchmark --client --ideal`
+* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss
+ 5`
+
+Use `tcp_benchmark --help` for full arguments.
+
+This tool may be used to easily generate data for graphing. For example, to
+generate a CSV for various latencies, you might do:
+
+```
+rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv
+latencies=$(seq 0 5 50;
+ seq 60 10 100;
+ seq 125 25 250;
+ seq 300 50 500)
+for latency in $latencies; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --client --ideal --latency $latency)
+ echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv
+done
+for latency in $latencies; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --ideal --latency $latency)
+ echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv
+done
+```
+
+Similarly, to generate a CSV for various levels of packet loss, the following
+would be appropriate:
+
+```
+rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv
+losses=$(seq 0 0.1 1.0;
+ seq 1.2 0.2 2.0;
+ seq 2.5 0.5 5.0;
+ seq 6.0 1.0 10.0)
+for loss in $losses; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss)
+ echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv
+done
+for loss in $losses; do
+ read throughput client_cpu server_cpu <<< \
+ $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss)
+ echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv
+done
+```
diff --git a/test/benchmarks/tcp/nsjoin.c b/test/benchmarks/tcp/nsjoin.c
new file mode 100644
index 000000000..524b4d549
--- /dev/null
+++ b/test/benchmarks/tcp/nsjoin.c
@@ -0,0 +1,47 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+
+#include <errno.h>
+#include <fcntl.h>
+#include <sched.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+int main(int argc, char** argv) {
+ if (argc <= 2) {
+ fprintf(stderr, "error: must provide a namespace file.\n");
+ fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]);
+ return 1;
+ }
+
+ int fd = open(argv[1], O_RDONLY);
+ if (fd < 0) {
+ fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno));
+ return 1;
+ }
+ if (setns(fd, 0) < 0) {
+ fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno));
+ return 1;
+ }
+
+ execvp(argv[2], &argv[2]);
+ return 1;
+}
diff --git a/test/benchmarks/tcp/tcp_benchmark.sh b/test/benchmarks/tcp/tcp_benchmark.sh
new file mode 100755
index 000000000..ef04b4ace
--- /dev/null
+++ b/test/benchmarks/tcp/tcp_benchmark.sh
@@ -0,0 +1,392 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TCP benchmark; see README.md for documentation.
+
+# Fixed parameters.
+iperf_port=45201 # Not likely to be privileged.
+proxy_port=44000 # Ditto.
+client_addr=10.0.0.1
+client_proxy_addr=10.0.0.2
+server_proxy_addr=10.0.0.3
+server_addr=10.0.0.4
+mask=8
+
+# Defaults; this provides a reasonable approximation of a decent internet link.
+# Parameters can be varied independently from this set to see response to
+# various changes in the kind of link available.
+client=false
+server=false
+verbose=false
+gso=0
+swgso=false
+mtu=1280 # 1280 is a reasonable lowest-common-denominator.
+latency=10 # 10ms approximates a fast, dedicated connection.
+latency_variation=1 # +/- 1ms is a relatively low amount of jitter.
+loss=0.1 # 0.1% loss is non-zero, but not extremely high.
+duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses.
+duration=30 # 30s is enough time to consistent results (experimentally).
+helper_dir=$(dirname $0)
+netstack_opts=
+disable_linux_gso=
+num_client_threads=1
+
+# Check for netem support.
+lsmod_output=$(lsmod | grep sch_netem)
+if [ "$?" != "0" ]; then
+ echo "warning: sch_netem may not be installed." >&2
+fi
+
+while [ $# -gt 0 ]; do
+ case "$1" in
+ --client)
+ client=true
+ ;;
+ --client_tcp_probe_file)
+ shift
+ netstack_opts="${netstack_opts} -client_tcp_probe_file=$1"
+ ;;
+ --server)
+ server=true
+ ;;
+ --verbose)
+ verbose=true
+ ;;
+ --gso)
+ shift
+ gso=$1
+ ;;
+ --swgso)
+ swgso=true
+ ;;
+ --server_tcp_probe_file)
+ shift
+ netstack_opts="${netstack_opts} -server_tcp_probe_file=$1"
+ ;;
+ --ideal)
+ mtu=1500 # Standard ethernet.
+ latency=0 # No latency.
+ latency_variation=0 # No jitter.
+ loss=0 # No loss.
+ duplicate=0 # No duplicates.
+ ;;
+ --mtu)
+ shift
+ [ "$#" -le 0 ] && echo "no mtu provided" && exit 1
+ mtu=$1
+ ;;
+ --sack)
+ netstack_opts="${netstack_opts} -sack"
+ ;;
+ --cubic)
+ netstack_opts="${netstack_opts} -cubic"
+ ;;
+ --moderate-recv-buf)
+ netstack_opts="${netstack_opts} -moderate_recv_buf"
+ ;;
+ --duration)
+ shift
+ [ "$#" -le 0 ] && echo "no duration provided" && exit 1
+ duration=$1
+ ;;
+ --latency)
+ shift
+ [ "$#" -le 0 ] && echo "no latency provided" && exit 1
+ latency=$1
+ ;;
+ --latency-variation)
+ shift
+ [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1
+ latency_variation=$1
+ ;;
+ --loss)
+ shift
+ [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1
+ loss=$1
+ ;;
+ --duplicate)
+ shift
+ [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1
+ duplicate=$1
+ ;;
+ --cpuprofile)
+ shift
+ netstack_opts="${netstack_opts} -cpuprofile=$1"
+ ;;
+ --memprofile)
+ shift
+ netstack_opts="${netstack_opts} -memprofile=$1"
+ ;;
+ --disable-linux-gso)
+ disable_linux_gso=1
+ ;;
+ --num-client-threads)
+ shift
+ num_client_threads=$1
+ ;;
+ --helpers)
+ shift
+ [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1
+ helper_dir=$1
+ ;;
+ *)
+ echo "usage: $0 [options]"
+ echo "options:"
+ echo " --help show this message"
+ echo " --verbose verbose output"
+ echo " --client use netstack as the client"
+ echo " --ideal reset all network emulation"
+ echo " --server use netstack as the server"
+ echo " --mtu set the mtu (bytes)"
+ echo " --sack enable SACK support"
+ echo " --moderate-recv-buf enable TCP receive buffer auto-tuning"
+ echo " --cubic enable CUBIC congestion control for Netstack"
+ echo " --duration set the test duration (s)"
+ echo " --latency set the latency (ms)"
+ echo " --latency-variation set the latency variation"
+ echo " --loss set the loss probability (%)"
+ echo " --duplicate set the duplicate probability (%)"
+ echo " --helpers set the helper directory"
+ echo " --num-client-threads number of parallel client threads to run"
+ echo " --disable-linux-gso disable segmentation offload in the Linux network stack"
+ echo ""
+ echo "The output will of the script will be:"
+ echo " <throughput> <client-cpu-usage> <server-cpu-usage>"
+ exit 1
+ esac
+ shift
+done
+
+if [ ${verbose} == "true" ]; then
+ set -x
+fi
+
+# Latency needs to be halved, since it's applied on both ways.
+half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}')
+half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}')
+half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}')
+helper_dir=${helper_dir#$(pwd)/} # Use relative paths.
+proxy_binary=${helper_dir}/tcp_proxy
+nsjoin_binary=${helper_dir}/nsjoin
+
+if [ ! -e ${proxy_binary} ]; then
+ echo "Could not locate ${proxy_binary}, please make sure you've built the binary"
+ exit 1
+fi
+
+if [ ! -e ${nsjoin_binary} ]; then
+ echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary"
+ exit 1
+fi
+
+if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then
+ # As long as there's some jitter, then we use the paretonormal distribution.
+ # This will preserve the minimum RTT, but add a realistic amount of jitter to
+ # the connection and cause re-ordering, etc. The regular pareto distribution
+ # appears to an unreasonable level of delay (we want only small spikes.)
+ distribution="distribution paretonormal"
+else
+ distribution=""
+fi
+
+# Client proxy that will listen on the client's iperf target forward traffic
+# using the host networking stack.
+client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}"
+if ${client}; then
+ # Client proxy that will listen on the client's iperf target
+ # and forward traffic using netstack.
+ client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\
+ -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\
+ -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}"
+fi
+
+# Server proxy that will listen on the proxy port and forward to the server's
+# iperf server using the host networking stack.
+server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}"
+if ${server}; then
+ # Server proxy that will listen on the proxy port and forward to the servers'
+ # iperf server using netstack.
+ server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\
+ -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\
+ -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}"
+fi
+
+# Specify loss and duplicate parameters only if they are non-zero
+loss_opt=""
+if [ "$(echo $half_loss | bc -q)" != "0" ]; then
+ loss_opt="loss random ${half_loss}%"
+fi
+duplicate_opt=""
+if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then
+ duplicate_opt="duplicate ${half_duplicate}%"
+fi
+
+exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF
+set -e -m
+
+if [ ${verbose} == "true" ]; then
+ set -x
+fi
+
+mount -t tmpfs netstack-bench /tmp
+
+# We may have reset the path in the unshare if the shell loaded some public
+# profiles. Ensure that tools are discoverable via the parent's PATH.
+export PATH=${PATH}
+
+# Add client, server interfaces.
+ip link add client.0 type veth peer name client.1
+ip link add server.0 type veth peer name server.1
+
+# Add network emulation devices.
+ip link add wan.0 type veth peer name wan.1
+ip link set wan.0 up
+ip link set wan.1 up
+
+# Enroll on the bridge.
+ip link add name br0 type bridge
+ip link add name br1 type bridge
+ip link set client.1 master br0
+ip link set server.1 master br1
+ip link set wan.0 master br0
+ip link set wan.1 master br1
+ip link set br0 up
+ip link set br1 up
+
+# Set the MTU appropriately.
+ip link set client.0 mtu ${mtu}
+ip link set server.0 mtu ${mtu}
+ip link set wan.0 mtu ${mtu}
+ip link set wan.1 mtu ${mtu}
+
+# Add appropriate latency, loss and duplication.
+#
+# This is added in at the point of bridge connection.
+for device in wan.0 wan.1; do
+ # NOTE: We don't support a loss correlation as testing has shown that it
+ # actually doesn't work. The man page actually has a small comment about this
+ # "It is also possible to add a correlation, but this option is now deprecated
+ # due to the noticed bad behavior." For more information see netem(8).
+ tc qdisc add dev \$device root netem \\
+ delay ${half_latency}ms ${latency_variation}ms ${distribution} \\
+ ${loss_opt} ${duplicate_opt}
+done
+
+# Start a client proxy.
+touch /tmp/client.netns
+unshare -n mount --bind /proc/self/ns/net /tmp/client.netns
+
+# Move the endpoint into the namespace.
+while ip link | grep client.0 > /dev/null; do
+ ip link set dev client.0 netns /tmp/client.netns
+done
+
+if ! ${client}; then
+ # Only add the address to NIC if netstack is not in use. Otherwise the host
+ # will also process the inbound SYN and send a RST back.
+ ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0
+fi
+
+# Start a server proxy.
+touch /tmp/server.netns
+unshare -n mount --bind /proc/self/ns/net /tmp/server.netns
+# Move the endpoint into the namespace.
+while ip link | grep server.0 > /dev/null; do
+ ip link set dev server.0 netns /tmp/server.netns
+done
+if ! ${server}; then
+ # Only add the address to NIC if netstack is not in use. Otherwise the host
+ # will also process the inbound SYN and send a RST back.
+ ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0
+fi
+
+# Add client and server addresses, and bring everything up.
+${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0
+${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0
+if [ "${disable_linux_gso}" == "1" ]; then
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off
+ ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off
+ ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off
+fi
+${nsjoin_binary} /tmp/client.netns ip link set client.0 up
+${nsjoin_binary} /tmp/client.netns ip link set lo up
+${nsjoin_binary} /tmp/server.netns ip link set server.0 up
+${nsjoin_binary} /tmp/server.netns ip link set lo up
+ip link set dev client.1 up
+ip link set dev server.1 up
+
+${nsjoin_binary} /tmp/client.netns ${client_args} &
+client_pid=\$!
+${nsjoin_binary} /tmp/server.netns ${server_args} &
+server_pid=\$!
+
+# Start the iperf server.
+${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 &
+iperf_pid=\$!
+
+# Show traffic information.
+if ! ${client} && ! ${server}; then
+ ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true
+fi
+
+results_file=\$(mktemp)
+function cleanup {
+ rm -f \$results_file
+ kill -TERM \$client_pid
+ kill -TERM \$server_pid
+ wait \$client_pid
+ wait \$server_pid
+ kill -9 \$iperf_pid 2>/dev/null
+}
+
+# Allow failure from this point.
+set +e
+trap cleanup EXIT
+
+# Run the benchmark, recording the results file.
+while ${nsjoin_binary} /tmp/client.netns iperf \\
+ -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\
+ | tee \$results_file \\
+ | grep "connect failed" >/dev/null; do
+ sleep 0.1 # Wait for all services.
+done
+
+# Unlink all relevant devices from the bridge. This is because when the bridge
+# is deleted, the kernel may hang. It appears that this problem is fixed in
+# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a.
+ip link set client.1 nomaster
+ip link set server.1 nomaster
+ip link set wan.0 nomaster
+ip link set wan.1 nomaster
+
+# Emit raw results.
+cat \$results_file >&2
+
+# Emit a useful result (final throughput).
+mbits=\$(grep Mbits/sec \$results_file \\
+ | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p')
+client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\
+ | awk '{print (\$14+\$15);}')
+server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\
+ | awk '{print (\$14+\$15);}')
+ticks_per_sec=\$(getconf CLK_TCK)
+client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration})
+server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration})
+echo \$mbits \$client_cpu_load \$server_cpu_load
+EOF
diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go
new file mode 100644
index 000000000..4b7ca7a14
--- /dev/null
+++ b/test/benchmarks/tcp/tcp_proxy.go
@@ -0,0 +1,451 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary tcp_proxy is a simple TCP proxy.
+package main
+
+import (
+ "encoding/gob"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/signal"
+ "regexp"
+ "runtime"
+ "runtime/pprof"
+ "strconv"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo"
+ "gvisor.dev/gvisor/pkg/tcpip/network/arp"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+)
+
+var (
+ port = flag.Int("port", 0, "bind port (all addresses)")
+ forward = flag.String("forward", "", "forwarding target")
+ client = flag.Bool("client", false, "use netstack for listen")
+ server = flag.Bool("server", false, "use netstack for dial")
+
+ // Netstack-specific options.
+ mtu = flag.Int("mtu", 1280, "mtu for network stack")
+ addr = flag.String("addr", "", "address for tap-based netstack")
+ mask = flag.Int("mask", 8, "mask size for address")
+ iface = flag.String("iface", "", "network interface name to bind for netstack")
+ sack = flag.Bool("sack", false, "enable SACK support for netstack")
+ moderateRecvBuf = flag.Bool("moderate_recv_buf", false, "enable TCP Receive Buffer Auto-tuning")
+ cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack")
+ gso = flag.Int("gso", 0, "GSO maximum size")
+ swgso = flag.Bool("swgso", false, "software-level GSO")
+ clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
+ serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
+ cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.")
+ memprofile = flag.String("memprofile", "", "write memory profile to the specified file.")
+)
+
+type impl interface {
+ dial(address string) (net.Conn, error)
+ listen(port int) (net.Listener, error)
+ printStats()
+}
+
+type netImpl struct{}
+
+func (netImpl) dial(address string) (net.Conn, error) {
+ return net.Dial("tcp", address)
+}
+
+func (netImpl) listen(port int) (net.Listener, error) {
+ return net.Listen("tcp", fmt.Sprintf(":%d", port))
+}
+
+func (netImpl) printStats() {
+}
+
+const (
+ nicID = 1 // Fixed.
+ bufSize = 4 << 20 // 4MB.
+)
+
+type netstackImpl struct {
+ s *stack.Stack
+ addr tcpip.Address
+ mode string
+}
+
+func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
+ // Get all interfaces in the namespace.
+ ifaces, err := net.Interfaces()
+ if err != nil {
+ return nil, fmt.Errorf("querying interfaces: %v", err)
+ }
+
+ for _, iface := range ifaces {
+ if iface.Name != ifaceName {
+ continue
+ }
+ // Create the socket.
+ const protocol = 0x0300 // htons(ETH_P_ALL)
+ fds := make([]int, numChannels)
+ for i := range fds {
+ fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ if err != nil {
+ return nil, fmt.Errorf("unable to create raw socket: %v", err)
+ }
+
+ // Bind to the appropriate device.
+ ll := syscall.SockaddrLinklayer{
+ Protocol: protocol,
+ Ifindex: iface.Index,
+ Pkttype: syscall.PACKET_HOST,
+ }
+ if err := syscall.Bind(fd, &ll); err != nil {
+ return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
+ }
+
+ // RAW Sockets by default have a very small SO_RCVBUF of 256KB,
+ // up it to at least 4MB to reduce packet drops.
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
+ return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
+ }
+
+ if !*swgso && *gso != 0 {
+ if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
+ }
+ }
+ fds[i] = fd
+ }
+ return fds, nil
+ }
+ return nil, fmt.Errorf("failed to find interface: %v", ifaceName)
+}
+
+func newNetstackImpl(mode string) (impl, error) {
+ fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1))
+ if err != nil {
+ return nil, err
+ }
+
+ // Parse details.
+ parsedAddr := tcpip.Address(net.ParseIP(*addr).To4())
+ parsedDest := tcpip.Address("") // Filled in below.
+ parsedMask := tcpip.AddressMask("") // Filled in below.
+ switch *mask {
+ case 8:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0})
+ case 16:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0})
+ case 24:
+ parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0})
+ parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0})
+ default:
+ // This is just laziness; we don't expect a different mask.
+ return nil, fmt.Errorf("mask %d not supported", mask)
+ }
+
+ // Create a new network stack.
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}
+ s := stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ })
+
+ // Generate a new mac for the eth device.
+ mac := make(net.HardwareAddr, 6)
+ rand.Read(mac) // Fill with random data.
+ mac[0] &^= 0x1 // Clear multicast bit.
+ mac[0] |= 0x2 // Set local assignment bit (IEEE802).
+ ep, err := fdbased.New(&fdbased.Options{
+ FDs: fds,
+ MTU: uint32(*mtu),
+ EthernetHeader: true,
+ Address: tcpip.LinkAddress(mac),
+ // Enable checksum generation as we need to generate valid
+ // checksums for the veth device to deliver our packets to the
+ // peer. But we do want to disable checksum verification as veth
+ // devices do perform GRO and the linux host kernel may not
+ // regenerate valid checksums after GRO.
+ TXChecksumOffload: false,
+ RXChecksumOffload: true,
+ PacketDispatchMode: fdbased.RecvMMsg,
+ GSOMaxSize: uint32(*gso),
+ SoftwareGSOEnabled: *swgso,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to create FD endpoint: %v", err)
+ }
+ if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil {
+ return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err)
+ }
+ if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err)
+ }
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil {
+ return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err)
+ }
+
+ subnet, err := tcpip.NewSubnet(parsedDest, parsedMask)
+ if err != nil {
+ return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err)
+ }
+ // Add default route; we only support
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+
+ // Set protocol options.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err)
+ }
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err)
+ }
+
+ // Set Congestion Control to cubic if requested.
+ if *cubic {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err)
+ }
+ }
+
+ return netstackImpl{
+ s: s,
+ addr: parsedAddr,
+ mode: mode,
+ }, nil
+}
+
+func (n netstackImpl) dial(address string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ if host == "" {
+ // A host must be provided for the dial.
+ return nil, fmt.Errorf("no host provided")
+ }
+ portNumber, err := strconv.Atoi(port)
+ if err != nil {
+ return nil, err
+ }
+ addr := tcpip.FullAddress{
+ NIC: nicID,
+ Addr: tcpip.Address(net.ParseIP(host).To4()),
+ Port: uint16(portNumber),
+ }
+ conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+func (n netstackImpl) listen(port int) (net.Listener, error) {
+ addr := tcpip.FullAddress{
+ NIC: nicID,
+ Port: uint16(port),
+ }
+ listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber)
+ if err != nil {
+ return nil, err
+ }
+ return listener, nil
+}
+
+var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`)
+
+func (n netstackImpl) printStats() {
+ // Don't show zero fields.
+ stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "")
+ log.Printf("netstack %s Stats: %+v\n", n.mode, stats)
+}
+
+// installProbe installs a TCP Probe function that will dump endpoint
+// state to the specified file. It also returns a close func() that
+// can be used to close the probeFile.
+func (n netstackImpl) installProbe(probeFileName string) (close func()) {
+ // Install Probe to dump out end point state.
+ probeFile, err := os.Create(probeFileName)
+ if err != nil {
+ log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err)
+ }
+ probeEncoder := gob.NewEncoder(probeFile)
+ // Install a TCP Probe.
+ n.s.AddTCPProbe(func(state stack.TCPEndpointState) {
+ probeEncoder.Encode(state)
+ })
+ return func() { probeFile.Close() }
+}
+
+func main() {
+ flag.Parse()
+ if *port == 0 {
+ log.Fatalf("no port provided")
+ }
+ if *forward == "" {
+ log.Fatalf("no forward provided")
+ }
+ // Seed the random number generator to ensure that we are given MAC addresses that don't
+ // for the case of the client and server stack.
+ rand.Seed(time.Now().UTC().UnixNano())
+
+ if *cpuprofile != "" {
+ f, err := os.Create(*cpuprofile)
+ if err != nil {
+ log.Fatal("could not create CPU profile: ", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ log.Print("error closing CPU profile: ", err)
+ }
+ }()
+ if err := pprof.StartCPUProfile(f); err != nil {
+ log.Fatal("could not start CPU profile: ", err)
+ }
+ defer pprof.StopCPUProfile()
+ }
+
+ var (
+ in impl
+ out impl
+ err error
+ )
+ if *server {
+ in, err = newNetstackImpl("server")
+ if *serverTCPProbeFile != "" {
+ defer in.(netstackImpl).installProbe(*serverTCPProbeFile)()
+ }
+
+ } else {
+ in = netImpl{}
+ }
+ if err != nil {
+ log.Fatalf("netstack error: %v", err)
+ }
+ if *client {
+ out, err = newNetstackImpl("client")
+ if *clientTCPProbeFile != "" {
+ defer out.(netstackImpl).installProbe(*clientTCPProbeFile)()
+ }
+ } else {
+ out = netImpl{}
+ }
+ if err != nil {
+ log.Fatalf("netstack error: %v", err)
+ }
+
+ // Dial forward before binding.
+ var next net.Conn
+ for {
+ next, err = out.dial(*forward)
+ if err == nil {
+ break
+ }
+ time.Sleep(50 * time.Millisecond)
+ log.Printf("connect failed retrying: %v", err)
+ }
+
+ // Bind once to the server socket.
+ listener, err := in.listen(*port)
+ if err != nil {
+ // Should not happen, everything must be bound by this time
+ // this proxy is started.
+ log.Fatalf("unable to listen: %v", err)
+ }
+ log.Printf("client=%v, server=%v, ready.", *client, *server)
+
+ sigs := make(chan os.Signal, 1)
+ signal.Notify(sigs, syscall.SIGTERM)
+ go func() {
+ <-sigs
+ if *cpuprofile != "" {
+ pprof.StopCPUProfile()
+ }
+ if *memprofile != "" {
+ f, err := os.Create(*memprofile)
+ if err != nil {
+ log.Fatal("could not create memory profile: ", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ log.Print("error closing memory profile: ", err)
+ }
+ }()
+ runtime.GC() // get up-to-date statistics
+ if err := pprof.WriteHeapProfile(f); err != nil {
+ log.Fatalf("Unable to write heap profile: %v", err)
+ }
+ }
+ os.Exit(0)
+ }()
+
+ for {
+ // Forward all connections.
+ inConn, err := listener.Accept()
+ if err != nil {
+ // This should not happen; we are listening
+ // successfully. Exhausted all available FDs?
+ log.Fatalf("accept error: %v", err)
+ }
+ log.Printf("incoming connection established.")
+
+ // Copy both ways.
+ go io.Copy(inConn, next)
+ go io.Copy(next, inConn)
+
+ // Print stats every second.
+ go func() {
+ t := time.NewTicker(time.Second)
+ defer t.Stop()
+ for {
+ <-t.C
+ in.printStats()
+ out.printStats()
+ }
+ }()
+
+ for {
+ // Dial again.
+ next, err = out.dial(*forward)
+ if err == nil {
+ break
+ }
+ }
+ }
+}
diff --git a/test/benchmarks/tools/BUILD b/test/benchmarks/tools/BUILD
new file mode 100644
index 000000000..e5734d85c
--- /dev/null
+++ b/test/benchmarks/tools/BUILD
@@ -0,0 +1,33 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tools",
+ srcs = [
+ "ab.go",
+ "fio.go",
+ "hey.go",
+ "iperf.go",
+ "meminfo.go",
+ "redis.go",
+ "sysbench.go",
+ "tools.go",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "tools_test",
+ size = "small",
+ srcs = [
+ "ab_test.go",
+ "fio_test.go",
+ "hey_test.go",
+ "iperf_test.go",
+ "meminfo_test.go",
+ "redis_test.go",
+ "sysbench_test.go",
+ ],
+ library = ":tools",
+)
diff --git a/test/benchmarks/tools/ab.go b/test/benchmarks/tools/ab.go
new file mode 100644
index 000000000..4cc9c3bce
--- /dev/null
+++ b/test/benchmarks/tools/ab.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "testing"
+)
+
+// ApacheBench is for the client application ApacheBench.
+type ApacheBench struct {
+ Requests int
+ Concurrency int
+ Doc string
+ // TODO(zkoopmans): support KeepAlive and pass option to enable.
+}
+
+// MakeCmd makes an ApacheBench command.
+func (a *ApacheBench) MakeCmd(ip net.IP, port int) []string {
+ path := fmt.Sprintf("http://%s:%d/%s", ip, port, a.Doc)
+ // See apachebench (ab) for flags.
+ cmd := fmt.Sprintf("ab -n %d -c %d %s", a.Requests, a.Concurrency, path)
+ return []string{"sh", "-c", cmd}
+}
+
+// Report parses and reports metrics from ApacheBench output.
+func (a *ApacheBench) Report(b *testing.B, output string) {
+ // Parse and report custom metrics.
+ transferRate, err := a.parseTransferRate(output)
+ if err != nil {
+ b.Logf("failed to parse transferrate: %v", err)
+ }
+ b.ReportMetric(transferRate*1024, "transfer_rate_b/s") // Convert from Kb/s to b/s.
+
+ latency, err := a.parseLatency(output)
+ if err != nil {
+ b.Logf("failed to parse latency: %v", err)
+ }
+ b.ReportMetric(latency/1000, "mean_latency_secs") // Convert from ms to s.
+
+ reqPerSecond, err := a.parseRequestsPerSecond(output)
+ if err != nil {
+ b.Logf("failed to parse requests per second: %v", err)
+ }
+ b.ReportMetric(reqPerSecond, "requests_per_second")
+}
+
+var transferRateRE = regexp.MustCompile(`Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received`)
+
+// parseTransferRate parses transfer rate from ApacheBench output.
+func (a *ApacheBench) parseTransferRate(data string) (float64, error) {
+ match := transferRateRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("failed get bandwidth: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+var latencyRE = regexp.MustCompile(`Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s`)
+
+// parseLatency parses latency from ApacheBench output.
+func (a *ApacheBench) parseLatency(data string) (float64, error) {
+ match := latencyRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("failed get bandwidth: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+var requestsPerSecondRE = regexp.MustCompile(`Requests per second:\s+(\d+\.?\d+?)\s+`)
+
+// parseRequestsPerSecond parses requests per second from ApacheBench output.
+func (a *ApacheBench) parseRequestsPerSecond(data string) (float64, error) {
+ match := requestsPerSecondRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("failed get bandwidth: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
diff --git a/test/benchmarks/tools/ab_test.go b/test/benchmarks/tools/ab_test.go
new file mode 100644
index 000000000..28ee66ec1
--- /dev/null
+++ b/test/benchmarks/tools/ab_test.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import "testing"
+
+// TestApacheBench checks the ApacheBench parsers on sample output.
+func TestApacheBench(t *testing.T) {
+ // Sample output from apachebench.
+ sampleData := `This is ApacheBench, Version 2.3 <$Revision: 1826891 $>
+Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/
+Licensed to The Apache Software Foundation, http://www.apache.org/
+
+Benchmarking 10.10.10.10 (be patient).....done
+
+
+Server Software: Apache/2.4.38
+Server Hostname: 10.10.10.10
+Server Port: 80
+
+Document Path: /latin10k.txt
+Document Length: 210 bytes
+
+Concurrency Level: 1
+Time taken for tests: 0.180 seconds
+Complete requests: 100
+Failed requests: 0
+Non-2xx responses: 100
+Total transferred: 38800 bytes
+HTML transferred: 21000 bytes
+Requests per second: 556.44 [#/sec] (mean)
+Time per request: 1.797 [ms] (mean)
+Time per request: 1.797 [ms] (mean, across all concurrent requests)
+Transfer rate: 210.84 [Kbytes/sec] received
+
+Connection Times (ms)
+ min mean[+/-sd] median max
+Connect: 0 0 0.2 0 2
+Processing: 1 2 1.0 1 8
+Waiting: 1 1 1.0 1 7
+Total: 1 2 1.2 1 10
+
+Percentage of the requests served within a certain time (ms)
+ 50% 1
+ 66% 2
+ 75% 2
+ 80% 2
+ 90% 2
+ 95% 3
+ 98% 7
+ 99% 10
+ 100% 10 (longest request)`
+
+ ab := ApacheBench{}
+ want := 210.84
+ got, err := ab.parseTransferRate(sampleData)
+ if err != nil {
+ t.Fatalf("failed to parse transfer rate with error: %v", err)
+ } else if got != want {
+ t.Fatalf("parseTransferRate got: %f, want: %f", got, want)
+ }
+
+ want = 2.0
+ got, err = ab.parseLatency(sampleData)
+ if err != nil {
+ t.Fatalf("failed to parse transfer rate with error: %v", err)
+ } else if got != want {
+ t.Fatalf("parseLatency got: %f, want: %f", got, want)
+ }
+
+ want = 556.44
+ got, err = ab.parseRequestsPerSecond(sampleData)
+ if err != nil {
+ t.Fatalf("failed to parse transfer rate with error: %v", err)
+ } else if got != want {
+ t.Fatalf("parseRequestsPerSecond got: %f, want: %f", got, want)
+ }
+}
diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go
new file mode 100644
index 000000000..20000db16
--- /dev/null
+++ b/test/benchmarks/tools/fio.go
@@ -0,0 +1,124 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// Fio makes 'fio' commands and parses their output.
+type Fio struct {
+ Test string // test to run: read, write, randread, randwrite.
+ Size string // total size to be read/written of format N[GMK] (e.g. 5G).
+ Blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K).
+ Iodepth int // iodepth for reads/writes.
+ Time int // time to run the test in seconds, usually for rand(read/write).
+}
+
+// MakeCmd makes a 'fio' command.
+func (f *Fio) MakeCmd(filename string) []string {
+ cmd := []string{"fio", "--output-format=json", "--ioengine=sync"}
+ cmd = append(cmd, fmt.Sprintf("--name=%s", f.Test))
+ cmd = append(cmd, fmt.Sprintf("--size=%s", f.Size))
+ cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.Blocksize))
+ cmd = append(cmd, fmt.Sprintf("--filename=%s", filename))
+ cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.Iodepth))
+ cmd = append(cmd, fmt.Sprintf("--rw=%s", f.Test))
+ if f.Time != 0 {
+ cmd = append(cmd, "--time_based")
+ cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.Time))
+ }
+ return cmd
+}
+
+// Report reports metrics based on output from an 'fio' command.
+func (f *Fio) Report(b *testing.B, output string) {
+ b.Helper()
+ // Parse the output and report the metrics.
+ isRead := strings.Contains(f.Test, "read")
+ bw, err := f.parseBandwidth(output, isRead)
+ if err != nil {
+ b.Fatalf("failed to parse bandwidth from %s with: %v", output, err)
+ }
+ b.ReportMetric(bw, "bandwidth_b/s") // in b/s.
+
+ iops, err := f.parseIOps(output, isRead)
+ if err != nil {
+ b.Fatalf("failed to parse iops from %s with: %v", output, err)
+ }
+ b.ReportMetric(iops, "iops")
+}
+
+// parseBandwidth reports the bandwidth in b/s.
+func (f *Fio) parseBandwidth(data string, isRead bool) (float64, error) {
+ if isRead {
+ result, err := f.parseFioJSON(data, "read", "bw")
+ if err != nil {
+ return 0, err
+ }
+ return 1024 * result, nil
+ }
+ result, err := f.parseFioJSON(data, "write", "bw")
+ if err != nil {
+ return 0, err
+ }
+ return 1024 * result, nil
+}
+
+// parseIOps reports the write IO per second metric.
+func (f *Fio) parseIOps(data string, isRead bool) (float64, error) {
+ if isRead {
+ return f.parseFioJSON(data, "read", "iops")
+ }
+ return f.parseFioJSON(data, "write", "iops")
+}
+
+// fioResult is for parsing FioJSON.
+type fioResult struct {
+ Jobs []fioJob
+}
+
+// fioJob is for parsing FioJSON.
+type fioJob map[string]json.RawMessage
+
+// fioMetrics is for parsing FioJSON.
+type fioMetrics map[string]json.RawMessage
+
+// parseFioJSON parses data and grabs "op" (read or write) and "metric"
+// (bw or iops) from the JSON.
+func (f *Fio) parseFioJSON(data, op, metric string) (float64, error) {
+ var result fioResult
+ if err := json.Unmarshal([]byte(data), &result); err != nil {
+ return 0, fmt.Errorf("could not unmarshal data: %v", err)
+ }
+
+ if len(result.Jobs) < 1 {
+ return 0, fmt.Errorf("no jobs present to parse")
+ }
+
+ var metrics fioMetrics
+ if err := json.Unmarshal(result.Jobs[0][op], &metrics); err != nil {
+ return 0, fmt.Errorf("could not unmarshal jobs: %v", err)
+ }
+
+ if _, ok := metrics[metric]; !ok {
+ return 0, fmt.Errorf("no metric found for op: %s", op)
+ }
+ return strconv.ParseFloat(string(metrics[metric]), 64)
+}
diff --git a/test/benchmarks/tools/fio_test.go b/test/benchmarks/tools/fio_test.go
new file mode 100644
index 000000000..a98277150
--- /dev/null
+++ b/test/benchmarks/tools/fio_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import "testing"
+
+// TestFio checks the Fio parsers on sample output.
+func TestFio(t *testing.T) {
+ sampleData := `
+{
+ "fio version" : "fio-3.1",
+ "timestamp" : 1554837456,
+ "timestamp_ms" : 1554837456621,
+ "time" : "Tue Apr 9 19:17:36 2019",
+ "jobs" : [
+ {
+ "jobname" : "test",
+ "groupid" : 0,
+ "error" : 0,
+ "eta" : 2147483647,
+ "elapsed" : 1,
+ "job options" : {
+ "name" : "test",
+ "ioengine" : "sync",
+ "size" : "1073741824",
+ "filename" : "/disk/file.dat",
+ "iodepth" : "4",
+ "bs" : "4096",
+ "rw" : "write"
+ },
+ "read" : {
+ "io_bytes" : 0,
+ "io_kbytes" : 0,
+ "bw" : 123456,
+ "iops" : 1234.5678,
+ "runtime" : 0,
+ "total_ios" : 0,
+ "short_ios" : 0,
+ "bw_min" : 0,
+ "bw_max" : 0,
+ "bw_agg" : 0.000000,
+ "bw_mean" : 0.000000,
+ "bw_dev" : 0.000000,
+ "bw_samples" : 0,
+ "iops_min" : 0,
+ "iops_max" : 0,
+ "iops_mean" : 0.000000,
+ "iops_stddev" : 0.000000,
+ "iops_samples" : 0
+ },
+ "write" : {
+ "io_bytes" : 1073741824,
+ "io_kbytes" : 1048576,
+ "bw" : 1753471,
+ "iops" : 438367.892977,
+ "runtime" : 598,
+ "total_ios" : 262144,
+ "bw_min" : 1731120,
+ "bw_max" : 1731120,
+ "bw_agg" : 98.725328,
+ "bw_mean" : 1731120.000000,
+ "bw_dev" : 0.000000,
+ "bw_samples" : 1,
+ "iops_min" : 432780,
+ "iops_max" : 432780,
+ "iops_mean" : 432780.000000,
+ "iops_stddev" : 0.000000,
+ "iops_samples" : 1
+ }
+ }
+ ]
+}
+`
+ fio := Fio{}
+ // WriteBandwidth.
+ got, err := fio.parseBandwidth(sampleData, false)
+ var want float64 = 1753471.0 * 1024
+ if err != nil {
+ t.Fatalf("parse failed with err: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+
+ // ReadBandwidth.
+ got, err = fio.parseBandwidth(sampleData, true)
+ want = 123456 * 1024
+ if err != nil {
+ t.Fatalf("parse failed with err: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+
+ // WriteIOps.
+ got, err = fio.parseIOps(sampleData, false)
+ want = 438367.892977
+ if err != nil {
+ t.Fatalf("parse failed with err: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+
+ // ReadIOps.
+ got, err = fio.parseIOps(sampleData, true)
+ want = 1234.5678
+ if err != nil {
+ t.Fatalf("parse failed with err: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+}
diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go
new file mode 100644
index 000000000..b1e20e356
--- /dev/null
+++ b/test/benchmarks/tools/hey.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// Hey is for the client application 'hey'.
+type Hey struct {
+ Requests int // Note: requests cannot be less than concurrency.
+ Concurrency int
+ Doc string
+}
+
+// MakeCmd returns a 'hey' command.
+func (h *Hey) MakeCmd(ip net.IP, port int) []string {
+ return strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/%s",
+ h.Requests, h.Concurrency, ip, port, h.Doc), " ")
+}
+
+// Report parses output from 'hey' and reports metrics.
+func (h *Hey) Report(b *testing.B, output string) {
+ b.Helper()
+ requests, err := h.parseRequestsPerSecond(output)
+ if err != nil {
+ b.Fatalf("failed to parse requests per second: %v", err)
+ }
+ b.ReportMetric(requests, "requests_per_second")
+
+ ave, err := h.parseAverageLatency(output)
+ if err != nil {
+ b.Fatalf("failed to parse average latency: %v", err)
+ }
+ b.ReportMetric(ave, "average_latency_secs")
+}
+
+var heyReqPerSecondRE = regexp.MustCompile(`Requests/sec:\s*(\d+\.?\d+?)\s+`)
+
+// parseRequestsPerSecond finds requests per second from 'hey' output.
+func (h *Hey) parseRequestsPerSecond(data string) (float64, error) {
+ match := heyReqPerSecondRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("failed get bandwidth: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+var heyAverageLatencyRE = regexp.MustCompile(`Average:\s*(\d+\.?\d+?)\s+secs`)
+
+// parseHeyAverageLatency finds Average Latency in seconds form 'hey' output.
+func (h *Hey) parseAverageLatency(data string) (float64, error) {
+ match := heyAverageLatencyRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("failed get average latency match%d : %s", len(match), data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
diff --git a/test/benchmarks/tools/hey_test.go b/test/benchmarks/tools/hey_test.go
new file mode 100644
index 000000000..e0cab1f52
--- /dev/null
+++ b/test/benchmarks/tools/hey_test.go
@@ -0,0 +1,81 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import "testing"
+
+// TestHey checks the Hey parsers on sample output.
+func TestHey(t *testing.T) {
+ sampleData := `
+ Summary:
+ Total: 2.2391 secs
+ Slowest: 1.6292 secs
+ Fastest: 0.0066 secs
+ Average: 0.5351 secs
+ Requests/sec: 89.3202
+
+ Total data: 841200 bytes
+ Size/request: 4206 bytes
+
+ Response time histogram:
+ 0.007 [1] |
+ 0.169 [0] |
+ 0.331 [149] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
+ 0.493 [0] |
+ 0.656 [0] |
+ 0.818 [0] |
+ 0.980 [0] |
+ 1.142 [0] |
+ 1.305 [0] |
+ 1.467 [49] |■■■■■■■■■■■■■
+ 1.629 [1] |
+
+
+ Latency distribution:
+ 10% in 0.2149 secs
+ 25% in 0.2449 secs
+ 50% in 0.2703 secs
+ 75% in 1.3315 secs
+ 90% in 1.4045 secs
+ 95% in 1.4232 secs
+ 99% in 1.4362 secs
+
+ Details (average, fastest, slowest):
+ DNS+dialup: 0.0002 secs, 0.0066 secs, 1.6292 secs
+ DNS-lookup: 0.0000 secs, 0.0000 secs, 0.0000 secs
+ req write: 0.0000 secs, 0.0000 secs, 0.0012 secs
+ resp wait: 0.5225 secs, 0.0064 secs, 1.4346 secs
+ resp read: 0.0122 secs, 0.0001 secs, 0.2006 secs
+
+ Status code distribution:
+ [200] 200 responses
+ `
+ hey := Hey{}
+ want := 89.3202
+ got, err := hey.parseRequestsPerSecond(sampleData)
+ if err != nil {
+ t.Fatalf("failed to parse request per second with: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+
+ want = 0.5351
+ got, err = hey.parseAverageLatency(sampleData)
+ if err != nil {
+ t.Fatalf("failed to parse average latency with: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %f, want: %f", got, want)
+ }
+}
diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go
new file mode 100644
index 000000000..df3d9349b
--- /dev/null
+++ b/test/benchmarks/tools/iperf.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// Iperf is for the client side of `iperf`.
+type Iperf struct {
+ Time int
+}
+
+// MakeCmd returns a iperf client command.
+func (i *Iperf) MakeCmd(ip net.IP, port int) []string {
+ // iperf report in Kb realtime
+ return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ")
+}
+
+// Report parses output from iperf client and reports metrics.
+func (i *Iperf) Report(b *testing.B, output string) {
+ b.Helper()
+ // Parse bandwidth and report it.
+ bW, err := i.bandwidth(output)
+ if err != nil {
+ b.Fatalf("failed to parse bandwitdth from %s: %v", output, err)
+ }
+ b.ReportMetric(bW*1024, "bandwidth_b/s") // Convert from Kb/s to b/s.
+}
+
+// bandwidth parses the Bandwidth number from an iperf report. A sample is below.
+func (i *Iperf) bandwidth(data string) (float64, error) {
+ re := regexp.MustCompile(`\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec`)
+ match := re.FindStringSubmatch(data)
+ if len(match) < 1 {
+ return 0, fmt.Errorf("failed get bandwidth: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
diff --git a/test/benchmarks/tools/iperf_test.go b/test/benchmarks/tools/iperf_test.go
new file mode 100644
index 000000000..03bb30d05
--- /dev/null
+++ b/test/benchmarks/tools/iperf_test.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package tools
+
+import "testing"
+
+// TestIperf checks the Iperf parsers on sample output.
+func TestIperf(t *testing.T) {
+ sampleData := `
+------------------------------------------------------------
+Client connecting to 10.138.15.215, TCP port 32779
+TCP window size: 45.0 KByte (default)
+------------------------------------------------------------
+[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779
+[ ID] Interval Transfer Bandwidth
+[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec
+`
+ i := Iperf{}
+ bandwidth, err := i.bandwidth(sampleData)
+ if err != nil || bandwidth != 45900 {
+ t.Fatalf("failed with: %v and %f", err, bandwidth)
+ }
+}
diff --git a/test/benchmarks/tools/meminfo.go b/test/benchmarks/tools/meminfo.go
new file mode 100644
index 000000000..2414a96a7
--- /dev/null
+++ b/test/benchmarks/tools/meminfo.go
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "testing"
+)
+
+// Meminfo wraps measurements of MemAvailable using /proc/meminfo.
+type Meminfo struct {
+}
+
+// MakeCmd returns a command for checking meminfo.
+func (*Meminfo) MakeCmd() (string, []string) {
+ return "cat", []string{"/proc/meminfo"}
+}
+
+// Report takes two reads of meminfo, parses them, and reports the difference
+// divided by b.N.
+func (*Meminfo) Report(b *testing.B, before, after string) {
+ b.Helper()
+
+ beforeVal, err := parseMemAvailable(before)
+ if err != nil {
+ b.Fatalf("could not parse before value %s: %v", before, err)
+ }
+
+ afterVal, err := parseMemAvailable(after)
+ if err != nil {
+ b.Fatalf("could not parse before value %s: %v", before, err)
+ }
+ val := 1024 * ((beforeVal - afterVal) / float64(b.N))
+ b.ReportMetric(val, "average_container_size_bytes")
+}
+
+var memInfoRE = regexp.MustCompile(`MemAvailable:\s*(\d+)\skB\n`)
+
+// parseMemAvailable grabs the MemAvailable number from /proc/meminfo.
+func parseMemAvailable(data string) (float64, error) {
+ match := memInfoRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0, fmt.Errorf("couldn't find MemAvailable in %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
diff --git a/test/benchmarks/tools/meminfo_test.go b/test/benchmarks/tools/meminfo_test.go
new file mode 100644
index 000000000..ba803540f
--- /dev/null
+++ b/test/benchmarks/tools/meminfo_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "testing"
+)
+
+// TestMeminfo checks the Meminfo parser on sample output.
+func TestMeminfo(t *testing.T) {
+ sampleData := `
+MemTotal: 16337408 kB
+MemFree: 3742696 kB
+MemAvailable: 9319948 kB
+Buffers: 1433884 kB
+Cached: 4607036 kB
+SwapCached: 45284 kB
+Active: 8288376 kB
+Inactive: 2685928 kB
+Active(anon): 4724912 kB
+Inactive(anon): 1047940 kB
+Active(file): 3563464 kB
+Inactive(file): 1637988 kB
+Unevictable: 326940 kB
+Mlocked: 48 kB
+SwapTotal: 33292284 kB
+SwapFree: 32865736 kB
+Dirty: 708 kB
+Writeback: 0 kB
+AnonPages: 4304204 kB
+Mapped: 975424 kB
+Shmem: 910292 kB
+KReclaimable: 744532 kB
+Slab: 1058448 kB
+SReclaimable: 744532 kB
+SUnreclaim: 313916 kB
+KernelStack: 25188 kB
+PageTables: 65300 kB
+NFS_Unstable: 0 kB
+Bounce: 0 kB
+WritebackTmp: 0 kB
+CommitLimit: 41460988 kB
+Committed_AS: 22859492 kB
+VmallocTotal: 34359738367 kB
+VmallocUsed: 63088 kB
+VmallocChunk: 0 kB
+Percpu: 9248 kB
+HardwareCorrupted: 0 kB
+AnonHugePages: 786432 kB
+ShmemHugePages: 0 kB
+ShmemPmdMapped: 0 kB
+FileHugePages: 0 kB
+FilePmdMapped: 0 kB
+HugePages_Total: 0
+HugePages_Free: 0
+HugePages_Rsvd: 0
+HugePages_Surp: 0
+Hugepagesize: 2048 kB
+Hugetlb: 0 kB
+DirectMap4k: 5408532 kB
+DirectMap2M: 11241472 kB
+DirectMap1G: 1048576 kB
+`
+ want := 9319948.0
+ got, err := parseMemAvailable(sampleData)
+ if err != nil {
+ t.Fatalf("parseMemAvailable failed: %v", err)
+ }
+ if got != want {
+ t.Fatalf("parseMemAvailable got %f, want %f", got, want)
+ }
+}
diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go
new file mode 100644
index 000000000..c899ae0d4
--- /dev/null
+++ b/test/benchmarks/tools/redis.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+// Redis is for the client 'redis-benchmark'.
+type Redis struct {
+ Operation string
+}
+
+// MakeCmd returns a redis-benchmark client command.
+func (r *Redis) MakeCmd(ip net.IP, port int) []string {
+ // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case.
+ // Note that "ping" will run both PING_INLINE and PING_BULK.
+ if r.Operation == "PING_BULK" {
+ return strings.Split(
+ fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ")
+ }
+
+ // runs redis-benchmark -t operation for 100K requests against server.
+ return strings.Split(
+ fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ")
+}
+
+// Report parses output from redis-benchmark client and reports metrics.
+func (r *Redis) Report(b *testing.B, output string) {
+ b.Helper()
+ result, err := r.parseOperation(output)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", output, err)
+ }
+ b.ReportMetric(result, r.Operation) // operations per second
+}
+
+// parseOperation grabs the metric operations per second from redis-benchmark output.
+func (r *Redis) parseOperation(data string) (float64, error) {
+ re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, r.Operation))
+ match := re.FindStringSubmatch(data)
+ if len(match) < 3 {
+ return 0.0, fmt.Errorf("could not find %s in %s", r.Operation, data)
+ }
+ return strconv.ParseFloat(match[2], 64)
+}
diff --git a/test/benchmarks/tools/redis_test.go b/test/benchmarks/tools/redis_test.go
new file mode 100644
index 000000000..4bafda66f
--- /dev/null
+++ b/test/benchmarks/tools/redis_test.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "testing"
+)
+
+// TestRedis checks the Redis parsers on sample output.
+func TestRedis(t *testing.T) {
+ sampleData := `
+ "PING_INLINE","48661.80"
+ "PING_BULK","50301.81"
+ "SET","48923.68"
+ "GET","49382.71"
+ "INCR","49975.02"
+ "LPUSH","49875.31"
+ "RPUSH","50276.52"
+ "LPOP","50327.12"
+ "RPOP","50556.12"
+ "SADD","49504.95"
+ "HSET","49504.95"
+ "SPOP","50025.02"
+ "LPUSH (needed to benchmark LRANGE)","48875.86"
+ "LRANGE_100 (first 100 elements)","33955.86"
+ "LRANGE_300 (first 300 elements)","16550.81"// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+ "LRANGE_500 (first 450 elements)","13653.74"
+ "LRANGE_600 (first 600 elements)","11219.57"
+ "MSET (10 keys)","44682.75"
+ `
+ wants := map[string]float64{
+ "PING_INLINE": 48661.80,
+ "PING_BULK": 50301.81,
+ "SET": 48923.68,
+ "GET": 49382.71,
+ "INCR": 49975.02,
+ "LPUSH": 49875.31,
+ "RPUSH": 50276.52,
+ "LPOP": 50327.12,
+ "RPOP": 50556.12,
+ "SADD": 49504.95,
+ "HSET": 49504.95,
+ "SPOP": 50025.02,
+ "LRANGE_100": 33955.86,
+ "LRANGE_300": 16550.81,
+ "LRANGE_500": 13653.74,
+ "LRANGE_600": 11219.57,
+ "MSET": 44682.75,
+ }
+ for op, want := range wants {
+ redis := Redis{
+ Operation: op,
+ }
+ if got, err := redis.parseOperation(sampleData); err != nil {
+ t.Fatalf("failed to parse %s: %v", op, err)
+ } else if want != got {
+ t.Fatalf("wanted %f for op %s, got %f", want, op, got)
+ }
+ }
+}
diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go
new file mode 100644
index 000000000..6b2f75ca2
--- /dev/null
+++ b/test/benchmarks/tools/sysbench.go
@@ -0,0 +1,245 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+)
+
+var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&"
+
+// Sysbench represents a 'sysbench' command.
+type Sysbench interface {
+ MakeCmd() []string // Makes a sysbench command.
+ flags() []string
+ Report(*testing.B, string) // Reports results contained in string.
+}
+
+// SysbenchBase is the top level struct for sysbench and holds top-level arguments
+// for sysbench. See: 'sysbench --help'
+type SysbenchBase struct {
+ Threads int // number of Threads for the test.
+ Time int // time limit for test in seconds.
+}
+
+// baseFlags returns top level flags.
+func (s *SysbenchBase) baseFlags() []string {
+ var ret []string
+ if s.Threads > 0 {
+ ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads))
+ }
+ if s.Time > 0 {
+ ret = append(ret, fmt.Sprintf("--time=%d", s.Time))
+ }
+ return ret
+}
+
+// SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments.
+type SysbenchCPU struct {
+ Base SysbenchBase
+ MaxPrime int // upper limit for primes generator [10000].
+}
+
+// MakeCmd makes commands for SysbenchCPU.
+func (s *SysbenchCPU) MakeCmd() []string {
+ cmd := []string{warmup, "sysbench"}
+ cmd = append(cmd, s.flags()...)
+ cmd = append(cmd, "cpu run")
+ return []string{"sh", "-c", strings.Join(cmd, " ")}
+}
+
+// flags makes flags for SysbenchCPU cmds.
+func (s *SysbenchCPU) flags() []string {
+ cmd := s.Base.baseFlags()
+ if s.MaxPrime > 0 {
+ return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime))
+ }
+ return cmd
+}
+
+// Report reports the relevant metrics for SysbenchCPU.
+func (s *SysbenchCPU) Report(b *testing.B, output string) {
+ b.Helper()
+ result, err := s.parseEvents(output)
+ if err != nil {
+ b.Fatalf("parsing CPU events from %s failed: %v", output, err)
+ }
+ b.ReportMetric(result, "cpu_events_per_second")
+}
+
+var cpuEventsPerSecondRE = regexp.MustCompile(`events per second:\s*(\d*.?\d*)\n`)
+
+// parseEvents parses cpu events per second.
+func (s *SysbenchCPU) parseEvents(data string) (float64, error) {
+ match := cpuEventsPerSecondRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0.0, fmt.Errorf("could not find events per second: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+// SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments.
+type SysbenchMemory struct {
+ Base SysbenchBase
+ BlockSize string // size of test memory block [1K].
+ TotalSize string // size of data to transfer [100G].
+ Scope string // memory access scope {global, local} [global].
+ HugeTLB bool // allocate memory from HugeTLB [off].
+ OperationType string // type of memory ops {read, write, none} [write].
+ AccessMode string // access mode {seq, rnd} [seq].
+}
+
+// MakeCmd makes commands for SysbenchMemory.
+func (s *SysbenchMemory) MakeCmd() []string {
+ cmd := []string{warmup, "sysbench"}
+ cmd = append(cmd, s.flags()...)
+ cmd = append(cmd, "memory run")
+ return []string{"sh", "-c", strings.Join(cmd, " ")}
+}
+
+// flags makes flags for SysbenchMemory cmds.
+func (s *SysbenchMemory) flags() []string {
+ cmd := s.Base.baseFlags()
+ if s.BlockSize != "" {
+ cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize))
+ }
+ if s.TotalSize != "" {
+ cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize))
+ }
+ if s.Scope != "" {
+ cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope))
+ }
+ if s.HugeTLB {
+ cmd = append(cmd, "--memory-hugetlb=on")
+ }
+ if s.OperationType != "" {
+ cmd = append(cmd, fmt.Sprintf("--memory-oper=%s", s.OperationType))
+ }
+ if s.AccessMode != "" {
+ cmd = append(cmd, fmt.Sprintf("--memory-access-mode=%s", s.AccessMode))
+ }
+ return cmd
+}
+
+// Report reports the relevant metrics for SysbenchMemory.
+func (s *SysbenchMemory) Report(b *testing.B, output string) {
+ b.Helper()
+ result, err := s.parseOperations(output)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", output, err)
+ }
+ b.ReportMetric(result, "operations_per_second")
+}
+
+var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`)
+
+// parseOperations parses memory operations per second form sysbench memory ouput.
+func (s *SysbenchMemory) parseOperations(data string) (float64, error) {
+ match := memoryOperationsRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0.0, fmt.Errorf("couldn't find memory operations per second: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+// SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments.
+type SysbenchMutex struct {
+ Base SysbenchBase
+ Num int // total size of mutex array [4096].
+ Locks int // number of mutex locks per thread [50K].
+ Loops int // number of loops to do outside mutex lock [10K].
+}
+
+// MakeCmd makes commands for SysbenchMutex.
+func (s *SysbenchMutex) MakeCmd() []string {
+ cmd := []string{warmup, "sysbench"}
+ cmd = append(cmd, s.flags()...)
+ cmd = append(cmd, "mutex run")
+ return []string{"sh", "-c", strings.Join(cmd, " ")}
+}
+
+// flags makes flags for SysbenchMutex commands.
+func (s *SysbenchMutex) flags() []string {
+ var cmd []string
+ cmd = append(cmd, s.Base.baseFlags()...)
+ if s.Num > 0 {
+ cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num))
+ }
+ if s.Locks > 0 {
+ cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", s.Locks))
+ }
+ if s.Loops > 0 {
+ cmd = append(cmd, fmt.Sprintf("--mutex-loops=%d", s.Loops))
+ }
+ return cmd
+}
+
+// Report parses and reports relevant sysbench mutex metrics.
+func (s *SysbenchMutex) Report(b *testing.B, output string) {
+ b.Helper()
+
+ result, err := s.parseExecutionTime(output)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", output, err)
+ }
+ b.ReportMetric(result, "average_execution_time_secs")
+
+ result, err = s.parseDeviation(output)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", output, err)
+ }
+ b.ReportMetric(result, "stdev_execution_time_secs")
+
+ result, err = s.parseLatency(output)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", output, err)
+ }
+ b.ReportMetric(result/1000, "average_latency_secs")
+}
+
+var executionTimeRE = regexp.MustCompile(`execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)`)
+
+// parseExecutionTime parses threads fairness average execution time from sysbench output.
+func (s *SysbenchMutex) parseExecutionTime(data string) (float64, error) {
+ match := executionTimeRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0.0, fmt.Errorf("could not find execution time average: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
+
+// parseDeviation parses threads fairness stddev time from sysbench output.
+func (s *SysbenchMutex) parseDeviation(data string) (float64, error) {
+ match := executionTimeRE.FindStringSubmatch(data)
+ if len(match) < 3 {
+ return 0.0, fmt.Errorf("could not find execution time deviation: %s", data)
+ }
+ return strconv.ParseFloat(match[2], 64)
+}
+
+var averageLatencyRE = regexp.MustCompile(`avg:[^\n^\d]*(\d*\.?\d*)`)
+
+// parseLatency parses latency from sysbench output.
+func (s *SysbenchMutex) parseLatency(data string) (float64, error) {
+ match := averageLatencyRE.FindStringSubmatch(data)
+ if len(match) < 2 {
+ return 0.0, fmt.Errorf("could not find average latency: %s", data)
+ }
+ return strconv.ParseFloat(match[1], 64)
+}
diff --git a/test/benchmarks/tools/sysbench_test.go b/test/benchmarks/tools/sysbench_test.go
new file mode 100644
index 000000000..850d1939e
--- /dev/null
+++ b/test/benchmarks/tools/sysbench_test.go
@@ -0,0 +1,169 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "testing"
+)
+
+// TestSysbenchCpu tests parses on sample 'sysbench cpu' output.
+func TestSysbenchCpu(t *testing.T) {
+ sampleData := `
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Prime numbers limit: 10000
+
+Initializing worker threads...
+
+Threads started!
+
+CPU speed:
+ events per second: 9093.38
+
+General statistics:
+ total time: 10.0007s
+ total number of events: 90949
+
+Latency (ms):
+ min: 0.64
+ avg: 0.88
+ max: 24.65
+ 95th percentile: 1.55
+ sum: 79936.91
+
+Threads fairness:
+ events (avg/stddev): 11368.6250/831.38
+ execution time (avg/stddev): 9.9921/0.01
+`
+ sysbench := SysbenchCPU{}
+ want := 9093.38
+ if got, err := sysbench.parseEvents(sampleData); err != nil {
+ t.Fatalf("parse cpu events failed: %v", err)
+ } else if want != got {
+ t.Fatalf("got: %f want: %f", got, want)
+ }
+}
+
+// TestSysbenchMemory tests parsers on sample 'sysbench memory' output.
+func TestSysbenchMemory(t *testing.T) {
+ sampleData := `
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Running memory speed test with the following options:
+ block size: 1KiB
+ total size: 102400MiB
+ operation: write
+ scope: global
+
+Initializing worker threads...
+
+Threads started!
+
+Total operations: 47999046 (9597428.64 per second)
+
+46874.07 MiB transferred (9372.49 MiB/sec)
+
+
+General statistics:
+ total time: 5.0001s
+ total number of events: 47999046
+
+Latency (ms):
+ min: 0.00
+ avg: 0.00
+ max: 0.21
+ 95th percentile: 0.00
+ sum: 33165.91
+
+Threads fairness:
+ events (avg/stddev): 5999880.7500/111242.52
+ execution time (avg/stddev): 4.1457/0.09
+`
+ sysbench := SysbenchMemory{}
+ want := 9597428.64
+ if got, err := sysbench.parseOperations(sampleData); err != nil {
+ t.Fatalf("parse memory ops failed: %v", err)
+ } else if want != got {
+ t.Fatalf("got: %f want: %f", got, want)
+ }
+}
+
+// TestSysbenchMutex tests parsers on sample 'sysbench mutex' output.
+func TestSysbenchMutex(t *testing.T) {
+ sampleData := `
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+The 'mutex' test requires a command argument. See 'sysbench mutex help'
+root@ec078132e294:/# sysbench mutex --threads=8 run
+sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3)
+
+Running the test with following options:
+Number of threads: 8
+Initializing random number generator from current time
+
+
+Initializing worker threads...
+
+Threads started!
+
+
+General statistics:
+ total time: 0.2320s
+ total number of events: 8
+
+Latency (ms):
+ min: 152.35
+ avg: 192.48
+ max: 231.41
+ 95th percentile: 231.53
+ sum: 1539.83
+
+Threads fairness:
+ events (avg/stddev): 1.0000/0.00
+ execution time (avg/stddev): 0.1925/0.04
+`
+
+ sysbench := SysbenchMutex{}
+ want := .1925
+ if got, err := sysbench.parseExecutionTime(sampleData); err != nil {
+ t.Fatalf("parse mutex time failed: %v", err)
+ } else if want != got {
+ t.Fatalf("got: %f want: %f", got, want)
+ }
+
+ want = 0.04
+ if got, err := sysbench.parseDeviation(sampleData); err != nil {
+ t.Fatalf("parse mutex deviation failed: %v", err)
+ } else if want != got {
+ t.Fatalf("got: %f want: %f", got, want)
+ }
+
+ want = 192.48
+ if got, err := sysbench.parseLatency(sampleData); err != nil {
+ t.Fatalf("parse mutex time failed: %v", err)
+ } else if want != got {
+ t.Fatalf("got: %f want: %f", got, want)
+ }
+}
diff --git a/test/benchmarks/tools/tools.go b/test/benchmarks/tools/tools.go
new file mode 100644
index 000000000..eb61c0136
--- /dev/null
+++ b/test/benchmarks/tools/tools.go
@@ -0,0 +1,17 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tools holds tooling to couple command formatting and output parsers
+// together.
+package tools
diff --git a/runsc/container/test_app/BUILD b/test/cmd/test_app/BUILD
index 9bf9e6e9d..98ba5a3d9 100644
--- a/runsc/container/test_app/BUILD
+++ b/test/cmd/test_app/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
@@ -9,11 +9,13 @@ go_binary(
"fds.go",
"test_app.go",
],
- pure = "on",
+ pure = True,
visibility = ["//runsc/container:__pkg__"],
deps = [
+ "//pkg/test/testutil",
"//pkg/unet",
- "//runsc/testutil",
+ "//runsc/flag",
"@com_github_google_subcommands//:go_default_library",
+ "@com_github_kr_pty//:go_default_library",
],
)
diff --git a/runsc/container/test_app/fds.go b/test/cmd/test_app/fds.go
index a90cc1662..a7658eefd 100644
--- a/runsc/container/test_app/fds.go
+++ b/test/cmd/test_app/fds.go
@@ -21,10 +21,10 @@ import (
"os"
"time"
- "flag"
"github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/pkg/unet"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/runsc/flag"
)
const fileContents = "foobarbaz"
diff --git a/runsc/container/test_app/test_app.go b/test/cmd/test_app/test_app.go
index 913d781c6..3ba4f38f8 100644
--- a/runsc/container/test_app/test_app.go
+++ b/test/cmd/test_app/test_app.go
@@ -19,6 +19,7 @@ package main
import (
"context"
"fmt"
+ "io"
"io/ioutil"
"log"
"net"
@@ -29,9 +30,10 @@ import (
sys "syscall"
"time"
- "flag"
"github.com/google/subcommands"
- "gvisor.dev/gvisor/runsc/testutil"
+ "github.com/kr/pty"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/flag"
)
func main() {
@@ -41,6 +43,7 @@ func main() {
subcommands.Register(new(fdReceiver), "")
subcommands.Register(new(fdSender), "")
subcommands.Register(new(forkBomb), "")
+ subcommands.Register(new(ptyRunner), "")
subcommands.Register(new(reaper), "")
subcommands.Register(new(syscall), "")
subcommands.Register(new(taskTree), "")
@@ -93,7 +96,7 @@ func (c *uds) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{})
listener, err := net.Listen("unix", c.socketPath)
if err != nil {
- log.Fatal("error listening on socket %q:", c.socketPath, err)
+ log.Fatalf("error listening on socket %q: %v", c.socketPath, err)
}
go server(listener, outputFile)
@@ -352,3 +355,40 @@ func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
return subcommands.ExitSuccess
}
+
+type ptyRunner struct{}
+
+// Name implements subcommands.Command.
+func (*ptyRunner) Name() string {
+ return "pty-runner"
+}
+
+// Synopsis implements subcommands.Command.
+func (*ptyRunner) Synopsis() string {
+ return "runs the given command with an open pty terminal"
+}
+
+// Usage implements subcommands.Command.
+func (*ptyRunner) Usage() string {
+ return "pty-runner [command]"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (*ptyRunner) SetFlags(f *flag.FlagSet) {}
+
+// Execute implements subcommands.Command.
+func (*ptyRunner) Execute(_ context.Context, fs *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
+ c := exec.Command(fs.Args()[0], fs.Args()[1:]...)
+ f, err := pty.Start(c)
+ if err != nil {
+ fmt.Printf("pty.Start failed: %v", err)
+ return subcommands.ExitFailure
+ }
+ defer f.Close()
+
+ // Copy stdout from the command to keep this process alive until the
+ // subprocess exits.
+ io.Copy(os.Stdout, f)
+
+ return subcommands.ExitSuccess
+}
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index 4fe03a220..29a84f184 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,7 +10,7 @@ go_test(
"integration_test.go",
"regression_test.go",
],
- embed = [":integration"],
+ library = ":integration",
tags = [
# Requires docker and runsc to be configured before the test runs.
"manual",
@@ -20,14 +20,14 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/bits",
- "//runsc/dockerutil",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
"//runsc/specutils",
- "//runsc/testutil",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
],
)
go_library(
name = "integration",
srcs = ["integration.go"],
- importpath = "gvisor.dev/gvisor/test/integration",
)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index 4074d2285..b47df447c 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -22,33 +22,34 @@
package integration
import (
+ "context"
"fmt"
"strconv"
"strings"
- "syscall"
"testing"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
- "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/runsc/specutils"
)
// Test that exec uses the exact same capability set as the container.
func TestExecCapabilities(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-capabilities-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container.
- if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ // Check that capability.
+ matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second)
if err != nil {
t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
}
@@ -59,7 +60,7 @@ func TestExecCapabilities(t *testing.T) {
t.Log("Root capabilities:", want)
// Now check that exec'd process capabilities match the root.
- got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "grep", "CapEff:", "/proc/self/status")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
@@ -72,19 +73,20 @@ func TestExecCapabilities(t *testing.T) {
// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW
// which is removed from the container when --net-raw=false.
func TestExecPrivileged(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-privileged-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container with all capabilities dropped.
- if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ CapDrop: []string{"all"},
+ }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Check that all capabilities where dropped from container.
- matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second)
if err != nil {
t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
}
@@ -100,9 +102,11 @@ func TestExecPrivileged(t *testing.T) {
t.Fatalf("Container should have no capabilities: %x", containerCaps)
}
- // Check that 'exec --privileged' adds all capabilities, except
- // for CAP_NET_RAW.
- got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status")
+ // Check that 'exec --privileged' adds all capabilities, except for
+ // CAP_NET_RAW.
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{
+ Privileged: true,
+ }, "grep", "CapEff:", "/proc/self/status")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
@@ -114,97 +118,83 @@ func TestExecPrivileged(t *testing.T) {
}
func TestExecJobControl(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-job-control-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- // Exec 'sh' with an attached pty.
- cmd, ptmx, err := d.ExecWithTerminal("sh")
+ p, err := d.ExecProcess(ctx, dockerutil.ExecOpts{UseTTY: true}, "/bin/sh")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
- defer ptmx.Close()
- // Call "sleep 100 | cat" in the shell. We pipe to cat so that there
- // will be two processes in the foreground process group.
- if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
+ if _, err = p.Write(time.Second, []byte("sleep 100 | cat\n")); err != nil {
+ t.Fatalf("error exit: %v", err)
}
+ time.Sleep(time.Second)
- // Give shell a few seconds to start executing the sleep.
- time.Sleep(2 * time.Second)
-
- // Send a ^C to the pty, which should kill sleep and cat, but not the
- // shell. \x03 is ASCII "end of text", which is the same as ^C.
- if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
- t.Fatalf("error writing to pty: %v", err)
+ if _, err = p.Write(time.Second, []byte{0x03}); err != nil {
+ t.Fatalf("error exit: %v", err)
}
- // The shell should still be alive at this point. Sleep should have
- // exited with code 2+128=130. We'll exit with 10 plus that number, so
- // that we can be sure that the shell did not get signalled.
- if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
+ if _, err = p.Write(time.Second, []byte("exit $(expr $? + 10)\n")); err != nil {
+ t.Fatalf("error exit: %v", err)
}
- // Exec process should exit with code 10+130=140.
- ps, err := cmd.Process.Wait()
+ want := 140
+ got, err := p.WaitExitStatus(ctx)
if err != nil {
- t.Fatalf("error waiting for exec process: %v", err)
- }
- ws := ps.Sys().(syscall.WaitStatus)
- if !ws.Exited() {
- t.Errorf("ws.Exited got false, want true")
- }
- if got, want := ws.ExitStatus(), 140; got != want {
- t.Errorf("ws.ExitedStatus got %d, want %d", got, want)
+ t.Fatalf("wait for exit failed with: %v", err)
+ } else if got != want {
+ t.Fatalf("wait for exit returned: %d want: %d", got, want)
}
}
// Test that failure to exec returns proper error message.
func TestExecError(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-error-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- _, err := d.Exec("no_can_find")
+ // Attempt to exec a binary that doesn't exist.
+ out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "no_can_find")
if err == nil {
t.Fatalf("docker exec didn't fail")
}
- if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) {
- t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
+ if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(out, want) {
+ t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", out, want)
}
}
// Test that exec inherits environment from run.
func TestExecEnv(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-env-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container with env FOO=BAR.
- if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Env: []string{"FOO=BAR"},
+ }, "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Exec "echo $FOO".
- got, err := d.Exec("/bin/sh", "-c", "echo $FOO")
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $FOO")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
@@ -216,17 +206,20 @@ func TestExecEnv(t *testing.T) {
// TestRunEnvHasHome tests that run always has HOME environment set.
func TestRunEnvHasHome(t *testing.T) {
// Base alpine image does not have any environment variables set.
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("run-env-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Exec "echo $HOME". The 'bin' user's home dir is '/bin'.
- got, err := d.RunFg("--user", "bin", "alpine", "/bin/sh", "-c", "echo $HOME")
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ User: "bin",
+ }, "/bin/sh", "-c", "echo $HOME")
if err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
+
+ // Check that the directory matches.
if got, want := strings.TrimSpace(got), "/bin"; got != want {
t.Errorf("bad output from 'docker run'. Got %q; Want %q.", got, want)
}
@@ -235,28 +228,18 @@ func TestRunEnvHasHome(t *testing.T) {
// Test that exec always has HOME environment set, even when not set in run.
func TestExecEnvHasHome(t *testing.T) {
// Base alpine image does not have any environment variables set.
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("exec-env-home-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
- // We will check that HOME is set for root user, and also for a new
- // non-root user we will create.
- newUID := 1234
- newHome := "/foo/bar"
-
- // Create a new user with a home directory, and then sleep.
- script := fmt.Sprintf(`
- mkdir -p -m 777 %s && \
- adduser foo -D -u %d -h %s && \
- sleep 1000`, newHome, newUID, newHome)
- if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Exec "echo $HOME", and expect to see "/root".
- got, err := d.Exec("/bin/sh", "-c", "echo $HOME")
+ got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $HOME")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
@@ -264,8 +247,18 @@ func TestExecEnvHasHome(t *testing.T) {
t.Errorf("wanted exec output to contain %q, got %q", want, got)
}
- // Execute the same as uid 123 and expect newHome.
- got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
+ // Create a new user with a home directory.
+ newUID := 1234
+ newHome := "/foo/bar"
+ cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome)
+ if _, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd); err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+
+ // Execute the same as the new user and expect newHome.
+ got, err = d.Exec(ctx, dockerutil.ExecOpts{
+ User: strconv.Itoa(newUID),
+ }, "/bin/sh", "-c", "echo $HOME")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 7cc0de129..809244bab 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -22,21 +22,27 @@
package integration
import (
+ "context"
"flag"
"fmt"
+ "io/ioutil"
"net"
"net/http"
"os"
+ "path/filepath"
"strconv"
"strings"
- "syscall"
"testing"
"time"
- "gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/testutil"
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
+// defaultWait is the default wait time used for tests.
+const defaultWait = time.Minute
+
// httpRequestSucceeds sends a request to a given url and checks that the status is OK.
func httpRequestSucceeds(client http.Client, server string, port int) error {
url := fmt.Sprintf("http://%s:%d", server, port)
@@ -53,78 +59,82 @@ func httpRequestSucceeds(client http.Client, server string, port int) error {
// TestLifeCycle tests a basic Create/Start/Stop docker container life cycle.
func TestLifeCycle(t *testing.T) {
- if err := dockerutil.Pull("nginx"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("lifecycle-test")
- if err := d.Create("-p", "80", "nginx"); err != nil {
- t.Fatal("docker create failed:", err)
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Create(ctx, dockerutil.RunOpts{
+ Image: "basic/nginx",
+ Ports: []int{80},
+ }); err != nil {
+ t.Fatalf("docker create failed: %v", err)
}
- if err := d.Start(); err != nil {
- d.CleanUp()
- t.Fatal("docker start failed:", err)
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("docker start failed: %v", err)
}
- // Test that container is working
- port, err := d.FindPort(80)
+ // Test that container is working.
+ port, err := d.FindPort(ctx, 80)
if err != nil {
- t.Fatal("docker.FindPort(80) failed: ", err)
+ t.Fatalf("docker.FindPort(80) failed: %v", err)
}
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
- t.Fatal("WaitForHTTP() timeout:", err)
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
}
- client := http.Client{Timeout: time.Duration(2 * time.Second)}
+ client := http.Client{Timeout: defaultWait}
if err := httpRequestSucceeds(client, "localhost", port); err != nil {
- t.Error("http request failed:", err)
+ t.Errorf("http request failed: %v", err)
}
- if err := d.Stop(); err != nil {
- d.CleanUp()
- t.Fatal("docker stop failed:", err)
+ if err := d.Stop(ctx); err != nil {
+ t.Fatalf("docker stop failed: %v", err)
}
- if err := d.Remove(); err != nil {
- t.Fatal("docker rm failed:", err)
+ if err := d.Remove(ctx); err != nil {
+ t.Fatalf("docker rm failed: %v", err)
}
}
func TestPauseResume(t *testing.T) {
- const img = "gcr.io/gvisor-presubmit/python-hello"
if !testutil.IsCheckpointSupported() {
- t.Log("Checkpoint is not supported, skipping test.")
- return
+ t.Skip("Checkpoint is not supported.")
}
- if err := dockerutil.Pull(img); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("pause-resume-test")
- if err := d.Run("-p", "8080", img); err != nil {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/python",
+ Ports: []int{8080}, // See Dockerfile.
+ }); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Find where port 8080 is mapped to.
- port, err := d.FindPort(8080)
+ port, err := d.FindPort(ctx, 8080)
if err != nil {
- t.Fatal("docker.FindPort(8080) failed:", err)
+ t.Fatalf("docker.FindPort(8080) failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
- t.Fatal("WaitForHTTP() timeout:", err)
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check that container is working.
- client := http.Client{Timeout: time.Duration(2 * time.Second)}
+ client := http.Client{Timeout: defaultWait}
if err := httpRequestSucceeds(client, "localhost", port); err != nil {
t.Error("http request failed:", err)
}
- if err := d.Pause(); err != nil {
- t.Fatal("docker pause failed:", err)
+ if err := d.Pause(ctx); err != nil {
+ t.Fatalf("docker pause failed: %v", err)
}
// Check if container is paused.
+ client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute.
switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) {
case nil:
t.Errorf("http req expected to fail but it succeeded")
@@ -136,62 +146,72 @@ func TestPauseResume(t *testing.T) {
t.Errorf("http req got unexpected error %v", v)
}
- if err := d.Unpause(); err != nil {
- t.Fatal("docker unpause failed:", err)
+ if err := d.Unpause(ctx); err != nil {
+ t.Fatalf("docker unpause failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
- t.Fatal("WaitForHTTP() timeout:", err)
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check if container is working again.
+ client = http.Client{Timeout: defaultWait}
if err := httpRequestSucceeds(client, "localhost", port); err != nil {
t.Error("http request failed:", err)
}
}
func TestCheckpointRestore(t *testing.T) {
- const img = "gcr.io/gvisor-presubmit/python-hello"
if !testutil.IsCheckpointSupported() {
- t.Log("Pause/resume is not supported, skipping test.")
- return
+ t.Skip("Pause/resume is not supported.")
}
- if err := dockerutil.Pull(img); err != nil {
- t.Fatal("docker pull failed:", err)
+ // TODO(gvisor.dev/issue/3373): Remove after implementing.
+ if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 {
+ t.Skip("CheckpointRestore not implemented in VFS2.")
+ } else if err != nil {
+ t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err)
}
- d := dockerutil.MakeDocker("save-restore-test")
- if err := d.Run("-p", "8080", img); err != nil {
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the container.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/python",
+ Ports: []int{8080}, // See Dockerfile.
+ }); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- if err := d.Checkpoint("test"); err != nil {
- t.Fatal("docker checkpoint failed:", err)
+ // Create a snapshot.
+ if err := d.Checkpoint(ctx, "test"); err != nil {
+ t.Fatalf("docker checkpoint failed: %v", err)
}
-
- if _, err := d.Wait(30 * time.Second); err != nil {
- t.Fatal(err)
+ if err := d.WaitTimeout(ctx, defaultWait); err != nil {
+ t.Fatalf("wait failed: %v", err)
}
- if err := d.Restore("test"); err != nil {
- t.Fatal("docker restore failed:", err)
+ // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed.
+ if err := testutil.Poll(func() error { return d.Restore(ctx, "test") }, defaultWait); err != nil {
+ t.Fatalf("docker restore failed: %v", err)
}
// Find where port 8080 is mapped to.
- port, err := d.FindPort(8080)
+ port, err := d.FindPort(ctx, 8080)
if err != nil {
- t.Fatal("docker.FindPort(8080) failed:", err)
+ t.Fatalf("docker.FindPort(8080) failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
- t.Fatal("WaitForHTTP() timeout:", err)
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
+ t.Fatalf("WaitForHTTP() timeout: %v", err)
}
// Check if container is working again.
- client := http.Client{Timeout: time.Duration(2 * time.Second)}
+ client := http.Client{Timeout: defaultWait}
if err := httpRequestSucceeds(client, "localhost", port); err != nil {
t.Error("http request failed:", err)
}
@@ -199,48 +219,55 @@ func TestCheckpointRestore(t *testing.T) {
// Create client and server that talk to each other using the local IP.
func TestConnectToSelf(t *testing.T) {
- d := dockerutil.MakeDocker("connect-to-self-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Creates server that replies "server" and exists. Sleeps at the end because
// 'docker exec' gets killed if the init process exists before it can finish.
- if err := d.Run("ubuntu:trusty", "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil {
- t.Fatal("docker run failed:", err)
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Finds IP address for host.
- ip, err := d.Exec("/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'")
+ ip, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'")
if err != nil {
- t.Fatal("docker exec failed:", err)
+ t.Fatalf("docker exec failed: %v", err)
}
ip = strings.TrimRight(ip, "\n")
// Runs client that sends "client" to the server and exits.
- reply, err := d.Exec("/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip))
+ reply, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip))
if err != nil {
- t.Fatal("docker exec failed:", err)
+ t.Fatalf("docker exec failed: %v", err)
}
// Ensure both client and server got the message from each other.
if want := "server\n"; reply != want {
t.Errorf("Error on server, want: %q, got: %q", want, reply)
}
- if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil {
- t.Fatal("docker.WaitForOutput(client) timeout:", err)
+ if _, err := d.WaitForOutput(ctx, "^client\n$", defaultWait); err != nil {
+ t.Fatalf("docker.WaitForOutput(client) timeout: %v", err)
}
}
func TestMemLimit(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("cgroup-test")
- cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'"
- out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd)
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // N.B. Because the size of the memory file may grow in large chunks,
+ // there is a minimum threshold of 1GB for the MemTotal figure.
+ allocMemory := 1024 * 1024 // In kb.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Memory: allocMemory * 1024, // In bytes.
+ }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'")
if err != nil {
- t.Fatal("docker run failed:", err)
+ t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Remove warning message that swap isn't present.
if strings.HasPrefix(out, "WARNING") {
@@ -251,27 +278,31 @@ func TestMemLimit(t *testing.T) {
out = lines[1]
}
+ // Ensure the memory matches what we want.
got, err := strconv.ParseUint(strings.TrimSpace(out), 10, 64)
if err != nil {
t.Fatalf("failed to parse %q: %v", out, err)
}
- if want := uint64(500 * 1024); got != want {
+ if want := uint64(allocMemory); got != want {
t.Errorf("MemTotal got: %d, want: %d", got, want)
}
}
func TestNumCPU(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("cgroup-test")
- cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l"
- out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd)
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Read how many cores are in the container.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ CpusetCpus: "0",
+ }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l")
if err != nil {
- t.Fatal("docker run failed:", err)
+ t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
+ // Ensure it matches what we want.
got, err := strconv.Atoi(strings.TrimSpace(out))
if err != nil {
t.Fatalf("failed to parse %q: %v", out, err)
@@ -283,62 +314,182 @@ func TestNumCPU(t *testing.T) {
// TestJobControl tests that job control characters are handled properly.
func TestJobControl(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("job-control-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container with an attached PTY.
- _, ptmx, err := d.RunWithPty("alpine", "sh")
+ p, err := d.SpawnProcess(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sh", "-c", "sleep 100 | cat")
if err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer ptmx.Close()
- defer d.CleanUp()
-
- // Call "sleep 100" in the shell.
- if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
// Give shell a few seconds to start executing the sleep.
time.Sleep(2 * time.Second)
- // Send a ^C to the pty, which should kill sleep, but not the shell.
- // \x03 is ASCII "end of text", which is the same as ^C.
- if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
- t.Fatalf("error writing to pty: %v", err)
+ if _, err := p.Write(time.Second, []byte{0x03}); err != nil {
+ t.Fatalf("error exit: %v", err)
}
- // The shell should still be alive at this point. Sleep should have
- // exited with code 2+128=130. We'll exit with 10 plus that number, so
- // that we can be sure that the shell did not get signalled.
- if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
+ if err := d.WaitTimeout(ctx, 3*time.Second); err != nil {
+ t.Fatalf("WaitTimeout failed: %v", err)
}
- // Wait for the container to exit.
- got, err := d.Wait(5 * time.Second)
+ want := 130
+ got, err := p.WaitExitStatus(ctx)
if err != nil {
- t.Fatalf("error getting exit code: %v", err)
+ t.Fatalf("wait for exit failed with: %v", err)
+ } else if got != want {
+ t.Fatalf("got: %d want: %d", got, want)
}
- // Container should exit with code 10+130=140.
- if want := syscall.WaitStatus(140); got != want {
- t.Errorf("container exited with code %d want %d", got, want)
+}
+
+// TestWorkingDirCreation checks that working dir is created if it doesn't exit.
+func TestWorkingDirCreation(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ workingDir string
+ }{
+ {name: "root", workingDir: "/foo"},
+ {name: "tmp", workingDir: "/tmp/foo"},
+ } {
+ for _, readonly := range []bool{true, false} {
+ name := tc.name
+ if readonly {
+ name += "-readonly"
+ }
+ t.Run(name, func(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ WorkDir: tc.workingDir,
+ ReadOnly: readonly,
+ }
+ got, err := d.Run(ctx, opts, "sh", "-c", "echo ${PWD}")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want := tc.workingDir + "\n"; want != got {
+ t.Errorf("invalid working dir, want: %q, got: %q", want, got)
+ }
+ })
+ }
}
}
-// TestTmpFile checks that files inside '/tmp' are not overridden. In addition,
-// it checks that working dir is created if it doesn't exit.
+// TestTmpFile checks that files inside '/tmp' are not overridden.
func TestTmpFile(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatal("docker pull failed:", err)
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{Image: "basic/tmpfile"}
+ got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want := "123\n"; want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
+ }
+}
+
+// TestTmpMount checks that mounts inside '/tmp' are not overridden.
+func TestTmpMount(t *testing.T) {
+ ctx := context.Background()
+ dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount")
+ if err != nil {
+ t.Fatalf("TempDir(): %v", err)
+ }
+ want := "123"
+ if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil {
+ t.Fatalf("WriteFile(): %v", err)
+ }
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ Mounts: []mount.Mount{
+ {
+ Type: mount.TypeBind,
+ Source: dir,
+ Target: "/tmp/foo",
+ },
+ },
+ }
+ got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ if want != got {
+ t.Errorf("invalid file content, want: %q, got: %q", want, got)
+ }
+}
+
+// TestHostOverlayfsCopyUp tests that the --overlayfs-stale-read option causes
+// runsc to hide the incoherence of FDs opened before and after overlayfs
+// copy-up on the host.
+func TestHostOverlayfsCopyUp(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/hostoverlaytest",
+ WorkDir: "/root",
+ }, "./test_copy_up"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
}
- d := dockerutil.MakeDocker("tmp-file-test")
- if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil {
- t.Fatal("docker run failed:", err)
+}
+
+// TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory
+// stream to refer to the current state of the corresponding directory, as a
+// call to opendir() would have done" as required by POSIX, when the directory
+// in question is host overlayfs.
+//
+// This test specifically targets host overlayfs because, per POSIX, "if a file
+// is removed from or added to the directory after the most recent call to
+// opendir() or rewinddir(), whether a subsequent call to readdir() returns an
+// entry for that file is unspecified"; the host filesystems used by other
+// automated tests yield newly-added files from readdir() even if the fsgofer
+// does not explicitly rewinddir(), but overlayfs does not.
+func TestHostOverlayfsRewindDir(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/hostoverlaytest",
+ WorkDir: "/root",
+ }, "./test_rewinddir"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
+ }
+}
+
+// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it
+// cannot use tricks like userns as root. For this reason, run a basic link test
+// to ensure some coverage.
+func TestLink(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/linktest",
+ WorkDir: "/root",
+ }, "./link_test"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ } else if got != "" {
+ t.Errorf("test failed:\n%s", got)
}
- defer d.CleanUp()
}
func TestMain(m *testing.M) {
diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go
index 2488be383..70bbe5121 100644
--- a/test/e2e/regression_test.go
+++ b/test/e2e/regression_test.go
@@ -15,10 +15,11 @@
package integration
import (
+ "context"
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
)
// Test that UDS can be created using overlay when parent directory is in lower
@@ -27,19 +28,20 @@ import (
// Prerequisite: the directory where the socket file is created must not have
// been open for write before bind(2) is called.
func TestBindOverlay(t *testing.T) {
- if err := dockerutil.Pull("ubuntu:trusty"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("bind-overlay-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
- cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p"
- got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd)
+ // Run the container.
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p")
if err != nil {
- t.Fatal("docker run failed:", err)
+ t.Fatalf("docker run failed: %v", err)
}
+ // Check the output contains what we want.
if want := "foobar-asdf"; !strings.Contains(got, want) {
t.Fatalf("docker run output is missing %q: %s", want, got)
}
- defer d.CleanUp()
}
diff --git a/test/fuse/BUILD b/test/fuse/BUILD
new file mode 100644
index 000000000..56157c96b
--- /dev/null
+++ b/test/fuse/BUILD
@@ -0,0 +1,9 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ fuse = "True",
+ test = "//test/fuse/linux:stat_test",
+ vfs2 = "True",
+)
diff --git a/test/fuse/README.md b/test/fuse/README.md
new file mode 100644
index 000000000..734c3a4e3
--- /dev/null
+++ b/test/fuse/README.md
@@ -0,0 +1,103 @@
+# gVisor FUSE Test Suite
+
+This is an integration test suite for fuse(4) filesystem. It runs under both
+gVisor and Linux, and ensures compatibility between the two. This test suite is
+based on system calls test.
+
+This document describes the framework of fuse integration test and the
+guidelines that should be followed when adding new fuse tests.
+
+## Integration Test Framework
+
+Please refer to the figure below. `>` is entering the function, `<` is leaving
+the function, and `=` indicates sequentially entering and leaving.
+
+```
+ | Client (Test Main Process) | Server (FUSE Daemon)
+ | |
+ | >TEST_F() |
+ | >SetUp() |
+ | =MountFuse() |
+ | >SetUpFuseServer() |
+ | [create communication pipes] |
+ | =fork() | =fork()
+ | >WaitCompleted() |
+ | [wait for MarkDone()] |
+ | | =ConsumeFuseInit()
+ | | =MarkDone()
+ | <WaitCompleted() |
+ | <SetUpFuseServer() |
+ | <SetUp() |
+ | >SetExpected() |
+ | [construct expected reaction] |
+ | | >FuseLoop()
+ | | >ReceiveExpected()
+ | | [wait data from pipe]
+ | [write data to pipe] |
+ | [wait for MarkDone()] |
+ | | [save data to memory]
+ | | =MarkDone()
+ | <SetExpected() |
+ | | <ReceiveExpected()
+ | | >read()
+ | | [wait for fs operation]
+ | >[Do fs operation] |
+ | [wait for fs response] |
+ | | <read()
+ | | =CompareRequest()
+ | | =write() [write fs response]
+ | <[Do fs operation] |
+ | =[Test fs operation result] |
+ | =[wait for MarkDone()] |
+ | | =MarkDone()
+ | >TearDown() |
+ | =UnmountFuse() |
+ | <TearDown() |
+ | <TEST_F() |
+```
+
+## Running the tests
+
+Based on syscall tests, fuse tests can run in different environments. To enable
+fuse testing environment, the test targets should be appended with `_fuse`.
+
+For example, to run fuse test in `stat_test.cc`:
+
+```bash
+$ bazel test //test/fuse:stat_test_runsc_ptrace_vfs2_fuse
+```
+
+Test all targets tagged with fuse:
+
+```bash
+$ bazel test --test_tag_filters=fuse //test/fuse/...
+```
+
+## Writing a new FUSE test
+
+1. Add test targets in `BUILD` and `linux/BUILD`.
+2. Inherit your test from `FuseTest` base class. It allows you to:
+ - Run a fake FUSE server in background during each test setup.
+ - Create pipes for communication and provide utility functions.
+ - Stop FUSE server after test completes.
+3. Customize your comparison function for request assessment in FUSE server.
+4. Add the mapping of the size of structs if you are working on new FUSE
+ opcode.
+ - Please update `FuseTest::GetPayloadSize()` for each new FUSE opcode.
+5. Build the expected request-response pair of your FUSE operation.
+6. Call `SetExpected()` function to inject the expected reaction.
+7. Check the response and/or errors.
+8. Finally call `WaitCompleted()` to ensure the FUSE server acts correctly.
+
+A few customized matchers used in syscalls test are encouraged to test the
+outcome of filesystem operations. Such as:
+
+```cc
+SyscallSucceeds()
+SyscallSucceedsWithValue(...)
+SyscallFails()
+SyscallFailsWithErrno(...)
+```
+
+Please refer to [test/syscalls/README.md](../syscalls/README.md) for further
+details.
diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD
new file mode 100644
index 000000000..4871bb531
--- /dev/null
+++ b/test/fuse/linux/BUILD
@@ -0,0 +1,32 @@
+load("//tools:defs.bzl", "cc_binary", "cc_library", "gtest")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "stat_test",
+ testonly = 1,
+ srcs = ["stat_test.cc"],
+ deps = [
+ gtest,
+ ":fuse_base",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "fuse_base",
+ testonly = 1,
+ srcs = ["fuse_base.cc"],
+ hdrs = ["fuse_base.h"],
+ deps = [
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc
new file mode 100644
index 000000000..9c3124472
--- /dev/null
+++ b/test/fuse/linux/fuse_base.cc
@@ -0,0 +1,208 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/fuse/linux/fuse_base.h"
+
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <string.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include <iostream>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_format.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void FuseTest::SetUp() {
+ MountFuse();
+ SetUpFuseServer();
+}
+
+void FuseTest::TearDown() { UnmountFuse(); }
+
+// Since CompareRequest is running in background thread, gTest assertions and
+// expectations won't directly reflect the test result. However, the FUSE
+// background server still connects to the same standard I/O as testing main
+// thread. So EXPECT_XX can still be used to show different results. To
+// ensure failed testing result is observable, return false and the result
+// will be sent to test main thread via pipe.
+bool FuseTest::CompareRequest(void* expected_mem, size_t expected_len,
+ void* real_mem, size_t real_len) {
+ if (expected_len != real_len) return false;
+ return memcmp(expected_mem, real_mem, expected_len) == 0;
+}
+
+// SetExpected is called by the testing main thread to set expected request-
+// response pair of a single FUSE operation.
+void FuseTest::SetExpected(struct iovec* iov_in, int iov_in_cnt,
+ struct iovec* iov_out, int iov_out_cnt) {
+ EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_in, iov_in_cnt),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ WaitCompleted();
+
+ EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_out, iov_out_cnt),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ WaitCompleted();
+}
+
+// WaitCompleted waits for the FUSE server to finish its job and check if it
+// completes without errors.
+void FuseTest::WaitCompleted() {
+ char success;
+ EXPECT_THAT(RetryEINTR(read)(done_[0], &success, sizeof(success)),
+ SyscallSucceedsWithValue(1));
+}
+
+void FuseTest::MountFuse() {
+ EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds());
+
+ std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, kMountOpts);
+ mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse",
+ MS_NODEV | MS_NOSUID, mount_opts.c_str()),
+ SyscallSucceedsWithValue(0));
+}
+
+void FuseTest::UnmountFuse() {
+ EXPECT_THAT(umount(mount_point_.path().c_str()), SyscallSucceeds());
+ // TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully.
+}
+
+// ConsumeFuseInit consumes the first FUSE request and returns the
+// corresponding PosixError.
+PosixError FuseTest::ConsumeFuseInit() {
+ RETURN_ERROR_IF_SYSCALL_FAIL(
+ RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()));
+
+ struct iovec iov_out[2];
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out),
+ .error = 0,
+ .unique = 2,
+ };
+ // Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED
+ // error in the initialization of FUSE connection.
+ struct fuse_init_out out_payload = {
+ .major = 7,
+ };
+ iov_out[0].iov_len = sizeof(out_header);
+ iov_out[0].iov_base = &out_header;
+ iov_out[1].iov_len = sizeof(out_payload);
+ iov_out[1].iov_base = &out_payload;
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(writev)(dev_fd_, iov_out, 2));
+ return NoError();
+}
+
+// ReceiveExpected reads 1 pair of expected fuse request-response `iovec`s
+// from pipe and save them into member variables of this testing instance.
+void FuseTest::ReceiveExpected() {
+ // Set expected fuse_in request.
+ EXPECT_THAT(len_in_ = RetryEINTR(read)(set_expected_[0], mem_in_.data(),
+ mem_in_.size()),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ MarkDone(len_in_ > 0);
+
+ // Set expected fuse_out response.
+ EXPECT_THAT(len_out_ = RetryEINTR(read)(set_expected_[0], mem_out_.data(),
+ mem_out_.size()),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ MarkDone(len_out_ > 0);
+}
+
+// MarkDone writes 1 byte of success indicator through pipe.
+void FuseTest::MarkDone(bool success) {
+ char data = success ? 1 : 0;
+ EXPECT_THAT(RetryEINTR(write)(done_[1], &data, sizeof(data)),
+ SyscallSucceedsWithValue(1));
+}
+
+// FuseLoop is the implementation of the fake FUSE server. Read from /dev/fuse,
+// compare the request by CompareRequest (use derived function if specified),
+// and write the expected response to /dev/fuse.
+void FuseTest::FuseLoop() {
+ bool success = true;
+ ssize_t len = 0;
+ while (true) {
+ ReceiveExpected();
+
+ EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()),
+ SyscallSucceedsWithValue(len_in_));
+ if (len != len_in_) success = false;
+
+ if (!CompareRequest(buf_.data(), len_in_, mem_in_.data(), len_in_)) {
+ std::cerr << "the FUSE request is not expected" << std::endl;
+ success = false;
+ }
+
+ EXPECT_THAT(len = RetryEINTR(write)(dev_fd_, mem_out_.data(), len_out_),
+ SyscallSucceedsWithValue(len_out_));
+ if (len != len_out_) success = false;
+ MarkDone(success);
+ }
+}
+
+// SetUpFuseServer creates 2 pipes. First is for testing client to send the
+// expected request-response pair, and the other acts as a checkpoint for the
+// FUSE server to notify the client that it can proceed.
+void FuseTest::SetUpFuseServer() {
+ ASSERT_THAT(pipe(set_expected_), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(pipe(done_), SyscallSucceedsWithValue(0));
+
+ switch (fork()) {
+ case -1:
+ GTEST_FAIL();
+ return;
+ case 0:
+ break;
+ default:
+ ASSERT_THAT(close(set_expected_[0]), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(close(done_[1]), SyscallSucceedsWithValue(0));
+ WaitCompleted();
+ return;
+ }
+
+ ASSERT_THAT(close(set_expected_[1]), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(close(done_[0]), SyscallSucceedsWithValue(0));
+
+ MarkDone(ConsumeFuseInit().ok());
+
+ FuseLoop();
+ _exit(0);
+}
+
+// GetPayloadSize is a helper function to get the number of bytes of a
+// specific FUSE operation struct.
+size_t FuseTest::GetPayloadSize(uint32_t opcode, bool in) {
+ switch (opcode) {
+ case FUSE_INIT:
+ return in ? sizeof(struct fuse_init_in) : sizeof(struct fuse_init_out);
+ default:
+ break;
+ }
+ return 0;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h
new file mode 100644
index 000000000..3a2f255a9
--- /dev/null
+++ b/test/fuse/linux/fuse_base.h
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_FUSE_FUSE_BASE_H_
+#define GVISOR_TEST_FUSE_FUSE_BASE_H_
+
+#include <linux/fuse.h>
+#include <sys/uio.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0";
+
+class FuseTest : public ::testing::Test {
+ public:
+ FuseTest() {
+ buf_.resize(FUSE_MIN_READ_BUFFER);
+ mem_in_.resize(FUSE_MIN_READ_BUFFER);
+ mem_out_.resize(FUSE_MIN_READ_BUFFER);
+ }
+ void SetUp() override;
+ void TearDown() override;
+
+ // CompareRequest is used by the FUSE server and should be implemented to
+ // compare different FUSE operations. It compares the actual FUSE input
+ // request with the expected one set by `SetExpected()`.
+ virtual bool CompareRequest(void* expected_mem, size_t expected_len,
+ void* real_mem, size_t real_len);
+
+ // SetExpected is called by the testing main thread. Writes a request-
+ // response pair into FUSE server's member variables via pipe.
+ void SetExpected(struct iovec* iov_in, int iov_in_cnt, struct iovec* iov_out,
+ int iov_out_cnt);
+
+ // WaitCompleted waits for FUSE server to complete its processing. It
+ // complains if the FUSE server responds failure during tests.
+ void WaitCompleted();
+
+ protected:
+ TempPath mount_point_;
+
+ private:
+ void MountFuse();
+ void UnmountFuse();
+
+ // ConsumeFuseInit is only used during FUSE server setup.
+ PosixError ConsumeFuseInit();
+
+ // ReceiveExpected is the FUSE server side's corresponding code of
+ // `SetExpected()`. Save the request-response pair into its memory.
+ void ReceiveExpected();
+
+ // MarkDone is used by the FUSE server to tell testing main if it's OK to
+ // proceed next command.
+ void MarkDone(bool success);
+
+ // FuseLoop is where the FUSE server stay until it is terminated.
+ void FuseLoop();
+
+ // SetUpFuseServer creates 2 pipes for communication and forks FUSE server.
+ void SetUpFuseServer();
+
+ // GetPayloadSize is a helper function to get the number of bytes of a
+ // specific FUSE operation struct.
+ size_t GetPayloadSize(uint32_t opcode, bool in);
+
+ int dev_fd_;
+ int set_expected_[2];
+ int done_[2];
+
+ std::vector<char> buf_;
+ std::vector<char> mem_in_;
+ std::vector<char> mem_out_;
+ ssize_t len_in_;
+ ssize_t len_out_;
+};
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_FUSE_FUSE_BASE_H_
diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc
new file mode 100644
index 000000000..172e09867
--- /dev/null
+++ b/test/fuse/linux/stat_test.cc
@@ -0,0 +1,169 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/fuse.h>
+#include <sys/stat.h>
+#include <sys/statfs.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/fuse/linux/fuse_base.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class StatTest : public FuseTest {
+ public:
+ bool CompareRequest(void* expected_mem, size_t expected_len, void* real_mem,
+ size_t real_len) override {
+ if (expected_len != real_len) return false;
+ struct fuse_in_header* real_header =
+ reinterpret_cast<fuse_in_header*>(real_mem);
+
+ if (real_header->opcode != FUSE_GETATTR) {
+ std::cerr << "expect header opcode " << FUSE_GETATTR << " but got "
+ << real_header->opcode << std::endl;
+ return false;
+ }
+ return true;
+ }
+
+ bool StatsAreEqual(struct stat expected, struct stat actual) {
+ // device number will be dynamically allocated by kernel, we cannot know
+ // in advance
+ actual.st_dev = expected.st_dev;
+ return memcmp(&expected, &actual, sizeof(struct stat)) == 0;
+ }
+};
+
+TEST_F(StatTest, StatNormal) {
+ struct iovec iov_in[2];
+ struct iovec iov_out[2];
+
+ struct fuse_in_header in_header = {
+ .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in),
+ .opcode = FUSE_GETATTR,
+ .unique = 4,
+ .nodeid = 1,
+ .uid = 0,
+ .gid = 0,
+ .pid = 4,
+ .padding = 0,
+ };
+ struct fuse_getattr_in in_payload = {0};
+ iov_in[0].iov_len = sizeof(in_header);
+ iov_in[0].iov_base = &in_header;
+ iov_in[1].iov_len = sizeof(in_payload);
+ iov_in[1].iov_base = &in_payload;
+
+ mode_t expected_mode = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH;
+ struct timespec atime = {.tv_sec = 1595436289, .tv_nsec = 134150844};
+ struct timespec mtime = {.tv_sec = 1595436290, .tv_nsec = 134150845};
+ struct timespec ctime = {.tv_sec = 1595436291, .tv_nsec = 134150846};
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out),
+ .error = 0,
+ .unique = 4,
+ };
+ struct fuse_attr attr = {
+ .ino = 1,
+ .size = 512,
+ .blocks = 4,
+ .atime = static_cast<uint64_t>(atime.tv_sec),
+ .mtime = static_cast<uint64_t>(mtime.tv_sec),
+ .ctime = static_cast<uint64_t>(ctime.tv_sec),
+ .atimensec = static_cast<uint32_t>(atime.tv_nsec),
+ .mtimensec = static_cast<uint32_t>(mtime.tv_nsec),
+ .ctimensec = static_cast<uint32_t>(ctime.tv_nsec),
+ .mode = expected_mode,
+ .nlink = 2,
+ .uid = 1234,
+ .gid = 4321,
+ .rdev = 12,
+ .blksize = 4096,
+ };
+ struct fuse_attr_out out_payload = {
+ .attr = attr,
+ };
+ iov_out[0].iov_len = sizeof(out_header);
+ iov_out[0].iov_base = &out_header;
+ iov_out[1].iov_len = sizeof(out_payload);
+ iov_out[1].iov_base = &out_payload;
+
+ SetExpected(iov_in, 2, iov_out, 2);
+
+ struct stat stat_buf;
+ EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds());
+
+ struct stat expected_stat = {
+ .st_ino = attr.ino,
+ .st_nlink = attr.nlink,
+ .st_mode = expected_mode,
+ .st_uid = attr.uid,
+ .st_gid = attr.gid,
+ .st_rdev = attr.rdev,
+ .st_size = static_cast<off_t>(attr.size),
+ .st_blksize = attr.blksize,
+ .st_blocks = static_cast<blkcnt_t>(attr.blocks),
+ .st_atim = atime,
+ .st_mtim = mtime,
+ .st_ctim = ctime,
+ };
+ EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat));
+ WaitCompleted();
+}
+
+TEST_F(StatTest, StatNotFound) {
+ struct iovec iov_in[2];
+ struct iovec iov_out[2];
+
+ struct fuse_in_header in_header = {
+ .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in),
+ .opcode = FUSE_GETATTR,
+ .unique = 4,
+ };
+ struct fuse_getattr_in in_payload = {0};
+ iov_in[0].iov_len = sizeof(in_header);
+ iov_in[0].iov_base = &in_header;
+ iov_in[1].iov_len = sizeof(in_payload);
+ iov_in[1].iov_base = &in_payload;
+
+ struct fuse_out_header out_header = {
+ .len = sizeof(struct fuse_out_header),
+ .error = -ENOENT,
+ .unique = 4,
+ };
+ iov_out[0].iov_len = sizeof(out_header);
+ iov_out[0].iov_base = &out_header;
+
+ SetExpected(iov_in, 2, iov_out, 1);
+
+ struct stat stat_buf;
+ EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf),
+ SyscallFailsWithErrno(ENOENT));
+ WaitCompleted();
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/image/BUILD b/test/image/BUILD
index 09b0a0ad5..e749e47d4 100644
--- a/test/image/BUILD
+++ b/test/image/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -14,7 +14,7 @@ go_test(
"ruby.rb",
"ruby.sh",
],
- embed = [":image"],
+ library = ":image",
tags = [
# Requires docker and runsc to be configured before the test runs.
"manual",
@@ -22,13 +22,12 @@ go_test(
],
visibility = ["//:sandbox"],
deps = [
- "//runsc/dockerutil",
- "//runsc/testutil",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
],
)
go_library(
name = "image",
srcs = ["image.go"],
- importpath = "gvisor.dev/gvisor/test/image",
)
diff --git a/test/image/image_test.go b/test/image/image_test.go
index d0dcb1861..ac6186688 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -22,30 +22,44 @@
package image
import (
+ "context"
"flag"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
- "path/filepath"
"strings"
"testing"
"time"
- "gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
+// defaultWait defines how long to wait for progress.
+//
+// See BUILD: This is at least a "large" test, so allow up to 1 minute for any
+// given "wait" step. Note that all tests are run in parallel, which may cause
+// individual slow-downs (but a huge speed-up in aggregate).
+const defaultWait = time.Minute
+
func TestHelloWorld(t *testing.T) {
- d := dockerutil.MakeDocker("hello-test")
- if err := d.Run("hello-world"); err != nil {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Run the basic container.
+ out, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "echo", "Hello world!")
+ if err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- if _, err := d.WaitForOutput("Hello from Docker!", 5*time.Second); err != nil {
- t.Fatalf("docker didn't say hello: %v", err)
+ // Check the output.
+ if !strings.Contains(out, "Hello world!") {
+ t.Fatalf("docker didn't say hello: got %s", out)
}
}
@@ -102,31 +116,28 @@ func testHTTPServer(t *testing.T, port int) {
}
func TestHttpd(t *testing.T) {
- if err := dockerutil.Pull("httpd"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("http-test")
-
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
- if err != nil {
- t.Fatalf("PrepareFiles() failed: %v", err)
- }
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container.
- mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly)
- if err := d.Run("-p", "80", mountArg, "httpd"); err != nil {
+ opts := dockerutil.RunOpts{
+ Image: "basic/httpd",
+ Ports: []int{80},
+ }
+ d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt")
+ if err := d.Spawn(ctx, opts); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Find where port 80 is mapped to.
- port, err := d.FindPort(80)
+ port, err := d.FindPort(ctx, 80)
if err != nil {
- t.Fatalf("docker.FindPort(80) failed: %v", err)
+ t.Fatalf("FindPort(80) failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
t.Errorf("WaitForHTTP() timeout: %v", err)
}
@@ -134,31 +145,28 @@ func TestHttpd(t *testing.T) {
}
func TestNginx(t *testing.T) {
- if err := dockerutil.Pull("nginx"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("net-test")
-
- dir, err := dockerutil.PrepareFiles("latin10k.txt")
- if err != nil {
- t.Fatalf("PrepareFiles() failed: %v", err)
- }
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// Start the container.
- mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly)
- if err := d.Run("-p", "80", mountArg, "nginx"); err != nil {
+ opts := dockerutil.RunOpts{
+ Image: "basic/nginx",
+ Ports: []int{80},
+ }
+ d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt")
+ if err := d.Spawn(ctx, opts); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Find where port 80 is mapped to.
- port, err := d.FindPort(80)
+ port, err := d.FindPort(ctx, 80)
if err != nil {
- t.Fatalf("docker.FindPort(80) failed: %v", err)
+ t.Fatalf("FindPort(80) failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
t.Errorf("WaitForHTTP() timeout: %v", err)
}
@@ -166,103 +174,65 @@ func TestNginx(t *testing.T) {
}
func TestMysql(t *testing.T) {
- if err := dockerutil.Pull("mysql"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("mysql-test")
+ ctx := context.Background()
+ server := dockerutil.MakeContainer(ctx, t)
+ defer server.CleanUp(ctx)
// Start the container.
- if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil {
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/mysql",
+ Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"},
+ }); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Wait until it's up and running.
- if _, err := d.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil {
- t.Fatalf("docker.WaitForOutput() timeout: %v", err)
+ if _, err := server.WaitForOutput(ctx, "port: 3306 MySQL Community Server", defaultWait); err != nil {
+ t.Fatalf("WaitForOutput() timeout: %v", err)
}
- client := dockerutil.MakeDocker("mysql-client-test")
- dir, err := dockerutil.PrepareFiles("mysql.sql")
- if err != nil {
- t.Fatalf("PrepareFiles() failed: %v", err)
- }
+ // Generate the client and copy in the SQL payload.
+ client := dockerutil.MakeContainer(ctx, t)
+ defer client.CleanUp(ctx)
- // Tell mysql client to connect to the server and execute the file in verbose
- // mode to verify the output.
- args := []string{
- dockerutil.LinkArg(&d, "mysql"),
- dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite),
- "mysql",
- "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql",
+ // Tell mysql client to connect to the server and execute the file in
+ // verbose mode to verify the output.
+ opts := dockerutil.RunOpts{
+ Image: "basic/mysql",
+ Links: []string{server.MakeLink("mysql")},
}
- if err := client.Run(args...); err != nil {
+ client.CopyFiles(&opts, "/sql", "test/image/mysql.sql")
+ if _, err := client.Run(ctx, opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer client.CleanUp()
// Ensure file executed to the end and shutdown mysql.
- if _, err := client.WaitForOutput("--------------\nshutdown\n--------------", 15*time.Second); err != nil {
- t.Fatalf("docker.WaitForOutput() timeout: %v", err)
- }
- if _, err := d.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil {
- t.Fatalf("docker.WaitForOutput() timeout: %v", err)
- }
-}
-
-func TestPythonHello(t *testing.T) {
- // TODO(b/136503277): Once we have more complete python runtime tests,
- // we can drop this one.
- const img = "gcr.io/gvisor-presubmit/python-hello"
- if err := dockerutil.Pull(img); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("python-hello-test")
- if err := d.Run("-p", "8080", img); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- // Find where port 8080 is mapped to.
- port, err := d.FindPort(8080)
- if err != nil {
- t.Fatalf("docker.FindPort(8080) failed: %v", err)
- }
-
- // Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
- t.Fatalf("WaitForHTTP() timeout: %v", err)
- }
-
- // Ensure that content is being served.
- url := fmt.Sprintf("http://localhost:%d", port)
- resp, err := http.Get(url)
- if err != nil {
- t.Errorf("Error reaching http server: %v", err)
- }
- if want := http.StatusOK; resp.StatusCode != want {
- t.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want)
+ if _, err := server.WaitForOutput(ctx, "mysqld: Shutdown complete", defaultWait); err != nil {
+ t.Fatalf("WaitForOutput() timeout: %v", err)
}
}
func TestTomcat(t *testing.T) {
- if err := dockerutil.Pull("tomcat:8.0"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("tomcat-test")
- if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start the server.
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/tomcat",
+ Ports: []int{8080},
+ }); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Find where port 8080 is mapped to.
- port, err := d.FindPort(8080)
+ port, err := d.FindPort(ctx, 8080)
if err != nil {
- t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ t.Fatalf("FindPort(8080) failed: %v", err)
}
// Wait until it's up and running.
- if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil {
+ if err := testutil.WaitForHTTP(port, defaultWait); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
@@ -278,32 +248,28 @@ func TestTomcat(t *testing.T) {
}
func TestRuby(t *testing.T) {
- if err := dockerutil.Pull("ruby"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("ruby-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
- dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh")
- if err != nil {
- t.Fatalf("PrepareFiles() failed: %v", err)
- }
- if err := os.Chmod(filepath.Join(dir, "ruby.sh"), 0333); err != nil {
- t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err)
+ // Execute the ruby workload.
+ opts := dockerutil.RunOpts{
+ Image: "basic/ruby",
+ Ports: []int{8080},
}
-
- if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil {
+ d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh")
+ if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// Find where port 8080 is mapped to.
- port, err := d.FindPort(8080)
+ port, err := d.FindPort(ctx, 8080)
if err != nil {
- t.Fatalf("docker.FindPort(8080) failed: %v", err)
+ t.Fatalf("FindPort(8080) failed: %v", err)
}
// Wait until it's up and running, 'gem install' can take some time.
- if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil {
+ if err := testutil.WaitForHTTP(port, time.Minute); err != nil {
t.Fatalf("WaitForHTTP() timeout: %v", err)
}
@@ -326,21 +292,21 @@ func TestRuby(t *testing.T) {
}
func TestStdio(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := dockerutil.MakeDocker("stdio-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
wantStdout := "hello stdout"
wantStderr := "bonjour stderr"
cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr)
- if err := d.Run("alpine", "/bin/sh", "-c", cmd); err != nil {
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "/bin/sh", "-c", cmd); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
for _, want := range []string{wantStdout, wantStderr} {
- if _, err := d.WaitForOutput(want, 5*time.Second); err != nil {
+ if _, err := d.WaitForOutput(ctx, want, defaultWait); err != nil {
t.Fatalf("docker didn't get output %q : %v", want, err)
}
}
diff --git a/test/image/ruby.sh b/test/image/ruby.sh
index ebe8d5b0e..ebe8d5b0e 100644..100755
--- a/test/image/ruby.sh
+++ b/test/image/ruby.sh
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
new file mode 100644
index 000000000..66453772a
--- /dev/null
+++ b/test/iptables/BUILD
@@ -0,0 +1,38 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "iptables",
+ testonly = 1,
+ srcs = [
+ "filter_input.go",
+ "filter_output.go",
+ "iptables.go",
+ "iptables_unsafe.go",
+ "iptables_util.go",
+ "nat.go",
+ ],
+ visibility = ["//test/iptables:__subpackages__"],
+ deps = [
+ "//pkg/test/testutil",
+ ],
+)
+
+go_test(
+ name = "iptables_test",
+ size = "large",
+ srcs = [
+ "iptables_test.go",
+ ],
+ data = ["//test/iptables/runner"],
+ library = ":iptables",
+ tags = [
+ "local",
+ "manual",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
diff --git a/test/iptables/README.md b/test/iptables/README.md
new file mode 100644
index 000000000..b9f44bd40
--- /dev/null
+++ b/test/iptables/README.md
@@ -0,0 +1,54 @@
+# iptables Tests
+
+iptables tests are run via `scripts/iptables_test.sh`.
+
+iptables requires raw socket support, so you must add the `--net-raw=true` flag
+to `/etc/docker/daemon.json` in order to use it.
+
+## Test Structure
+
+Each test implements `TestCase`, providing (1) a function to run inside the
+container and (2) a function to run locally. Those processes are given each
+others' IP addresses. The test succeeds when both functions succeed.
+
+The function inside the container (`ContainerAction`) typically sets some
+iptables rules and then tries to send or receive packets. The local function
+(`LocalAction`) will typically just send or receive packets.
+
+### Adding Tests
+
+1) Add your test to the `iptables` package.
+
+2) Register the test in an `init` function via `RegisterTestCase` (see
+`filter_input.go` as an example).
+
+3) Add it to `iptables_test.go` (see the other tests in that file).
+
+Your test is now runnable with bazel!
+
+## Run individual tests
+
+Build and install `runsc`. Re-run this when you modify gVisor:
+
+```bash
+$ bazel build //runsc && sudo cp bazel-bin/runsc/linux_amd64_pure_stripped/runsc $(which runsc)
+```
+
+Build the testing Docker container. Re-run this when you modify the test code in
+this directory:
+
+```bash
+$ make load-iptables
+```
+
+Run an individual test via:
+
+```bash
+$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME>
+```
+
+To run an individual test with `runc`:
+
+```bash
+$ bazel test //test/iptables:iptables_test --test_filter=<TESTNAME> --test_arg=--runtime=runc
+```
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
new file mode 100644
index 000000000..b45d448b8
--- /dev/null
+++ b/test/iptables/filter_input.go
@@ -0,0 +1,745 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "time"
+)
+
+const (
+ dropPort = 2401
+ acceptPort = 2402
+ sendloopDuration = 2 * time.Second
+ chainName = "foochain"
+)
+
+func init() {
+ RegisterTestCase(FilterInputDropAll{})
+ RegisterTestCase(FilterInputDropDifferentUDPPort{})
+ RegisterTestCase(FilterInputDropOnlyUDP{})
+ RegisterTestCase(FilterInputDropTCPDestPort{})
+ RegisterTestCase(FilterInputDropTCPSrcPort{})
+ RegisterTestCase(FilterInputDropUDPPort{})
+ RegisterTestCase(FilterInputDropUDP{})
+ RegisterTestCase(FilterInputCreateUserChain{})
+ RegisterTestCase(FilterInputDefaultPolicyAccept{})
+ RegisterTestCase(FilterInputDefaultPolicyDrop{})
+ RegisterTestCase(FilterInputReturnUnderflow{})
+ RegisterTestCase(FilterInputSerializeJump{})
+ RegisterTestCase(FilterInputJumpBasic{})
+ RegisterTestCase(FilterInputJumpReturn{})
+ RegisterTestCase(FilterInputJumpReturnDrop{})
+ RegisterTestCase(FilterInputJumpBuiltin{})
+ RegisterTestCase(FilterInputJumpTwice{})
+ RegisterTestCase(FilterInputDestination{})
+ RegisterTestCase(FilterInputInvertDestination{})
+ RegisterTestCase(FilterInputSource{})
+ RegisterTestCase(FilterInputInvertSource{})
+}
+
+// FilterInputDropUDP tests that we can drop UDP traffic.
+type FilterInputDropUDP struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropUDP) Name() string {
+ return "FilterInputDropUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic.
+type FilterInputDropOnlyUDP struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropOnlyUDP) Name() string {
+ return "FilterInputDropOnlyUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropOnlyUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for a TCP connection, which should be allowed.
+ if err := listenTCP(ctx, acceptPort); err != nil {
+ return fmt.Errorf("failed to establish a connection %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropOnlyUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Try to establish a TCP connection with the container, which should
+ // succeed.
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// FilterInputDropUDPPort tests that we can drop UDP traffic by port.
+type FilterInputDropUDPPort struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropUDPPort) Name() string {
+ return "FilterInputDropUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port
+// doesn't drop packets on other ports.
+type FilterInputDropDifferentUDPPort struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropDifferentUDPPort) Name() string {
+ return "FilterInputDropDifferentUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropDifferentUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on another port.
+ if err := listenUDP(ctx, acceptPort); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropDifferentUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports.
+type FilterInputDropTCPDestPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPDestPort) Name() string {
+ return "FilterInputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Ensure we cannot connect to the container.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, dropPort); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort)
+ }
+ return nil
+}
+
+// FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports.
+type FilterInputDropTCPSrcPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropTCPSrcPort) Name() string {
+ return "FilterInputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Drop anything from an ephemeral port.
+ if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Ensure we cannot connect to the container.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, dropPort); err == nil {
+ return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort)
+ }
+ return nil
+}
+
+// FilterInputDropAll tests that we can drop all traffic to the INPUT chain.
+type FilterInputDropAll struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDropAll) Name() string {
+ return "FilterInputDropAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDropAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for all packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets should have been dropped, but got a packet")
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDropAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// FilterInputMultiUDPRules verifies that multiple UDP rules are applied
+// correctly. This has the added benefit of testing whether we're serializing
+// rules correctly -- if we do it incorrectly, the iptables tool will
+// misunderstand and save the wrong tables.
+type FilterInputMultiUDPRules struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputMultiUDPRules) Name() string {
+ return "FilterInputMultiUDPRules"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputMultiUDPRules) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"},
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"},
+ {"-L"},
+ }
+ return filterTableRules(ipv6, rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputMultiUDPRules) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be
+// specified.
+type FilterInputRequireProtocolUDP struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputRequireProtocolUDP) Name() string {
+ return "FilterInputRequireProtocolUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputRequireProtocolUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil {
+ return errors.New("expected iptables to fail with out \"-p udp\", but succeeded")
+ }
+ return nil
+}
+
+func (FilterInputRequireProtocolUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputCreateUserChain tests chain creation.
+type FilterInputCreateUserChain struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputCreateUserChain) Name() string {
+ return "FilterInputCreateUserChain"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputCreateUserChain) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ // Create a chain.
+ {"-N", chainName},
+ // Add a simple rule to the chain.
+ {"-A", chainName, "-j", "DROP"},
+ }
+ return filterTableRules(ipv6, rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputCreateUserChain) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputDefaultPolicyAccept tests the default ACCEPT policy.
+type FilterInputDefaultPolicyAccept struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyAccept) Name() string {
+ return "FilterInputDefaultPolicyAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Set the default policy to accept, then receive a packet.
+ if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil {
+ return err
+ }
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputDefaultPolicyDrop tests the default DROP policy.
+type FilterInputDefaultPolicyDrop struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDefaultPolicyDrop) Name() string {
+ return "FilterInputDefaultPolicyDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDefaultPolicyDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDefaultPolicyDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes
+// the underflow rule (i.e. default policy) to be executed.
+type FilterInputReturnUnderflow struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputReturnUnderflow) Name() string {
+ return "FilterInputReturnUnderflow"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputReturnUnderflow) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Add a RETURN rule followed by an unconditional accept, and set the
+ // default policy to DROP.
+ rules := [][]string{
+ {"-A", "INPUT", "-j", "RETURN"},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-P", "INPUT", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // We should receive packets, as the RETURN rule will trigger the default
+ // ACCEPT policy.
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputReturnUnderflow) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputSerializeJump verifies that we can serialize jumps.
+type FilterInputSerializeJump struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputSerializeJump) Name() string {
+ return "FilterInputSerializeJump"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSerializeJump) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Write a JUMP rule, the serialize it with `-L`.
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-L"},
+ }
+ return filterTableRules(ipv6, rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSerializeJump) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpBasic jumps to a chain and executes a rule there.
+type FilterInputJumpBasic struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBasic) Name() string {
+ return "FilterInputJumpBasic"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBasic) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBasic) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputJumpReturn jumps, returns, and executes a rule.
+type FilterInputJumpReturn struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturn) Name() string {
+ return "FilterInputJumpReturn"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturn) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-P", "INPUT", "ACCEPT"},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "RETURN"},
+ {"-A", chainName, "-j", "DROP"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturn) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets.
+type FilterInputJumpReturnDrop struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturnDrop) Name() string {
+ return "FilterInputJumpReturnDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturnDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-A", chainName, "-j", "RETURN"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturnDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal.
+type FilterInputJumpBuiltin struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBuiltin) Name() string {
+ return "FilterInputJumpBuiltin"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBuiltin) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil {
+ return fmt.Errorf("iptables should be unable to jump to a built-in chain")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBuiltin) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpTwice jumps twice, then returns twice and executes a rule.
+type FilterInputJumpTwice struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputJumpTwice) Name() string {
+ return "FilterInputJumpTwice"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpTwice) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ const chainName2 = chainName + "2"
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-N", chainName2},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", chainName2},
+ {"-A", "INPUT", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // UDP packets should jump and return twice, eventually hitting the
+ // ACCEPT rule.
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpTwice) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputDestination verifies that we can filter packets via `-d
+// <ipaddr>`.
+type FilterInputDestination struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputDestination) Name() string {
+ return "FilterInputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ addrs, err := localAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+
+ // Make INPUT's default action DROP, then ACCEPT all packets bound for
+ // this machine.
+ rules := [][]string{{"-P", "INPUT", "DROP"}}
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"})
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInvertDestination verifies that we can filter packets via `! -d
+// <ipaddr>`.
+type FilterInputInvertDestination struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputInvertDestination) Name() string {
+ return "FilterInputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not bound
+ // for 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputSource verifies that we can filter packets via `-s
+// <ipaddr>`.
+type FilterInputSource struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputSource) Name() string {
+ return "FilterInputSource"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets from this
+ // machine.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// FilterInputInvertSource verifies that we can filter packets via `! -s
+// <ipaddr>`.
+type FilterInputInvertSource struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (FilterInputInvertSource) Name() string {
+ return "FilterInputInvertSource"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Make INPUT's default action DROP, then ACCEPT all packets not bound
+ // for 127.0.0.1.
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-A", "INPUT", "!", "-s", localIP(ipv6), "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go
new file mode 100644
index 000000000..32bf2a992
--- /dev/null
+++ b/test/iptables/filter_output.go
@@ -0,0 +1,663 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+)
+
+func init() {
+ RegisterTestCase(FilterOutputDropTCPDestPort{})
+ RegisterTestCase(FilterOutputDropTCPSrcPort{})
+ RegisterTestCase(FilterOutputDestination{})
+ RegisterTestCase(FilterOutputInvertDestination{})
+ RegisterTestCase(FilterOutputAcceptTCPOwner{})
+ RegisterTestCase(FilterOutputDropTCPOwner{})
+ RegisterTestCase(FilterOutputAcceptUDPOwner{})
+ RegisterTestCase(FilterOutputDropUDPOwner{})
+ RegisterTestCase(FilterOutputOwnerFail{})
+ RegisterTestCase(FilterOutputAcceptGIDOwner{})
+ RegisterTestCase(FilterOutputDropGIDOwner{})
+ RegisterTestCase(FilterOutputInvertGIDOwner{})
+ RegisterTestCase(FilterOutputInvertUIDOwner{})
+ RegisterTestCase(FilterOutputInvertUIDAndGIDOwner{})
+ RegisterTestCase(FilterOutputInterfaceAccept{})
+ RegisterTestCase(FilterOutputInterfaceDrop{})
+ RegisterTestCase(FilterOutputInterface{})
+ RegisterTestCase(FilterOutputInterfaceBeginsWith{})
+ RegisterTestCase(FilterOutputInterfaceInvertDrop{})
+ RegisterTestCase(FilterOutputInterfaceInvertAccept{})
+}
+
+// FilterOutputDropTCPDestPort tests that connections are not accepted on
+// specified source ports.
+type FilterOutputDropTCPDestPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPDestPort) Name() string {
+ return "FilterOutputDropTCPDestPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDropTCPSrcPort tests that connections are not accepted on
+// specified source ports.
+type FilterOutputDropTCPSrcPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPSrcPort) Name() string {
+ return "FilterOutputDropTCPSrcPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on drop port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, dropPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted.
+type FilterOutputAcceptTCPOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptTCPOwner) Name() string {
+ return "FilterOutputAcceptTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped.
+type FilterOutputDropTCPOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDropTCPOwner) Name() string {
+ return "FilterOutputDropTCPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted.
+type FilterOutputAcceptUDPOwner struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptUDPOwner) Name() string {
+ return "FilterOutputAcceptUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on acceptPort.
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(ctx, acceptPort)
+}
+
+// FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped.
+type FilterOutputDropUDPOwner struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDropUDPOwner) Name() string {
+ return "FilterOutputDropUDPOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Send UDP packets on dropPort.
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Listen for UDP packets on dropPort.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, dropPort); err == nil {
+ return fmt.Errorf("packets should not be received")
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// FilterOutputOwnerFail tests that without uid/gid option, owner rule
+// will fail.
+type FilterOutputOwnerFail struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputOwnerFail) Name() string {
+ return "FilterOutputOwnerFail"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil {
+ return fmt.Errorf("Invalid argument")
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputOwnerFail) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // no-op.
+ return nil
+}
+
+// FilterOutputAcceptGIDOwner tests that TCP connections from gid owner are accepted.
+type FilterOutputAcceptGIDOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputAcceptGIDOwner) Name() string {
+ return "FilterOutputAcceptGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputAcceptGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputAcceptGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// FilterOutputDropGIDOwner tests that TCP connections from gid owner are dropped.
+type FilterOutputDropGIDOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDropGIDOwner) Name() string {
+ return "FilterOutputDropGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDropGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDropGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInvertGIDOwner tests that TCP connections from gid owner are dropped.
+type FilterOutputInvertGIDOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertGIDOwner) Name() string {
+ return "FilterOutputInvertGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInvertUIDOwner tests that TCP connections from gid owner are dropped.
+type FilterOutputInvertUIDOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertUIDOwner) Name() string {
+ return "FilterOutputInvertUIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertUIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertUIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// FilterOutputInvertUIDAndGIDOwner tests that TCP connections from uid and gid
+// owner are dropped.
+type FilterOutputInvertUIDAndGIDOwner struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertUIDAndGIDOwner) Name() string {
+ return "FilterOutputInvertUIDAndGIDOwner"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"},
+ {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputDestination tests that we can selectively allow packets to
+// certain destinations.
+type FilterOutputDestination struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputDestination) Name() string {
+ return "FilterOutputDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenUDP(ctx, acceptPort)
+}
+
+// FilterOutputInvertDestination tests that we can selectively allow packets
+// not headed for a particular destination.
+type FilterOutputInvertDestination struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInvertDestination) Name() string {
+ return "FilterOutputInvertDestination"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ rules := [][]string{
+ {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"},
+ {"-P", "OUTPUT", "DROP"},
+ }
+ if err := filterTableRules(ipv6, rules); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenUDP(ctx, acceptPort)
+}
+
+// FilterOutputInterfaceAccept tests that packets are sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceAccept struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceAccept) Name() string {
+ return "FilterOutputInterfaceAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenUDP(ctx, acceptPort)
+}
+
+// FilterOutputInterfaceDrop tests that packets are not sent via interface
+// matching the iptables rule.
+type FilterOutputInterfaceDrop struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceDrop) Name() string {
+ return "FilterOutputInterfaceDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ ifname, ok := getInterfaceName()
+ if !ok {
+ return fmt.Errorf("no interface is present, except loopback")
+ }
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// FilterOutputInterface tests that packets are sent via interface which is
+// not matching the interface name in the iptables rule.
+type FilterOutputInterface struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterface) Name() string {
+ return "FilterOutputInterface"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenUDP(ctx, acceptPort)
+}
+
+// FilterOutputInterfaceBeginsWith tests that packets are not sent via an
+// interface which begins with the given interface name.
+type FilterOutputInterfaceBeginsWith struct{ localCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceBeginsWith) Name() string {
+ return "FilterOutputInterfaceBeginsWith"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertDrop tests that we selectively do not send
+// packets via interface not matching the interface name.
+type FilterOutputInterfaceInvertDrop struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertDrop) Name() string {
+ return "FilterOutputInterfaceInvertDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenTCP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := connectTCP(timedCtx, ip, acceptPort); err == nil {
+ return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort)
+ }
+
+ return nil
+}
+
+// FilterOutputInterfaceInvertAccept tests that we can selectively send packets
+// not matching the specific outgoing interface.
+type FilterOutputInterfaceInvertAccept struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (FilterOutputInterfaceInvertAccept) Name() string {
+ return "FilterOutputInterfaceInvertAccept"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterOutputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on accept port.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterOutputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go
new file mode 100644
index 000000000..c2a03f54c
--- /dev/null
+++ b/test/iptables/iptables.go
@@ -0,0 +1,115 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package iptables contains a set of iptables tests implemented as TestCases
+package iptables
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+)
+
+// IPExchangePort is the port the container listens on to receive the IP
+// address of the local process.
+const IPExchangePort = 2349
+
+// TerminalStatement is the last statement in the test runner.
+const TerminalStatement = "Finished!"
+
+// TestTimeout is the timeout used for all tests.
+const TestTimeout = 10 * time.Second
+
+// NegativeTimeout is the time tests should wait to establish the negative
+// case, i.e. that connections are not made.
+const NegativeTimeout = 2 * time.Second
+
+// A TestCase contains one action to run in the container and one to run
+// locally. The actions run concurrently and each must succeed for the test
+// pass.
+type TestCase interface {
+ // Name returns the name of the test.
+ Name() string
+
+ // ContainerAction runs inside the container. It receives the IP of the
+ // local process.
+ ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error
+
+ // LocalAction runs locally. It receives the IP of the container.
+ LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error
+
+ // ContainerSufficient indicates whether ContainerAction's return value
+ // alone indicates whether the test succeeded.
+ ContainerSufficient() bool
+
+ // LocalSufficient indicates whether LocalAction's return value alone
+ // indicates whether the test succeeded.
+ LocalSufficient() bool
+}
+
+// baseCase provides defaults for ContainerSufficient and LocalSufficient when
+// both actions are required to finish.
+type baseCase struct{}
+
+// ContainerSufficient implements TestCase.ContainerSufficient.
+func (baseCase) ContainerSufficient() bool {
+ return false
+}
+
+// LocalSufficient implements TestCase.LocalSufficient.
+func (baseCase) LocalSufficient() bool {
+ return false
+}
+
+// localCase provides defaults for ContainerSufficient and LocalSufficient when
+// only the local action is required to finish.
+type localCase struct{}
+
+// ContainerSufficient implements TestCase.ContainerSufficient.
+func (localCase) ContainerSufficient() bool {
+ return false
+}
+
+// LocalSufficient implements TestCase.LocalSufficient.
+func (localCase) LocalSufficient() bool {
+ return true
+}
+
+// containerCase provides defaults for ContainerSufficient and LocalSufficient
+// when only the container action is required to finish.
+type containerCase struct{}
+
+// ContainerSufficient implements TestCase.ContainerSufficient.
+func (containerCase) ContainerSufficient() bool {
+ return true
+}
+
+// LocalSufficient implements TestCase.LocalSufficient.
+func (containerCase) LocalSufficient() bool {
+ return false
+}
+
+// Tests maps test names to TestCase.
+//
+// New TestCases are added by calling RegisterTestCase in an init function.
+var Tests = map[string]TestCase{}
+
+// RegisterTestCase registers tc so it can be run.
+func RegisterTestCase(tc TestCase) {
+ if _, ok := Tests[tc.Name()]; ok {
+ panic(fmt.Sprintf("TestCase %s already registered.", tc.Name()))
+ }
+ Tests[tc.Name()] = tc
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
new file mode 100644
index 000000000..e2beb30d5
--- /dev/null
+++ b/test/iptables/iptables_test.go
@@ -0,0 +1,427 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "reflect"
+ "sync"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// singleTest runs a TestCase. Each test follows a pattern:
+// - Create a container.
+// - Get the container's IP.
+// - Send the container our IP.
+// - Start a new goroutine running the local action of the test.
+// - Wait for both the container and local actions to finish.
+//
+// Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or
+// to stderr.
+func singleTest(t *testing.T, test TestCase) {
+ for _, tc := range []bool{false, true} {
+ subtest := "IPv4"
+ if tc {
+ subtest = "IPv6"
+ }
+ t.Run(subtest, func(t *testing.T) {
+ iptablesTest(t, test, tc)
+ })
+ }
+}
+
+func iptablesTest(t *testing.T, test TestCase, ipv6 bool) {
+ if _, ok := Tests[test.Name()]; !ok {
+ t.Fatalf("no test found with name %q. Has it been registered?", test.Name())
+ }
+
+ // Wait for the local and container goroutines to finish.
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ ctx, cancel := context.WithTimeout(context.Background(), TestTimeout)
+ defer cancel()
+
+ d := dockerutil.MakeContainer(ctx, t)
+ defer func() {
+ if logs, err := d.Logs(context.Background()); err != nil {
+ t.Logf("Failed to retrieve container logs.")
+ } else {
+ t.Logf("=== Container logs: ===\n%s", logs)
+ }
+ // Use a new context, as cleanup should run even when we
+ // timeout.
+ d.CleanUp(context.Background())
+ }()
+
+ // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests.
+ if ipv6 && dockerutil.Runtime() != "runc" {
+ t.Skip("gVisor ip6tables not yet implemented")
+ }
+
+ // Create and start the container.
+ opts := dockerutil.RunOpts{
+ Image: "iptables",
+ CapAdd: []string{"NET_ADMIN"},
+ }
+ d.CopyFiles(&opts, "/runner", "test/iptables/runner/runner")
+ args := []string{"/runner/runner", "-name", test.Name()}
+ if ipv6 {
+ args = append(args, "-ipv6")
+ }
+ if err := d.Spawn(ctx, opts, args...); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Get the container IP.
+ ip, err := d.FindIP(ctx, ipv6)
+ if err != nil {
+ t.Fatalf("failed to get container IP: %v", err)
+ }
+
+ // Give the container our IP.
+ if err := sendIP(ip); err != nil {
+ t.Fatalf("failed to send IP to container: %v", err)
+ }
+
+ // Run our side of the test.
+ errCh := make(chan error, 2)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ if err := test.LocalAction(ctx, ip, ipv6); err != nil && !errors.Is(err, context.Canceled) {
+ errCh <- fmt.Errorf("LocalAction failed: %v", err)
+ } else {
+ errCh <- nil
+ }
+ if test.LocalSufficient() {
+ errCh <- nil
+ }
+ }()
+
+ // Run the container side.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ // Wait for the final statement. This structure has the side
+ // effect that all container logs will appear within the
+ // individual test context.
+ if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil && !errors.Is(err, context.Canceled) {
+ errCh <- fmt.Errorf("ContainerAction failed: %v", err)
+ } else {
+ errCh <- nil
+ }
+ if test.ContainerSufficient() {
+ errCh <- nil
+ }
+ }()
+
+ for i := 0; i < 2; i++ {
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func sendIP(ip net.IP) error {
+ contAddr := net.TCPAddr{
+ IP: ip,
+ Port: IPExchangePort,
+ }
+ var conn *net.TCPConn
+ // The container may not be listening when we first connect, so retry
+ // upon error.
+ cb := func() error {
+ c, err := net.DialTCP("tcp", nil, &contAddr)
+ conn = c
+ return err
+ }
+ if err := testutil.Poll(cb, TestTimeout); err != nil {
+ return fmt.Errorf("timed out waiting to send IP, most recent error: %v", err)
+ }
+ if _, err := conn.Write([]byte{0}); err != nil {
+ return fmt.Errorf("error writing to container: %v", err)
+ }
+ return nil
+}
+
+func TestFilterInputDropUDP(t *testing.T) {
+ singleTest(t, FilterInputDropUDP{})
+}
+
+func TestFilterInputDropUDPPort(t *testing.T) {
+ singleTest(t, FilterInputDropUDPPort{})
+}
+
+func TestFilterInputDropDifferentUDPPort(t *testing.T) {
+ singleTest(t, FilterInputDropDifferentUDPPort{})
+}
+
+func TestFilterInputDropAll(t *testing.T) {
+ singleTest(t, FilterInputDropAll{})
+}
+
+func TestFilterInputDropOnlyUDP(t *testing.T) {
+ singleTest(t, FilterInputDropOnlyUDP{})
+}
+
+func TestFilterInputDropTCPDestPort(t *testing.T) {
+ singleTest(t, FilterInputDropTCPDestPort{})
+}
+
+func TestFilterInputDropTCPSrcPort(t *testing.T) {
+ singleTest(t, FilterInputDropTCPSrcPort{})
+}
+
+func TestFilterInputCreateUserChain(t *testing.T) {
+ singleTest(t, FilterInputCreateUserChain{})
+}
+
+func TestFilterInputDefaultPolicyAccept(t *testing.T) {
+ singleTest(t, FilterInputDefaultPolicyAccept{})
+}
+
+func TestFilterInputDefaultPolicyDrop(t *testing.T) {
+ singleTest(t, FilterInputDefaultPolicyDrop{})
+}
+
+func TestFilterInputReturnUnderflow(t *testing.T) {
+ singleTest(t, FilterInputReturnUnderflow{})
+}
+
+func TestFilterOutputDropTCPDestPort(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPDestPort{})
+}
+
+func TestFilterOutputDropTCPSrcPort(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPSrcPort{})
+}
+
+func TestFilterOutputAcceptTCPOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptTCPOwner{})
+}
+
+func TestFilterOutputDropTCPOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropTCPOwner{})
+}
+
+func TestFilterOutputAcceptUDPOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptUDPOwner{})
+}
+
+func TestFilterOutputDropUDPOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropUDPOwner{})
+}
+
+func TestFilterOutputOwnerFail(t *testing.T) {
+ singleTest(t, FilterOutputOwnerFail{})
+}
+
+func TestFilterOutputAcceptGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputAcceptGIDOwner{})
+}
+
+func TestFilterOutputDropGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputDropGIDOwner{})
+}
+
+func TestFilterOutputInvertGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertGIDOwner{})
+}
+
+func TestFilterOutputInvertUIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertUIDOwner{})
+}
+
+func TestFilterOutputInvertUIDAndGIDOwner(t *testing.T) {
+ singleTest(t, FilterOutputInvertUIDAndGIDOwner{})
+}
+
+func TestFilterOutputInterfaceAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceAccept{})
+}
+
+func TestFilterOutputInterfaceDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceDrop{})
+}
+
+func TestFilterOutputInterface(t *testing.T) {
+ singleTest(t, FilterOutputInterface{})
+}
+
+func TestFilterOutputInterfaceBeginsWith(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceBeginsWith{})
+}
+
+func TestFilterOutputInterfaceInvertDrop(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertDrop{})
+}
+
+func TestFilterOutputInterfaceInvertAccept(t *testing.T) {
+ singleTest(t, FilterOutputInterfaceInvertAccept{})
+}
+
+func TestJumpSerialize(t *testing.T) {
+ singleTest(t, FilterInputSerializeJump{})
+}
+
+func TestJumpBasic(t *testing.T) {
+ singleTest(t, FilterInputJumpBasic{})
+}
+
+func TestJumpReturn(t *testing.T) {
+ singleTest(t, FilterInputJumpReturn{})
+}
+
+func TestJumpReturnDrop(t *testing.T) {
+ singleTest(t, FilterInputJumpReturnDrop{})
+}
+
+func TestJumpBuiltin(t *testing.T) {
+ singleTest(t, FilterInputJumpBuiltin{})
+}
+
+func TestJumpTwice(t *testing.T) {
+ singleTest(t, FilterInputJumpTwice{})
+}
+
+func TestInputDestination(t *testing.T) {
+ singleTest(t, FilterInputDestination{})
+}
+
+func TestInputInvertDestination(t *testing.T) {
+ singleTest(t, FilterInputInvertDestination{})
+}
+
+func TestOutputDestination(t *testing.T) {
+ singleTest(t, FilterOutputDestination{})
+}
+
+func TestOutputInvertDestination(t *testing.T) {
+ singleTest(t, FilterOutputInvertDestination{})
+}
+
+func TestNATPreRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectUDPPort{})
+}
+
+func TestNATPreRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATPreRedirectTCPPort{})
+}
+
+func TestNATPreRedirectTCPOutgoing(t *testing.T) {
+ singleTest(t, NATPreRedirectTCPOutgoing{})
+}
+
+func TestNATOutRedirectTCPIncoming(t *testing.T) {
+ singleTest(t, NATOutRedirectTCPIncoming{})
+}
+func TestNATOutRedirectUDPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectUDPPort{})
+}
+
+func TestNATOutRedirectTCPPort(t *testing.T) {
+ singleTest(t, NATOutRedirectTCPPort{})
+}
+
+func TestNATDropUDP(t *testing.T) {
+ singleTest(t, NATDropUDP{})
+}
+
+func TestNATAcceptAll(t *testing.T) {
+ singleTest(t, NATAcceptAll{})
+}
+
+func TestNATOutRedirectIP(t *testing.T) {
+ singleTest(t, NATOutRedirectIP{})
+}
+
+func TestNATOutDontRedirectIP(t *testing.T) {
+ singleTest(t, NATOutDontRedirectIP{})
+}
+
+func TestNATOutRedirectInvert(t *testing.T) {
+ singleTest(t, NATOutRedirectInvert{})
+}
+
+func TestNATPreRedirectIP(t *testing.T) {
+ singleTest(t, NATPreRedirectIP{})
+}
+
+func TestNATPreDontRedirectIP(t *testing.T) {
+ singleTest(t, NATPreDontRedirectIP{})
+}
+
+func TestNATPreRedirectInvert(t *testing.T) {
+ singleTest(t, NATPreRedirectInvert{})
+}
+
+func TestNATRedirectRequiresProtocol(t *testing.T) {
+ singleTest(t, NATRedirectRequiresProtocol{})
+}
+
+func TestNATLoopbackSkipsPrerouting(t *testing.T) {
+ singleTest(t, NATLoopbackSkipsPrerouting{})
+}
+
+func TestInputSource(t *testing.T) {
+ singleTest(t, FilterInputSource{})
+}
+
+func TestInputInvertSource(t *testing.T) {
+ singleTest(t, FilterInputInvertSource{})
+}
+
+func TestFilterAddrs(t *testing.T) {
+ tcs := []struct {
+ ipv6 bool
+ addrs []string
+ want []string
+ }{
+ {
+ ipv6: false,
+ addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"},
+ want: []string{"192.168.0.1", "192.168.0.2"},
+ },
+ {
+ ipv6: true,
+ addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"},
+ want: []string{"::1", "::2"},
+ },
+ }
+
+ for _, tc := range tcs {
+ if got := filterAddrs(tc.addrs, tc.ipv6); !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("%v with IPv6 %t: got %v, but wanted %v", tc.addrs, tc.ipv6, got, tc.want)
+ }
+ }
+}
+
+func TestNATPreOriginalDst(t *testing.T) {
+ singleTest(t, NATPreOriginalDst{})
+}
+
+func TestNATOutOriginalDst(t *testing.T) {
+ singleTest(t, NATOutOriginalDst{})
+}
diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go
new file mode 100644
index 000000000..bd85a8fea
--- /dev/null
+++ b/test/iptables/iptables_unsafe.go
@@ -0,0 +1,63 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "fmt"
+ "syscall"
+ "unsafe"
+)
+
+type originalDstError struct {
+ errno syscall.Errno
+}
+
+func (e originalDstError) Error() string {
+ return fmt.Sprintf("errno (%d) when calling getsockopt(SO_ORIGINAL_DST): %v", int(e.errno), e.errno.Error())
+}
+
+// SO_ORIGINAL_DST gets the original destination of a redirected packet via
+// getsockopt.
+const SO_ORIGINAL_DST = 80
+
+func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) {
+ var addr syscall.RawSockaddrInet4
+ var addrLen uint32 = syscall.SizeofSockaddrInet4
+ if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return syscall.RawSockaddrInet4{}, originalDstError{errno}
+ }
+ return addr, nil
+}
+
+func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) {
+ var addr syscall.RawSockaddrInet6
+ var addrLen uint32 = syscall.SizeofSockaddrInet6
+ if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return syscall.RawSockaddrInet6{}, originalDstError{errno}
+ }
+ return addr, nil
+}
+
+func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno {
+ _, _, errno := syscall.Syscall6(
+ syscall.SYS_GETSOCKOPT,
+ uintptr(connfd),
+ level,
+ SO_ORIGINAL_DST,
+ uintptr(optval),
+ uintptr(unsafe.Pointer(optlen)),
+ 0)
+ return errno
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
new file mode 100644
index 000000000..a6ec5cca3
--- /dev/null
+++ b/test/iptables/iptables_util.go
@@ -0,0 +1,282 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+ "os/exec"
+ "strings"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// filterTable calls `ip{6}tables -t filter` with the given args.
+func filterTable(ipv6 bool, args ...string) error {
+ return tableCmd(ipv6, "filter", args)
+}
+
+// natTable calls `ip{6}tables -t nat` with the given args.
+func natTable(ipv6 bool, args ...string) error {
+ return tableCmd(ipv6, "nat", args)
+}
+
+func tableCmd(ipv6 bool, table string, args []string) error {
+ args = append([]string{"-t", table}, args...)
+ binary := "iptables"
+ if ipv6 {
+ binary = "ip6tables"
+ }
+ cmd := exec.Command(binary, args...)
+ if out, err := cmd.CombinedOutput(); err != nil {
+ return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
+ }
+ return nil
+}
+
+// filterTableRules is like filterTable, but runs multiple iptables commands.
+func filterTableRules(ipv6 bool, argsList [][]string) error {
+ return tableRules(ipv6, "filter", argsList)
+}
+
+// natTableRules is like natTable, but runs multiple iptables commands.
+func natTableRules(ipv6 bool, argsList [][]string) error {
+ return tableRules(ipv6, "nat", argsList)
+}
+
+func tableRules(ipv6 bool, table string, argsList [][]string) error {
+ for _, args := range argsList {
+ if err := tableCmd(ipv6, table, args); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
+// the first read on that port.
+func listenUDP(ctx context.Context, port int) error {
+ localAddr := net.UDPAddr{
+ Port: port,
+ }
+ conn, err := net.ListenUDP("udp", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ ch := make(chan error)
+ go func() {
+ _, err = conn.Read([]byte{0})
+ ch <- err
+ }()
+
+ select {
+ case err := <-ch:
+ return err
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified
+// over a duration.
+func sendUDPLoop(ctx context.Context, ip net.IP, port int) error {
+ remote := net.UDPAddr{
+ IP: ip,
+ Port: port,
+ }
+ conn, err := net.DialUDP("udp", nil, &remote)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ for {
+ // This may return an error (connection refused) if the remote
+ // hasn't started listening yet or they're dropping our
+ // packets. So we ignore Write errors and depend on the remote
+ // to report a failure if it doesn't get a packet it needs.
+ conn.Write([]byte{0})
+ select {
+ case <-ctx.Done():
+ // Being cancelled or timing out isn't an error, as we
+ // cannot tell with UDP whether we succeeded.
+ return nil
+ // Continue looping.
+ case <-time.After(200 * time.Millisecond):
+ }
+ }
+}
+
+// listenTCP listens for connections on a TCP port.
+func listenTCP(ctx context.Context, port int) error {
+ localAddr := net.TCPAddr{
+ Port: port,
+ }
+
+ // Starts listening on port.
+ lConn, err := net.ListenTCP("tcp", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+ // Accept connections on port.
+ ch := make(chan error)
+ go func() {
+ conn, err := lConn.AcceptTCP()
+ ch <- err
+ conn.Close()
+ }()
+
+ select {
+ case err := <-ch:
+ return err
+ case <-ctx.Done():
+ return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
+ }
+}
+
+// connectTCP connects to the given IP and port from an ephemeral local address.
+func connectTCP(ctx context.Context, ip net.IP, port int) error {
+ contAddr := net.TCPAddr{
+ IP: ip,
+ Port: port,
+ }
+ // The container may not be listening when we first connect, so retry
+ // upon error.
+ callback := func() error {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, "tcp", contAddr.String())
+ if conn != nil {
+ conn.Close()
+ }
+ return err
+ }
+ if err := testutil.PollContext(ctx, callback); err != nil {
+ return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err)
+ }
+
+ return nil
+}
+
+// localAddrs returns a list of local network interface addresses. When ipv6 is
+// true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are
+// returned.
+func localAddrs(ipv6 bool) ([]string, error) {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ // Add only IPv4 or only IPv6 addresses.
+ parts := strings.Split(addr.String(), "/")
+ if len(parts) != 2 {
+ return nil, fmt.Errorf("bad interface address: %q", addr.String())
+ }
+ if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 {
+ addrStrs = append(addrStrs, addr.String())
+ }
+ }
+ return filterAddrs(addrStrs, ipv6), nil
+}
+
+func filterAddrs(addrs []string, ipv6 bool) []string {
+ addrStrs := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ // Add only IPv4 or only IPv6 addresses.
+ parts := strings.Split(addr, "/")
+ if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 {
+ addrStrs = append(addrStrs, parts[0])
+ }
+ }
+ return addrStrs
+}
+
+// getInterfaceName returns the name of the interface other than loopback.
+func getInterfaceName() (string, bool) {
+ iface, ok := getNonLoopbackInterface()
+ if !ok {
+ return "", false
+ }
+ return iface.Name, true
+}
+
+func getInterfaceAddrs(ipv6 bool) ([]net.IP, error) {
+ iface, ok := getNonLoopbackInterface()
+ if !ok {
+ return nil, errors.New("no non-loopback interface found")
+ }
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return nil, err
+ }
+
+ // Get only IPv4 or IPv6 addresses.
+ ips := make([]net.IP, 0, len(addrs))
+ for _, addr := range addrs {
+ parts := strings.Split(addr.String(), "/")
+ var ip net.IP
+ // To16() returns IPv4 addresses as IPv4-mapped IPv6 addresses.
+ // So we check whether To4() returns nil to test whether the
+ // address is v4 or v6.
+ if v4 := net.ParseIP(parts[0]).To4(); ipv6 && v4 == nil {
+ ip = net.ParseIP(parts[0]).To16()
+ } else {
+ ip = v4
+ }
+ if ip != nil {
+ ips = append(ips, ip)
+ }
+ }
+ return ips, nil
+}
+
+func getNonLoopbackInterface() (net.Interface, bool) {
+ if interfaces, err := net.Interfaces(); err == nil {
+ for _, intf := range interfaces {
+ if intf.Name != "lo" {
+ return intf, true
+ }
+ }
+ }
+ return net.Interface{}, false
+}
+
+func htons(x uint16) uint16 {
+ buf := make([]byte, 2)
+ binary.BigEndian.PutUint16(buf, x)
+ return binary.LittleEndian.Uint16(buf)
+}
+
+func localIP(ipv6 bool) string {
+ if ipv6 {
+ return "::1"
+ }
+ return "127.0.0.1"
+}
+
+func nowhereIP(ipv6 bool) string {
+ if ipv6 {
+ return "2001:db8::1"
+ }
+ return "192.0.2.1"
+}
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
new file mode 100644
index 000000000..dd9a18339
--- /dev/null
+++ b/test/iptables/nat.go
@@ -0,0 +1,657 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "syscall"
+)
+
+const redirectPort = 42
+
+func init() {
+ RegisterTestCase(NATPreRedirectUDPPort{})
+ RegisterTestCase(NATPreRedirectTCPPort{})
+ RegisterTestCase(NATPreRedirectTCPOutgoing{})
+ RegisterTestCase(NATOutRedirectTCPIncoming{})
+ RegisterTestCase(NATOutRedirectUDPPort{})
+ RegisterTestCase(NATOutRedirectTCPPort{})
+ RegisterTestCase(NATDropUDP{})
+ RegisterTestCase(NATAcceptAll{})
+ RegisterTestCase(NATPreRedirectIP{})
+ RegisterTestCase(NATPreDontRedirectIP{})
+ RegisterTestCase(NATPreRedirectInvert{})
+ RegisterTestCase(NATOutRedirectIP{})
+ RegisterTestCase(NATOutDontRedirectIP{})
+ RegisterTestCase(NATOutRedirectInvert{})
+ RegisterTestCase(NATRedirectRequiresProtocol{})
+ RegisterTestCase(NATLoopbackSkipsPrerouting{})
+ RegisterTestCase(NATPreOriginalDst{})
+ RegisterTestCase(NATOutOriginalDst{})
+}
+
+// NATPreRedirectUDPPort tests that packets are redirected to different port.
+type NATPreRedirectUDPPort struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreRedirectUDPPort) Name() string {
+ return "NATPreRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ if err := listenUDP(ctx, redirectPort); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATPreRedirectTCPPort tests that connections are redirected on specified ports.
+type NATPreRedirectTCPPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATPreRedirectTCPPort) Name() string {
+ return "NATPreRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ // Listen for TCP packets on redirect port.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, dropPort)
+}
+
+// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't
+// affected by PREROUTING connection tracking.
+type NATPreRedirectTCPOutgoing struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATPreRedirectTCPOutgoing) Name() string {
+ return "NATPreRedirectTCPOutgoing"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect all incoming TCP traffic to a closed port.
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+
+ // Establish a connection to the host process.
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenTCP(ctx, acceptPort)
+}
+
+// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't
+// affected by OUTPUT connection tracking.
+type NATOutRedirectTCPIncoming struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATOutRedirectTCPIncoming) Name() string {
+ return "NATOutRedirectTCPIncoming"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect all outgoing TCP traffic to a closed port.
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+
+ // Establish a connection to the host process.
+ return listenTCP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, acceptPort)
+}
+
+// NATOutRedirectUDPPort tests that packets are redirected to different port.
+type NATOutRedirectUDPPort struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATOutRedirectUDPPort) Name() string {
+ return "NATOutRedirectUDPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// NATDropUDP tests that packets are not received in ports other than redirect
+// port.
+type NATDropUDP struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATDropUDP) Name() string {
+ return "NATDropUDP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil {
+ return err
+ }
+
+ timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout)
+ defer cancel()
+ if err := listenUDP(timedCtx, acceptPort); err == nil {
+ return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort)
+ } else if !errors.Is(err, context.DeadlineExceeded) {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATAcceptAll tests that all UDP packets are accepted.
+type NATAcceptAll struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATAcceptAll) Name() string {
+ return "NATAcceptAll"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil {
+ return err
+ }
+
+ if err := listenUDP(ctx, acceptPort); err != nil {
+ return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err)
+ }
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATOutRedirectIP uses iptables to select packets based on destination IP and
+// redirects them.
+type NATOutRedirectIP struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATOutRedirectIP) Name() string {
+ return "NATOutRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)),
+ "-A", "OUTPUT",
+ "-d", nowhereIP(ipv6),
+ "-p", "udp",
+ "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// NATOutDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATOutDontRedirectIP struct{ localCase }
+
+// Name implements TestCase.Name.
+func (NATOutDontRedirectIP) Name() string {
+ return "NATOutDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return listenUDP(ctx, acceptPort)
+}
+
+// NATOutRedirectInvert tests that iptables can match with "! -d".
+type NATOutRedirectInvert struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATOutRedirectInvert) Name() string {
+ return "NATOutRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect OUTPUT packets to a listening localhost port.
+ dest := "192.0.2.2"
+ if ipv6 {
+ dest = "2001:db8::2"
+ }
+ return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)),
+ "-A", "OUTPUT",
+ "!", "-d", dest,
+ "-p", "udp",
+ "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort))
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// NATPreRedirectIP tests that we can use iptables to select packets based on
+// destination IP and redirect them.
+type NATPreRedirectIP struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreRedirectIP) Name() string {
+ return "NATPreRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ addrs, err := localAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+
+ var rules [][]string
+ for _, addr := range addrs {
+ rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)})
+ }
+ if err := natTableRules(ipv6, rules); err != nil {
+ return err
+ }
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// NATPreDontRedirectIP tests that iptables matching with "-d" does not match
+// packets it shouldn't.
+type NATPreDontRedirectIP struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreDontRedirectIP) Name() string {
+ return "NATPreDontRedirectIP"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, acceptPort)
+}
+
+// NATPreRedirectInvert tests that iptables can match with "! -d".
+type NATPreRedirectInvert struct{ containerCase }
+
+// Name implements TestCase.Name.
+func (NATPreRedirectInvert) Name() string {
+ return "NATPreRedirectInvert"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+ return listenUDP(ctx, acceptPort)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return sendUDPLoop(ctx, ip, dropPort)
+}
+
+// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a
+// protocol to be specified with -p.
+type NATRedirectRequiresProtocol struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATRedirectRequiresProtocol) Name() string {
+ return "NATRedirectRequiresProtocol"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil {
+ return errors.New("expected an error using REDIRECT --to-ports without a protocol")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// NATOutRedirectTCPPort tests that connections are redirected on specified ports.
+type NATOutRedirectTCPPort struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATOutRedirectTCPPort) Name() string {
+ return "NATOutRedirectTCPPort"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ localAddr := net.TCPAddr{
+ IP: net.ParseIP(localIP(ipv6)),
+ Port: acceptPort,
+ }
+
+ // Starts listening on port.
+ lConn, err := net.ListenTCP("tcp", &localAddr)
+ if err != nil {
+ return err
+ }
+ defer lConn.Close()
+
+ // Accept connections on port.
+ if err := connectTCP(ctx, ip, dropPort); err != nil {
+ return err
+ }
+
+ conn, err := lConn.AcceptTCP()
+ if err != nil {
+ return err
+ }
+ conn.Close()
+
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return nil
+}
+
+// NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't
+// affected by PREROUTING rules.
+type NATLoopbackSkipsPrerouting struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATLoopbackSkipsPrerouting) Name() string {
+ return "NATLoopbackSkipsPrerouting"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect anything sent to localhost to an unused port.
+ dest := []byte{127, 0, 0, 1}
+ if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
+ return err
+ }
+
+ // Establish a connection via localhost. If the PREROUTING rule did apply to
+ // loopback traffic, the connection would fail.
+ sendCh := make(chan error)
+ go func() {
+ sendCh <- connectTCP(ctx, dest, acceptPort)
+ }()
+
+ if err := listenTCP(ctx, acceptPort); err != nil {
+ return err
+ }
+ return <-sendCh
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+// NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination
+// of PREROUTING NATted packets.
+type NATPreOriginalDst struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATPreOriginalDst) Name() string {
+ return "NATPreOriginalDst"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect incoming TCP connections to acceptPort.
+ if err := natTable(ipv6, "-A", "PREROUTING",
+ "-p", "tcp",
+ "--destination-port", fmt.Sprintf("%d", dropPort),
+ "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ addrs, err := getInterfaceAddrs(ipv6)
+ if err != nil {
+ return err
+ }
+ return listenForRedirectedConn(ctx, ipv6, addrs)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ return connectTCP(ctx, ip, dropPort)
+}
+
+// NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination
+// of OUTBOUND NATted packets.
+type NATOutOriginalDst struct{ baseCase }
+
+// Name implements TestCase.Name.
+func (NATOutOriginalDst) Name() string {
+ return "NATOutOriginalDst"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // Redirect incoming TCP connections to acceptPort.
+ if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil {
+ return err
+ }
+
+ connCh := make(chan error)
+ go func() {
+ connCh <- connectTCP(ctx, ip, dropPort)
+ }()
+
+ if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil {
+ return err
+ }
+ return <-connCh
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
+ // No-op.
+ return nil
+}
+
+func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error {
+ // The net package doesn't give guarantee access to the connection's
+ // underlying FD, and thus we cannot call getsockopt. We have to use
+ // traditional syscalls for SO_ORIGINAL_DST.
+
+ // Create the listening socket, bind, listen, and accept.
+ family := syscall.AF_INET
+ if ipv6 {
+ family = syscall.AF_INET6
+ }
+ sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0)
+ if err != nil {
+ return err
+ }
+ defer syscall.Close(sockfd)
+
+ var bindAddr syscall.Sockaddr
+ if ipv6 {
+ bindAddr = &syscall.SockaddrInet6{
+ Port: acceptPort,
+ Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
+ }
+ } else {
+ bindAddr = &syscall.SockaddrInet4{
+ Port: acceptPort,
+ Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
+ }
+ }
+ if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ return err
+ }
+
+ if err := syscall.Listen(sockfd, 1); err != nil {
+ return err
+ }
+
+ // Block on accept() in another goroutine.
+ connCh := make(chan int)
+ errCh := make(chan error)
+ go func() {
+ connFD, _, err := syscall.Accept(sockfd)
+ if err != nil {
+ errCh <- err
+ }
+ connCh <- connFD
+ }()
+
+ // Wait for accept() to return or for the context to finish.
+ var connFD int
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case err := <-errCh:
+ return err
+ case connFD = <-connCh:
+ }
+ defer syscall.Close(connFD)
+
+ // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST
+ // indicates the packet was sent to originalDst:dropPort.
+ if ipv6 {
+ got, err := originalDestination6(connFD)
+ if err != nil {
+ return err
+ }
+ // The original destination could be any of our IPs.
+ for _, dst := range originalDsts {
+ want := syscall.RawSockaddrInet6{
+ Family: syscall.AF_INET6,
+ Port: htons(dropPort),
+ }
+ copy(want.Addr[:], dst.To16())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ } else {
+ got, err := originalDestination4(connFD)
+ if err != nil {
+ return err
+ }
+ // The original destination could be any of our IPs.
+ for _, dst := range originalDsts {
+ want := syscall.RawSockaddrInet4{
+ Family: syscall.AF_INET,
+ Port: htons(dropPort),
+ }
+ copy(want.Addr[:], dst.To4())
+ if got == want {
+ return nil
+ }
+ }
+ return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts)
+ }
+}
+
+// loopbackTests runs an iptables rule and ensures that packets sent to
+// dest:dropPort are received by localhost:acceptPort.
+func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error {
+ if err := natTable(ipv6, args...); err != nil {
+ return err
+ }
+ sendCh := make(chan error, 1)
+ listenCh := make(chan error, 1)
+ go func() {
+ sendCh <- sendUDPLoop(ctx, dest, dropPort)
+ }()
+ go func() {
+ listenCh <- listenUDP(ctx, acceptPort)
+ }()
+ select {
+ case err := <-listenCh:
+ return err
+ case err := <-sendCh:
+ return err
+ }
+}
diff --git a/test/iptables/runner/BUILD b/test/iptables/runner/BUILD
new file mode 100644
index 000000000..24504a1b9
--- /dev/null
+++ b/test/iptables/runner/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["main.go"],
+ pure = True,
+ visibility = ["//test/iptables:__subpackages__"],
+ deps = ["//test/iptables"],
+)
diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go
new file mode 100644
index 000000000..9ae2d1b4d
--- /dev/null
+++ b/test/iptables/runner/main.go
@@ -0,0 +1,79 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main runs iptables tests from within a docker container.
+package main
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "log"
+ "net"
+
+ "gvisor.dev/gvisor/test/iptables"
+)
+
+var (
+ name = flag.String("name", "", "name of the test to run")
+ ipv6 = flag.Bool("ipv6", false, "whether the test utilizes ip6tables")
+)
+
+func main() {
+ flag.Parse()
+
+ // Find out which test we're running.
+ test, ok := iptables.Tests[*name]
+ if !ok {
+ log.Fatalf("No test found named %q", *name)
+ }
+ log.Printf("Running test %q", *name)
+
+ // Get the IP of the local process.
+ ip, err := getIP()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Run the test.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ if err := test.ContainerAction(ctx, ip, *ipv6); err != nil {
+ log.Fatalf("Failed running test %q: %v", *name, err)
+ }
+
+ // Emit the final line.
+ log.Printf("%s", iptables.TerminalStatement)
+}
+
+// getIP listens for a connection from the local process and returns the source
+// IP of that connection.
+func getIP() (net.IP, error) {
+ localAddr := net.TCPAddr{
+ Port: iptables.IPExchangePort,
+ }
+ listener, err := net.ListenTCP("tcp", &localAddr)
+ if err != nil {
+ return net.IP{}, fmt.Errorf("failed listening for IP: %v", err)
+ }
+ defer listener.Close()
+ conn, err := listener.AcceptTCP()
+ if err != nil {
+ return net.IP{}, fmt.Errorf("failed accepting IP: %v", err)
+ }
+ defer conn.Close()
+ log.Printf("Connected to %v", conn.RemoteAddr())
+
+ return conn.RemoteAddr().(*net.TCPAddr).IP, nil
+}
diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD
new file mode 100644
index 000000000..49642f282
--- /dev/null
+++ b/test/packetdrill/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "bzl_library")
+load("//test/packetdrill:defs.bzl", "packetdrill_test")
+
+package(licenses = ["notice"])
+
+packetdrill_test(
+ name = "packetdrill_sanity_test",
+ scripts = ["sanity_test.pkt"],
+)
+
+packetdrill_test(
+ name = "accept_ack_drop_test",
+ scripts = ["accept_ack_drop.pkt"],
+)
+
+packetdrill_test(
+ name = "fin_wait2_timeout_test",
+ scripts = ["fin_wait2_timeout.pkt"],
+)
+
+packetdrill_test(
+ name = "listen_close_before_handshake_complete_test",
+ scripts = ["listen_close_before_handshake_complete.pkt"],
+)
+
+packetdrill_test(
+ name = "no_rst_to_rst_test",
+ scripts = ["no_rst_to_rst.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_test",
+ scripts = ["tcp_defer_accept.pkt"],
+)
+
+packetdrill_test(
+ name = "tcp_defer_accept_timeout_test",
+ scripts = ["tcp_defer_accept_timeout.pkt"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/test/packetdrill/accept_ack_drop.pkt b/test/packetdrill/accept_ack_drop.pkt
new file mode 100644
index 000000000..76e638fd4
--- /dev/null
+++ b/test/packetdrill/accept_ack_drop.pkt
@@ -0,0 +1,27 @@
+// Test that the accept works if the final ACK is dropped and an ack with data
+// follows the dropped ack.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
++0.0 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl
new file mode 100644
index 000000000..fc28ce9ba
--- /dev/null
+++ b/test/packetdrill/defs.bzl
@@ -0,0 +1,91 @@
+"""Defines a rule for packetdrill test targets."""
+
+def _packetdrill_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+
+ script_paths = []
+ for script in ctx.files.scripts:
+ script_paths.append(script.short_path)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ # This test will run part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -exec chmod a+rx {} \\;",
+ "find . -type d -exec chmod a+rx {} \\;",
+ "%s %s --init_script %s $@ -- %s\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files._init_script[0].short_path,
+ " ".join(script_paths),
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ transitive_files = depset()
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files = ctx.attr._test_runner.data_runfiles.files
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files._init_script + ctx.files.scripts,
+ transitive_files = transitive_files,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_packetdrill_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "host",
+ allow_files = True,
+ default = "packetdrill_test.sh",
+ ),
+ "_init_script": attr.label(
+ allow_single_file = True,
+ default = "packetdrill_setup.sh",
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ "scripts": attr.label_list(
+ mandatory = True,
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _packetdrill_test_impl,
+)
+
+PACKETDRILL_TAGS = [
+ "local",
+ "manual",
+ "packetdrill",
+]
+
+def packetdrill_linux_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name,
+ flags = ["--dut_platform", "linux"],
+ **kwargs
+ )
+
+def packetdrill_netstack_test(name, **kwargs):
+ if "tags" not in kwargs:
+ kwargs["tags"] = PACKETDRILL_TAGS
+ _packetdrill_test(
+ name = name,
+ # This is the default runtime unless
+ # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value.
+ flags = ["--dut_platform", "netstack", "--runtime", "runsc-d"],
+ **kwargs
+ )
+
+def packetdrill_test(name, **kwargs):
+ packetdrill_linux_test(name + "_linux_test", **kwargs)
+ packetdrill_netstack_test(name + "_netstack_test", **kwargs)
diff --git a/test/packetdrill/fin_wait2_timeout.pkt b/test/packetdrill/fin_wait2_timeout.pkt
new file mode 100644
index 000000000..93ab08575
--- /dev/null
+++ b/test/packetdrill/fin_wait2_timeout.pkt
@@ -0,0 +1,23 @@
+// Test that a socket in FIN_WAIT_2 eventually times out and a subsequent
+// packet generates a RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+// set FIN_WAIT2 timeout to 1 seconds.
++0.100 setsockopt(4, SOL_TCP, TCP_LINGER2, [1], 4) = 0
++0 close(4) = 0
+
++0 > F. 1:1(0) ack 1 <...>
++0 < . 1:1(0) ack 2 win 257
+
++2 < . 1:1(0) ack 2 win 257
++0 > R 2:2(0) win 0
diff --git a/test/packetdrill/listen_close_before_handshake_complete.pkt b/test/packetdrill/listen_close_before_handshake_complete.pkt
new file mode 100644
index 000000000..51c3f1a32
--- /dev/null
+++ b/test/packetdrill/listen_close_before_handshake_complete.pkt
@@ -0,0 +1,31 @@
+// Test that closing a listening socket closes any connections in SYN-RCVD
+// state and any packets bound for these connections generate a RESET.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
+
++0.100 close(3) = 0
++0.1 < P. 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.0 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/no_rst_to_rst.pkt b/test/packetdrill/no_rst_to_rst.pkt
new file mode 100644
index 000000000..612747827
--- /dev/null
+++ b/test/packetdrill/no_rst_to_rst.pkt
@@ -0,0 +1,36 @@
+// Test a RST is not generated in response to a RST and a RST is correctly
+// generated when an accepted endpoint is RST due to an incoming RST.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0 > S. 0:0(0) ack 1 <...>
++0 < P. 1:1(0) ack 1 win 257
+
++0.100 accept(3, ..., ...) = 4
+
++0.200 < R 1:1(0) win 0
+
++0.300 read(4,..., 4) = -1 ECONNRESET (Connection Reset by Peer)
+
++0.00 < . 1:1(0) ack 1 win 257
+
+// Linux generates a reset with no ack number/bit set. This is contradictory to
+// what is specified in Rule 1 under Reset Generation in
+// https://tools.ietf.org/html/rfc793#section-3.4.
+// "1. If the connection does not exist (CLOSED) then a reset is sent
+// in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected
+// by this means.
+//
+// If the incoming segment has an ACK field, the reset takes its
+// sequence number from the ACK field of the segment, otherwise the
+// reset has sequence number zero and the ACK field is set to the sum
+// of the sequence number and segment length of the incoming segment.
+// The connection remains in the CLOSED state."
+
++0.00 > R 1:1(0) win 0 \ No newline at end of file
diff --git a/test/packetdrill/packetdrill_setup.sh b/test/packetdrill/packetdrill_setup.sh
new file mode 100755
index 000000000..b858072f0
--- /dev/null
+++ b/test/packetdrill/packetdrill_setup.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script runs both within the sentry context and natively. It should tweak
+# TCP parameters to match expectations found in the script files.
+sysctl -q net.ipv4.tcp_sack=1
+sysctl -q net.ipv4.tcp_rmem="4096 2097152 $((8*1024*1024))"
+sysctl -q net.ipv4.tcp_wmem="4096 2097152 $((8*1024*1024))"
+
+# There may be errors from the above, but they will show up in the test logs and
+# we always want to proceed from this point. It's possible that values were
+# already set correctly and the nodes were not available in the namespace.
+exit 0
diff --git a/test/packetdrill/packetdrill_test.sh b/test/packetdrill/packetdrill_test.sh
new file mode 100755
index 000000000..922547d65
--- /dev/null
+++ b/test/packetdrill/packetdrill_test.sh
@@ -0,0 +1,226 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run a packetdrill test. Two docker containers are made, one for the
+# Device-Under-Test (DUT) and one for the test runner. Each is attached with
+# two networks, one for control packets that aid the test and one for test
+# packets which are sent as part of the test and observed for correctness.
+
+set -euxo pipefail
+
+function failure() {
+ local lineno=$1
+ local msg=$2
+ local filename="$0"
+ echo "FAIL: $filename:$lineno: $msg"
+}
+trap 'failure ${LINENO} "$BASH_COMMAND"' ERR
+
+declare -r LONGOPTS="dut_platform:,init_script:,runtime:"
+
+# Don't use declare below so that the error from getopt will end the script.
+PARSED=$(getopt --options "" --longoptions=$LONGOPTS --name "$0" -- "$@")
+
+eval set -- "$PARSED"
+
+while true; do
+ case "$1" in
+ --dut_platform)
+ # Either "linux" or "netstack".
+ declare -r DUT_PLATFORM="$2"
+ shift 2
+ ;;
+ --init_script)
+ declare -r INIT_SCRIPT="$2"
+ shift 2
+ ;;
+ --runtime)
+ # Not readonly because there might be multiple --runtime arguments and we
+ # want to use just the last one. Only used if --dut_platform is
+ # "netstack".
+ declare RUNTIME="$2"
+ shift 2
+ ;;
+ --)
+ shift
+ break
+ ;;
+ *)
+ echo "Programming error"
+ exit 3
+ esac
+done
+
+# All the other arguments are scripts.
+declare -r scripts="$@"
+
+# Check that the required flags are defined in a way that is safe for "set -u".
+if [[ "${DUT_PLATFORM-}" == "netstack" ]]; then
+ if [[ -z "${RUNTIME-}" ]]; then
+ echo "FAIL: Missing --runtime argument: ${RUNTIME-}"
+ exit 2
+ fi
+ declare -r RUNTIME_ARG="--runtime ${RUNTIME}"
+elif [[ "${DUT_PLATFORM-}" == "linux" ]]; then
+ declare -r RUNTIME_ARG=""
+else
+ echo "FAIL: Bad or missing --dut_platform argument: ${DUT_PLATFORM-}"
+ exit 2
+fi
+if [[ ! -x "${INIT_SCRIPT-}" ]]; then
+ echo "FAIL: Bad or missing --init_script: ${INIT_SCRIPT-}"
+ exit 2
+fi
+
+function new_net_prefix() {
+ # Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ echo "$(shuf -i 192-223 -n 1).$(shuf -i 0-255 -n 1).$(shuf -i 0-255 -n 1)"
+}
+
+# Variables specific to the control network and interface start with CTRL_.
+# Variables specific to the test network and interface start with TEST_.
+# Variables specific to the DUT start with DUT_.
+# Variables specific to the test runner start with TEST_RUNNER_.
+declare -r PACKETDRILL="/packetdrill/gtests/net/packetdrill/packetdrill"
+# Use random numbers so that test networks don't collide.
+declare CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+declare CTRL_NET_PREFIX=$(new_net_prefix)
+declare TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
+declare TEST_NET_PREFIX=$(new_net_prefix)
+declare -r tolerance_usecs=100000
+# On both DUT and test runner, testing packets are on the eth2 interface.
+declare -r TEST_DEVICE="eth2"
+# Number of bits in the *_NET_PREFIX variables.
+declare -r NET_MASK="24"
+# Last bits of the DUT's IP address.
+declare -r DUT_NET_SUFFIX=".10"
+# Control port.
+declare -r CTRL_PORT="40000"
+# Last bits of the test runner's IP address.
+declare -r TEST_RUNNER_NET_SUFFIX=".20"
+declare -r TIMEOUT="60"
+declare -r IMAGE_TAG="gcr.io/gvisor-presubmit/packetdrill"
+
+# Make sure that docker is installed.
+docker --version
+
+function finish {
+ local cleanup_success=1
+ for net in "${CTRL_NET}" "${TEST_NET}"; do
+ # Kill all processes attached to ${net}.
+ for docker_command in "kill" "rm"; do
+ (docker network inspect "${net}" \
+ --format '{{range $key, $value := .Containers}}{{$key}} {{end}}' \
+ | xargs -r docker "${docker_command}") || \
+ cleanup_success=0
+ done
+ # Remove the network.
+ docker network rm "${net}" || \
+ cleanup_success=0
+ done
+
+ if ((!$cleanup_success)); then
+ echo "FAIL: Cleanup command failed"
+ exit 4
+ fi
+}
+trap finish EXIT
+
+# Subnet for control packets between test runner and DUT.
+while ! docker network create \
+ "--subnet=${CTRL_NET_PREFIX}.0/${NET_MASK}" "${CTRL_NET}"; do
+ sleep 0.1
+ CTRL_NET_PREFIX=$(new_net_prefix)
+ CTRL_NET="ctrl_net-$(shuf -i 0-99999999 -n 1)"
+done
+
+# Subnet for the packets that are part of the test.
+while ! docker network create \
+ "--subnet=${TEST_NET_PREFIX}.0/${NET_MASK}" "${TEST_NET}"; do
+ sleep 0.1
+ TEST_NET_PREFIX=$(new_net_prefix)
+ TEST_NET="test_net-$(shuf -i 0-99999999 -n 1)"
+done
+
+# Create the DUT container and connect to network.
+DUT=$(docker create ${RUNTIME_ARG} --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" "${DUT}" \
+ || (docker kill ${DUT}; docker rm ${DUT}; false)
+docker start "${DUT}"
+
+# Create the test runner container and connect to network.
+TEST_RUNNER=$(docker create --privileged --rm \
+ --stop-timeout ${TIMEOUT} -it ${IMAGE_TAG})
+docker network connect "${CTRL_NET}" \
+ --ip "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+ || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false)
+docker network connect "${TEST_NET}" \
+ --ip "${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${TEST_RUNNER}" \
+ || (docker kill ${TEST_RUNNER}; docker rm ${REST_RUNNER}; false)
+docker start "${TEST_RUNNER}"
+
+# Run tcpdump in the test runner unbuffered, without dns resolution, just on the
+# interface with the test packets.
+docker exec -t ${TEST_RUNNER} tcpdump -U -n -i "${TEST_DEVICE}" &
+
+# Start a packetdrill server on the test_runner. The packetdrill server sends
+# packets and asserts that they are received.
+docker exec -d "${TEST_RUNNER}" \
+ ${PACKETDRILL} --wire_server --wire_server_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}"
+
+# Because the Linux kernel receives the SYN-ACK but didn't send the SYN it will
+# issue a RST. To prevent this IPtables can be used to filter those out.
+docker exec "${TEST_RUNNER}" \
+ iptables -A OUTPUT -p tcp --tcp-flags RST RST -j DROP
+
+# Wait for the packetdrill server on the test runner to come. Attempt to
+# connect to it from the DUT every 100 milliseconds until success.
+while ! docker exec "${DUT}" \
+ nc -zv "${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" "${CTRL_PORT}"; do
+ sleep 0.1
+done
+
+# Copy the packetdrill setup script to the DUT.
+docker cp -L "${INIT_SCRIPT}" "${DUT}:packetdrill_setup.sh"
+
+# Copy the packetdrill scripts to the DUT.
+declare -a dut_scripts
+for script in $scripts; do
+ docker cp -L "${script}" "${DUT}:$(basename ${script})"
+ dut_scripts+=("/$(basename ${script})")
+done
+
+# Start a packetdrill client on the DUT. The packetdrill client runs POSIX
+# socket commands and also sends instructions to the server.
+docker exec -t "${DUT}" \
+ ${PACKETDRILL} --wire_client --wire_client_dev="${TEST_DEVICE}" \
+ --wire_server_ip="${CTRL_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --wire_server_port="${CTRL_PORT}" \
+ --local_ip="${TEST_NET_PREFIX}${DUT_NET_SUFFIX}" \
+ --remote_ip="${TEST_NET_PREFIX}${TEST_RUNNER_NET_SUFFIX}" \
+ --init_scripts=/packetdrill_setup.sh \
+ --tolerance_usecs="${tolerance_usecs}" "${dut_scripts[@]}"
+
+echo PASS: No errors.
diff --git a/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
new file mode 100644
index 000000000..a86b90ce6
--- /dev/null
+++ b/test/packetdrill/reset_for_ack_when_no_syn_cookies_in_use.pkt
@@ -0,0 +1,9 @@
+// Test that a listening socket generates a RST when it receives an
+// ACK and syn cookies are not in use.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 bind(3, ..., ...) = 0
+
++0 listen(3, 1) = 0
++0.1 < . 1:1(0) ack 1 win 32792
++0 > R 1:1(0) ack 0 win 0 \ No newline at end of file
diff --git a/test/packetdrill/sanity_test.pkt b/test/packetdrill/sanity_test.pkt
new file mode 100644
index 000000000..b3b58c366
--- /dev/null
+++ b/test/packetdrill/sanity_test.pkt
@@ -0,0 +1,7 @@
+// Basic sanity test. One system call.
+//
+// All of the plumbing has to be working however, and the packetdrill wire
+// client needs to be able to connect to the wire server and send the script,
+// probe local interfaces, run through the test w/ timings, etc.
+
+0.000 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
diff --git a/test/packetdrill/tcp_defer_accept.pkt b/test/packetdrill/tcp_defer_accept.pkt
new file mode 100644
index 000000000..a17f946db
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK does not complete a connection when TCP_DEFER_ACCEPT
+// timeout is not hit but an ACK w/ data does complete and deliver the
+// connection to the accept queue.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [5], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// Now send an ACK w/ data. This should complete the connection
+// and deliver the socket to the accept queue.
++0.1 < . 1:5(4) ack 1 win 257
++0.0 > . 1:1(0) ack 5 <...>
+
+// This should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
+// Now read the data and we should get 4 bytes.
++0.000 read(4,..., 4) = 4
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 5 <...>
++0.0 < F. 5:5(0) ack 2 win 257
++0.01 > . 2:2(0) ack 6 <...> \ No newline at end of file
diff --git a/test/packetdrill/tcp_defer_accept_timeout.pkt b/test/packetdrill/tcp_defer_accept_timeout.pkt
new file mode 100644
index 000000000..201fdeb14
--- /dev/null
+++ b/test/packetdrill/tcp_defer_accept_timeout.pkt
@@ -0,0 +1,48 @@
+// Test that a bare ACK is accepted after TCP_DEFER_ACCEPT timeout
+// is hit and a connection is delivered.
+
+0 socket(..., SOCK_STREAM, IPPROTO_TCP) = 3
++0 setsockopt(3, SOL_TCP, TCP_DEFER_ACCEPT, [3], 4) = 0
++0.000 fcntl(3, F_SETFL, O_RDWR|O_NONBLOCK) = 0
++0 bind(3, ..., ...) = 0
+
+// Set backlog to 1 so that we can easily test.
++0 listen(3, 1) = 0
+
+// Establish a connection without timestamps.
++0.0 < S 0:0(0) win 32792 <mss 1460,sackOK,nop,nop,nop,wscale 7>
++0.0 > S. 0:0(0) ack 1 <...>
+
+// Send a bare ACK this should not complete the connection as we
+// set the TCP_DEFER_ACCEPT above.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The bare ACK should be dropped and no connection should be delivered
+// to the accept queue.
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// Send another bare ACK and it should still fail we set TCP_DEFER_ACCEPT
+// to 5 seconds above.
++2.5 < . 1:1(0) ack 1 win 257
++0.100 accept(3, ..., ...) = -1 EWOULDBLOCK (operation would block)
+
+// set accept socket back to blocking.
++0.000 fcntl(3, F_SETFL, O_RDWR) = 0
+
+// We should see one more retransmit of the SYN-ACK as a last ditch
+// attempt when TCP_DEFER_ACCEPT timeout is hit to trigger another
+// ACK or a packet with data.
++.35~+2.35 > S. 0:0(0) ack 1 <...>
+
+// Now send another bare ACK after TCP_DEFER_ACCEPT time has been passed.
++0.0 < . 1:1(0) ack 1 win 257
+
+// The ACK above should cause connection to transition to connected state.
++0.000 accept(3, ..., ...) = 4
++0.000 fcntl(4, F_SETFL, O_RDWR|O_NONBLOCK) = 0
+
++0.000 close(4) = 0
+
++0.0 > F. 1:1(0) ack 1 <...>
++0.0 < F. 1:1(0) ack 2 win 257
++0.01 > . 2:2(0) ack 2 <...>
diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md
new file mode 100644
index 000000000..ffa96ba98
--- /dev/null
+++ b/test/packetimpact/README.md
@@ -0,0 +1,702 @@
+# Packetimpact
+
+## What is packetimpact?
+
+Packetimpact is a tool for platform-independent network testing. It is heavily
+inspired by [packetdrill](https://github.com/google/packetdrill). It creates two
+docker containers connected by a network. One is for the test bench, which
+operates the test. The other is for the device-under-test (DUT), which is the
+software being tested. The test bench communicates over the network with the DUT
+to check correctness of the network.
+
+### Goals
+
+Packetimpact aims to provide:
+
+* A **multi-platform** solution that can test both Linux and gVisor.
+* **Conciseness** on par with packetdrill scripts.
+* **Control-flow** like for loops, conditionals, and variables.
+* **Flexibilty** to specify every byte in a packet or use multiple sockets.
+
+## How to run packetimpact tests?
+
+Build the test container image by running the following at the root of the
+repository:
+
+```bash
+$ make load-packetimpact
+```
+
+Run a test, e.g. `fin_wait2_timeout`, against Linux:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_native_test
+```
+
+Run the same test, but against gVisor:
+
+```bash
+$ bazel test //test/packetimpact/tests:fin_wait2_timeout_netstack_test
+```
+
+## When to use packetimpact?
+
+There are a few ways to write networking tests for gVisor currently:
+
+* [Go unit tests](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+* [syscall tests](https://github.com/google/gvisor/tree/master/test/syscalls/linux)
+* [packetdrill tests](https://github.com/google/gvisor/tree/master/test/packetdrill)
+* packetimpact tests
+
+The right choice depends on the needs of the test.
+
+Feature | Go unit test | syscall test | packetdrill | packetimpact
+-------------- | ------------ | ------------ | ----------- | ------------
+Multi-platform | no | **YES** | **YES** | **YES**
+Concise | no | somewhat | somewhat | **VERY**
+Control-flow | **YES** | **YES** | no | **YES**
+Flexible | **VERY** | no | somewhat | **VERY**
+
+### Go unit tests
+
+If the test depends on the internals of gVisor and doesn't need to run on Linux
+or other platforms for comparison purposes, a Go unit test can be appropriate.
+They can observe internals of gVisor networking. The downside is that they are
+**not concise** and **not multi-platform**. If you require insight on gVisor
+internals, this is the right choice.
+
+### Syscall tests
+
+Syscall tests are **multi-platform** but cannot examine the internals of gVisor
+networking. They are **concise**. They can use **control-flow** structures like
+conditionals, for loops, and variables. However, they are limited to only what
+the POSIX interface provides so they are **not flexible**. For example, you
+would have difficulty writing a syscall test that intentionally sends a bad IP
+checksum. Or if you did write that test with raw sockets, it would be very
+**verbose** to write a test that intentionally send wrong checksums, wrong
+protocols, wrong sequence numbers, etc.
+
+### Packetdrill tests
+
+Packetdrill tests are **multi-platform** and can run against both Linux and
+gVisor. They are **concise** and use a special packetdrill scripting language.
+They are **more flexible** than a syscall test in that they can send packets
+that a syscall test would have difficulty sending, like a packet with a
+calcuated ACK number. But they are also somewhat limimted in flexibiilty in that
+they can't do tests with multiple sockets. They have **no control-flow** ability
+like variables or conditionals. For example, it isn't possible to send a packet
+that depends on the window size of a previous packet because the packetdrill
+language can't express that. Nor could you branch based on whether or not the
+other side supports window scaling, for example.
+
+### Packetimpact tests
+
+Packetimpact tests are similar to Packetdrill tests except that they are written
+in Go instead of the packetdrill scripting language. That gives them all the
+**control-flow** abilities of Go (loops, functions, variables, etc). They are
+**multi-platform** in the same way as packetdrill tests but even more
+**flexible** because Go is more expressive than the scripting language of
+packetdrill. However, Go is **not as concise** as the packetdrill language. Many
+design decisions below are made to mitigate that.
+
+## How it works
+
+```
+ Testbench Device-Under-Test (DUT)
+ +-------------------+ +------------------------+
+ | | TEST NET | |
+ | rawsockets.go <-->| <===========> | <---+ |
+ | ^ | | | |
+ | | | | | |
+ | v | | | |
+ | unittest | | | |
+ | ^ | | | |
+ | | | | | |
+ | v | | v |
+ | dut.go <========gRPC========> posix server |
+ | | CONTROL NET | |
+ +-------------------+ +------------------------+
+```
+
+Two docker containers are created by a "runner" script, one for the testbench
+and the other for the device under test (DUT). The script connects the two
+containers with a control network and test network. It also does some other
+tasks like waiting until the DUT is ready before starting the test and disabling
+Linux networking that would interfere with the test bench.
+
+### DUT
+
+The DUT container runs a program called the "posix_server". The posix_server is
+written in c++ for maximum portability. It is compiled on the host. The script
+that starts the containers copies it into the DUT's container and runs it. It's
+job is to receive directions from the test bench on what actions to take. For
+this, the posix_server does three steps in a loop:
+
+1. Listen for a request from the test bench.
+2. Execute a command.
+3. Send the response back to the test bench.
+
+The requests and responses are
+[protobufs](https://developers.google.com/protocol-buffers) and the
+communication is done with [gRPC](https://grpc.io/). The commands run are
+[POSIX socket commands](https://en.wikipedia.org/wiki/Berkeley_sockets#Socket_API_functions),
+with the inputs and outputs converted into protobuf requests and responses. All
+communication is on the control network, so that the test network is unaffected
+by extra packets.
+
+For example, this is the request and response pair to call
+[`socket()`](http://man7.org/linux/man-pages/man2/socket.2.html):
+
+```protocol-buffer
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2;
+}
+```
+
+##### Alternatives considered
+
+* We could have use JSON for communication instead. It would have been a
+ lighter-touch than protobuf but protobuf handles all the data type and has
+ strict typing to prevent a class of errors. The test bench could be written
+ in other languages, too.
+* Instead of mimicking the POSIX interfaces, arguments could have had a more
+ natural form, like the `bind()` getting a string IP address instead of bytes
+ in a `sockaddr_t`. However, conforming to the existing structures keeps more
+ of the complexity in Go and keeps the posix_server simpler and thus more
+ likely to compile everywhere.
+
+### Test Bench
+
+The test bench does most of the work in a test. It is a Go program that compiles
+on the host and is copied by the script into test bench's container. It is a
+regular [go unit test](https://golang.org/pkg/testing/) that imports the test
+bench framework. The test bench framwork is based on three basic utilities:
+
+* Commanding the DUT to run POSIX commands and return responses.
+* Sending raw packets to the DUT on the test network.
+* Listening for raw packets from the DUT on the test network.
+
+#### DUT commands
+
+To keep the interface to the DUT consistent and easy-to-use, each POSIX command
+supported by the posix_server is wrapped in functions with signatures similar to
+the ones in the [Go unix package](https://godoc.org/golang.org/x/sys/unix). This
+way all the details of endianess and (un)marshalling of go structs such as
+[unix.Timeval](https://godoc.org/golang.org/x/sys/unix#Timeval) is handled in
+one place. This also makes it straight-forward to convert tests that use `unix.`
+or `syscall.` calls to `dut.` calls.
+
+For example, creating a connection to the DUT and commanding it to make a socket
+looks like this:
+
+```go
+dut := testbench.NewDut(t)
+fd, err := dut.SocketWithErrno(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+if fd < 0 {
+ t.Fatalf(...)
+}
+```
+
+Because the usual case is to fail the test when the DUT fails to create a
+socket, there is a concise version of each of the `...WithErrno` functions that
+does that:
+
+```go
+dut := testbench.NewDut(t)
+fd := dut.Socket(unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_IP)
+```
+
+The DUT and other structs in the code store a `*testing.T` so that they can
+provide versions of functions that call `t.Fatalf(...)`. This helps keep tests
+concise.
+
+##### Alternatives considered
+
+* Instead of mimicking the `unix.` go interface, we could have invented a more
+ natural one, like using `float64` instead of `Timeval`. However, using the
+ same function signatures that `unix.` has makes it easier to convert code to
+ `dut.`. Also, using an existing interface ensures that we don't invent an
+ interface that isn't extensible. For example, if we invented a function for
+ `bind()` that didn't support IPv6 and later we had to add a second `bind6()`
+ function.
+
+#### Sending/Receiving Raw Packets
+
+The framework wraps POSIX sockets for sending and receiving raw frames. Both
+send and receive are synchronous commands.
+[SO_RCVTIMEO](http://man7.org/linux/man-pages/man7/socket.7.html) is used to set
+a timeout on the receive commands. For ease of use, these are wrapped in an
+`Injector` and a `Sniffer`. They have functions:
+
+```go
+func (s *Sniffer) Recv(timeout time.Duration) []byte {...}
+func (i *Injector) Send(b []byte) {...}
+```
+
+##### Alternatives considered
+
+* [gopacket](https://github.com/google/gopacket) pcap has raw socket support
+ but requires cgo. cgo is not guaranteed to be portable from the host to the
+ container and in practice, the container doesn't recognize binaries built on
+ the host if they use cgo.
+* Both gVisor and gopacket have the ability to read and write pcap files
+ without cgo but that is insufficient here because we can't just replay pcap
+ files, we need a more dynamic solution.
+* The sniffer and injector can't share a socket because they need to be bound
+ differently.
+* Sniffing could have been done asynchronously with channels, obviating the
+ need for `SO_RCVTIMEO`. But that would introduce asynchronous complication.
+ `SO_RCVTIMEO` is well supported on the test bench.
+
+#### `Layer` struct
+
+A large part of packetimpact tests is creating packets to send and comparing
+received packets against expectations. To keep tests concise, it is useful to be
+able to specify just the important parts of packets that need to be set. For
+example, sending a packet with default values except for TCP Flags. And for
+packets received, it's useful to be able to compare just the necessary parts of
+received packets and ignore the rest.
+
+To aid in both of those, Go structs with optional fields are created for each
+encapsulation type, such as IPv4, TCP, and Ethernet. This is inspired by
+[scapy](https://scapy.readthedocs.io/en/latest/). For example, here is the
+struct for Ethernet:
+
+```go
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+```
+
+Each struct has the same fields as those in the
+[gVisor headers](https://github.com/google/gvisor/tree/master/pkg/tcpip/header)
+but with a pointer for each field that may be `nil`.
+
+##### Alternatives considered
+
+* Just use []byte like gVisor headers do. The drawback is that it makes the
+ tests more verbose.
+ * For example, there would be no way to call `Send(myBytes)` concisely and
+ indicate if the checksum should be calculated automatically versus
+ overridden. The only way would be to add lines to the test to calculate
+ it before each Send, which is wordy. Or make multiple versions of Send:
+ one that checksums IP, one that doesn't, one that checksums TCP, one
+ that does both, etc. That would be many combinations.
+ * Filtering inputs would become verbose. Either:
+ * large conditionals that need to be repeated many places:
+ `h[FlagOffset] == SYN && h[LengthOffset:LengthOffset+2] == ...` or
+ * Many functions, one per field, like: `filterByFlag(myBytes, SYN)`,
+ `filterByLength(myBytes, 20)`, `filterByNextProto(myBytes, 0x8000)`,
+ etc.
+ * Using pointers allows us to combine `Layer`s with reflection. So the
+ default `Layers` can be overridden by a `Layers` with just the TCP
+ conection's src/dst which can be overridden by one with just a test
+ specific TCP window size.
+ * It's a proven way to separate the details of a packet from the byte
+ format as shown by scapy's success.
+* Use packetgo. It's more general than parsing packets with gVisor. However:
+ * packetgo doesn't have optional fields so many of the above problems
+ still apply.
+ * It would be yet another dependency.
+ * It's not as well known to engineers that are already writing gVisor
+ code.
+ * It might be a good candidate for replacing the parsing of packets into
+ `Layer`s if all that parsing turns out to be more work than parsing by
+ packetgo and converting *that* to `Layer`. packetgo has easier to use
+ getters for the layers. This could be done later in a way that doesn't
+ break tests.
+
+#### `Layer` methods
+
+The `Layer` structs provide a way to partially specify an encapsulation. They
+also need methods for using those partially specified encapsulation, for example
+to marshal them to bytes or compare them. For those, each encapsulation
+implements the `Layer` interface:
+
+```go
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ // toBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ toBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+}
+```
+
+The `next` and `prev` make up a link listed so that each layer can get at the
+information in the layer around it. This is necessary for some protocols, like
+TCP that needs the layer before and payload after to compute the checksum. Any
+sequence of `Layer` structs is valid so long as the parser and `toBytes`
+functions can map from type to protool number and vice-versa. When the mapping
+fails, an error is emitted explaining what functionality is missing. The
+solution is either to fix the ordering or implement the missing protocol.
+
+For each `Layer` there is also a parsing function. For example, this one is for
+Ethernet:
+
+```
+func ParseEther(b []byte) (Layers, error)
+```
+
+The parsing function converts bytes received on the wire into a `Layer`
+(actually `Layers`, see below) which has no `nil`s in it. By using
+`match(Layer)` to compare against another `Layer` that *does* have `nil`s in it,
+the received bytes can be partially compared. The `nil`s behave as
+"don't-cares".
+
+##### Alternatives considered
+
+* Matching against `[]byte` instead of converting to `Layer` first.
+ * The downside is that it precludes the use of a `cmp.Equal` one-liner to
+ do comparisons.
+ * It creates confusion in the code to deal with both representations at
+ different times. For example, is the checksum calculated on `[]byte` or
+ `Layer` when sending? What about when checking received packets?
+
+#### `Layers`
+
+```
+type Layers []Layer
+
+func (ls *Layers) match(other Layers) bool {...}
+func (ls *Layers) toBytes() ([]byte, error) {...}
+```
+
+`Layers` is an array of `Layer`. It represents a stack of encapsulations, such
+as `Layers{Ether{},IPv4{},TCP{},Payload{}}`. It also has `toBytes()` and
+`match(Layers)`, like `Layer`. The parse functions above actually return
+`Layers` and not `Layer` because they know about the headers below and
+sequentially call each parser on the remaining, encapsulated bytes.
+
+All this leads to the ability to write concise packet processing. For example:
+
+```go
+etherType := 0x8000
+flags = uint8(header.TCPFlagSyn|header.TCPFlagAck)
+toMatch := Layers{Ether{Type: &etherType}, IPv4{}, TCP{Flags: &flags}}
+for {
+ recvBytes := sniffer.Recv(time.Second)
+ if recvBytes == nil {
+ println("Got no packet for 1 second")
+ }
+ gotPacket, err := ParseEther(recvBytes)
+ if err == nil && toMatch.match(gotPacket) {
+ println("Got a TCP/IPv4/Eth packet with SYNACK")
+ }
+}
+```
+
+##### Alternatives considered
+
+* Don't use previous and next pointers.
+ * Each layer may need to be able to interrogate the layers around it, like
+ for computing the next protocol number or total length. So *some*
+ mechanism is needed for a `Layer` to see neighboring layers.
+ * We could pass the entire array `Layers` to the `toBytes()` function.
+ Passing an array to a method that includes in the array the function
+ receiver itself seems wrong.
+
+#### `layerState`
+
+`Layers` represents the different headers of a packet but a connection includes
+more state. For example, a TCP connection needs to keep track of the next
+expected sequence number and also the next sequence number to send. This is
+stored in a `layerState` struct. This is the `layerState` for TCP:
+
+```go
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+```
+
+The next sequence numbers for each side of the connection are stored. `out` and
+`in` have defaults for the TCP header, such as the expected source and
+destination ports for outgoing packets and incoming packets.
+
+##### `layerState` interface
+
+```go
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is receieved. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was receieved is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+```
+
+`outgoing` generates the default Layer for an outgoing packet. For TCP, this
+would be a `TCP` with the source and destination ports populated. Because they
+are static, they are stored inside the `out` member of `tcpState`. However, the
+sequence numbers change frequently so the outgoing sequence number is stored in
+the `localSeqNum` and put into the output of outgoing for each call.
+
+`incoming` does the same functions for packets that arrive but instead of
+generating a packet to send, it generates an expect packet for filtering packets
+that arrive. For example, if a `TCP` header arrives with the wrong ports, it can
+be ignored as belonging to a different connection. `incoming` needs the received
+header itself as an input because the filter may depend on the input. For
+example, the expected sequence number depends on the flags in the TCP header.
+
+`sent` and `received` are run for each header that is actually sent or received
+and used to update the internal state. `incoming` and `outgoing` should *not* be
+used for these purpose. For example, `incoming` is called on every packet that
+arrives but only packets that match ought to actually update the state.
+`outgoing` is called to created outgoing packets and those packets are always
+sent, so unlike `incoming`/`received`, there is one `outgoing` call for each
+`sent` call.
+
+`close` cleans up after the layerState. For example, TCP and UDP need to keep a
+port reserved and then release it.
+
+#### Connections
+
+Using `layerState` above, we can create connections.
+
+```go
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+ t *testing.T
+}
+```
+
+The connection stores an array of `layerState` in the order that the headers
+should be present in the frame to send. For example, Ether then IPv4 then TCP.
+The injector and sniffer are for writing and reading frames. A `*testing.T` is
+stored so that internal errors can be reported directly without code in the unit
+test.
+
+The `Connection` has some useful functions:
+
+```go
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close() {...}
+// CreateFrame builds a frame for the connection with layer overriding defaults
+// of the innermost layer and additionalLayers added after it.
+func (conn *Connection) CreateFrame(layer Layer, additionalLayers ...Layer) Layers {...}
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(frame Layers) {...}
+// Send a packet with reasonable defaults. Potentially override the final layer
+// in the connection with the provided layer and add additionLayers.
+func (conn *Connection) Send(layer Layer, additionalLayers ...Layer) {...}
+// Expect a frame with the final layerStates layer matching the provided Layer
+// within the timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {...}
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, it returns nil.
+func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {...}
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain() {...}
+```
+
+`CreateFrame` uses the `[]layerState` to create a frame to send. The first
+argument is for overriding defaults in the last header of the frame, because
+this is the most common need. For a TCPIPv4 connection, this would be the TCP
+header. Optional additionalLayers can be specified to add to the frame being
+created, such as a `Payload` for `TCP`.
+
+`SendFrame` sends the frame to the DUT. It is combined with `CreateFrame` to
+make `Send`. For unittests with basic sending needs, `Send` can be used. If more
+control is needed over the frame, it can be made with `CreateFrame`, modified in
+the unit test, and then sent with `SendFrame`.
+
+On the receiving side, there is `Expect` and `ExpectFrame`. Like with the
+sending side, there are two forms of each function, one for just the last header
+and one for the whole frame. The expect functions use the `[]layerState` to
+create a template for the expected incoming frame. That frame is then overridden
+by the values in the first argument. Finally, a loop starts sniffing packets on
+the wire for frames. If a matching frame is found before the timeout, it is
+returned without error. If not, nil is returned and the error contains text of
+all the received frames that didn't match. Exactly one of the outputs will be
+non-nil, even if no frames are received at all.
+
+`Drain` sniffs and discards all the frames that have yet to be received. A
+common way to write a test is:
+
+```go
+conn.Drain() // Discard all outstanding frames.
+conn.Send(...) // Send a frame with overrides.
+// Now expect a frame with a certain header and fail if it doesn't arrive.
+if _, err := conn.Expect(...); err != nil { t.Fatal(...) }
+```
+
+Or for a test where we want to check that no frame arrives:
+
+```go
+if gotOne, _ := conn.Expect(...); gotOne != nil { t.Fatal(...) }
+```
+
+#### Specializing `Connection`
+
+Because there are some common combinations of `layerState` into `Connection`,
+they are defined:
+
+```go
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+```
+
+Each has a `NewXxx` function to create a new connection with reasonable
+defaults. They also have functions that call the underlying `Connection`
+functions but with specialization and tighter type-checking. For example:
+
+```go
+func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
+ (*Connection)(conn).Send(&tcp, additionalLayers...)
+}
+func (conn *TCPIPv4) Drain() {
+ conn.sniffer.Drain()
+}
+```
+
+They may also have some accessors to get or set the internal state of the
+connection:
+
+```go
+func (conn *TCPIPv4) state() *tcpState {
+ state, ok := conn.layerStates[len(conn.layerStates)-1].(*tcpState)
+ if !ok {
+ conn.t.Fatalf("expected final state of %v to be tcpState", conn.layerStates)
+ }
+ return state
+}
+func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
+ return conn.state().remoteSeqNum
+}
+func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
+ return conn.state().localSeqNum
+}
+```
+
+Unittests will in practice use these functions and not the functions on
+`Connection`. For example, `NewTCPIPv4()` and then call `Send` on that rather
+than cast is to a `Connection` and call `Send` on that cast result.
+
+##### Alternatives considered
+
+* Instead of storing `outgoing` and `incoming`, store values.
+ * There would be many more things to store instead, like `localMac`,
+ `remoteMac`, `localIP`, `remoteIP`, `localPort`, and `remotePort`.
+ * Construction of a packet would be many lines to copy each of these
+ values into a `[]byte`. And there would be slight variations needed for
+ each encapsulation stack, like TCPIPv6 and ARP.
+ * Filtering incoming packets would be a long sequence:
+ * Compare the MACs, then
+ * Parse the next header, then
+ * Compare the IPs, then
+ * Parse the next header, then
+ * Compare the TCP ports. Instead it's all just one call to
+ `cmp.Equal(...)`, for all sequences.
+ * A TCPIPv6 connection could share most of the code. Only the type of the
+ IP addresses are different. The types of `outgoing` and `incoming` would
+ be remain `Layers`.
+ * An ARP connection could share all the Ethernet parts. The IP `Layer`
+ could be factored out of `outgoing`. After that, the IPv4 and IPv6
+ connections could implement one interface and a single TCP struct could
+ have either network protocol through composition.
+
+## Putting it all together
+
+Here's what te start of a packetimpact unit test looks like. This test creates a
+TCP connection with the DUT. There are added comments for explanation in this
+document but a real test might not include them in order to stay even more
+concise.
+
+```go
+func TestMyTcpTest(t *testing.T) {
+ // Prepare a DUT for communication.
+ dut := testbench.NewDUT(t)
+
+ // This does:
+ // dut.Socket()
+ // dut.Bind()
+ // dut.Getsockname() to learn the new port number
+ // dut.Listen()
+ listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(listenFD) // Tell the DUT to close the socket at the end of the test.
+
+ // Monitor a new TCP connection with sniffer, injector, sequence number tracking,
+ // and reasonable outgoing and incoming packet field default IPs, MACs, and port numbers.
+ conn := testbench.NewTCPIPv4(t, dut, remotePort)
+
+ // Perform a 3-way handshake: send SYN, expect SYNACK, send ACK.
+ conn.Handshake()
+
+ // Tell the DUT to accept the new connection.
+ acceptFD := dut.Accept(acceptFd)
+}
+```
+
+## Other notes
+
+* The time between receiving a SYN-ACK and replying with an ACK in `Handshake`
+ is about 3ms. This is much slower than the native unix response, which is
+ about 0.3ms. Packetdrill gets closer to 0.3ms. For tests where timing is
+ crucial, packetdrill is faster and more precise.
diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD
new file mode 100644
index 000000000..3ce63c2c6
--- /dev/null
+++ b/test/packetimpact/dut/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "cc_binary", "grpcpp")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "posix_server",
+ srcs = ["posix_server.cc"],
+ linkstatic = 1,
+ static = True, # This is needed for running in a docker container.
+ deps = [
+ grpcpp,
+ "//test/packetimpact/proto:posix_server_cc_grpc_proto",
+ "//test/packetimpact/proto:posix_server_cc_proto",
+ ],
+)
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
new file mode 100644
index 000000000..29d4cc6fe
--- /dev/null
+++ b/test/packetimpact/dut/posix_server.cc
@@ -0,0 +1,371 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at //
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <getopt.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <unordered_map>
+
+#include "include/grpcpp/security/server_credentials.h"
+#include "include/grpcpp/server_builder.h"
+#include "test/packetimpact/proto/posix_server.grpc.pb.h"
+#include "test/packetimpact/proto/posix_server.pb.h"
+
+// Converts a sockaddr_storage to a Sockaddr message.
+::grpc::Status sockaddr_to_proto(const sockaddr_storage &addr,
+ socklen_t addrlen,
+ posix_server::Sockaddr *sockaddr_proto) {
+ switch (addr.ss_family) {
+ case AF_INET: {
+ auto addr_in = reinterpret_cast<const sockaddr_in *>(&addr);
+ auto response_in = sockaddr_proto->mutable_in();
+ response_in->set_family(addr_in->sin_family);
+ response_in->set_port(ntohs(addr_in->sin_port));
+ response_in->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in->sin_addr.s_addr), 4);
+ return ::grpc::Status::OK;
+ }
+ case AF_INET6: {
+ auto addr_in6 = reinterpret_cast<const sockaddr_in6 *>(&addr);
+ auto response_in6 = sockaddr_proto->mutable_in6();
+ response_in6->set_family(addr_in6->sin6_family);
+ response_in6->set_port(ntohs(addr_in6->sin6_port));
+ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo));
+ response_in6->mutable_addr()->assign(
+ reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ // sin6_scope_id is stored in host byte order.
+ //
+ // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html
+ response_in6->set_scope_id(addr_in6->sin6_scope_id);
+ return ::grpc::Status::OK;
+ }
+ }
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown Sockaddr");
+}
+
+::grpc::Status proto_to_sockaddr(const posix_server::Sockaddr &sockaddr_proto,
+ sockaddr_storage *addr, socklen_t *addr_len) {
+ switch (sockaddr_proto.sockaddr_case()) {
+ case posix_server::Sockaddr::SockaddrCase::kIn: {
+ auto proto_in = sockaddr_proto.in();
+ if (proto_in.addr().size() != 4) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv4 address must be 4 bytes");
+ }
+ auto addr_in = reinterpret_cast<sockaddr_in *>(addr);
+ addr_in->sin_family = proto_in.family();
+ addr_in->sin_port = htons(proto_in.port());
+ proto_in.addr().copy(reinterpret_cast<char *>(&addr_in->sin_addr.s_addr),
+ 4);
+ *addr_len = sizeof(*addr_in);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::kIn6: {
+ auto proto_in6 = sockaddr_proto.in6();
+ if (proto_in6.addr().size() != 16) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "IPv6 address must be 16 bytes");
+ }
+ auto addr_in6 = reinterpret_cast<sockaddr_in6 *>(addr);
+ addr_in6->sin6_family = proto_in6.family();
+ addr_in6->sin6_port = htons(proto_in6.port());
+ addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo());
+ proto_in6.addr().copy(
+ reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16);
+ // sin6_scope_id is stored in host byte order.
+ //
+ // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html
+ addr_in6->sin6_scope_id = proto_in6.scope_id();
+ *addr_len = sizeof(*addr_in6);
+ break;
+ }
+ case posix_server::Sockaddr::SockaddrCase::SOCKADDR_NOT_SET:
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown Sockaddr");
+ }
+ return ::grpc::Status::OK;
+}
+
+class PosixImpl final : public posix_server::Posix::Service {
+ ::grpc::Status Accept(grpc_impl::ServerContext *context,
+ const ::posix_server::AcceptRequest *request,
+ ::posix_server::AcceptResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_fd(accept(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status Bind(grpc_impl::ServerContext *context,
+ const ::posix_server::BindRequest *request,
+ ::posix_server::BindResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(
+ bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Close(grpc_impl::ServerContext *context,
+ const ::posix_server::CloseRequest *request,
+ ::posix_server::CloseResponse *response) override {
+ response->set_ret(close(request->fd()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Connect(grpc_impl::ServerContext *context,
+ const ::posix_server::ConnectRequest *request,
+ ::posix_server::ConnectResponse *response) override {
+ if (!request->has_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(connect(request->sockfd(),
+ reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Fcntl(grpc_impl::ServerContext *context,
+ const ::posix_server::FcntlRequest *request,
+ ::posix_server::FcntlResponse *response) override {
+ response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status GetSockName(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockNameRequest *request,
+ ::posix_server::GetSockNameResponse *response) override {
+ sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ response->set_ret(getsockname(
+ request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen));
+ response->set_errno_(errno);
+ return sockaddr_to_proto(addr, addrlen, response->mutable_addr());
+ }
+
+ ::grpc::Status GetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::GetSockOptRequest *request,
+ ::posix_server::GetSockOptResponse *response) override {
+ switch (request->type()) {
+ case ::posix_server::GetSockOptRequest::BYTES: {
+ socklen_t optlen = request->optlen();
+ std::vector<char> buf(optlen);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), buf.data(),
+ &optlen));
+ if (optlen >= 0) {
+ response->mutable_optval()->set_bytesval(buf.data(), optlen);
+ }
+ break;
+ }
+ case ::posix_server::GetSockOptRequest::INT: {
+ int intval = 0;
+ socklen_t optlen = sizeof(intval);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &intval, &optlen));
+ response->mutable_optval()->set_intval(intval);
+ break;
+ }
+ case ::posix_server::GetSockOptRequest::TIME: {
+ timeval tv;
+ socklen_t optlen = sizeof(tv);
+ response->set_ret(::getsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, &optlen));
+ response->mutable_optval()->mutable_timeval()->set_seconds(tv.tv_sec);
+ response->mutable_optval()->mutable_timeval()->set_microseconds(
+ tv.tv_usec);
+ break;
+ }
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown SockOpt Type");
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Listen(grpc_impl::ServerContext *context,
+ const ::posix_server::ListenRequest *request,
+ ::posix_server::ListenResponse *response) override {
+ response->set_ret(listen(request->sockfd(), request->backlog()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Send(::grpc::ServerContext *context,
+ const ::posix_server::SendRequest *request,
+ ::posix_server::SendResponse *response) override {
+ response->set_ret(::send(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SendTo(::grpc::ServerContext *context,
+ const ::posix_server::SendToRequest *request,
+ ::posix_server::SendToResponse *response) override {
+ if (!request->has_dest_addr()) {
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Missing address");
+ }
+ sockaddr_storage addr;
+ socklen_t addr_len;
+ auto err = proto_to_sockaddr(request->dest_addr(), &addr, &addr_len);
+ if (!err.ok()) {
+ return err;
+ }
+
+ response->set_ret(::sendto(request->sockfd(), request->buf().data(),
+ request->buf().size(), request->flags(),
+ reinterpret_cast<sockaddr *>(&addr), addr_len));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status SetSockOpt(
+ grpc_impl::ServerContext *context,
+ const ::posix_server::SetSockOptRequest *request,
+ ::posix_server::SetSockOptResponse *response) override {
+ switch (request->optval().val_case()) {
+ case ::posix_server::SockOptVal::kBytesval:
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(),
+ request->optval().bytesval().c_str(),
+ request->optval().bytesval().size()));
+ break;
+ case ::posix_server::SockOptVal::kIntval: {
+ int opt = request->optval().intval();
+ response->set_ret(::setsockopt(request->sockfd(), request->level(),
+ request->optname(), &opt, sizeof(opt)));
+ break;
+ }
+ case ::posix_server::SockOptVal::kTimeval: {
+ timeval tv = {.tv_sec = static_cast<__time_t>(
+ request->optval().timeval().seconds()),
+ .tv_usec = static_cast<__suseconds_t>(
+ request->optval().timeval().microseconds())};
+ response->set_ret(setsockopt(request->sockfd(), request->level(),
+ request->optname(), &tv, sizeof(tv)));
+ break;
+ }
+ default:
+ return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
+ "Unknown SockOpt Type");
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Socket(grpc_impl::ServerContext *context,
+ const ::posix_server::SocketRequest *request,
+ ::posix_server::SocketResponse *response) override {
+ response->set_fd(
+ socket(request->domain(), request->type(), request->protocol()));
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+
+ ::grpc::Status Recv(::grpc::ServerContext *context,
+ const ::posix_server::RecvRequest *request,
+ ::posix_server::RecvResponse *response) override {
+ std::vector<char> buf(request->len());
+ response->set_ret(
+ recv(request->sockfd(), buf.data(), buf.size(), request->flags()));
+ if (response->ret() >= 0) {
+ response->set_buf(buf.data(), response->ret());
+ }
+ response->set_errno_(errno);
+ return ::grpc::Status::OK;
+ }
+};
+
+// Parse command line options. Returns a pointer to the first argument beyond
+// the options.
+void parse_command_line_options(int argc, char *argv[], std::string *ip,
+ int *port) {
+ static struct option options[] = {{"ip", required_argument, NULL, 1},
+ {"port", required_argument, NULL, 2},
+ {0, 0, 0, 0}};
+
+ // Parse the arguments.
+ int c;
+ while ((c = getopt_long(argc, argv, "", options, NULL)) > 0) {
+ if (c == 1) {
+ *ip = optarg;
+ } else if (c == 2) {
+ *port = std::stoi(std::string(optarg));
+ }
+ }
+}
+
+void run_server(const std::string &ip, int port) {
+ PosixImpl posix_service;
+ grpc::ServerBuilder builder;
+ std::string server_address = ip + ":" + std::to_string(port);
+ // Set the authentication mechanism.
+ std::shared_ptr<grpc::ServerCredentials> creds =
+ grpc::InsecureServerCredentials();
+ builder.AddListeningPort(server_address, creds);
+ builder.RegisterService(&posix_service);
+
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ std::cerr << "Server listening on " << server_address << std::endl;
+ server->Wait();
+ std::cerr << "posix_server is finished." << std::endl;
+}
+
+int main(int argc, char *argv[]) {
+ std::cerr << "posix_server is starting." << std::endl;
+ std::string ip;
+ int port;
+ parse_command_line_options(argc, argv, &ip, &port);
+
+ std::cerr << "Got IP " << ip << " and port " << port << "." << std::endl;
+ run_server(ip, port);
+}
diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD
new file mode 100644
index 000000000..8d1193fed
--- /dev/null
+++ b/test/packetimpact/netdevs/BUILD
@@ -0,0 +1,23 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "netdevs",
+ srcs = ["netdevs.go"],
+ visibility = ["//test/packetimpact:__subpackages__"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ ],
+)
+
+go_test(
+ name = "netdevs_test",
+ size = "small",
+ srcs = ["netdevs_test.go"],
+ library = ":netdevs",
+ deps = ["@com_github_google_go_cmp//cmp:go_default_library"],
+)
diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go
new file mode 100644
index 000000000..eecfe0730
--- /dev/null
+++ b/test/packetimpact/netdevs/netdevs.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netdevs contains utilities for working with network devices.
+package netdevs
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// A DeviceInfo represents a network device.
+type DeviceInfo struct {
+ ID uint32
+ MAC net.HardwareAddr
+ IPv4Addr net.IP
+ IPv4Net *net.IPNet
+ IPv6Addr net.IP
+ IPv6Net *net.IPNet
+}
+
+var (
+ deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`)
+ linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`)
+ inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`)
+ inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`)
+)
+
+// ParseDevices parses the output from `ip addr show` into a map from device
+// name to information about the device.
+//
+// Note: if multiple IPv6 addresses are assigned to a device, the last address
+// displayed by `ip addr show` will be used. This is fine for packetimpact
+// because we will always only have at most one IPv6 address assigned to each
+// device.
+func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) {
+ var currentDevice string
+ var currentInfo DeviceInfo
+ deviceInfos := make(map[string]DeviceInfo)
+ for _, line := range strings.Split(cmdOutput, "\n") {
+ if m := deviceLine.FindStringSubmatch(line); m != nil {
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ id, err := strconv.ParseUint(m[1], 10, 32)
+ if err != nil {
+ return nil, fmt.Errorf("parsing device ID %s: %w", m[1], err)
+ }
+ currentInfo = DeviceInfo{ID: uint32(id)}
+ currentDevice = m[2]
+ } else if m := linkLine.FindStringSubmatch(line); m != nil {
+ mac, err := net.ParseMAC(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.MAC = mac
+ } else if m := inetLine.FindStringSubmatch(line); m != nil {
+ ipv4Addr, ipv4Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv4Addr = ipv4Addr
+ currentInfo.IPv4Net = ipv4Net
+ } else if m := inet6Line.FindStringSubmatch(line); m != nil {
+ ipv6Addr, ipv6Net, err := net.ParseCIDR(m[1])
+ if err != nil {
+ return nil, err
+ }
+ currentInfo.IPv6Addr = ipv6Addr
+ currentInfo.IPv6Net = ipv6Net
+ }
+ }
+ if currentDevice != "" {
+ deviceInfos[currentDevice] = currentInfo
+ }
+ return deviceInfos, nil
+}
+
+// MACToIP converts the MAC address to an IPv6 link local address as described
+// in RFC 4291 page 20: https://tools.ietf.org/html/rfc4291#page-20
+func MACToIP(mac net.HardwareAddr) net.IP {
+ addr := make([]byte, header.IPv6AddressSize)
+ addr[0] = 0xfe
+ addr[1] = 0x80
+ header.EthernetAdddressToModifiedEUI64IntoBuf(tcpip.LinkAddress(mac), addr[8:])
+ return net.IP(addr)
+}
+
+// FindDeviceByIP finds a DeviceInfo and device name from an IP address in the
+// output of ParseDevices.
+func FindDeviceByIP(ip net.IP, devices map[string]DeviceInfo) (string, DeviceInfo, error) {
+ for dev, info := range devices {
+ if info.IPv4Addr.Equal(ip) {
+ return dev, info, nil
+ }
+ }
+ return "", DeviceInfo{}, fmt.Errorf("can't find %s on any interface", ip)
+}
diff --git a/test/packetimpact/netdevs/netdevs_test.go b/test/packetimpact/netdevs/netdevs_test.go
new file mode 100644
index 000000000..24ad12198
--- /dev/null
+++ b/test/packetimpact/netdevs/netdevs_test.go
@@ -0,0 +1,227 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package netdevs
+
+import (
+ "fmt"
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func mustParseMAC(s string) net.HardwareAddr {
+ mac, err := net.ParseMAC(s)
+ if err != nil {
+ panic(fmt.Sprintf("failed to parse test MAC %q: %s", s, err))
+ }
+ return mac
+}
+
+func TestParseDevices(t *testing.T) {
+ for _, v := range []struct {
+ desc string
+ cmdOutput string
+ want map[string]DeviceInfo
+ }{
+ {
+ desc: "v4 and v6",
+ cmdOutput: `
+1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
+ link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
+ inet 127.0.0.1/8 scope host lo
+ valid_lft forever preferred_lft forever
+ inet6 ::1/128 scope host
+ valid_lft forever preferred_lft forever
+2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:c0ff:fea8:902/64 scope link tentative
+ valid_lft forever preferred_lft forever
+2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 223.245.225.10/24 brd 223.245.225.255 scope global eth2
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative
+ valid_lft forever preferred_lft forever
+2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:daff:fe33:130a/64 scope link tentative
+ valid_lft forever preferred_lft forever`,
+ want: map[string]DeviceInfo{
+ "lo": DeviceInfo{
+ ID: 1,
+ MAC: mustParseMAC("00:00:00:00:00:00"),
+ IPv4Addr: net.IPv4(127, 0, 0, 1),
+ IPv4Net: &net.IPNet{
+ IP: net.IPv4(127, 0, 0, 0),
+ Mask: net.CIDRMask(8, 32),
+ },
+ IPv6Addr: net.ParseIP("::1"),
+ IPv6Net: &net.IPNet{
+ IP: net.ParseIP("::1"),
+ Mask: net.CIDRMask(128, 128),
+ },
+ },
+ "eth0": DeviceInfo{
+ ID: 2613,
+ MAC: mustParseMAC("02:42:c0:a8:09:02"),
+ IPv4Addr: net.IPv4(192, 168, 9, 2),
+ IPv4Net: &net.IPNet{
+ IP: net.IPv4(192, 168, 9, 0),
+ Mask: net.CIDRMask(24, 32),
+ },
+ IPv6Addr: net.ParseIP("fe80::42:c0ff:fea8:902"),
+ IPv6Net: &net.IPNet{
+ IP: net.ParseIP("fe80::"),
+ Mask: net.CIDRMask(64, 128),
+ },
+ },
+ "eth1": DeviceInfo{
+ ID: 2617,
+ MAC: mustParseMAC("02:42:da:33:13:0a"),
+ IPv4Addr: net.IPv4(218, 51, 19, 10),
+ IPv4Net: &net.IPNet{
+ IP: net.IPv4(218, 51, 19, 0),
+ Mask: net.CIDRMask(24, 32),
+ },
+ IPv6Addr: net.ParseIP("fe80::42:daff:fe33:130a"),
+ IPv6Net: &net.IPNet{
+ IP: net.ParseIP("fe80::"),
+ Mask: net.CIDRMask(64, 128),
+ },
+ },
+ "eth2": DeviceInfo{
+ ID: 2615,
+ MAC: mustParseMAC("02:42:df:f5:e1:0a"),
+ IPv4Addr: net.IPv4(223, 245, 225, 10),
+ IPv4Net: &net.IPNet{
+ IP: net.IPv4(223, 245, 225, 0),
+ Mask: net.CIDRMask(24, 32),
+ },
+ IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"),
+ IPv6Net: &net.IPNet{
+ IP: net.ParseIP("fe80::"),
+ Mask: net.CIDRMask(64, 128),
+ },
+ },
+ },
+ },
+ {
+ desc: "v4 only",
+ cmdOutput: `
+2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0
+ valid_lft forever preferred_lft forever`,
+ want: map[string]DeviceInfo{
+ "eth0": DeviceInfo{
+ ID: 2613,
+ MAC: mustParseMAC("02:42:c0:a8:09:02"),
+ IPv4Addr: net.IPv4(192, 168, 9, 2),
+ IPv4Net: &net.IPNet{
+ IP: net.IPv4(192, 168, 9, 0),
+ Mask: net.CIDRMask(24, 32),
+ },
+ },
+ },
+ },
+ {
+ desc: "v6 only",
+ cmdOutput: `
+2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative
+ valid_lft forever preferred_lft forever`,
+ want: map[string]DeviceInfo{
+ "eth2": DeviceInfo{
+ ID: 2615,
+ MAC: mustParseMAC("02:42:df:f5:e1:0a"),
+ IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"),
+ IPv6Net: &net.IPNet{
+ IP: net.ParseIP("fe80::"),
+ Mask: net.CIDRMask(64, 128),
+ },
+ },
+ },
+ },
+ } {
+ t.Run(v.desc, func(t *testing.T) {
+ got, err := ParseDevices(v.cmdOutput)
+ if err != nil {
+ t.Errorf("ParseDevices(\n%s\n) got unexpected error: %s", v.cmdOutput, err)
+ }
+ if diff := cmp.Diff(v.want, got); diff != "" {
+ t.Errorf("ParseDevices(\n%s\n) got output diff (-want, +got):\n%s", v.cmdOutput, diff)
+ }
+ })
+ }
+}
+
+func TestParseDevicesErrors(t *testing.T) {
+ for _, v := range []struct {
+ desc string
+ cmdOutput string
+ }{
+ {
+ desc: "invalid MAC addr",
+ cmdOutput: `
+2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:da:33:13:0a:ffffffff brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:daff:fe33:130a/64 scope link tentative
+ valid_lft forever preferred_lft forever`,
+ },
+ {
+ desc: "invalid v4 addr",
+ cmdOutput: `
+2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 1234.4321.424242.0/24 brd 218.51.19.255 scope global eth1
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:daff:fe33:130a/64 scope link tentative
+ valid_lft forever preferred_lft forever`,
+ },
+ {
+ desc: "invalid v6 addr",
+ cmdOutput: `
+2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1
+ valid_lft forever preferred_lft forever
+ inet6 fe80:ffffffff::42:daff:fe33:130a/64 scope link tentative
+ valid_lft forever preferred_lft forever`,
+ },
+ {
+ desc: "invalid CIDR missing prefixlen",
+ cmdOutput: `
+2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default
+ link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0
+ inet 218.51.19.10 brd 218.51.19.255 scope global eth1
+ valid_lft forever preferred_lft forever
+ inet6 fe80::42:daff:fe33:130a scope link tentative
+ valid_lft forever preferred_lft forever`,
+ },
+ } {
+ t.Run(v.desc, func(t *testing.T) {
+ if _, err := ParseDevices(v.cmdOutput); err == nil {
+ t.Errorf("ParseDevices(\n%s\n) succeeded unexpectedly, want error", v.cmdOutput)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/proto/BUILD b/test/packetimpact/proto/BUILD
new file mode 100644
index 000000000..4a4370f42
--- /dev/null
+++ b/test/packetimpact/proto/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "proto_library")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+proto_library(
+ name = "posix_server",
+ srcs = ["posix_server.proto"],
+ has_services = 1,
+)
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
new file mode 100644
index 000000000..ccd20b10d
--- /dev/null
+++ b/test/packetimpact/proto/posix_server.proto
@@ -0,0 +1,230 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+
+package posix_server;
+
+message SockaddrIn {
+ int32 family = 1;
+ uint32 port = 2;
+ bytes addr = 3;
+}
+
+message SockaddrIn6 {
+ uint32 family = 1;
+ uint32 port = 2;
+ uint32 flowinfo = 3;
+ bytes addr = 4;
+ uint32 scope_id = 5;
+}
+
+message Sockaddr {
+ oneof sockaddr {
+ SockaddrIn in = 1;
+ SockaddrIn6 in6 = 2;
+ }
+}
+
+message Timeval {
+ int64 seconds = 1;
+ int64 microseconds = 2;
+}
+
+message SockOptVal {
+ oneof val {
+ bytes bytesval = 1;
+ int32 intval = 2;
+ Timeval timeval = 3;
+ }
+}
+
+// Request and Response pairs for each Posix service RPC call, sorted.
+
+message AcceptRequest {
+ int32 sockfd = 1;
+}
+
+message AcceptResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message BindRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message BindResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message CloseRequest {
+ int32 fd = 1;
+}
+
+message CloseResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message ConnectRequest {
+ int32 sockfd = 1;
+ Sockaddr addr = 2;
+}
+
+message ConnectResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message FcntlRequest {
+ int32 fd = 1;
+ int32 cmd = 2;
+ int32 arg = 3;
+}
+
+message FcntlResponse {
+ int32 ret = 1;
+ int32 errno_ = 2;
+}
+
+message GetSockNameRequest {
+ int32 sockfd = 1;
+}
+
+message GetSockNameResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ Sockaddr addr = 3;
+}
+
+message GetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ int32 optlen = 4;
+ enum SockOptType {
+ UNSPECIFIED = 0;
+ BYTES = 1;
+ INT = 2;
+ TIME = 3;
+ }
+ SockOptType type = 5;
+}
+
+message GetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ SockOptVal optval = 3;
+}
+
+message ListenRequest {
+ int32 sockfd = 1;
+ int32 backlog = 2;
+}
+
+message ListenResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SendRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+}
+
+message SendResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SendToRequest {
+ int32 sockfd = 1;
+ bytes buf = 2;
+ int32 flags = 3;
+ Sockaddr dest_addr = 4;
+}
+
+message SendToResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SetSockOptRequest {
+ int32 sockfd = 1;
+ int32 level = 2;
+ int32 optname = 3;
+ SockOptVal optval = 4;
+}
+
+message SetSockOptResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message SocketRequest {
+ int32 domain = 1;
+ int32 type = 2;
+ int32 protocol = 3;
+}
+
+message SocketResponse {
+ int32 fd = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+}
+
+message RecvRequest {
+ int32 sockfd = 1;
+ int32 len = 2;
+ int32 flags = 3;
+}
+
+message RecvResponse {
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
+ bytes buf = 3;
+}
+
+service Posix {
+ // Call accept() on the DUT.
+ rpc Accept(AcceptRequest) returns (AcceptResponse);
+ // Call bind() on the DUT.
+ rpc Bind(BindRequest) returns (BindResponse);
+ // Call close() on the DUT.
+ rpc Close(CloseRequest) returns (CloseResponse);
+ // Call connect() on the DUT.
+ rpc Connect(ConnectRequest) returns (ConnectResponse);
+ // Call fcntl() on the DUT.
+ rpc Fcntl(FcntlRequest) returns (FcntlResponse);
+ // Call getsockname() on the DUT.
+ rpc GetSockName(GetSockNameRequest) returns (GetSockNameResponse);
+ // Call getsockopt() on the DUT.
+ rpc GetSockOpt(GetSockOptRequest) returns (GetSockOptResponse);
+ // Call listen() on the DUT.
+ rpc Listen(ListenRequest) returns (ListenResponse);
+ // Call send() on the DUT.
+ rpc Send(SendRequest) returns (SendResponse);
+ // Call sendto() on the DUT.
+ rpc SendTo(SendToRequest) returns (SendToResponse);
+ // Call setsockopt() on the DUT.
+ rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse);
+ // Call socket() on the DUT.
+ rpc Socket(SocketRequest) returns (SocketResponse);
+ // Call recv() on the DUT.
+ rpc Recv(RecvRequest) returns (RecvResponse);
+}
diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD
new file mode 100644
index 000000000..ff2be9b30
--- /dev/null
+++ b/test/packetimpact/runner/BUILD
@@ -0,0 +1,27 @@
+load("//tools:defs.bzl", "bzl_library", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_test(
+ name = "packetimpact_test",
+ srcs = ["packetimpact_test.go"],
+ tags = [
+ # Not intended to be run directly.
+ "local",
+ "manual",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/packetimpact/netdevs",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
+ ],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
new file mode 100644
index 000000000..93a36c6c2
--- /dev/null
+++ b/test/packetimpact/runner/defs.bzl
@@ -0,0 +1,143 @@
+"""Defines rules for packetimpact test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _packetimpact_test_impl(ctx):
+ test_runner = ctx.executable._test_runner
+ bench = ctx.actions.declare_file("%s-bench" % ctx.label.name)
+ bench_content = "\n".join([
+ "#!/bin/bash",
+ # This test will run part in a distinct user namespace. This can cause
+ # permission problems, because all runfiles may not be owned by the
+ # current user, and no other users will be mapped in that namespace.
+ # Make sure that everything is readable here.
+ "find . -type f -or -type d -exec chmod a+rx {} \\;",
+ "%s %s --testbench_binary %s $@\n" % (
+ test_runner.short_path,
+ " ".join(ctx.attr.flags),
+ ctx.files.testbench_binary[0].short_path,
+ ),
+ ])
+ ctx.actions.write(bench, bench_content, is_executable = True)
+
+ transitive_files = []
+ if hasattr(ctx.attr._test_runner, "data_runfiles"):
+ transitive_files.append(ctx.attr._test_runner.data_runfiles.files)
+ runfiles = ctx.runfiles(
+ files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary,
+ transitive_files = depset(transitive = transitive_files),
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = bench, runfiles = runfiles)]
+
+_packetimpact_test = rule(
+ attrs = {
+ "_test_runner": attr.label(
+ executable = True,
+ cfg = "target",
+ default = ":packetimpact_test",
+ ),
+ "_posix_server_binary": attr.label(
+ cfg = "target",
+ default = "//test/packetimpact/dut:posix_server",
+ ),
+ "testbench_binary": attr.label(
+ cfg = "target",
+ mandatory = True,
+ ),
+ "flags": attr.string_list(
+ mandatory = False,
+ default = [],
+ ),
+ },
+ test = True,
+ implementation = _packetimpact_test_impl,
+)
+
+PACKETIMPACT_TAGS = [
+ "local",
+ "manual",
+ "packetimpact",
+]
+
+def packetimpact_native_test(
+ name,
+ testbench_binary,
+ expect_failure = False,
+ **kwargs):
+ """Add a native packetimpact test.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ expect_failure: the test must fail
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ expect_failure_flag = ["--expect_failure"] if expect_failure else []
+ _packetimpact_test(
+ name = name + "_native_test",
+ testbench_binary = testbench_binary,
+ flags = ["--native"] + expect_failure_flag,
+ tags = PACKETIMPACT_TAGS,
+ **kwargs
+ )
+
+def packetimpact_netstack_test(
+ name,
+ testbench_binary,
+ expect_failure = False,
+ **kwargs):
+ """Add a packetimpact test on netstack.
+
+ Args:
+ name: name of the test
+ testbench_binary: the testbench binary
+ expect_failure: the test must fail
+ **kwargs: all the other args, forwarded to _packetimpact_test
+ """
+ expect_failure_flag = []
+ if expect_failure:
+ expect_failure_flag = ["--expect_failure"]
+ _packetimpact_test(
+ name = name + "_netstack_test",
+ testbench_binary = testbench_binary,
+ # Note that a distinct runtime must be provided in the form
+ # --test_arg=--runtime=other when invoking bazel.
+ flags = expect_failure_flag,
+ tags = PACKETIMPACT_TAGS,
+ **kwargs
+ )
+
+def packetimpact_go_test(name, size = "small", pure = True, expect_native_failure = False, expect_netstack_failure = False, **kwargs):
+ """Add packetimpact tests written in go.
+
+ Args:
+ name: name of the test
+ size: size of the test
+ pure: make a static go binary
+ expect_native_failure: the test must fail natively
+ expect_netstack_failure: the test must fail for Netstack
+ **kwargs: all the other args, forwarded to go_test
+ """
+ testbench_binary = name + "_test"
+ go_test(
+ name = testbench_binary,
+ size = size,
+ pure = pure,
+ tags = [
+ "local",
+ "manual",
+ ],
+ **kwargs
+ )
+ packetimpact_native_test(
+ name = name,
+ expect_failure = expect_native_failure,
+ testbench_binary = testbench_binary,
+ )
+ packetimpact_netstack_test(
+ name = name,
+ expect_failure = expect_netstack_failure,
+ testbench_binary = testbench_binary,
+ )
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
new file mode 100644
index 000000000..e8c183977
--- /dev/null
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -0,0 +1,383 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// The runner starts docker containers and networking for a packetimpact test.
+package packetimpact_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "os/exec"
+ "path"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/docker/docker/api/types/mount"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+// stringList implements flag.Value.
+type stringList []string
+
+// String implements flag.Value.String.
+func (l *stringList) String() string {
+ return strings.Join(*l, ",")
+}
+
+// Set implements flag.Value.Set.
+func (l *stringList) Set(value string) error {
+ *l = append(*l, value)
+ return nil
+}
+
+var (
+ native = flag.Bool("native", false, "whether the test should be run natively")
+ testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary")
+ tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump")
+ extraTestArgs = stringList{}
+ expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run")
+
+ dutAddr = net.IPv4(0, 0, 0, 10)
+ testbenchAddr = net.IPv4(0, 0, 0, 20)
+)
+
+const ctrlPort = "40000"
+
+// logger implements testutil.Logger.
+//
+// Labels logs based on their source and formats multi-line logs.
+type logger string
+
+// Name implements testutil.Logger.Name.
+func (l logger) Name() string {
+ return string(l)
+}
+
+// Logf implements testutil.Logger.Logf.
+func (l logger) Logf(format string, args ...interface{}) {
+ lines := strings.Split(fmt.Sprintf(format, args...), "\n")
+ log.Printf("%s: %s", l, lines[0])
+ for _, line := range lines[1:] {
+ log.Printf("%*s %s", len(l), "", line)
+ }
+}
+
+func TestOne(t *testing.T) {
+ flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench")
+ flag.Parse()
+ if *testbenchBinary == "" {
+ t.Fatal("--testbench_binary is missing")
+ }
+ dockerutil.EnsureSupportedDockerVersion()
+ ctx := context.Background()
+
+ // Create the networks needed for the test. One control network is needed for
+ // the gRPC control packets and one test network on which to transmit the test
+ // packets.
+ ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet"))
+ testNet := dockerutil.NewNetwork(ctx, logger("testNet"))
+ for _, dn := range []*dockerutil.Network{ctrlNet, testNet} {
+ for {
+ if err := createDockerNetwork(ctx, dn); err != nil {
+ t.Log("creating docker network:", err)
+ const wait = 100 * time.Millisecond
+ t.Logf("sleeping %s and will try creating docker network again", wait)
+ // This can fail if another docker network claimed the same IP so we'll
+ // just try again.
+ time.Sleep(wait)
+ continue
+ }
+ break
+ }
+ defer func(dn *dockerutil.Network) {
+ if err := dn.Cleanup(ctx); err != nil {
+ t.Errorf("unable to cleanup container %s: %s", dn.Name, err)
+ }
+ }(dn)
+ // Sanity check.
+ inspect, err := dn.Inspect(ctx)
+ if err != nil {
+ t.Fatalf("failed to inspect network %s: %v", dn.Name, err)
+ } else if inspect.Name != dn.Name {
+ t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name)
+ }
+
+ }
+
+ tmpDir, err := ioutil.TempDir("", "container-output")
+ if err != nil {
+ t.Fatal("creating temp dir:", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ const testOutputDir = "/tmp/testoutput"
+
+ // Create the Docker container for the DUT.
+ var dut *dockerutil.Container
+ if *native {
+ dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
+ } else {
+ dut = dockerutil.MakeContainer(ctx, logger("dut"))
+ }
+
+ runOpts := dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Mounts: []mount.Mount{mount.Mount{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: testOutputDir,
+ ReadOnly: false,
+ }},
+ }
+
+ const containerPosixServerBinary = "/packetimpact/posix_server"
+ dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server")
+
+ conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort)
+ hostconf.AutoRemove = true
+ hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+
+ if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("unable to create container %s: %v", dut.Name, err)
+ }
+
+ defer dut.CleanUp(ctx)
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ const testNetDev = "eth2"
+ if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := dut.Start(ctx); err != nil {
+ t.Fatalf("unable to start container %s: %s", dut.Name, err)
+ }
+
+ if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil {
+ t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err)
+ }
+
+ dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ remoteMAC := dutDeviceInfo.MAC
+ remoteIPv6 := dutDeviceInfo.IPv6Addr
+ // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if
+ // needed.
+ if remoteIPv6 == nil {
+ if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil {
+ t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err)
+ }
+ // Now try again, to make sure that it worked.
+ _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet))
+ if err != nil {
+ t.Fatal(err)
+ }
+ remoteIPv6 = dutDeviceInfo.IPv6Addr
+ if remoteIPv6 == nil {
+ t.Fatal("unable to set IPv6 address on container", dut.Name)
+ }
+ }
+
+ // Create the Docker container for the testbench.
+ testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
+
+ tbb := path.Base(*testbenchBinary)
+ containerTestbenchBinary := "/packetimpact/" + tbb
+ runOpts = dockerutil.RunOpts{
+ Image: "packetimpact",
+ CapAdd: []string{"NET_ADMIN"},
+ Mounts: []mount.Mount{mount.Mount{
+ Type: mount.TypeBind,
+ Source: tmpDir,
+ Target: testOutputDir,
+ ReadOnly: false,
+ }},
+ }
+ testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb)
+
+ // Run tcpdump in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs := []string{
+ "tcpdump",
+ "-S", "-vvv", "-U", "-n",
+ "-i", testNetDev,
+ "-w", testOutputDir + "/dump.pcap",
+ }
+ snifferRegex := "tcpdump: listening.*\n"
+ if *tshark {
+ // Run tshark in the test bench unbuffered, without DNS resolution, just on
+ // the interface with the test packets.
+ snifferArgs = []string{
+ "tshark", "-V", "-l", "-n", "-i", testNetDev,
+ "-o", "tcp.check_checksum:TRUE",
+ "-o", "udp.check_checksum:TRUE",
+ }
+ snifferRegex = "Capturing on.*\n"
+ }
+
+ defer func() {
+ if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil {
+ t.Error("unable to copy container output files:", err)
+ }
+ }()
+
+ conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...)
+ hostconf.AutoRemove = true
+ hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+
+ if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("unable to create container %s: %s", testbench.Name, err)
+ }
+ defer testbench.CleanUp(ctx)
+
+ // Add ctrlNet as eth1 and testNet as eth2.
+ if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := testbench.Start(ctx); err != nil {
+ t.Fatalf("unable to start container %s: %s", testbench.Name, err)
+ }
+
+ // Kill so that it will flush output.
+ defer func() {
+ time.Sleep(1 * time.Second)
+ testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0])
+ }()
+
+ if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil {
+ t.Fatalf("sniffer on %s never listened: %s", dut.Name, err)
+ }
+
+ // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it
+ // will issue an RST. To prevent this IPtables can be used to filter out all
+ // incoming packets. The raw socket that packetimpact tests use will still see
+ // everything.
+ for _, bin := range []string{"iptables", "ip6tables"} {
+ if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil {
+ t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs)
+ }
+ }
+
+ // FIXME(b/156449515): Some piece of the system has a race. The old
+ // bash script version had a sleep, so we have one too. The race should
+ // be fixed and this sleep removed.
+ time.Sleep(time.Second)
+
+ // Start a packetimpact test on the test bench. The packetimpact test sends
+ // and receives packets and also sends POSIX socket commands to the
+ // posix_server to be executed on the DUT.
+ testArgs := []string{containerTestbenchBinary}
+ testArgs = append(testArgs, extraTestArgs...)
+ testArgs = append(testArgs,
+ "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(),
+ "--posix_server_port", ctrlPort,
+ "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(),
+ "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(),
+ "--remote_ipv6", remoteIPv6.String(),
+ "--remote_mac", remoteMAC.String(),
+ "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID),
+ "--device", testNetDev,
+ fmt.Sprintf("--native=%t", *native),
+ )
+ testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
+ if (err != nil) != *expectFailure {
+ var dutLogs string
+ if logs, err := dut.Logs(ctx); err != nil {
+ dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err)
+ } else {
+ dutLogs = logs
+ }
+
+ t.Errorf(`test error: %v, expect failure: %t
+
+====== Begin of DUT Logs ======
+
+%s
+
+====== End of DUT Logs ======
+
+====== Begin of Testbench Logs ======
+
+%s
+
+====== End of Testbench Logs ======`,
+ err, *expectFailure, dutLogs, testbenchLogs)
+ }
+}
+
+func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error {
+ for _, dn := range networks {
+ ip := addressInSubnet(addr, *dn.Subnet)
+ // Connect to the network with the specified IP address.
+ if err := dn.Connect(ctx, d, ip.String(), ""); err != nil {
+ return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err)
+ }
+ }
+ return nil
+}
+
+// addressInSubnet combines the subnet provided with the address and returns a
+// new address. The return address bits come from the subnet where the mask is 1
+// and from the ip address where the mask is 0.
+func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP {
+ var octets []byte
+ for i := 0; i < 4; i++ {
+ octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i])))
+ }
+ return net.IP(octets)
+}
+
+// createDockerNetwork makes a randomly-named network that will start with the
+// namePrefix. The network will be a random /24 subnet.
+func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
+ randSource := rand.NewSource(time.Now().UnixNano())
+ r1 := rand.New(randSource)
+ // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24.
+ ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0)
+ n.Subnet = &net.IPNet{
+ IP: ip,
+ Mask: ip.DefaultMask(),
+ }
+ return n.Create(ctx)
+}
+
+// deviceByIP finds a deviceInfo and device name from an IP address.
+func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) {
+ out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show")
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err)
+ }
+ devs, err := netdevs.ParseDevices(out)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err)
+ }
+ testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs)
+ if err != nil {
+ return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err)
+ }
+ return testDevice, deviceInfo, nil
+}
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
new file mode 100644
index 000000000..5a0ee1367
--- /dev/null
+++ b/test/packetimpact/testbench/BUILD
@@ -0,0 +1,46 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+go_library(
+ name = "testbench",
+ srcs = [
+ "connections.go",
+ "dut.go",
+ "dut_client.go",
+ "layers.go",
+ "rawsockets.go",
+ "testbench.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//pkg/usermem",
+ "//test/packetimpact/netdevs",
+ "//test/packetimpact/proto:posix_server_go_proto",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ "@org_golang_google_grpc//:go_default_library",
+ "@org_golang_google_grpc//keepalive:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ "@org_uber_go_multierr//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "testbench_test",
+ size = "small",
+ srcs = ["layers_test.go"],
+ library = ":testbench",
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "@com_github_mohae_deepcopy//:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
new file mode 100644
index 000000000..3af5f83fd
--- /dev/null
+++ b/test/packetimpact/testbench/connections.go
@@ -0,0 +1,1205 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package testbench has utilities to send and receive packets and also command
+// the DUT to run POSIX functions.
+package testbench
+
+import (
+ "fmt"
+ "math/rand"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/mohae/deepcopy"
+ "go.uber.org/multierr"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+func portFromSockaddr(sa unix.Sockaddr) (uint16, error) {
+ switch sa := sa.(type) {
+ case *unix.SockaddrInet4:
+ return uint16(sa.Port), nil
+ case *unix.SockaddrInet6:
+ return uint16(sa.Port), nil
+ }
+ return 0, fmt.Errorf("sockaddr type %T does not contain port", sa)
+}
+
+// pickPort makes a new socket and returns the socket FD and port. The domain
+// should be AF_INET or AF_INET6. The caller must close the FD when done with
+// the port if there is no error.
+func pickPort(domain, typ int) (fd int, port uint16, err error) {
+ fd, err = unix.Socket(domain, typ, 0)
+ if err != nil {
+ return -1, 0, fmt.Errorf("creating socket: %w", err)
+ }
+ defer func() {
+ if err != nil {
+ if cerr := unix.Close(fd); cerr != nil {
+ err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr))
+ }
+ }
+ }()
+ var sa unix.Sockaddr
+ switch domain {
+ case unix.AF_INET:
+ var sa4 unix.SockaddrInet4
+ copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4())
+ sa = &sa4
+ case unix.AF_INET6:
+ sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)}
+ copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16())
+ sa = &sa6
+ default:
+ return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain)
+ }
+ if err = unix.Bind(fd, sa); err != nil {
+ return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err)
+ }
+ sa, err = unix.Getsockname(fd)
+ if err != nil {
+ return -1, 0, fmt.Errorf("Getsocketname(%d): %w", fd, err)
+ }
+ port, err = portFromSockaddr(sa)
+ if err != nil {
+ return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err)
+ }
+ return fd, port, nil
+}
+
+// layerState stores the state of a layer of a connection.
+type layerState interface {
+ // outgoing returns an outgoing layer to be sent in a frame. It should not
+ // update layerState, that is done in layerState.sent.
+ outgoing() Layer
+
+ // incoming creates an expected Layer for comparing against a received Layer.
+ // Because the expectation can depend on values in the received Layer, it is
+ // an input to incoming. For example, the ACK number needs to be checked in a
+ // TCP packet but only if the ACK flag is set in the received packet. It
+ // should not update layerState, that is done in layerState.received. The
+ // caller takes ownership of the returned Layer.
+ incoming(received Layer) Layer
+
+ // sent updates the layerState based on the Layer that was sent. The input is
+ // a Layer with all prev and next pointers populated so that the entire frame
+ // as it was sent is available.
+ sent(sent Layer) error
+
+ // received updates the layerState based on a Layer that is receieved. The
+ // input is a Layer with all prev and next pointers populated so that the
+ // entire frame as it was receieved is available.
+ received(received Layer) error
+
+ // close frees associated resources held by the LayerState.
+ close() error
+}
+
+// etherState maintains state about an Ethernet connection.
+type etherState struct {
+ out, in Ether
+}
+
+var _ layerState = (*etherState)(nil)
+
+// newEtherState creates a new etherState.
+func newEtherState(out, in Ether) (*etherState, error) {
+ lMAC, err := tcpip.ParseMACAddress(LocalMAC)
+ if err != nil {
+ return nil, fmt.Errorf("parsing local MAC: %q: %w", LocalMAC, err)
+ }
+
+ rMAC, err := tcpip.ParseMACAddress(RemoteMAC)
+ if err != nil {
+ return nil, fmt.Errorf("parsing remote MAC: %q: %w", RemoteMAC, err)
+ }
+ s := etherState{
+ out: Ether{SrcAddr: &lMAC, DstAddr: &rMAC},
+ in: Ether{SrcAddr: &rMAC, DstAddr: &lMAC},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *etherState) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *etherState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*etherState) sent(Layer) error {
+ return nil
+}
+
+func (*etherState) received(Layer) error {
+ return nil
+}
+
+func (*etherState) close() error {
+ return nil
+}
+
+// ipv4State maintains state about an IPv4 connection.
+type ipv4State struct {
+ out, in IPv4
+}
+
+var _ layerState = (*ipv4State)(nil)
+
+// newIPv4State creates a new ipv4State.
+func newIPv4State(out, in IPv4) (*ipv4State, error) {
+ lIP := tcpip.Address(net.ParseIP(LocalIPv4).To4())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv4).To4())
+ s := ipv4State{
+ out: IPv4{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv4{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *ipv4State) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *ipv4State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*ipv4State) sent(Layer) error {
+ return nil
+}
+
+func (*ipv4State) received(Layer) error {
+ return nil
+}
+
+func (*ipv4State) close() error {
+ return nil
+}
+
+// ipv6State maintains state about an IPv6 connection.
+type ipv6State struct {
+ out, in IPv6
+}
+
+var _ layerState = (*ipv6State)(nil)
+
+// newIPv6State creates a new ipv6State.
+func newIPv6State(out, in IPv6) (*ipv6State, error) {
+ lIP := tcpip.Address(net.ParseIP(LocalIPv6).To16())
+ rIP := tcpip.Address(net.ParseIP(RemoteIPv6).To16())
+ s := ipv6State{
+ out: IPv6{SrcAddr: &lIP, DstAddr: &rIP},
+ in: IPv6{SrcAddr: &rIP, DstAddr: &lIP},
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+// outgoing returns an outgoing layer to be sent in a frame.
+func (s *ipv6State) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+func (s *ipv6State) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (s *ipv6State) sent(Layer) error {
+ // Nothing to do.
+ return nil
+}
+
+func (s *ipv6State) received(Layer) error {
+ // Nothing to do.
+ return nil
+}
+
+// close cleans up any resources held.
+func (s *ipv6State) close() error {
+ return nil
+}
+
+// tcpState maintains state about a TCP connection.
+type tcpState struct {
+ out, in TCP
+ localSeqNum, remoteSeqNum *seqnum.Value
+ synAck *TCP
+ portPickerFD int
+ finSent bool
+}
+
+var _ layerState = (*tcpState)(nil)
+
+// SeqNumValue is a helper routine that allocates a new seqnum.Value value to
+// store v and returns a pointer to it.
+func SeqNumValue(v seqnum.Value) *seqnum.Value {
+ return &v
+}
+
+// newTCPState creates a new TCPState.
+func newTCPState(domain int, out, in TCP) (*tcpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_STREAM)
+ if err != nil {
+ return nil, err
+ }
+ s := tcpState{
+ out: TCP{SrcPort: &localPort},
+ in: TCP{DstPort: &localPort},
+ localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())),
+ portPickerFD: portPickerFD,
+ finSent: false,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *tcpState) outgoing() Layer {
+ newOutgoing := deepcopy.Copy(s.out).(TCP)
+ if s.localSeqNum != nil {
+ newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum))
+ }
+ if s.remoteSeqNum != nil {
+ newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ return &newOutgoing
+}
+
+// incoming implements layerState.incoming.
+func (s *tcpState) incoming(received Layer) Layer {
+ tcpReceived, ok := received.(*TCP)
+ if !ok {
+ return nil
+ }
+ newIn := deepcopy.Copy(s.in).(TCP)
+ if s.remoteSeqNum != nil {
+ newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum))
+ }
+ if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 {
+ // The caller didn't specify an AckNum so we'll expect the calculated one,
+ // but only if the ACK flag is set because the AckNum is not valid in a
+ // header if ACK is not set.
+ newIn.AckNum = Uint32(uint32(*s.localSeqNum))
+ }
+ return &newIn
+}
+
+func (s *tcpState) sent(sent Layer) error {
+ tcp, ok := sent.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", sent)
+ }
+ if !s.finSent {
+ // update localSeqNum by the payload only when FIN is not yet sent by us
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.localSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ }
+ if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.localSeqNum.UpdateForward(1)
+ }
+ if *tcp.Flags&(header.TCPFlagFin) != 0 {
+ s.finSent = true
+ }
+ return nil
+}
+
+func (s *tcpState) received(l Layer) error {
+ tcp, ok := l.(*TCP)
+ if !ok {
+ return fmt.Errorf("can't update tcpState with %T Layer", l)
+ }
+ s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum))
+ if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 {
+ s.remoteSeqNum.UpdateForward(1)
+ }
+ for current := tcp.next(); current != nil; current = current.next() {
+ s.remoteSeqNum.UpdateForward(seqnum.Size(current.length()))
+ }
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *tcpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// udpState maintains state about a UDP connection.
+type udpState struct {
+ out, in UDP
+ portPickerFD int
+}
+
+var _ layerState = (*udpState)(nil)
+
+// newUDPState creates a new udpState.
+func newUDPState(domain int, out, in UDP) (*udpState, error) {
+ portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM)
+ if err != nil {
+ return nil, fmt.Errorf("picking port: %w", err)
+ }
+ s := udpState{
+ out: UDP{SrcPort: &localPort},
+ in: UDP{DstPort: &localPort},
+ portPickerFD: portPickerFD,
+ }
+ if err := s.out.merge(&out); err != nil {
+ return nil, err
+ }
+ if err := s.in.merge(&in); err != nil {
+ return nil, err
+ }
+ return &s, nil
+}
+
+func (s *udpState) outgoing() Layer {
+ return deepcopy.Copy(&s.out).(Layer)
+}
+
+// incoming implements layerState.incoming.
+func (s *udpState) incoming(Layer) Layer {
+ return deepcopy.Copy(&s.in).(Layer)
+}
+
+func (*udpState) sent(l Layer) error {
+ return nil
+}
+
+func (*udpState) received(l Layer) error {
+ return nil
+}
+
+// close frees the port associated with this connection.
+func (s *udpState) close() error {
+ if err := unix.Close(s.portPickerFD); err != nil {
+ return err
+ }
+ s.portPickerFD = -1
+ return nil
+}
+
+// Connection holds a collection of layer states for maintaining a connection
+// along with sockets for sniffer and injecting packets.
+type Connection struct {
+ layerStates []layerState
+ injector Injector
+ sniffer Sniffer
+}
+
+// Returns the default incoming frame against which to match. If received is
+// longer than layerStates then that may still count as a match. The reverse is
+// never a match and nil is returned.
+func (conn *Connection) incoming(received Layers) Layers {
+ if len(received) < len(conn.layerStates) {
+ return nil
+ }
+ in := Layers{}
+ for i, s := range conn.layerStates {
+ toMatch := s.incoming(received[i])
+ if toMatch == nil {
+ return nil
+ }
+ in = append(in, toMatch)
+ }
+ return in
+}
+
+func (conn *Connection) match(override, received Layers) bool {
+ toMatch := conn.incoming(received)
+ if toMatch == nil {
+ return false // Not enough layers in gotLayers for matching.
+ }
+ if err := toMatch.merge(override); err != nil {
+ return false // Failing to merge is not matching.
+ }
+ return toMatch.match(received)
+}
+
+// Close frees associated resources held by the Connection.
+func (conn *Connection) Close(t *testing.T) {
+ t.Helper()
+
+ errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
+ for _, s := range conn.layerStates {
+ if err := s.close(); err != nil {
+ errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err))
+ }
+ }
+ if errs != nil {
+ t.Fatalf("unable to close %+v: %s", conn, errs)
+ }
+}
+
+// CreateFrame builds a frame for the connection with defaults overriden
+// from the innermost layer out, and additionalLayers added after it.
+//
+// Note that overrideLayers can have a length that is less than the number
+// of layers in this connection, and in such cases the innermost layers are
+// overriden first. As an example, valid values of overrideLayers for a TCP-
+// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
+// [Ethernet, IPv4, TCP].
+func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers {
+ t.Helper()
+
+ var layersToSend Layers
+ for i, s := range conn.layerStates {
+ layer := s.outgoing()
+ // overrideLayers and conn.layerStates have their tails aligned, so
+ // to find the index we move backwards by the distance i is to the
+ // end.
+ if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
+ if err := layer.merge(overrideLayers[j]); err != nil {
+ t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ }
+ }
+ layersToSend = append(layersToSend, layer)
+ }
+ layersToSend = append(layersToSend, additionalLayers...)
+ return layersToSend
+}
+
+// SendFrameStateless sends a frame without updating any of the layer states.
+//
+// This method is useful for sending out-of-band control messages such as
+// ICMP packets, where it would not make sense to update the transport layer's
+// state using the ICMP header.
+func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) {
+ t.Helper()
+
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(t, outBytes)
+}
+
+// SendFrame sends a frame on the wire and updates the state of all layers.
+func (conn *Connection) SendFrame(t *testing.T, frame Layers) {
+ t.Helper()
+
+ outBytes, err := frame.ToBytes()
+ if err != nil {
+ t.Fatalf("can't build outgoing packet: %s", err)
+ }
+ conn.injector.Send(t, outBytes)
+
+ // frame might have nil values where the caller wanted to use default values.
+ // sentFrame will have no nil values in it because it comes from parsing the
+ // bytes that were actually sent.
+ sentFrame := parse(parseEther, outBytes)
+ // Update the state of each layer based on what was sent.
+ for i, s := range conn.layerStates {
+ if err := s.sent(sentFrame[i]); err != nil {
+ t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ }
+ }
+}
+
+// send sends a packet, possibly with layers of this connection overridden and
+// additional layers added.
+//
+// Types defined with Connection as the underlying type should expose
+// type-safe versions of this method.
+func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ t.Helper()
+
+ conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...))
+}
+
+// recvFrame gets the next successfully parsed frame (of type Layers) within the
+// timeout provided. If no parsable frame arrives before the timeout, it returns
+// nil.
+func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers {
+ t.Helper()
+
+ if timeout <= 0 {
+ return nil
+ }
+ b := conn.sniffer.Recv(t, timeout)
+ if b == nil {
+ return nil
+ }
+ return parse(parseEther, b)
+}
+
+// layersError stores the Layers that we got and the Layers that we wanted to
+// match.
+type layersError struct {
+ got, want Layers
+}
+
+func (e *layersError) Error() string {
+ return e.got.diff(e.want)
+}
+
+// Expect expects a frame with the final layerStates layer matching the
+// provided Layer within the timeout specified. If it doesn't arrive in time,
+// an error is returned.
+func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) {
+ t.Helper()
+
+ // Make a frame that will ignore all but the final layer.
+ layers := make([]Layer, len(conn.layerStates))
+ layers[len(layers)-1] = layer
+
+ gotFrame, err := conn.ExpectFrame(t, layers, timeout)
+ if err != nil {
+ return nil, err
+ }
+ if len(conn.layerStates)-1 < len(gotFrame) {
+ return gotFrame[len(conn.layerStates)-1], nil
+ }
+ t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame)
+ panic("unreachable")
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If one arrives in time, the Layers is returned without an
+// error. If it doesn't arrive in time, it returns nil and error is non-nil.
+func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ var errs error
+ for {
+ var gotLayers Layers
+ if timeout = time.Until(deadline); timeout > 0 {
+ gotLayers = conn.recvFrame(t, timeout)
+ }
+ if gotLayers == nil {
+ if errs == nil {
+ return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout)
+ }
+ return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs)
+ }
+ if conn.match(layers, gotLayers) {
+ for i, s := range conn.layerStates {
+ if err := s.received(gotLayers[i]); err != nil {
+ t.Fatalf("failed to update test connection's layer states based on received frame: %s", err)
+ }
+ }
+ return gotLayers, nil
+ }
+ errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)})
+ }
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *Connection) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
+type TCPIPv4 Connection
+
+// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
+func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ tcpState, err := newTCPState(unix.AF_INET, outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv4{
+ layerStates: []layerState{etherState, ipv4State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+// Connect performs a TCP 3-way handshake. The input Connection should have a
+// final TCP Layer.
+func (conn *TCPIPv4) Connect(t *testing.T) {
+ t.Helper()
+
+ // Send the SYN.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
+// The input Connection should have a final TCP Layer.
+func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) {
+ t.Helper()
+
+ // Send the SYN.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+
+ // Wait for the SYN-ACK.
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
+
+ // Send an ACK.
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// ExpectNextData attempts to receive the next incoming segment for the
+// connection and expects that to match the given layers.
+//
+// It differs from ExpectData() in that here we are only interested in the next
+// received segment, while ExpectData() can receive multiple segments for the
+// connection until there is a match with given layers or a timeout.
+func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ // Receive the first incoming TCP segment for this connection.
+ got, err := conn.ExpectData(t, &TCP{}, nil, timeout)
+ if err != nil {
+ return nil, err
+ }
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length()))
+ }
+ if !(*Connection)(conn).match(expected, got) {
+ return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
+ }
+ return got, nil
+}
+
+// Send a packet with reasonable defaults. Potentially override the TCP layer in
+// the connection with the provided layer and add additionLayers.
+func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...)
+}
+
+// Close frees associated resources held by the TCPIPv4 connection.
+func (conn *TCPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
+
+// Expect expects a frame with the TCP layer matching the provided TCP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &tcp, timeout)
+ if layer == nil {
+ return nil, err
+ }
+ gotTCP, ok := layer.(*TCP)
+ if !ok {
+ t.Fatalf("expected %s to be TCP", layer)
+ }
+ return gotTCP, err
+}
+
+func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*tcpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
+}
+
+// RemoteSeqNum returns the next expected sequence number from the DUT.
+func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).remoteSeqNum
+}
+
+// LocalSeqNum returns the next sequence number to send from the testbench.
+func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).localSeqNum
+}
+
+// SynAck returns the SynAck that was part of the handshake.
+func (conn *TCPIPv4) SynAck(t *testing.T) *TCP {
+ t.Helper()
+
+ return conn.tcpState(t).synAck
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
+ return sa
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *TCPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// IPv6Conn maintains the state for all the layers in a IPv6 connection.
+type IPv6Conn Connection
+
+// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
+func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make EtherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(outgoingIPv6, incomingIPv6)
+ if err != nil {
+ t.Fatalf("can't make IPv6State: %s", err)
+ }
+
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return IPv6Conn{
+ layerStates: []layerState{etherState, ipv6State},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
+// additionalLayers added after it.
+func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...)
+}
+
+// Close to clean up any resources held.
+func (conn *IPv6Conn) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
+
+// ExpectFrame expects a frame that matches the provided Layers within the
+// timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ return (*Connection)(conn).ExpectFrame(t, frame, timeout)
+}
+
+// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
+type UDPIPv4 Connection
+
+// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
+func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv4State, err := newIPv4State(IPv4{}, IPv4{})
+ if err != nil {
+ t.Fatalf("can't make ipv4State: %s", err)
+ }
+ udpState, err := newUDPState(unix.AF_INET, outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return UDPIPv4{
+ layerStates: []layerState{etherState, ipv4State, udpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+func (conn *UDPIPv4) udpState(t *testing.T) *udpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv4State)
+ if !ok {
+ t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ }
+ return state
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
+ return sa
+}
+
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
+}
+
+// SendIP sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv4 headers and adding additionLayers.
+func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
+}
+
+// Expect expects a frame with the UDP layer matching the provided UDP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
+ if err != nil {
+ return nil, err
+ }
+ gotUDP, ok := layer.(*UDP)
+ if !ok {
+ t.Fatalf("expected %s to be UDP", layer)
+ }
+ return gotUDP, nil
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = &udp
+ if payload.length() != 0 {
+ expected = append(expected, &payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// Close frees associated resources held by the UDPIPv4 connection.
+func (conn *UDPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection.
+type UDPIPv6 Connection
+
+// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults.
+func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
+ t.Helper()
+
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ if err != nil {
+ t.Fatalf("can't make IPv6State: %s", err)
+ }
+ udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP)
+ if err != nil {
+ t.Fatalf("can't make udpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+ return UDPIPv6{
+ layerStates: []layerState{etherState, ipv6State, udpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+func (conn *UDPIPv6) udpState(t *testing.T) *udpState {
+ t.Helper()
+
+ state, ok := conn.layerStates[2].(*udpState)
+ if !ok {
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ }
+ return state
+}
+
+func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State {
+ t.Helper()
+
+ state, ok := conn.layerStates[1].(*ipv6State)
+ if !ok {
+ t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1])
+ }
+ return state
+}
+
+// LocalAddr gets the local socket address of this connection.
+func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet6{
+ Port: int(*conn.udpState(t).out.SrcPort),
+ // Local address is in perspective to the remote host, so it's scoped to the
+ // ID of the remote interface.
+ ZoneId: uint32(RemoteInterfaceID),
+ }
+ copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr)
+ return sa
+}
+
+// Send sends a packet with reasonable defaults, potentially overriding the UDP
+// layer and adding additionLayers.
+func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
+}
+
+// SendIPv6 sends a packet with reasonable defaults, potentially overriding the
+// UDP and IPv6 headers and adding additionLayers.
+func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
+}
+
+// Expect expects a frame with the UDP layer matching the provided UDP within
+// the timeout specified. If it doesn't arrive in time, an error is returned.
+func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
+ if err != nil {
+ return nil, err
+ }
+ gotUDP, ok := layer.(*UDP)
+ if !ok {
+ t.Fatalf("expected %s to be UDP", layer)
+ }
+ return gotUDP, nil
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = &udp
+ if payload.length() != 0 {
+ expected = append(expected, &payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// Close frees associated resources held by the UDPIPv6 connection.
+func (conn *UDPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
+
+// Drain drains the sniffer's receive buffer by receiving packets until there's
+// nothing else to receive.
+func (conn *UDPIPv6) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
+}
+
+// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection.
+type TCPIPv6 Connection
+
+// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults.
+func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
+ etherState, err := newEtherState(Ether{}, Ether{})
+ if err != nil {
+ t.Fatalf("can't make etherState: %s", err)
+ }
+ ipv6State, err := newIPv6State(IPv6{}, IPv6{})
+ if err != nil {
+ t.Fatalf("can't make ipv6State: %s", err)
+ }
+ tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP)
+ if err != nil {
+ t.Fatalf("can't make tcpState: %s", err)
+ }
+ injector, err := NewInjector(t)
+ if err != nil {
+ t.Fatalf("can't make injector: %s", err)
+ }
+ sniffer, err := NewSniffer(t)
+ if err != nil {
+ t.Fatalf("can't make sniffer: %s", err)
+ }
+
+ return TCPIPv6{
+ layerStates: []layerState{etherState, ipv6State, tcpState},
+ injector: injector,
+ sniffer: sniffer,
+ }
+}
+
+func (conn *TCPIPv6) SrcPort() uint16 {
+ state := conn.layerStates[2].(*tcpState)
+ return *state.out.SrcPort
+}
+
+// ExpectData is a convenient method that expects a Layer and the Layer after
+// it. If it doens't arrive in time, it returns nil.
+func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ expected := make([]Layer, len(conn.layerStates))
+ expected[len(expected)-1] = tcp
+ if payload != nil {
+ expected = append(expected, payload)
+ }
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
+}
+
+// Close frees associated resources held by the TCPIPv6 connection.
+func (conn *TCPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
+}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
new file mode 100644
index 000000000..73c532e75
--- /dev/null
+++ b/test/packetimpact/testbench/dut.go
@@ -0,0 +1,702 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "context"
+ "flag"
+ "net"
+ "strconv"
+ "syscall"
+ "testing"
+
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+
+ "golang.org/x/sys/unix"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/keepalive"
+)
+
+// DUT communicates with the DUT to force it to make POSIX calls.
+type DUT struct {
+ conn *grpc.ClientConn
+ posixServer POSIXClient
+}
+
+// NewDUT creates a new connection with the DUT over gRPC.
+func NewDUT(t *testing.T) DUT {
+ t.Helper()
+
+ flag.Parse()
+ if err := genPseudoFlags(); err != nil {
+ t.Fatal("generating psuedo flags:", err)
+ }
+
+ posixServerAddress := POSIXServerIP + ":" + strconv.Itoa(POSIXServerPort)
+ conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
+ if err != nil {
+ t.Fatalf("failed to grpc.Dial(%s): %s", posixServerAddress, err)
+ }
+ posixServer := NewPOSIXClient(conn)
+ return DUT{
+ conn: conn,
+ posixServer: posixServer,
+ }
+}
+
+// TearDown closes the underlying connection.
+func (dut *DUT) TearDown() {
+ dut.conn.Close()
+}
+
+func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr {
+ t.Helper()
+
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In{
+ In: &pb.SockaddrIn{
+ Family: unix.AF_INET,
+ Port: uint32(s.Port),
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ case *unix.SockaddrInet6:
+ return &pb.Sockaddr{
+ Sockaddr: &pb.Sockaddr_In6{
+ In6: &pb.SockaddrIn6{
+ Family: unix.AF_INET6,
+ Port: uint32(s.Port),
+ Flowinfo: 0,
+ ScopeId: s.ZoneId,
+ Addr: s.Addr[:],
+ },
+ },
+ }
+ }
+ t.Fatalf("can't parse Sockaddr struct: %+v", sa)
+ return nil
+}
+
+func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr {
+ t.Helper()
+
+ switch s := sa.Sockaddr.(type) {
+ case *pb.Sockaddr_In:
+ ret := unix.SockaddrInet4{
+ Port: int(s.In.GetPort()),
+ }
+ copy(ret.Addr[:], s.In.GetAddr())
+ return &ret
+ case *pb.Sockaddr_In6:
+ ret := unix.SockaddrInet6{
+ Port: int(s.In6.GetPort()),
+ ZoneId: s.In6.GetScopeId(),
+ }
+ copy(ret.Addr[:], s.In6.GetAddr())
+ return &ret
+ }
+ t.Fatalf("can't parse Sockaddr proto: %#v", sa)
+ return nil
+}
+
+// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
+// proto, and bound to the IP address addr. Returns the new file descriptor and
+// the port that was selected on the DUT.
+func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) {
+ t.Helper()
+
+ var fd int32
+ if addr.To4() != nil {
+ fd = dut.Socket(t, unix.AF_INET, typ, proto)
+ sa := unix.SockaddrInet4{}
+ copy(sa.Addr[:], addr.To4())
+ dut.Bind(t, fd, &sa)
+ } else if addr.To16() != nil {
+ fd = dut.Socket(t, unix.AF_INET6, typ, proto)
+ sa := unix.SockaddrInet6{}
+ copy(sa.Addr[:], addr.To16())
+ sa.ZoneId = uint32(RemoteInterfaceID)
+ dut.Bind(t, fd, &sa)
+ } else {
+ t.Fatalf("invalid IP address: %s", addr)
+ }
+ sa := dut.GetSockName(t, fd)
+ var port int
+ switch s := sa.(type) {
+ case *unix.SockaddrInet4:
+ port = s.Port
+ case *unix.SockaddrInet6:
+ port = s.Port
+ default:
+ t.Fatalf("unknown sockaddr type from getsockname: %T", sa)
+ }
+ return fd, uint16(port)
+}
+
+// CreateListener makes a new TCP connection. If it fails, the test ends.
+func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) {
+ t.Helper()
+
+ fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4))
+ dut.Listen(t, fd, backlog)
+ return fd, remotePort
+}
+
+// All the functions that make gRPC calls to the POSIX service are below, sorted
+// alphabetically.
+
+// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// AcceptWithErrno.
+func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd)
+ if fd < 0 {
+ t.Fatalf("failed to accept: %s", err)
+ }
+ return fd, sa
+}
+
+// AcceptWithErrno calls accept on the DUT.
+func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
+ req := pb.AcceptRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.Accept(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Accept: %s", err)
+ }
+ return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is
+// needed, use BindWithErrno.
+func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.BindWithErrno(ctx, t, fd, sa)
+ if ret != 0 {
+ t.Fatalf("failed to bind socket: %s", err)
+ }
+}
+
+// BindWithErrno calls bind on the DUT.
+func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
+ req := pb.BindRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(t, sa),
+ }
+ resp, err := dut.posixServer.Bind(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Close calls close on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// CloseWithErrno.
+func (dut *DUT) Close(t *testing.T, fd int32) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.CloseWithErrno(ctx, t, fd)
+ if ret != 0 {
+ t.Fatalf("failed to close: %s", err)
+ }
+}
+
+// CloseWithErrno calls close on the DUT.
+func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) {
+ t.Helper()
+
+ req := pb.CloseRequest{
+ Fd: fd,
+ }
+ resp, err := dut.posixServer.Close(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Close: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Connect calls connect on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use ConnectWithErrno.
+func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.ConnectWithErrno(ctx, t, fd, sa)
+ // Ignore 'operation in progress' error that can be returned when the socket
+ // is non-blocking.
+ if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
+ t.Fatalf("failed to connect socket: %s", err)
+ }
+}
+
+// ConnectWithErrno calls bind on the DUT.
+func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
+ req := pb.ConnectRequest{
+ Sockfd: fd,
+ Addr: dut.sockaddrToProto(t, sa),
+ }
+ resp, err := dut.posixServer.Connect(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Connect: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Fcntl calls fcntl on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use FcntlWithErrno.
+func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg)
+ if ret == -1 {
+ t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
+ }
+ return ret
+}
+
+// FcntlWithErrno calls fcntl on the DUT.
+func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) {
+ t.Helper()
+
+ req := pb.FcntlRequest{
+ Fd: fd,
+ Cmd: cmd,
+ Arg: arg,
+ }
+ resp, err := dut.posixServer.Fcntl(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Fcntl: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockName calls getsockname on the DUT and causes a fatal test failure if
+// it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockNameWithErrno.
+func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd)
+ if ret != 0 {
+ t.Fatalf("failed to getsockname: %s", err)
+ }
+ return sa
+}
+
+// GetSockNameWithErrno calls getsockname on the DUT.
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
+ req := pb.GetSockNameRequest{
+ Sockfd: sockfd,
+ }
+ resp, err := dut.posixServer.GetSockName(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Bind: %s", err)
+ }
+ return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+}
+
+func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
+ t.Helper()
+
+ req := pb.GetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optlen: optlen,
+ Type: typ,
+ }
+ resp, err := dut.posixServer.GetSockOpt(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call GetSockOpt: %s", err)
+ }
+ optval := resp.GetOptval()
+ if optval == nil {
+ t.Fatalf("GetSockOpt response does not contain a value")
+ }
+ return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_())
+}
+
+// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific GetSockOptXxx function.
+func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen)
+ if ret != 0 {
+ t.Fatalf("failed to GetSockOpt: %s", err)
+ }
+ return optval
+}
+
+// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific GetSockOptXxxWithErrno function.
+func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
+ bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval)
+ if !ok {
+ t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val)
+ }
+ return ret, bytesval.Bytesval, errno
+}
+
+// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use GetSockOptIntWithErrno.
+func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname)
+ if ret != 0 {
+ t.Fatalf("failed to GetSockOptInt: %s", err)
+ }
+ return intval
+}
+
+// GetSockOptIntWithErrno calls getsockopt with an integer optval.
+func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
+ intval, ok := optval.Val.(*pb.SockOptVal_Intval)
+ if !ok {
+ t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val)
+ }
+ return ret, intval.Intval, errno
+}
+
+// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use GetSockOptTimevalWithErrno.
+func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname)
+ if ret != 0 {
+ t.Fatalf("failed to GetSockOptTimeval: %s", err)
+ }
+ return timeval
+}
+
+// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
+func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
+ tv, ok := optval.Val.(*pb.SockOptVal_Timeval)
+ if !ok {
+ t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val)
+ }
+ timeval := unix.Timeval{
+ Sec: tv.Timeval.Seconds,
+ Usec: tv.Timeval.Microseconds,
+ }
+ return ret, timeval, errno
+}
+
+// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// ListenWithErrno.
+func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog)
+ if ret != 0 {
+ t.Fatalf("failed to listen: %s", err)
+ }
+}
+
+// ListenWithErrno calls listen on the DUT.
+func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) {
+ t.Helper()
+
+ req := pb.ListenRequest{
+ Sockfd: sockfd,
+ Backlog: backlog,
+ }
+ resp, err := dut.posixServer.Listen(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Listen: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// Send calls send on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendWithErrno.
+func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags)
+ if ret == -1 {
+ t.Fatalf("failed to send: %s", err)
+ }
+ return ret
+}
+
+// SendWithErrno calls send on the DUT.
+func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) {
+ t.Helper()
+
+ req := pb.SendRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Send(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Send: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// SendToWithErrno.
+func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr)
+ if ret == -1 {
+ t.Fatalf("failed to sendto: %s", err)
+ }
+ return ret
+}
+
+// SendToWithErrno calls sendto on the DUT.
+func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
+ t.Helper()
+
+ req := pb.SendToRequest{
+ Sockfd: sockfd,
+ Buf: buf,
+ Flags: flags,
+ DestAddr: dut.sockaddrToProto(t, destAddr),
+ }
+ resp, err := dut.posixServer.SendTo(ctx, &req)
+ if err != nil {
+ t.Fatalf("faled to call SendTo: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking
+// is true, otherwise it will clear the flag.
+func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) {
+ t.Helper()
+
+ flags := dut.Fcntl(t, fd, unix.F_GETFL, 0)
+ if nonblocking {
+ flags |= unix.O_NONBLOCK
+ } else {
+ flags &= ^unix.O_NONBLOCK
+ }
+ dut.Fcntl(t, fd, unix.F_SETFL, flags)
+}
+
+func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
+ t.Helper()
+
+ req := pb.SetSockOptRequest{
+ Sockfd: sockfd,
+ Level: level,
+ Optname: optname,
+ Optval: optval,
+ }
+ resp, err := dut.posixServer.SetSockOpt(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call SetSockOpt: %s", err)
+ }
+ return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+}
+
+// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
+// doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptWithErrno. Because endianess and the width of values
+// might differ between the testbench and DUT architectures, prefer to use a
+// more specific SetSockOptXxx function.
+func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval)
+ if ret != 0 {
+ t.Fatalf("failed to SetSockOpt: %s", err)
+ }
+}
+
+// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
+// width of values might differ between the testbench and DUT architectures,
+// prefer to use a more specific SetSockOptXxxWithErrno function.
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
+}
+
+// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the int optval or error handling
+// is needed, use SetSockOptIntWithErrno.
+func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval)
+ if ret != 0 {
+ t.Fatalf("failed to SetSockOptInt: %s", err)
+ }
+}
+
+// SetSockOptIntWithErrno calls setsockopt with an integer optval.
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
+}
+
+// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
+// if it doesn't succeed. If more control over the timeout or error handling is
+// needed, use SetSockOptTimevalWithErrno.
+func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv)
+ if ret != 0 {
+ t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ }
+}
+
+// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
+// bytes.
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ t.Helper()
+
+ timeval := pb.Timeval{
+ Seconds: int64(tv.Sec),
+ Microseconds: int64(tv.Usec),
+ }
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
+}
+
+// Socket calls socket on the DUT and returns the file descriptor. If socket
+// fails on the DUT, the test ends.
+func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 {
+ t.Helper()
+
+ fd, err := dut.SocketWithErrno(t, domain, typ, proto)
+ if fd < 0 {
+ t.Fatalf("failed to create socket: %s", err)
+ }
+ return fd
+}
+
+// SocketWithErrno calls socket on the DUT and returns the fd and errno.
+func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) {
+ t.Helper()
+
+ req := pb.SocketRequest{
+ Domain: domain,
+ Type: typ,
+ Protocol: proto,
+ }
+ ctx := context.Background()
+ resp, err := dut.posixServer.Socket(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Socket: %s", err)
+ }
+ return resp.GetFd(), syscall.Errno(resp.GetErrno_())
+}
+
+// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
+// succeed. If more control over the timeout or error handling is needed, use
+// RecvWithErrno.
+func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
+ defer cancel()
+ ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags)
+ if ret == -1 {
+ t.Fatalf("failed to recv: %s", err)
+ }
+ return buf
+}
+
+// RecvWithErrno calls recv on the DUT.
+func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) {
+ t.Helper()
+
+ req := pb.RecvRequest{
+ Sockfd: sockfd,
+ Len: len,
+ Flags: flags,
+ }
+ resp, err := dut.posixServer.Recv(ctx, &req)
+ if err != nil {
+ t.Fatalf("failed to call Recv: %s", err)
+ }
+ return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
+}
diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go
new file mode 100644
index 000000000..d0e68c5da
--- /dev/null
+++ b/test/packetimpact/testbench/dut_client.go
@@ -0,0 +1,28 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "google.golang.org/grpc"
+ pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto"
+)
+
+// PosixClient is a gRPC client for the Posix service.
+type POSIXClient pb.PosixClient
+
+// NewPOSIXClient makes a new gRPC client for the POSIX service.
+func NewPOSIXClient(c grpc.ClientConnInterface) POSIXClient {
+ return pb.NewPosixClient(c)
+}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
new file mode 100644
index 000000000..a35562ca8
--- /dev/null
+++ b/test/packetimpact/testbench/layers.go
@@ -0,0 +1,1506 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/binary"
+ "encoding/hex"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/google/go-cmp/cmp/cmpopts"
+ "go.uber.org/multierr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+// Layer is the interface that all encapsulations must implement.
+//
+// A Layer is an encapsulation in a packet, such as TCP, IPv4, IPv6, etc. A
+// Layer contains all the fields of the encapsulation. Each field is a pointer
+// and may be nil.
+type Layer interface {
+ fmt.Stringer
+
+ // ToBytes converts the Layer into bytes. In places where the Layer's field
+ // isn't nil, the value that is pointed to is used. When the field is nil, a
+ // reasonable default for the Layer is used. For example, "64" for IPv4 TTL
+ // and a calculated checksum for TCP or IP. Some layers require information
+ // from the previous or next layers in order to compute a default, such as
+ // TCP's checksum or Ethernet's type, so each Layer has a doubly-linked list
+ // to the layer's neighbors.
+ ToBytes() ([]byte, error)
+
+ // match checks if the current Layer matches the provided Layer. If either
+ // Layer has a nil in a given field, that field is considered matching.
+ // Otherwise, the values pointed to by the fields must match. The LayerBase is
+ // ignored.
+ match(Layer) bool
+
+ // length in bytes of the current encapsulation
+ length() int
+
+ // next gets a pointer to the encapsulated Layer.
+ next() Layer
+
+ // prev gets a pointer to the Layer encapsulating this one.
+ Prev() Layer
+
+ // setNext sets the pointer to the encapsulated Layer.
+ setNext(Layer)
+
+ // setPrev sets the pointer to the Layer encapsulating this one.
+ setPrev(Layer)
+
+ // merge overrides the values in the interface with the provided values.
+ merge(Layer) error
+}
+
+// LayerBase is the common elements of all layers.
+type LayerBase struct {
+ nextLayer Layer
+ prevLayer Layer
+}
+
+func (lb *LayerBase) next() Layer {
+ return lb.nextLayer
+}
+
+// Prev returns the previous layer.
+func (lb *LayerBase) Prev() Layer {
+ return lb.prevLayer
+}
+
+func (lb *LayerBase) setNext(l Layer) {
+ lb.nextLayer = l
+}
+
+func (lb *LayerBase) setPrev(l Layer) {
+ lb.prevLayer = l
+}
+
+// equalLayer compares that two Layer structs match while ignoring field in
+// which either input has a nil and also ignoring the LayerBase of the inputs.
+func equalLayer(x, y Layer) bool {
+ if x == nil || y == nil {
+ return true
+ }
+ // opt ignores comparison pairs where either of the inputs is a nil.
+ opt := cmp.FilterValues(func(x, y interface{}) bool {
+ for _, l := range []interface{}{x, y} {
+ v := reflect.ValueOf(l)
+ if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Slice) && v.IsNil() {
+ return true
+ }
+ }
+ return false
+ }, cmp.Ignore())
+ return cmp.Equal(x, y, opt, cmpopts.IgnoreTypes(LayerBase{}))
+}
+
+// mergeLayer merges y into x. Any fields for which y has a non-nil value, that
+// value overwrite the corresponding fields in x.
+func mergeLayer(x, y Layer) error {
+ if y == nil {
+ return nil
+ }
+ if reflect.TypeOf(x) != reflect.TypeOf(y) {
+ return fmt.Errorf("can't merge %T into %T", y, x)
+ }
+ vx := reflect.ValueOf(x).Elem()
+ vy := reflect.ValueOf(y).Elem()
+ t := vy.Type()
+ for i := 0; i < vy.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := vy.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ vx.Field(i).Set(v)
+ }
+ return nil
+}
+
+func stringLayer(l Layer) string {
+ v := reflect.ValueOf(l).Elem()
+ t := v.Type()
+ var ret []string
+ for i := 0; i < v.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ v := v.Field(i)
+ if v.IsNil() {
+ continue
+ }
+ v = reflect.Indirect(v)
+ if v.Kind() == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
+ ret = append(ret, fmt.Sprintf("%s:\n%v", t.Name, hex.Dump(v.Bytes())))
+ } else {
+ ret = append(ret, fmt.Sprintf("%s:%v", t.Name, v))
+ }
+ }
+ return fmt.Sprintf("&%s{%s}", t, strings.Join(ret, " "))
+}
+
+// Ether can construct and match an ethernet encapsulation.
+type Ether struct {
+ LayerBase
+ SrcAddr *tcpip.LinkAddress
+ DstAddr *tcpip.LinkAddress
+ Type *tcpip.NetworkProtocolNumber
+}
+
+func (l *Ether) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *Ether) ToBytes() ([]byte, error) {
+ b := make([]byte, header.EthernetMinimumSize)
+ h := header.Ethernet(b)
+ fields := &header.EthernetFields{}
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Type != nil {
+ fields.Type = *l.Type
+ } else {
+ switch n := l.next().(type) {
+ case *IPv4:
+ fields.Type = header.IPv4ProtocolNumber
+ case *IPv6:
+ fields.Type = header.IPv6ProtocolNumber
+ default:
+ return nil, fmt.Errorf("ethernet header's next layer is unrecognized: %#v", n)
+ }
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// LinkAddress is a helper routine that allocates a new tcpip.LinkAddress value
+// to store v and returns a pointer to it.
+func LinkAddress(v tcpip.LinkAddress) *tcpip.LinkAddress {
+ return &v
+}
+
+// NetworkProtocolNumber is a helper routine that allocates a new
+// tcpip.NetworkProtocolNumber value to store v and returns a pointer to it.
+func NetworkProtocolNumber(v tcpip.NetworkProtocolNumber) *tcpip.NetworkProtocolNumber {
+ return &v
+}
+
+// layerParser parses the input bytes and returns a Layer along with the next
+// layerParser to run. If there is no more parsing to do, the returned
+// layerParser is nil.
+type layerParser func([]byte) (Layer, layerParser)
+
+// parse parses bytes starting with the first layerParser and using successive
+// layerParsers until all the bytes are parsed.
+func parse(parser layerParser, b []byte) Layers {
+ var layers Layers
+ for {
+ var layer Layer
+ layer, parser = parser(b)
+ layers = append(layers, layer)
+ if parser == nil {
+ break
+ }
+ b = b[layer.length():]
+ }
+ layers.linkLayers()
+ return layers
+}
+
+// parseEther parses the bytes assuming that they start with an ethernet header
+// and continues parsing further encapsulations.
+func parseEther(b []byte) (Layer, layerParser) {
+ h := header.Ethernet(b)
+ ether := Ether{
+ SrcAddr: LinkAddress(h.SourceAddress()),
+ DstAddr: LinkAddress(h.DestinationAddress()),
+ Type: NetworkProtocolNumber(h.Type()),
+ }
+ var nextParser layerParser
+ switch h.Type() {
+ case header.IPv4ProtocolNumber:
+ nextParser = parseIPv4
+ case header.IPv6ProtocolNumber:
+ nextParser = parseIPv6
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ether, nextParser
+}
+
+func (l *Ether) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Ether) length() int {
+ return header.EthernetMinimumSize
+}
+
+// merge implements Layer.merge.
+func (l *Ether) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv4 can construct and match an IPv4 encapsulation.
+type IPv4 struct {
+ LayerBase
+ IHL *uint8
+ TOS *uint8
+ TotalLength *uint16
+ ID *uint16
+ Flags *uint8
+ FragmentOffset *uint16
+ TTL *uint8
+ Protocol *uint8
+ Checksum *uint16
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv4) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv4) ToBytes() ([]byte, error) {
+ b := make([]byte, header.IPv4MinimumSize)
+ h := header.IPv4(b)
+ fields := &header.IPv4Fields{
+ IHL: 20,
+ TOS: 0,
+ TotalLength: 0,
+ ID: 0,
+ Flags: 0,
+ FragmentOffset: 0,
+ TTL: 64,
+ Protocol: 0,
+ Checksum: 0,
+ SrcAddr: tcpip.Address(""),
+ DstAddr: tcpip.Address(""),
+ }
+ if l.TOS != nil {
+ fields.TOS = *l.TOS
+ }
+ if l.TotalLength != nil {
+ fields.TotalLength = *l.TotalLength
+ } else {
+ fields.TotalLength = uint16(l.length())
+ current := l.next()
+ for current != nil {
+ fields.TotalLength += uint16(current.length())
+ current = current.next()
+ }
+ }
+ if l.ID != nil {
+ fields.ID = *l.ID
+ }
+ if l.Flags != nil {
+ fields.Flags = *l.Flags
+ }
+ if l.FragmentOffset != nil {
+ fields.FragmentOffset = *l.FragmentOffset
+ }
+ if l.TTL != nil {
+ fields.TTL = *l.TTL
+ }
+ if l.Protocol != nil {
+ fields.Protocol = *l.Protocol
+ } else {
+ switch n := l.next().(type) {
+ case *TCP:
+ fields.Protocol = uint8(header.TCPProtocolNumber)
+ case *UDP:
+ fields.Protocol = uint8(header.UDPProtocolNumber)
+ case *ICMPv4:
+ fields.Protocol = uint8(header.ICMPv4ProtocolNumber)
+ default:
+ // TODO(b/150301488): Support more protocols as needed.
+ return nil, fmt.Errorf("ipv4 header's next layer is unrecognized: %#v", n)
+ }
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ if l.Checksum != nil {
+ fields.Checksum = *l.Checksum
+ }
+ h.Encode(fields)
+ if l.Checksum == nil {
+ h.SetChecksum(^h.CalculateChecksum())
+ }
+ return h, nil
+}
+
+// Uint16 is a helper routine that allocates a new
+// uint16 value to store v and returns a pointer to it.
+func Uint16(v uint16) *uint16 {
+ return &v
+}
+
+// Uint8 is a helper routine that allocates a new
+// uint8 value to store v and returns a pointer to it.
+func Uint8(v uint8) *uint8 {
+ return &v
+}
+
+// Address is a helper routine that allocates a new tcpip.Address value to store
+// v and returns a pointer to it.
+func Address(v tcpip.Address) *tcpip.Address {
+ return &v
+}
+
+// parseIPv4 parses the bytes assuming that they start with an ipv4 header and
+// continues parsing further encapsulations.
+func parseIPv4(b []byte) (Layer, layerParser) {
+ h := header.IPv4(b)
+ tos, _ := h.TOS()
+ ipv4 := IPv4{
+ IHL: Uint8(h.HeaderLength()),
+ TOS: &tos,
+ TotalLength: Uint16(h.TotalLength()),
+ ID: Uint16(h.ID()),
+ Flags: Uint8(h.Flags()),
+ FragmentOffset: Uint16(h.FragmentOffset()),
+ TTL: Uint8(h.TTL()),
+ Protocol: Uint8(h.Protocol()),
+ Checksum: Uint16(h.Checksum()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ var nextParser layerParser
+ switch h.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ nextParser = parseTCP
+ case header.UDPProtocolNumber:
+ nextParser = parseUDP
+ case header.ICMPv4ProtocolNumber:
+ nextParser = parseICMPv4
+ default:
+ // Assume that the rest is a payload.
+ nextParser = parsePayload
+ }
+ return &ipv4, nextParser
+}
+
+func (l *IPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv4) length() int {
+ if l.IHL == nil {
+ return header.IPv4MinimumSize
+ }
+ return int(*l.IHL)
+}
+
+// merge implements Layer.merge.
+func (l *IPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv6 can construct and match an IPv6 encapsulation.
+type IPv6 struct {
+ LayerBase
+ TrafficClass *uint8
+ FlowLabel *uint32
+ PayloadLength *uint16
+ NextHeader *uint8
+ HopLimit *uint8
+ SrcAddr *tcpip.Address
+ DstAddr *tcpip.Address
+}
+
+func (l *IPv6) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6) ToBytes() ([]byte, error) {
+ b := make([]byte, header.IPv6MinimumSize)
+ h := header.IPv6(b)
+ fields := &header.IPv6Fields{
+ HopLimit: 64,
+ }
+ if l.TrafficClass != nil {
+ fields.TrafficClass = *l.TrafficClass
+ }
+ if l.FlowLabel != nil {
+ fields.FlowLabel = *l.FlowLabel
+ }
+ if l.PayloadLength != nil {
+ fields.PayloadLength = *l.PayloadLength
+ } else {
+ for current := l.next(); current != nil; current = current.next() {
+ fields.PayloadLength += uint16(current.length())
+ }
+ }
+ if l.NextHeader != nil {
+ fields.NextHeader = *l.NextHeader
+ } else {
+ nh, err := nextHeaderByLayer(l.next())
+ if err != nil {
+ return nil, err
+ }
+ fields.NextHeader = nh
+ }
+ if l.HopLimit != nil {
+ fields.HopLimit = *l.HopLimit
+ }
+ if l.SrcAddr != nil {
+ fields.SrcAddr = *l.SrcAddr
+ }
+ if l.DstAddr != nil {
+ fields.DstAddr = *l.DstAddr
+ }
+ h.Encode(fields)
+ return h, nil
+}
+
+// nextIPv6PayloadParser finds the corresponding parser for nextHeader.
+func nextIPv6PayloadParser(nextHeader uint8) layerParser {
+ switch tcpip.TransportProtocolNumber(nextHeader) {
+ case header.TCPProtocolNumber:
+ return parseTCP
+ case header.UDPProtocolNumber:
+ return parseUDP
+ case header.ICMPv6ProtocolNumber:
+ return parseICMPv6
+ }
+ switch header.IPv6ExtensionHeaderIdentifier(nextHeader) {
+ case header.IPv6HopByHopOptionsExtHdrIdentifier:
+ return parseIPv6HopByHopOptionsExtHdr
+ case header.IPv6DestinationOptionsExtHdrIdentifier:
+ return parseIPv6DestinationOptionsExtHdr
+ case header.IPv6FragmentExtHdrIdentifier:
+ return parseIPv6FragmentExtHdr
+ }
+ return parsePayload
+}
+
+// parseIPv6 parses the bytes assuming that they start with an ipv6 header and
+// continues parsing further encapsulations.
+func parseIPv6(b []byte) (Layer, layerParser) {
+ h := header.IPv6(b)
+ tos, flowLabel := h.TOS()
+ ipv6 := IPv6{
+ TrafficClass: &tos,
+ FlowLabel: &flowLabel,
+ PayloadLength: Uint16(h.PayloadLength()),
+ NextHeader: Uint8(h.NextHeader()),
+ HopLimit: Uint8(h.HopLimit()),
+ SrcAddr: Address(h.SourceAddress()),
+ DstAddr: Address(h.DestinationAddress()),
+ }
+ nextParser := nextIPv6PayloadParser(h.NextHeader())
+ return &ipv6, nextParser
+}
+
+func (l *IPv6) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *IPv6) length() int {
+ return header.IPv6MinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions
+// Extension Header.
+type IPv6HopByHopOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions
+// Extension Header.
+type IPv6DestinationOptionsExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ Options []byte
+}
+
+// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header.
+type IPv6FragmentExtHdr struct {
+ LayerBase
+ NextHeader *header.IPv6ExtensionHeaderIdentifier
+ FragmentOffset *uint16
+ MoreFragments *bool
+ Identification *uint32
+}
+
+// nextHeaderByLayer finds the correct next header protocol value for layer l.
+func nextHeaderByLayer(l Layer) (uint8, error) {
+ if l == nil {
+ return uint8(header.IPv6NoNextHeaderIdentifier), nil
+ }
+ switch l.(type) {
+ case *TCP:
+ return uint8(header.TCPProtocolNumber), nil
+ case *UDP:
+ return uint8(header.UDPProtocolNumber), nil
+ case *ICMPv6:
+ return uint8(header.ICMPv6ProtocolNumber), nil
+ case *Payload:
+ return uint8(header.IPv6NoNextHeaderIdentifier), nil
+ case *IPv6HopByHopOptionsExtHdr:
+ return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil
+ case *IPv6DestinationOptionsExtHdr:
+ return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil
+ case *IPv6FragmentExtHdr:
+ return uint8(header.IPv6FragmentExtHdrIdentifier), nil
+ default:
+ // TODO(b/161005083): Support more protocols as needed.
+ return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l)
+ }
+}
+
+// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes.
+func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) {
+ length := len(options) + 2
+ if length%8 != 0 {
+ return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options))
+ }
+ bytes := make([]byte, length)
+ if nextHeader != nil {
+ bytes[0] = byte(*nextHeader)
+ } else {
+ nh, err := nextHeaderByLayer(nextLayer)
+ if err != nil {
+ return nil, err
+ }
+ bytes[0] = nh
+ }
+ // ExtHdrLen field is the length of the extension header
+ // in 8-octet unit, ignoring the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
+ bytes[1] = uint8((length - 8) / 8)
+ copy(bytes[2:], options)
+ return bytes, nil
+}
+
+// IPv6ExtHdrIdent is a helper routine that allocates a new
+// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer
+// to it.
+func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier {
+ return &id
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) {
+ return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) {
+ var offset, mflag uint16
+ var ident uint32
+ bytes := make([]byte, header.IPv6FragmentExtHdrLength)
+ if l.NextHeader != nil {
+ bytes[0] = byte(*l.NextHeader)
+ } else {
+ nh, err := nextHeaderByLayer(l.next())
+ if err != nil {
+ return nil, err
+ }
+ bytes[0] = nh
+ }
+ bytes[1] = 0 // reserved
+ if l.MoreFragments != nil && *l.MoreFragments {
+ mflag = 1
+ }
+ if l.FragmentOffset != nil {
+ offset = *l.FragmentOffset
+ }
+ if l.Identification != nil {
+ ident = *l.Identification
+ }
+ offsetAndMflag := offset<<3 | mflag
+ binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag)
+ binary.BigEndian.PutUint32(bytes[4:], ident)
+
+ return bytes, nil
+}
+
+// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader
+// field, the rest of the payload and a parser function for the corresponding
+// next extension header.
+func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) {
+ nextHeader := b[0]
+ // For HopByHop and Destination options extension headers,
+ // This field is the length of the extension header in
+ // 8-octet units, not including the first 8 octets.
+ // https://tools.ietf.org/html/rfc2460#section-4.3
+ // https://tools.ietf.org/html/rfc2460#section-4.6
+ length := b[1]*8 + 8
+ data := b[2:length]
+ nextParser := nextIPv6PayloadParser(nextHeader)
+ return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser
+}
+
+// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 HopByHop Options Extension Header.
+func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start
+// with an IPv6 Destination Options Extension Header.
+func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader, options, nextParser := parseIPv6ExtHdr(b)
+ return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser
+}
+
+// Bool is a helper routine that allocates a new
+// bool value to store v and returns a pointer to it.
+func Bool(v bool) *bool {
+ return &v
+}
+
+// parseIPv6FragmentExtHdr parses the bytes assuming that they start
+// with an IPv6 Fragment Extension Header.
+func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) {
+ nextHeader := b[0]
+ var extHdr header.IPv6FragmentExtHdr
+ copy(extHdr[:], b[2:])
+ return &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)),
+ FragmentOffset: Uint16(extHdr.FragmentOffset()),
+ MoreFragments: Bool(extHdr.More()),
+ Identification: Uint32(extHdr.ID()),
+ }, nextIPv6PayloadParser(nextHeader)
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6HopByHopOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) length() int {
+ return len(l.Options) + 2
+}
+
+func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6DestinationOptionsExtHdr) String() string {
+ return stringLayer(l)
+}
+
+func (*IPv6FragmentExtHdr) length() int {
+ return header.IPv6FragmentExtHdrLength
+}
+
+func (l *IPv6FragmentExtHdr) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *IPv6FragmentExtHdr) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+func (l *IPv6FragmentExtHdr) String() string {
+ return stringLayer(l)
+}
+
+// ICMPv6 can construct and match an ICMPv6 encapsulation.
+type ICMPv6 struct {
+ LayerBase
+ Type *header.ICMPv6Type
+ Code *header.ICMPv6Code
+ Checksum *uint16
+ Payload []byte
+}
+
+func (l *ICMPv6) String() string {
+ // TODO(eyalsoha): Do something smarter here when *l.Type is ParameterProblem?
+ // We could parse the contents of the Payload as if it were an IPv6 packet.
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *ICMPv6) ToBytes() ([]byte, error) {
+ b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload))
+ h := header.ICMPv6(b)
+ if l.Type != nil {
+ h.SetType(*l.Type)
+ }
+ if l.Code != nil {
+ h.SetCode(*l.Code)
+ }
+ copy(h.NDPPayload(), l.Payload)
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ } else {
+ // It is possible that the ICMPv6 header does not follow the IPv6 header
+ // immediately, there could be one or more extension headers in between.
+ // We need to search forward to find the IPv6 header.
+ for prev := l.Prev(); prev != nil; prev = prev.Prev() {
+ if ipv6, ok := prev.(*IPv6); ok {
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload))
+ break
+ }
+ }
+ }
+ return h, nil
+}
+
+// ICMPv6Type is a helper routine that allocates a new ICMPv6Type value to store
+// v and returns a pointer to it.
+func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type {
+ return &v
+}
+
+// ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store
+// v and returns a pointer to it.
+func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code {
+ return &v
+}
+
+// Byte is a helper routine that allocates a new byte value to store
+// v and returns a pointer to it.
+func Byte(v byte) *byte {
+ return &v
+}
+
+// parseICMPv6 parses the bytes assuming that they start with an ICMPv6 header.
+func parseICMPv6(b []byte) (Layer, layerParser) {
+ h := header.ICMPv6(b)
+ icmpv6 := ICMPv6{
+ Type: ICMPv6Type(h.Type()),
+ Code: ICMPv6Code(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ Payload: h.NDPPayload(),
+ }
+ return &icmpv6, nil
+}
+
+func (l *ICMPv6) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *ICMPv6) length() int {
+ return header.ICMPv6HeaderSize + len(l.Payload)
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *ICMPv6) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// ICMPv4Type is a helper routine that allocates a new header.ICMPv4Type value
+// to store t and returns a pointer to it.
+func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type {
+ return &t
+}
+
+// ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value
+// to store t and returns a pointer to it.
+func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code {
+ return &t
+}
+
+// ICMPv4 can construct and match an ICMPv4 encapsulation.
+type ICMPv4 struct {
+ LayerBase
+ Type *header.ICMPv4Type
+ Code *header.ICMPv4Code
+ Checksum *uint16
+}
+
+func (l *ICMPv4) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *ICMPv4) ToBytes() ([]byte, error) {
+ b := make([]byte, header.ICMPv4MinimumSize)
+ h := header.ICMPv4(b)
+ if l.Type != nil {
+ h.SetType(*l.Type)
+ }
+ if l.Code != nil {
+ h.SetCode(*l.Code)
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ payload, err := payload(l)
+ if err != nil {
+ return nil, err
+ }
+ h.SetChecksum(header.ICMPv4Checksum(h, payload))
+ return h, nil
+}
+
+// parseICMPv4 parses the bytes as an ICMPv4 header, returning a Layer and a
+// parser for the encapsulated payload.
+func parseICMPv4(b []byte) (Layer, layerParser) {
+ h := header.ICMPv4(b)
+ icmpv4 := ICMPv4{
+ Type: ICMPv4Type(h.Type()),
+ Code: ICMPv4Code(h.Code()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &icmpv4, parsePayload
+}
+
+func (l *ICMPv4) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *ICMPv4) length() int {
+ return header.ICMPv4MinimumSize
+}
+
+// merge overrides the values in l with the values from other but only in fields
+// where the value is not nil.
+func (l *ICMPv4) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// TCP can construct and match a TCP encapsulation.
+type TCP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ SeqNum *uint32
+ AckNum *uint32
+ DataOffset *uint8
+ Flags *uint8
+ WindowSize *uint16
+ Checksum *uint16
+ UrgentPointer *uint16
+ Options []byte
+}
+
+func (l *TCP) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *TCP) ToBytes() ([]byte, error) {
+ b := make([]byte, l.length())
+ h := header.TCP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.SeqNum != nil {
+ h.SetSequenceNumber(*l.SeqNum)
+ }
+ if l.AckNum != nil {
+ h.SetAckNumber(*l.AckNum)
+ }
+ if l.DataOffset != nil {
+ h.SetDataOffset(*l.DataOffset)
+ } else {
+ h.SetDataOffset(uint8(l.length()))
+ }
+ if l.Flags != nil {
+ h.SetFlags(*l.Flags)
+ }
+ if l.WindowSize != nil {
+ h.SetWindowSize(*l.WindowSize)
+ } else {
+ h.SetWindowSize(32768)
+ }
+ if l.UrgentPointer != nil {
+ h.SetUrgentPoiner(*l.UrgentPointer)
+ }
+ copy(b[header.TCPMinimumSize:], l.Options)
+ header.AddTCPOptionPadding(b[header.TCPMinimumSize:], len(l.Options))
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setTCPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// totalLength returns the length of the provided layer and all following
+// layers.
+func totalLength(l Layer) int {
+ var totalLength int
+ for ; l != nil; l = l.next() {
+ totalLength += l.length()
+ }
+ return totalLength
+}
+
+// payload returns a buffer.VectorisedView of l's payload.
+func payload(l Layer) (buffer.VectorisedView, error) {
+ var payloadBytes buffer.VectorisedView
+ for current := l.next(); current != nil; current = current.next() {
+ payload, err := current.ToBytes()
+ if err != nil {
+ return buffer.VectorisedView{}, fmt.Errorf("can't get bytes for next header: %s", payload)
+ }
+ payloadBytes.AppendView(payload)
+ }
+ return payloadBytes, nil
+}
+
+// layerChecksum calculates the checksum of the Layer header, including the
+// peusdeochecksum of the layer before it and all the bytes after it.
+func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) {
+ totalLength := uint16(totalLength(l))
+ var xsum uint16
+ switch p := l.Prev().(type) {
+ case *IPv4:
+ xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
+ case *IPv6:
+ xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength)
+ default:
+ // TODO(b/161246171): Support more protocols.
+ return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p)
+ }
+ payloadBytes, err := payload(l)
+ if err != nil {
+ return 0, err
+ }
+ xsum = header.ChecksumVV(payloadBytes, xsum)
+ return xsum, nil
+}
+
+// setTCPChecksum calculates the checksum of the TCP header and sets it in h.
+func setTCPChecksum(h *header.TCP, tcp *TCP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(tcp, header.TCPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// Uint32 is a helper routine that allocates a new
+// uint32 value to store v and returns a pointer to it.
+func Uint32(v uint32) *uint32 {
+ return &v
+}
+
+// parseTCP parses the bytes assuming that they start with a tcp header and
+// continues parsing further encapsulations.
+func parseTCP(b []byte) (Layer, layerParser) {
+ h := header.TCP(b)
+ tcp := TCP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ SeqNum: Uint32(h.SequenceNumber()),
+ AckNum: Uint32(h.AckNumber()),
+ DataOffset: Uint8(h.DataOffset()),
+ Flags: Uint8(h.Flags()),
+ WindowSize: Uint16(h.WindowSize()),
+ Checksum: Uint16(h.Checksum()),
+ UrgentPointer: Uint16(h.UrgentPointer()),
+ Options: b[header.TCPMinimumSize:h.DataOffset()],
+ }
+ return &tcp, parsePayload
+}
+
+func (l *TCP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *TCP) length() int {
+ if l.DataOffset == nil {
+ // TCP header including the options must end on a 32-bit
+ // boundary; the user could potentially give us a slice
+ // whose length is not a multiple of 4 bytes, so we have
+ // to do the alignment here.
+ optlen := (len(l.Options) + 3) & ^3
+ return header.TCPMinimumSize + optlen
+ }
+ return int(*l.DataOffset)
+}
+
+// merge implements Layer.merge.
+func (l *TCP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// UDP can construct and match a UDP encapsulation.
+type UDP struct {
+ LayerBase
+ SrcPort *uint16
+ DstPort *uint16
+ Length *uint16
+ Checksum *uint16
+}
+
+func (l *UDP) String() string {
+ return stringLayer(l)
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *UDP) ToBytes() ([]byte, error) {
+ b := make([]byte, header.UDPMinimumSize)
+ h := header.UDP(b)
+ if l.SrcPort != nil {
+ h.SetSourcePort(*l.SrcPort)
+ }
+ if l.DstPort != nil {
+ h.SetDestinationPort(*l.DstPort)
+ }
+ if l.Length != nil {
+ h.SetLength(*l.Length)
+ } else {
+ h.SetLength(uint16(totalLength(l)))
+ }
+ if l.Checksum != nil {
+ h.SetChecksum(*l.Checksum)
+ return h, nil
+ }
+ if err := setUDPChecksum(&h, l); err != nil {
+ return nil, err
+ }
+ return h, nil
+}
+
+// setUDPChecksum calculates the checksum of the UDP header and sets it in h.
+func setUDPChecksum(h *header.UDP, udp *UDP) error {
+ h.SetChecksum(0)
+ xsum, err := layerChecksum(udp, header.UDPProtocolNumber)
+ if err != nil {
+ return err
+ }
+ h.SetChecksum(^h.CalculateChecksum(xsum))
+ return nil
+}
+
+// parseUDP parses the bytes assuming that they start with a udp header and
+// returns the parsed layer and the next parser to use.
+func parseUDP(b []byte) (Layer, layerParser) {
+ h := header.UDP(b)
+ udp := UDP{
+ SrcPort: Uint16(h.SourcePort()),
+ DstPort: Uint16(h.DestinationPort()),
+ Length: Uint16(h.Length()),
+ Checksum: Uint16(h.Checksum()),
+ }
+ return &udp, parsePayload
+}
+
+func (l *UDP) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *UDP) length() int {
+ return header.UDPMinimumSize
+}
+
+// merge implements Layer.merge.
+func (l *UDP) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Payload has bytes beyond OSI layer 4.
+type Payload struct {
+ LayerBase
+ Bytes []byte
+}
+
+func (l *Payload) String() string {
+ return stringLayer(l)
+}
+
+// parsePayload parses the bytes assuming that they start with a payload and
+// continue to the end. There can be no further encapsulations.
+func parsePayload(b []byte) (Layer, layerParser) {
+ payload := Payload{
+ Bytes: b,
+ }
+ return &payload, nil
+}
+
+// ToBytes implements Layer.ToBytes.
+func (l *Payload) ToBytes() ([]byte, error) {
+ return l.Bytes, nil
+}
+
+// Length returns payload byte length.
+func (l *Payload) Length() int {
+ return l.length()
+}
+
+func (l *Payload) match(other Layer) bool {
+ return equalLayer(l, other)
+}
+
+func (l *Payload) length() int {
+ return len(l.Bytes)
+}
+
+// merge implements Layer.merge.
+func (l *Payload) merge(other Layer) error {
+ return mergeLayer(l, other)
+}
+
+// Layers is an array of Layer and supports similar functions to Layer.
+type Layers []Layer
+
+// linkLayers sets the linked-list ponters in ls.
+func (ls *Layers) linkLayers() {
+ for i, l := range *ls {
+ if i > 0 {
+ l.setPrev((*ls)[i-1])
+ } else {
+ l.setPrev(nil)
+ }
+ if i+1 < len(*ls) {
+ l.setNext((*ls)[i+1])
+ } else {
+ l.setNext(nil)
+ }
+ }
+}
+
+// ToBytes converts the Layers into bytes. It creates a linked list of the Layer
+// structs and then concatentates the output of ToBytes on each Layer.
+func (ls *Layers) ToBytes() ([]byte, error) {
+ ls.linkLayers()
+ outBytes := []byte{}
+ for _, l := range *ls {
+ layerBytes, err := l.ToBytes()
+ if err != nil {
+ return nil, err
+ }
+ outBytes = append(outBytes, layerBytes...)
+ }
+ return outBytes, nil
+}
+
+func (ls *Layers) match(other Layers) bool {
+ if len(*ls) > len(other) {
+ return false
+ }
+ for i, l := range *ls {
+ if !equalLayer(l, other[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// layerDiff stores the diffs for each field along with the label for the Layer.
+// If rows is nil, that means that there was no diff.
+type layerDiff struct {
+ label string
+ rows []layerDiffRow
+}
+
+// layerDiffRow stores the fields and corresponding values for two got and want
+// layers. If the value was nil then the string stored is the empty string.
+type layerDiffRow struct {
+ field, got, want string
+}
+
+// diffLayer extracts all differing fields between two layers.
+func diffLayer(got, want Layer) []layerDiffRow {
+ vGot := reflect.ValueOf(got).Elem()
+ vWant := reflect.ValueOf(want).Elem()
+ if vGot.Type() != vWant.Type() {
+ return nil
+ }
+ t := vGot.Type()
+ var result []layerDiffRow
+ for i := 0; i < t.NumField(); i++ {
+ t := t.Field(i)
+ if t.Anonymous {
+ // Ignore the LayerBase in the Layer struct.
+ continue
+ }
+ vGot := vGot.Field(i)
+ vWant := vWant.Field(i)
+ gotString := ""
+ if !vGot.IsNil() {
+ gotString = fmt.Sprint(reflect.Indirect(vGot))
+ }
+ wantString := ""
+ if !vWant.IsNil() {
+ wantString = fmt.Sprint(reflect.Indirect(vWant))
+ }
+ result = append(result, layerDiffRow{t.Name, gotString, wantString})
+ }
+ return result
+}
+
+// layerType returns a concise string describing the type of the Layer, like
+// "TCP", or "IPv6".
+func layerType(l Layer) string {
+ return reflect.TypeOf(l).Elem().Name()
+}
+
+// diff compares Layers and returns a representation of the difference. Each
+// Layer in the Layers is pairwise compared. If an element in either is nil, it
+// is considered a match with the other Layer. If two Layers have differing
+// types, they don't match regardless of the contents. If two Layers have the
+// same type then the fields in the Layer are pairwise compared. Fields that are
+// nil always match. Two non-nil fields only match if they point to equal
+// values. diff returns an empty string if and only if *ls and other match.
+func (ls *Layers) diff(other Layers) string {
+ var allDiffs []layerDiff
+ // Check the cases where one list is longer than the other, where one or both
+ // elements are nil, where the sides have different types, and where the sides
+ // have the same type.
+ for i := 0; i < len(*ls) || i < len(other); i++ {
+ if i >= len(*ls) {
+ // Matching ls against other where other is longer than ls. missing
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "missing matches " + layerType(other[i]),
+ })
+ continue
+ }
+
+ if i >= len(other) {
+ // Matching ls against other where ls is longer than other. missing
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches missing",
+ })
+ continue
+ }
+
+ if (*ls)[i] == nil && other[i] == nil {
+ // Matching ls against other where both elements are nil. nil matches
+ // everything so we just include a label without any rows. Having no rows
+ // is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "nil matches nil",
+ })
+ continue
+ }
+
+ if (*ls)[i] == nil {
+ // Matching ls against other where the element in ls is nil. nil matches
+ // everything so we just include a label without any rows. Having no rows
+ // is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: "nil matches " + layerType(other[i]),
+ })
+ continue
+ }
+
+ if other[i] == nil {
+ // Matching ls against other where the element in other is nil. nil
+ // matches everything so we just include a label without any rows. Having
+ // no rows is a sign that there was no diff.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches nil",
+ })
+ continue
+ }
+
+ if reflect.TypeOf((*ls)[i]) == reflect.TypeOf(other[i]) {
+ // Matching ls against other where both elements have the same type. Match
+ // each field pairwise and only report a diff if there is a mismatch,
+ // which is only when both sides are non-nil and have differring values.
+ diff := diffLayer((*ls)[i], other[i])
+ var layerDiffRows []layerDiffRow
+ for _, d := range diff {
+ if d.got == "" || d.want == "" || d.got == d.want {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ d.got,
+ d.want,
+ })
+ }
+ if len(layerDiffRows) > 0 {
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]),
+ rows: layerDiffRows,
+ })
+ } else {
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " matches " + layerType(other[i]),
+ // Having no rows is a sign that there was no diff.
+ })
+ }
+ continue
+ }
+ // Neither side is nil and the types are different, so we'll display one
+ // side then the other.
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]) + " doesn't match " + layerType(other[i]),
+ })
+ diff := diffLayer((*ls)[i], (*ls)[i])
+ layerDiffRows := []layerDiffRow{}
+ for _, d := range diff {
+ if len(d.got) == 0 {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ d.got,
+ "",
+ })
+ }
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType((*ls)[i]),
+ rows: layerDiffRows,
+ })
+
+ layerDiffRows = []layerDiffRow{}
+ diff = diffLayer(other[i], other[i])
+ for _, d := range diff {
+ if len(d.want) == 0 {
+ continue
+ }
+ layerDiffRows = append(layerDiffRows, layerDiffRow{
+ d.field,
+ "",
+ d.want,
+ })
+ }
+ allDiffs = append(allDiffs, layerDiff{
+ label: layerType(other[i]),
+ rows: layerDiffRows,
+ })
+ }
+
+ output := ""
+ // These are for output formatting.
+ maxLabelLen, maxFieldLen, maxGotLen, maxWantLen := 0, 0, 0, 0
+ foundOne := false
+ for _, l := range allDiffs {
+ if len(l.label) > maxLabelLen && len(l.rows) > 0 {
+ maxLabelLen = len(l.label)
+ }
+ if l.rows != nil {
+ foundOne = true
+ }
+ for _, r := range l.rows {
+ if len(r.field) > maxFieldLen {
+ maxFieldLen = len(r.field)
+ }
+ if l := len(fmt.Sprint(r.got)); l > maxGotLen {
+ maxGotLen = l
+ }
+ if l := len(fmt.Sprint(r.want)); l > maxWantLen {
+ maxWantLen = l
+ }
+ }
+ }
+ if !foundOne {
+ return ""
+ }
+ for _, l := range allDiffs {
+ if len(l.rows) == 0 {
+ output += "(" + l.label + ")\n"
+ continue
+ }
+ for i, r := range l.rows {
+ var label string
+ if i == 0 {
+ label = l.label + ":"
+ }
+ output += fmt.Sprintf(
+ "%*s %*s %*v %*v\n",
+ maxLabelLen+1, label,
+ maxFieldLen+1, r.field+":",
+ maxGotLen, r.got,
+ maxWantLen, r.want,
+ )
+ }
+ }
+ return output
+}
+
+// merge merges the other Layers into ls. If the other Layers is longer, those
+// additional Layer structs are added to ls. The errors from merging are
+// collected and returned.
+func (ls *Layers) merge(other Layers) error {
+ var errs error
+ for i, o := range other {
+ if i < len(*ls) {
+ errs = multierr.Combine(errs, (*ls)[i].merge(o))
+ } else {
+ *ls = append(*ls, o)
+ }
+ }
+ return errs
+}
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
new file mode 100644
index 000000000..eca0780b5
--- /dev/null
+++ b/test/packetimpact/testbench/layers_test.go
@@ -0,0 +1,728 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "bytes"
+ "net"
+ "testing"
+
+ "github.com/mohae/deepcopy"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+func TestLayerMatch(t *testing.T) {
+ var nilPayload *Payload
+ noPayload := &Payload{}
+ emptyPayload := &Payload{Bytes: []byte{}}
+ fullPayload := &Payload{Bytes: []byte{1, 2, 3}}
+ emptyTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: emptyPayload}}
+ fullTCP := &TCP{SrcPort: Uint16(1234), LayerBase: LayerBase{nextLayer: fullPayload}}
+ for _, tt := range []struct {
+ a, b Layer
+ want bool
+ }{
+ {nilPayload, nilPayload, true},
+ {nilPayload, noPayload, true},
+ {nilPayload, emptyPayload, true},
+ {nilPayload, fullPayload, true},
+ {noPayload, noPayload, true},
+ {noPayload, emptyPayload, true},
+ {noPayload, fullPayload, true},
+ {emptyPayload, emptyPayload, true},
+ {emptyPayload, fullPayload, false},
+ {fullPayload, fullPayload, true},
+ {emptyTCP, fullTCP, true},
+ } {
+ if got := tt.a.match(tt.b); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.a, tt.b, got, tt.want)
+ }
+ if got := tt.b.match(tt.a); got != tt.want {
+ t.Errorf("%s.match(%s) = %t, want %t", tt.b, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestLayerMergeMismatch(t *testing.T) {
+ tcp := &TCP{}
+ otherTCP := &TCP{}
+ ipv4 := &IPv4{}
+ ether := &Ether{}
+ for _, tt := range []struct {
+ a, b Layer
+ success bool
+ }{
+ {tcp, tcp, true},
+ {tcp, otherTCP, true},
+ {tcp, ipv4, false},
+ {tcp, ether, false},
+ {tcp, nil, true},
+
+ {otherTCP, otherTCP, true},
+ {otherTCP, ipv4, false},
+ {otherTCP, ether, false},
+ {otherTCP, nil, true},
+
+ {ipv4, ipv4, true},
+ {ipv4, ether, false},
+ {ipv4, nil, true},
+
+ {ether, ether, true},
+ {ether, nil, true},
+ } {
+ if err := tt.a.merge(tt.b); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.a, tt.b, err)
+ }
+ if tt.b != nil {
+ if err := tt.b.merge(tt.a); (err == nil) != tt.success {
+ t.Errorf("%s.merge(%s) got %s, wanted the opposite", tt.b, tt.a, err)
+ }
+ }
+ }
+}
+
+func TestLayerMerge(t *testing.T) {
+ zero := Uint32(0)
+ one := Uint32(1)
+ two := Uint32(2)
+ empty := []byte{}
+ foo := []byte("foo")
+ bar := []byte("bar")
+ for _, tt := range []struct {
+ a, b Layer
+ want Layer
+ }{
+ {&TCP{AckNum: nil}, &TCP{AckNum: nil}, &TCP{AckNum: nil}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: nil}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: nil}, nil, &TCP{AckNum: nil}},
+
+ {&TCP{AckNum: zero}, &TCP{AckNum: nil}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: zero}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: zero}, nil, &TCP{AckNum: zero}},
+
+ {&TCP{AckNum: one}, &TCP{AckNum: nil}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: one}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: one}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: one}, nil, &TCP{AckNum: one}},
+
+ {&TCP{AckNum: two}, &TCP{AckNum: nil}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, &TCP{AckNum: zero}, &TCP{AckNum: zero}},
+ {&TCP{AckNum: two}, &TCP{AckNum: one}, &TCP{AckNum: one}},
+ {&TCP{AckNum: two}, &TCP{AckNum: two}, &TCP{AckNum: two}},
+ {&TCP{AckNum: two}, nil, &TCP{AckNum: two}},
+
+ {&Payload{Bytes: nil}, &Payload{Bytes: nil}, &Payload{Bytes: nil}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: nil}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: nil}, nil, &Payload{Bytes: nil}},
+
+ {&Payload{Bytes: empty}, &Payload{Bytes: nil}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: empty}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: empty}, nil, &Payload{Bytes: empty}},
+
+ {&Payload{Bytes: foo}, &Payload{Bytes: nil}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: foo}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: foo}, nil, &Payload{Bytes: foo}},
+
+ {&Payload{Bytes: bar}, &Payload{Bytes: nil}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: empty}, &Payload{Bytes: empty}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: foo}, &Payload{Bytes: foo}},
+ {&Payload{Bytes: bar}, &Payload{Bytes: bar}, &Payload{Bytes: bar}},
+ {&Payload{Bytes: bar}, nil, &Payload{Bytes: bar}},
+ } {
+ a := deepcopy.Copy(tt.a).(Layer)
+ if err := a.merge(tt.b); err != nil {
+ t.Errorf("%s.merge(%s) = %s, wanted nil", tt.a, tt.b, err)
+ continue
+ }
+ if a.String() != tt.want.String() {
+ t.Errorf("%s.merge(%s) merge result got %s, want %s", tt.a, tt.b, a, tt.want)
+ }
+ }
+}
+
+func TestLayerStringFormat(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ l Layer
+ want string
+ }{
+ {
+ name: "TCP",
+ l: &TCP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ SeqNum: Uint32(3452155723),
+ AckNum: Uint32(2596996163),
+ DataOffset: Uint8(5),
+ Flags: Uint8(20),
+ WindowSize: Uint16(64240),
+ Checksum: Uint16(0x2e2b),
+ },
+ want: "&testbench.TCP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "SeqNum:3452155723 " +
+ "AckNum:2596996163 " +
+ "DataOffset:5 " +
+ "Flags:20 " +
+ "WindowSize:64240 " +
+ "Checksum:11819" +
+ "}",
+ },
+ {
+ name: "UDP",
+ l: &UDP{
+ SrcPort: Uint16(34785),
+ DstPort: Uint16(47767),
+ Length: Uint16(12),
+ },
+ want: "&testbench.UDP{" +
+ "SrcPort:34785 " +
+ "DstPort:47767 " +
+ "Length:12" +
+ "}",
+ },
+ {
+ name: "IPv4",
+ l: &IPv4{
+ IHL: Uint8(5),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(0),
+ Flags: Uint8(2),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(6),
+ Checksum: Uint16(0x2e2b),
+ SrcAddr: Address(tcpip.Address([]byte{197, 34, 63, 10})),
+ DstAddr: Address(tcpip.Address([]byte{197, 34, 63, 20})),
+ },
+ want: "&testbench.IPv4{" +
+ "IHL:5 " +
+ "TOS:0 " +
+ "TotalLength:44 " +
+ "ID:0 " +
+ "Flags:2 " +
+ "FragmentOffset:0 " +
+ "TTL:64 " +
+ "Protocol:6 " +
+ "Checksum:11819 " +
+ "SrcAddr:197.34.63.10 " +
+ "DstAddr:197.34.63.20" +
+ "}",
+ },
+ {
+ name: "Ether",
+ l: &Ether{
+ SrcAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x0a})),
+ DstAddr: LinkAddress(tcpip.LinkAddress([]byte{0x02, 0x42, 0xc5, 0x22, 0x3f, 0x14})),
+ Type: NetworkProtocolNumber(4),
+ },
+ want: "&testbench.Ether{" +
+ "SrcAddr:02:42:c5:22:3f:0a " +
+ "DstAddr:02:42:c5:22:3f:14 " +
+ "Type:4" +
+ "}",
+ },
+ {
+ name: "Payload",
+ l: &Payload{
+ Bytes: []byte("Hooray for packetimpact."),
+ },
+ want: "&testbench.Payload{Bytes:\n" +
+ "00000000 48 6f 6f 72 61 79 20 66 6f 72 20 70 61 63 6b 65 |Hooray for packe|\n" +
+ "00000010 74 69 6d 70 61 63 74 2e |timpact.|\n" +
+ "}",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.l.String(); got != tt.want {
+ t.Errorf("%s.String() = %s, want: %s", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestConnectionMatch(t *testing.T) {
+ conn := Connection{
+ layerStates: []layerState{&etherState{}},
+ }
+ protoNum0 := tcpip.NetworkProtocolNumber(0)
+ protoNum1 := tcpip.NetworkProtocolNumber(1)
+ for _, tt := range []struct {
+ description string
+ override, received Layers
+ wantMatch bool
+ }{
+ {
+ description: "shorter override",
+ override: []Layer{&Ether{}},
+ received: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ wantMatch: true,
+ },
+ {
+ description: "longer override",
+ override: []Layer{&Ether{}, &Payload{Bytes: []byte("hello")}},
+ received: []Layer{&Ether{}},
+ wantMatch: false,
+ },
+ {
+ description: "ether layer mismatch",
+ override: []Layer{&Ether{Type: &protoNum0}},
+ received: []Layer{&Ether{Type: &protoNum1}},
+ wantMatch: false,
+ },
+ {
+ description: "both nil",
+ override: nil,
+ received: nil,
+ wantMatch: false,
+ },
+ {
+ description: "nil override",
+ override: nil,
+ received: []Layer{&Ether{}},
+ wantMatch: true,
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ if gotMatch := conn.match(tt.override, tt.received); gotMatch != tt.wantMatch {
+ t.Fatalf("conn.match(%s, %s) = %t, want %t", tt.override, tt.received, gotMatch, tt.wantMatch)
+ }
+ })
+ }
+}
+
+func TestLayersDiff(t *testing.T) {
+ for _, tt := range []struct {
+ x, y Layers
+ want string
+ }{
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(12)}, &TCP{DataOffset: Uint8(5), SeqNum: Uint32(5)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "Ether: Type: 12 13\n" +
+ " TCP: SeqNum: 5 6\n" +
+ " DataOffset: 5 7\n",
+ },
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(12)}, &UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "Ether: Type: 12 13\n" +
+ "(UDP doesn't match TCP)\n" +
+ " UDP: SrcPort: 123 \n" +
+ " TCP: SeqNum: 6\n" +
+ " DataOffset: 7\n",
+ },
+ {
+ Layers{&UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(UDP doesn't match Ether)\n" +
+ " UDP: SrcPort: 123 \n" +
+ "Ether: Type: 13\n" +
+ "(missing matches TCP)\n",
+ },
+ {
+ Layers{nil, &UDP{SrcPort: Uint16(123)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(nil matches Ether)\n" +
+ "(UDP doesn't match TCP)\n" +
+ "UDP: SrcPort: 123 \n" +
+ "TCP: SeqNum: 6\n" +
+ " DataOffset: 7\n",
+ },
+ {
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(4)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ Layers{&Ether{Type: NetworkProtocolNumber(13)}, &IPv4{IHL: Uint8(6)}, &TCP{DataOffset: Uint8(7), SeqNum: Uint32(6)}},
+ "(Ether matches Ether)\n" +
+ "IPv4: IHL: 4 6\n" +
+ "(TCP matches TCP)\n",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("foo")}},
+ Layers{&Payload{Bytes: []byte("bar")}},
+ "Payload: Bytes: [102 111 111] [98 97 114]\n",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("")}},
+ Layers{&Payload{}},
+ "",
+ },
+ {
+ Layers{&Payload{Bytes: []byte("")}},
+ Layers{&Payload{Bytes: []byte("")}},
+ "",
+ },
+ {
+ Layers{&UDP{}},
+ Layers{&TCP{}},
+ "(UDP doesn't match TCP)\n" +
+ "(UDP)\n" +
+ "(TCP)\n",
+ },
+ } {
+ if got := tt.x.diff(tt.y); got != tt.want {
+ t.Errorf("%s.diff(%s) = %q, want %q", tt.x, tt.y, got, tt.want)
+ }
+ if tt.x.match(tt.y) != (tt.x.diff(tt.y) == "") {
+ t.Errorf("match and diff of %s and %s disagree", tt.x, tt.y)
+ }
+ if tt.y.match(tt.x) != (tt.y.diff(tt.x) == "") {
+ t.Errorf("match and diff of %s and %s disagree", tt.y, tt.x)
+ }
+ }
+}
+
+func TestTCPOptions(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ wantBytes []byte
+ wantLayers Layers
+ }{
+ {
+ description: "without payload",
+ wantBytes: []byte{
+ // IPv4 Header
+ 0x45, 0x00, 0x00, 0x2c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06,
+ 0xf9, 0x77, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01,
+ // TCP Header
+ 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xf5, 0x1c, 0x00, 0x00,
+ // WindowScale Option
+ 0x03, 0x03, 0x02,
+ // NOP Option
+ 0x00,
+ },
+ wantLayers: []Layer{
+ &IPv4{
+ IHL: Uint8(20),
+ TOS: Uint8(0),
+ TotalLength: Uint16(44),
+ ID: Uint16(1),
+ Flags: Uint8(0),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(uint8(header.TCPProtocolNumber)),
+ Checksum: Uint16(0xf977),
+ SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())),
+ DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())),
+ },
+ &TCP{
+ SrcPort: Uint16(12345),
+ DstPort: Uint16(54321),
+ SeqNum: Uint32(0),
+ AckNum: Uint32(0),
+ Flags: Uint8(header.TCPFlagSyn),
+ WindowSize: Uint16(8192),
+ Checksum: Uint16(0xf51c),
+ UrgentPointer: Uint16(0),
+ Options: []byte{3, 3, 2, 0},
+ },
+ &Payload{Bytes: nil},
+ },
+ },
+ {
+ description: "with payload",
+ wantBytes: []byte{
+ // IPv4 header
+ 0x45, 0x00, 0x00, 0x37, 0x00, 0x01, 0x00, 0x00, 0x40, 0x06,
+ 0xf9, 0x6c, 0xc0, 0xa8, 0x00, 0x02, 0xc0, 0xa8, 0x00, 0x01,
+ // TCP header
+ 0x30, 0x39, 0xd4, 0x31, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x60, 0x02, 0x20, 0x00, 0xe5, 0x21, 0x00, 0x00,
+ // WindowScale Option
+ 0x03, 0x03, 0x02,
+ // NOP Option
+ 0x00,
+ // Payload: "Sample Data"
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv4{
+ IHL: Uint8(20),
+ TOS: Uint8(0),
+ TotalLength: Uint16(55),
+ ID: Uint16(1),
+ Flags: Uint8(0),
+ FragmentOffset: Uint16(0),
+ TTL: Uint8(64),
+ Protocol: Uint8(uint8(header.TCPProtocolNumber)),
+ Checksum: Uint16(0xf96c),
+ SrcAddr: Address(tcpip.Address(net.ParseIP("192.168.0.2").To4())),
+ DstAddr: Address(tcpip.Address(net.ParseIP("192.168.0.1").To4())),
+ },
+ &TCP{
+ SrcPort: Uint16(12345),
+ DstPort: Uint16(54321),
+ SeqNum: Uint32(0),
+ AckNum: Uint32(0),
+ Flags: Uint8(header.TCPFlagSyn),
+ WindowSize: Uint16(8192),
+ Checksum: Uint16(0xe521),
+ UrgentPointer: Uint16(0),
+ Options: []byte{3, 3, 2, 0},
+ },
+ &Payload{Bytes: []byte("Sample Data")},
+ },
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ layers := parse(parseIPv4, tt.wantBytes)
+ if !layers.match(tt.wantLayers) {
+ t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers))
+ }
+ gotBytes, err := layers.ToBytes()
+ if err != nil {
+ t.Fatalf("ToBytes() failed on %s: %s", &layers, err)
+ }
+ if !bytes.Equal(tt.wantBytes, gotBytes) {
+ t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes)
+ }
+ })
+ }
+}
+
+func TestIPv6ExtHdrOptions(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ wantBytes []byte
+ wantLayers Layers
+ }{
+ {
+ description: "IPv6/HopByHop",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: nil,
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Destination/ICMPv6",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Destination Options
+ 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // ICMPv6 Param Problem
+ 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6DestinationOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &ICMPv6{
+ Type: ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: ICMPv6Code(header.ICMPv6ErroneousHeader),
+ Checksum: Uint16(0x5f98),
+ Payload: []byte{0x00, 0x00, 0x00, 0x06},
+ },
+ },
+ },
+ {
+ description: "IPv6/HopByHop/Fragment",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // HopByHop Options
+ 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6HopByHopOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(false),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: nil,
+ },
+ },
+ },
+ {
+ description: "IPv6/DestOpt/Fragment/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // Destination Options
+ 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6DestinationOptionsExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier),
+ Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00},
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(true),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ {
+ description: "IPv6/Fragment/Payload",
+ wantBytes: []byte{
+ // IPv6 Header
+ 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
+ // Fragment ExtHdr
+ 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a,
+ // Sample Data
+ 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61,
+ },
+ wantLayers: []Layer{
+ &IPv6{
+ SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))),
+ DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))),
+ },
+ &IPv6FragmentExtHdr{
+ NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier),
+ FragmentOffset: Uint16(100),
+ MoreFragments: Bool(true),
+ Identification: Uint32(42),
+ },
+ &Payload{
+ Bytes: []byte("Sample Data"),
+ },
+ },
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ layers := parse(parseIPv6, tt.wantBytes)
+ if !layers.match(tt.wantLayers) {
+ t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers))
+ }
+ // Make sure we can generate correct next header values and checksums
+ for _, layer := range layers {
+ switch layer := layer.(type) {
+ case *IPv6HopByHopOptionsExtHdr:
+ layer.NextHeader = nil
+ case *IPv6DestinationOptionsExtHdr:
+ layer.NextHeader = nil
+ case *IPv6FragmentExtHdr:
+ layer.NextHeader = nil
+ case *ICMPv6:
+ layer.Checksum = nil
+ }
+ }
+ gotBytes, err := layers.ToBytes()
+ if err != nil {
+ t.Fatalf("ToBytes() failed on %s: %s", &layers, err)
+ }
+ if !bytes.Equal(tt.wantBytes, gotBytes) {
+ t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
new file mode 100644
index 000000000..57e822725
--- /dev/null
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -0,0 +1,188 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Sniffer can sniff raw packets on the wire.
+type Sniffer struct {
+ fd int
+}
+
+func htons(x uint16) uint16 {
+ buf := [2]byte{}
+ binary.BigEndian.PutUint16(buf[:], x)
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+// NewSniffer creates a Sniffer connected to *device.
+func NewSniffer(t *testing.T) (Sniffer, error) {
+ t.Helper()
+
+ snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Sniffer{}, err
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, 1); err != nil {
+ t.Fatalf("can't set sockopt SO_RCVBUFFORCE to 1: %s", err)
+ }
+ if err := unix.SetsockoptInt(snifferFd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1e7); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
+ }
+ return Sniffer{
+ fd: snifferFd,
+ }, nil
+}
+
+// maxReadSize should be large enough for the maximum frame size in bytes. If a
+// packet too large for the buffer arrives, the test will get a fatal error.
+const maxReadSize int = 65536
+
+// Recv tries to read one frame until the timeout is up.
+func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ for {
+ timeout = deadline.Sub(time.Now())
+ if timeout <= 0 {
+ return nil
+ }
+ whole, frac := math.Modf(timeout.Seconds())
+ tv := unix.Timeval{
+ Sec: int64(whole),
+ Usec: int64(frac * float64(time.Microsecond/time.Second)),
+ }
+
+ if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
+ t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ }
+
+ buf := make([]byte, maxReadSize)
+ nread, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN {
+ // There was a timeout.
+ continue
+ }
+ if err != nil {
+ t.Fatalf("can't read: %s", err)
+ }
+ if nread > maxReadSize {
+ t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize)
+ }
+ return buf[:nread]
+ }
+}
+
+// Drain drains the Sniffer's socket receive buffer by receiving until there's
+// nothing else to receive.
+func (s *Sniffer) Drain(t *testing.T) {
+ t.Helper()
+
+ flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
+ if err != nil {
+ t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ }
+ nonBlockingFlags := flags | unix.O_NONBLOCK
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil {
+ t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err)
+ }
+ for {
+ buf := make([]byte, maxReadSize)
+ _, _, err := unix.Recvfrom(s.fd, buf, unix.MSG_TRUNC)
+ if err == unix.EINTR || err == unix.EAGAIN || err == unix.EWOULDBLOCK {
+ break
+ }
+ }
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
+ t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err)
+ }
+}
+
+// close the socket that Sniffer is using.
+func (s *Sniffer) close() error {
+ if err := unix.Close(s.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ s.fd = -1
+ return nil
+}
+
+// Injector can inject raw frames.
+type Injector struct {
+ fd int
+}
+
+// NewInjector creates a new injector on *device.
+func NewInjector(t *testing.T) (Injector, error) {
+ t.Helper()
+
+ ifInfo, err := net.InterfaceByName(Device)
+ if err != nil {
+ return Injector{}, err
+ }
+
+ var haddr [8]byte
+ copy(haddr[:], ifInfo.HardwareAddr)
+ sa := unix.SockaddrLinklayer{
+ Protocol: unix.ETH_P_IP,
+ Ifindex: ifInfo.Index,
+ Halen: uint8(len(ifInfo.HardwareAddr)),
+ Addr: haddr,
+ }
+
+ injectFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
+ if err != nil {
+ return Injector{}, err
+ }
+ if err := unix.Bind(injectFd, &sa); err != nil {
+ return Injector{}, err
+ }
+ return Injector{
+ fd: injectFd,
+ }, nil
+}
+
+// Send a raw frame.
+func (i *Injector) Send(t *testing.T, b []byte) {
+ t.Helper()
+
+ n, err := unix.Write(i.fd, b)
+ if err != nil {
+ t.Fatalf("can't write bytes of len %d: %s", len(b), err)
+ }
+ if n != len(b) {
+ t.Fatalf("got %d bytes written, want %d", n, len(b))
+ }
+}
+
+// close the underlying socket.
+func (i *Injector) close() error {
+ if err := unix.Close(i.fd); err != nil {
+ return fmt.Errorf("can't close sniffer socket: %w", err)
+ }
+ i.fd = -1
+ return nil
+}
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
new file mode 100644
index 000000000..e3629e1f3
--- /dev/null
+++ b/test/packetimpact/testbench/testbench.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package testbench
+
+import (
+ "flag"
+ "fmt"
+ "math/rand"
+ "net"
+ "os/exec"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/test/packetimpact/netdevs"
+)
+
+var (
+ // Native indicates that the test is being run natively.
+ Native = false
+ // Device is the local device on the test network.
+ Device = ""
+
+ // LocalIPv4 is the local IPv4 address on the test network.
+ LocalIPv4 = ""
+ // RemoteIPv4 is the DUT's IPv4 address on the test network.
+ RemoteIPv4 = ""
+ // IPv4PrefixLength is the network prefix length of the IPv4 test network.
+ IPv4PrefixLength = 0
+
+ // LocalIPv6 is the local IPv6 address on the test network.
+ LocalIPv6 = ""
+ // RemoteIPv6 is the DUT's IPv6 address on the test network.
+ RemoteIPv6 = ""
+
+ // LocalInterfaceID is the ID of the local interface on the test network.
+ LocalInterfaceID uint32
+ // RemoteInterfaceID is the ID of the remote interface on the test network.
+ //
+ // Not using uint32 because package flag does not support uint32.
+ RemoteInterfaceID uint64
+
+ // LocalMAC is the local MAC address on the test network.
+ LocalMAC = ""
+ // RemoteMAC is the DUT's MAC address on the test network.
+ RemoteMAC = ""
+
+ // POSIXServerIP is the POSIX server's IP address on the control network.
+ POSIXServerIP = ""
+ // POSIXServerPort is the UDP port the POSIX server is bound to on the
+ // control network.
+ POSIXServerPort = 40000
+
+ // RPCKeepalive is the gRPC keepalive.
+ RPCKeepalive = 10 * time.Second
+ // RPCTimeout is the gRPC timeout.
+ RPCTimeout = 100 * time.Millisecond
+)
+
+// RegisterFlags defines flags and associates them with the package-level
+// exported variables above. It should be called by tests in their init
+// functions.
+func RegisterFlags(fs *flag.FlagSet) {
+ fs.StringVar(&POSIXServerIP, "posix_server_ip", POSIXServerIP, "ip address to listen to for UDP commands")
+ fs.IntVar(&POSIXServerPort, "posix_server_port", POSIXServerPort, "port to listen to for UDP commands")
+ fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
+ fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
+ fs.StringVar(&LocalIPv4, "local_ipv4", LocalIPv4, "local IPv4 address for test packets")
+ fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets")
+ fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets")
+ fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets")
+ fs.StringVar(&Device, "device", Device, "local device for test packets")
+ fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
+ fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets")
+}
+
+// genPseudoFlags populates flag-like global config based on real flags.
+//
+// genPseudoFlags must only be called after flag.Parse.
+func genPseudoFlags() error {
+ out, err := exec.Command("ip", "addr", "show").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("listing devices: %q: %w", string(out), err)
+ }
+ devs, err := netdevs.ParseDevices(string(out))
+ if err != nil {
+ return fmt.Errorf("parsing devices: %w", err)
+ }
+
+ _, deviceInfo, err := netdevs.FindDeviceByIP(net.ParseIP(LocalIPv4), devs)
+ if err != nil {
+ return fmt.Errorf("can't find deviceInfo: %w", err)
+ }
+
+ LocalMAC = deviceInfo.MAC.String()
+ LocalIPv6 = deviceInfo.IPv6Addr.String()
+ LocalInterfaceID = deviceInfo.ID
+
+ if deviceInfo.IPv4Net != nil {
+ IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size()
+ } else {
+ IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size()
+ }
+
+ return nil
+}
+
+// GenerateRandomPayload generates a random byte slice of the specified length,
+// causing a fatal test failure if it is unable to do so.
+func GenerateRandomPayload(t *testing.T, n int) []byte {
+ t.Helper()
+ buf := make([]byte, n)
+ if _, err := rand.Read(buf); err != nil {
+ t.Fatalf("rand.Read(buf) failed: %s", err)
+ }
+ return buf
+}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
new file mode 100644
index 000000000..74658fea0
--- /dev/null
+++ b/test/packetimpact/tests/BUILD
@@ -0,0 +1,310 @@
+load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test")
+
+package(
+ default_visibility = ["//test/packetimpact:__subpackages__"],
+ licenses = ["notice"],
+)
+
+packetimpact_go_test(
+ name = "fin_wait2_timeout",
+ srcs = ["fin_wait2_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "ipv4_id_uniqueness",
+ srcs = ["ipv4_id_uniqueness_test.go"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_discard_mcast_source_addr",
+ srcs = ["udp_discard_mcast_source_addr_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_recv_mcast_bcast",
+ srcs = ["udp_recv_mcast_bcast_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_any_addr_recv_unicast",
+ srcs = ["udp_any_addr_recv_unicast_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//test/packetimpact/testbench",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_icmp_error_propagation",
+ srcs = ["udp_icmp_error_propagation_test.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_reordering",
+ srcs = ["tcp_reordering_test.go"],
+ # TODO(b/139368047): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_window_shrink",
+ srcs = ["tcp_window_shrink_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe",
+ srcs = ["tcp_zero_window_probe_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe_retransmit",
+ srcs = ["tcp_zero_window_probe_retransmit_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_zero_window_probe_usertimeout",
+ srcs = ["tcp_zero_window_probe_usertimeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_retransmits",
+ srcs = ["tcp_retransmits_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_outside_the_window",
+ srcs = ["tcp_outside_the_window_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_noaccept_close_rst",
+ srcs = ["tcp_noaccept_close_rst_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_send_window_sizes_piggyback",
+ srcs = ["tcp_send_window_sizes_piggyback_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_close_wait_ack",
+ srcs = ["tcp_close_wait_ack_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_paws_mechanism",
+ srcs = ["tcp_paws_mechanism_test.go"],
+ # TODO(b/156682000): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_user_timeout",
+ srcs = ["tcp_user_timeout_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_queue_receive_in_syn_sent",
+ srcs = ["tcp_queue_receive_in_syn_sent_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_synsent_reset",
+ srcs = ["tcp_synsent_reset_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_synrcvd_reset",
+ srcs = ["tcp_synrcvd_reset_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_network_unreachable",
+ srcs = ["tcp_network_unreachable_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_cork_mss",
+ srcs = ["tcp_cork_mss_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "tcp_handshake_window_size",
+ srcs = ["tcp_handshake_window_size_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "icmpv6_param_problem",
+ srcs = ["icmpv6_param_problem_test.go"],
+ # TODO(b/153485026): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "ipv6_unknown_options_action",
+ srcs = ["ipv6_unknown_options_action_test.go"],
+ # TODO(b/159928940): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "ipv6_fragment_reassembly",
+ srcs = ["ipv6_fragment_reassembly_test.go"],
+ # TODO(b/160919104): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_go_test(
+ name = "udp_send_recv_dgram",
+ srcs = ["udp_send_recv_dgram_test.go"],
+ deps = [
+ "//test/packetimpact/testbench",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
new file mode 100644
index 000000000..a61054c2c
--- /dev/null
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -0,0 +1,75 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fin_wait2_timeout_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestFinWait2Timeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ linger2 bool
+ }{
+ {"WithLinger2", true},
+ {"WithoutLinger2", false},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+
+ acceptFd, _ := dut.Accept(t, listenFd)
+ if tt.linger2 {
+ tv := unix.Timeval{Sec: 1, Usec: 0}
+ dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
+ }
+ dut.Close(t, acceptFd)
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ time.Sleep(5 * time.Second)
+ conn.Drain(t)
+
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if tt.linger2 {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected a RST packet within a second but got none: %s", err)
+ }
+ } else {
+ if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
+ t.Fatalf("expected no RST packets within ten seconds but got one: %s", got)
+ }
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
new file mode 100644
index 000000000..2d59d552d
--- /dev/null
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmpv6_param_problem_test
+
+import (
+ "encoding/binary"
+ "flag"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestICMPv6ParamProblemTest sends a packet with a bad next header. The DUT
+// should respond with an ICMPv6 Parameter Problem message.
+func TestICMPv6ParamProblemTest(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ defer conn.Close(t)
+ ipv6 := testbench.IPv6{
+ // 254 is reserved and used for experimentation and testing. This should
+ // cause an error.
+ NextHeader: testbench.Uint8(254),
+ }
+ icmpv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
+ Payload: []byte("hello world"),
+ }
+
+ toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6)
+ (*testbench.Connection)(&conn).SendFrame(t, toSend)
+
+ // Build the expected ICMPv6 payload, which includes an index to the
+ // problematic byte and also the problematic packet as described in
+ // https://tools.ietf.org/html/rfc4443#page-12 .
+ ipv6Sent := toSend[1:]
+ expectedPayload, err := ipv6Sent.ToBytes()
+ if err != nil {
+ t.Fatalf("can't convert %s to bytes: %s", ipv6Sent, err)
+ }
+
+ // The problematic field is the NextHeader.
+ b := make([]byte, 4)
+ binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset)
+ expectedPayload = append(b, expectedPayload...)
+ expectedICMPv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem),
+ Payload: expectedPayload,
+ }
+
+ paramProblem := testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &expectedICMPv6,
+ }
+ timeout := time.Second
+ if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil {
+ t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err)
+ }
+}
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
new file mode 100644
index 000000000..cf881418c
--- /dev/null
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -0,0 +1,122 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv4_id_uniqueness_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
+ layers, err := conn.ExpectData(t, expect, expectPayload, time.Second)
+ if err != nil {
+ return 0, fmt.Errorf("failed to receive TCP segment: %s", err)
+ }
+ if len(layers) < 2 {
+ return 0, fmt.Errorf("got packet with layers: %v, expected to have at least 2 layers (link and network)", layers)
+ }
+ ipv4, ok := layers[1].(*testbench.IPv4)
+ if !ok {
+ return 0, fmt.Errorf("got network layer: %T, expected: *IPv4", layers[1])
+ }
+ if *ipv4.Flags&header.IPv4FlagDontFragment != 0 {
+ return 0, fmt.Errorf("got IPv4 DF=1, expected DF=0")
+ }
+ return *ipv4.ID, nil
+}
+
+// RFC 6864 section 4.2 states: "The IPv4 ID of non-atomic datagrams MUST NOT
+// be reused when sending a copy of an earlier non-atomic datagram."
+//
+// This test creates a TCP connection, uses the IP_MTU_DISCOVER socket option
+// to force the DF bit to be 0, and checks that a retransmitted segment has a
+// different IPv4 Identification value than the original segment.
+func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ payload []byte
+ }{
+ {"SmallPayload", []byte("sample data")},
+ // 512 bytes is chosen because sending more than this in a single segment
+ // causes the retransmission to send less than the original amount.
+ {"512BytePayload", testbench.GenerateRandomPayload(t, 512)},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ remoteFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, remoteFD)
+
+ dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ // TODO(b/129291778) The following socket option clears the DF bit on
+ // IP packets sent over the socket, and is currently not supported by
+ // gVisor. gVisor by default sends packets with DF=0 anyway, so the
+ // socket option being not supported does not affect the operation of
+ // this test. Once the socket option is supported, the following call
+ // can be changed to simply assert success.
+ ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT)
+ if ret == -1 && errno != unix.ENOTSUP {
+ t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno)
+ }
+
+ samplePayload := &testbench.Payload{Bytes: tc.payload}
+
+ dut.Send(t, remoteFD, tc.payload, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err)
+ }
+ // Let the DUT estimate RTO with RTT from the DATA-ACK.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip sending this ACK.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ dut.Send(t, remoteFD, tc.payload, 0)
+ expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))}
+ originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload)
+ if err != nil {
+ t.Fatalf("failed to receive TCP segment: %s", err)
+ }
+
+ retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload)
+ if err != nil {
+ t.Fatalf("failed to receive retransmitted TCP segment: %s", err)
+ }
+ if originalID == retransmitID {
+ t.Fatalf("unexpectedly got retransmitted TCP segment with same IPv4 ID field=%d", originalID)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
new file mode 100644
index 000000000..a24c85566
--- /dev/null
+++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
@@ -0,0 +1,168 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_fragment_reassembly_test
+
+import (
+ "bytes"
+ "encoding/binary"
+ "encoding/hex"
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+const (
+ // The payload length for the first fragment we send. This number
+ // is a multiple of 8 near 750 (half of 1500).
+ firstPayloadLength = 752
+ // The ID field for our outgoing fragments.
+ fragmentID = 1
+ // A node must be able to accept a fragmented packet that,
+ // after reassembly, is as large as 1500 octets.
+ reassemblyCap = 1500
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestIPv6FragmentReassembly(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ defer conn.Close(t)
+
+ firstPayloadToSend := make([]byte, firstPayloadLength)
+ for i := range firstPayloadToSend {
+ firstPayloadToSend[i] = 'A'
+ }
+
+ secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize
+ secondPayloadToSend := firstPayloadToSend[:secondPayloadLength]
+
+ icmpv6EchoPayload := make([]byte, 4)
+ binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0)
+ binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0)
+ icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...)
+
+ lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16())
+ rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16())
+ icmpv6 := testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
+ Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
+ Payload: icmpv6EchoPayload,
+ }
+ icmpv6Bytes, err := icmpv6.ToBytes()
+ if err != nil {
+ t.Fatalf("failed to serialize ICMPv6: %s", err)
+ }
+ cksum := header.ICMPv6Checksum(
+ header.ICMPv6(icmpv6Bytes),
+ lIP,
+ rIP,
+ buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}),
+ )
+
+ conn.Send(t, testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{
+ FragmentOffset: testbench.Uint16(0),
+ MoreFragments: testbench.Bool(true),
+ Identification: testbench.Uint32(fragmentID),
+ },
+ &testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest),
+ Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
+ Payload: icmpv6EchoPayload,
+ Checksum: &cksum,
+ })
+
+ icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)
+
+ conn.Send(t, testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{
+ NextHeader: &icmpv6ProtoNum,
+ FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8),
+ MoreFragments: testbench.Bool(false),
+ Identification: testbench.Uint32(fragmentID),
+ },
+ &testbench.Payload{
+ Bytes: secondPayloadToSend,
+ })
+
+ gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{
+ FragmentOffset: testbench.Uint16(0),
+ MoreFragments: testbench.Bool(true),
+ },
+ &testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6EchoReply),
+ Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode),
+ },
+ }, time.Second)
+ if err != nil {
+ t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err)
+ }
+
+ id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification
+ gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes()
+ if err != nil {
+ t.Fatalf("failed to serialize ICMPv6: %s", err)
+ }
+ icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:]
+ receivedLen := len(icmpPayload)
+ wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen
+ wantFirstPayload := make([]byte, receivedLen)
+ for i := range wantFirstPayload {
+ wantFirstPayload[i] = 'A'
+ }
+ wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen]
+ if !bytes.Equal(icmpPayload, wantFirstPayload) {
+ t.Fatalf("received unexpected payload, got: %s, want: %s",
+ hex.Dump(icmpPayload),
+ hex.Dump(wantFirstPayload))
+ }
+
+ gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &testbench.IPv6FragmentExtHdr{
+ NextHeader: &icmpv6ProtoNum,
+ FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)),
+ MoreFragments: testbench.Bool(false),
+ Identification: &id,
+ },
+ &testbench.ICMPv6{},
+ }, time.Second)
+ if err != nil {
+ t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err)
+ }
+ secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes()
+ if err != nil {
+ t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err)
+ }
+ if !bytes.Equal(secondPayload, wantSecondPayload) {
+ t.Fatalf("received unexpected payload, got: %s, want: %s",
+ hex.Dump(secondPayload),
+ hex.Dump(wantSecondPayload))
+ }
+}
diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
new file mode 100644
index 000000000..e79d74476
--- /dev/null
+++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
@@ -0,0 +1,187 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6_unknown_options_action_test
+
+import (
+ "encoding/binary"
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer {
+ return &testbench.IPv6HopByHopOptionsExtHdr{
+ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
+ }
+}
+
+func mkDestinationOptionsExtHdr(optType byte) testbench.Layer {
+ return &testbench.IPv6DestinationOptionsExtHdr{
+ Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
+ }
+}
+
+func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte {
+ return byte(action << 6)
+}
+
+func TestIPv6UnknownOptionAction(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ mkExtHdr func(optType byte) testbench.Layer
+ action header.IPv6OptionUnknownAction
+ multicastDst bool
+ wantICMPv6 bool
+ }{
+ {
+ description: "0b00/hbh",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionSkip,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b01/hbh",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscard,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b10/hbh/unicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b10/hbh/multicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: true,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/hbh/unicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/hbh/multicast",
+ mkExtHdr: mkHopByHopOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: true,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b00/destination",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionSkip,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b01/destination",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscard,
+ multicastDst: false,
+ wantICMPv6: false,
+ },
+ {
+ description: "0b10/destination/unicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b10/destination/multicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMP,
+ multicastDst: true,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/destination/unicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: false,
+ wantICMPv6: true,
+ },
+ {
+ description: "0b11/destination/multicast",
+ mkExtHdr: mkDestinationOptionsExtHdr,
+ action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest,
+ multicastDst: true,
+ wantICMPv6: false,
+ },
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ conn := (*testbench.Connection)(&ipv6Conn)
+ defer ipv6Conn.Close(t)
+
+ outgoingOverride := testbench.Layers{}
+ if tt.multicastDst {
+ outgoingOverride = testbench.Layers{&testbench.IPv6{
+ DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))),
+ }}
+ }
+
+ outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action)))
+ conn.SendFrame(t, outgoing)
+ ipv6Sent := outgoing[1:]
+ invokingPacket, err := ipv6Sent.ToBytes()
+ if err != nil {
+ t.Fatalf("failed to serialize the outgoing packet: %s", err)
+ }
+ icmpv6Payload := make([]byte, 4)
+ // The pointer in the ICMPv6 parameter problem message should point to
+ // the option type of the unknown option. In our test case, it is the
+ // first option in the extension header whose option type is 2 bytes
+ // after the IPv6 header (after NextHeader and ExtHdrLen).
+ binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2)
+ icmpv6Payload = append(icmpv6Payload, invokingPacket...)
+ gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: testbench.ICMPv6Code(header.ICMPv6UnknownOption),
+ Payload: icmpv6Payload,
+ },
+ }, time.Second)
+ if tt.wantICMPv6 && err != nil {
+ t.Fatalf("expected ICMPv6 Parameter Problem but got none: %s", err)
+ }
+ if !tt.wantICMPv6 && gotICMPv6 != nil {
+ t.Fatalf("expected no ICMPv6 Parameter Problem but got one: %s", gotICMPv6)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
new file mode 100644
index 000000000..e6a96f214
--- /dev/null
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_close_wait_ack_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestCloseWaitAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {"OTW", generateOTWSeqSegment, 0, false},
+ {"OTW", generateOTWSeqSegment, 1, true},
+ {"OTW", generateOTWSeqSegment, 2, true},
+ {"ACK", generateUnaccACKSegment, 0, false},
+ {"ACK", generateUnaccACKSegment, 1, true},
+ {"ACK", generateUnaccACKSegment, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+
+ // Send a FIN to DUT to intiate the active close
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
+ }
+ windowSize := seqnum.Size(*gotTCP.WindowSize)
+
+ // Send a segment with OTW Seq / unacc ACK and expect an ACK back
+ conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Fatalf("expected an ack but got none: %s", err)
+ }
+ if !tt.expectAck && gotAck != nil {
+ t.Fatalf("expected no ack but got one: %s", gotAck)
+ }
+
+ // Now let's verify DUT is indeed in CLOSE_WAIT
+ dut.Close(t, acceptFd)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send a FIN: %s", err)
+ }
+ // Ack the FIN from DUT
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Send some extra data to DUT
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ t.Fatalf("expected DUT to send an RST: %s", err)
+ }
+ })
+ }
+}
+
+// generateOTWSeqSegment generates an segment with
+// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
+// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
+// receiver.
+func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
+ otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
+
+// generateUnaccACKSegment generates an segment with
+// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
+// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
+func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.RemoteSeqNum(t)
+ unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
+ return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
+}
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
new file mode 100644
index 000000000..8feea4a82
--- /dev/null
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_cork_mss_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPCorkMSS tests for segment coalesce and split as per MSS.
+func TestTCPCorkMSS(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ const mss = uint32(header.TCPDefaultMSS)
+ options := make([]byte, header.TCPOptionMSSLength)
+ header.EncodeMSSOption(mss, options)
+ conn.ConnectWithOptions(t, options)
+
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1)
+
+ // Let the dut application send 2 small segments to be held up and coalesced
+ // until the application sends a larger segment to fill up to > MSS.
+ sampleData := []byte("Sample Data")
+ dut.Send(t, acceptFD, sampleData, 0)
+ dut.Send(t, acceptFD, sampleData, 0)
+
+ expectedData := sampleData
+ expectedData = append(expectedData, sampleData...)
+ largeData := make([]byte, mss+1)
+ expectedData = append(expectedData, largeData...)
+ dut.Send(t, acceptFD, largeData, 0)
+
+ // Expect the segments to be coalesced and sent and capped to MSS.
+ expectedPayload := testbench.Payload{Bytes: expectedData[:mss]}
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the coalesced segment to be split and transmitted.
+ expectedPayload = testbench.Payload{Bytes: expectedData[mss:]}
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Check for segments to *not* be held up because of TCP_CORK when
+ // the current send window is less than MSS.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
+ dut.Send(t, acceptFD, sampleData, 0)
+ dut.Send(t, acceptFD, sampleData, 0)
+ expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)}
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+}
diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go
new file mode 100644
index 000000000..22937d92f
--- /dev/null
+++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_handshake_window_size_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPHandshakeWindowSize tests if the stack is honoring the window size
+// communicated during handshake.
+func TestTCPHandshakeWindowSize(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ // Start handshake with zero window size.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK: %s", err)
+ }
+ // Update the advertised window size to a non-zero value with the ACK that
+ // completes the handshake.
+ //
+ // Set the window size with MSB set and expect the dut to treat it as
+ // an unsigned value.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
+
+ acceptFd, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFd)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Since we advertised a zero window followed by a non-zero window,
+ // expect the dut to honor the recently advertised non-zero window
+ // and actually send out the data instead of probing for zero window.
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go
new file mode 100644
index 000000000..2f57dff19
--- /dev/null
+++ b/test/packetimpact/tests/tcp_network_unreachable_test.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_synsent_reset_test
+
+import (
+ "context"
+ "flag"
+ "net"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when
+// an ICMP destination unreachable message is sent in response to the inital
+// SYN.
+func TestTCPSynSentUnreachable(t *testing.T) {
+ // Create the DUT and connection.
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ port := uint16(9001)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port})
+ defer conn.Close(t)
+
+ // Bring the DUT to SYN-SENT state with a non-blocking connect.
+ ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout)
+ defer cancel()
+ sa := unix.SockaddrInet4{Port: int(port)}
+ copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4())
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
+ }
+
+ // Get the SYN.
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ if err != nil {
+ t.Fatalf("expected SYN: %s", err)
+ }
+
+ // Send a host unreachable message.
+ rawConn := (*testbench.Connection)(&conn)
+ layers := rawConn.CreateFrame(t, nil)
+ layers = layers[:len(layers)-1]
+ const ipLayer = 1
+ const tcpLayer = ipLayer + 1
+ ip, ok := tcpLayers[ipLayer].(*testbench.IPv4)
+ if !ok {
+ t.Fatalf("expected %s to be IPv4", tcpLayers[ipLayer])
+ }
+ tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP)
+ if !ok {
+ t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer])
+ }
+ var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{
+ Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable),
+ Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)}
+ layers = append(layers, &icmpv4, ip, tcp)
+ rawConn.SendFrameStateless(t, layers)
+
+ if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) {
+ t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err)
+ }
+}
+
+// TestTCPSynSentUnreachable6 verifies that TCP connections fail immediately when
+// an ICMP destination unreachable message is sent in response to the inital
+// SYN.
+func TestTCPSynSentUnreachable6(t *testing.T) {
+ // Create the DUT and connection.
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6))
+ conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort})
+ defer conn.Close(t)
+
+ // Bring the DUT to SYN-SENT state with a non-blocking connect.
+ ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout)
+ defer cancel()
+ sa := unix.SockaddrInet6{
+ Port: int(conn.SrcPort()),
+ ZoneId: uint32(testbench.RemoteInterfaceID),
+ }
+ copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16())
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
+ }
+
+ // Get the SYN.
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ if err != nil {
+ t.Fatalf("expected SYN: %s", err)
+ }
+
+ // Send a host unreachable message.
+ rawConn := (*testbench.Connection)(&conn)
+ layers := rawConn.CreateFrame(t, nil)
+ layers = layers[:len(layers)-1]
+ const ipLayer = 1
+ const tcpLayer = ipLayer + 1
+ ip, ok := tcpLayers[ipLayer].(*testbench.IPv6)
+ if !ok {
+ t.Fatalf("expected %s to be IPv6", tcpLayers[ipLayer])
+ }
+ tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP)
+ if !ok {
+ t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer])
+ }
+ var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable),
+ Code: testbench.ICMPv6Code(header.ICMPv6NetworkUnreachable),
+ // Per RFC 4443 3.1, the payload contains 4 zeroed bytes.
+ Payload: []byte{0, 0, 0, 0},
+ }
+ layers = append(layers, &icmpv6, ip, tcp)
+ rawConn.SendFrameStateless(t, layers)
+
+ if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) {
+ t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
new file mode 100644
index 000000000..82b7a85ff
--- /dev/null
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_noaccept_close_rst_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestTcpNoAcceptCloseReset(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ conn.Connect(t)
+ defer conn.Close(t)
+ dut.Close(t, listenFd)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ t.Fatalf("expected a RST-ACK packet but got none: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
new file mode 100644
index 000000000..08f759f7c
--- /dev/null
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -0,0 +1,93 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_outside_the_window_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPOutsideTheWindows tests the behavior of the DUT when packets arrive
+// that are inside or outside the TCP window. Packets that are outside the
+// window should force an extra ACK, as described in RFC793 page 69:
+// https://tools.ietf.org/html/rfc793#page-69
+func TestTCPOutsideTheWindow(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ tcpFlags uint8
+ payload []testbench.Layer
+ seqNumOffset seqnum.Size
+ expectACK bool
+ }{
+ {"SYN", header.TCPFlagSyn, nil, 0, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 0, true},
+ {"ACK", header.TCPFlagAck, nil, 0, false},
+ {"FIN", header.TCPFlagFin, nil, 0, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 0, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 1, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 1, true},
+ {"ACK", header.TCPFlagAck, nil, 1, true},
+ {"FIN", header.TCPFlagFin, nil, 1, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 1, true},
+
+ {"SYN", header.TCPFlagSyn, nil, 2, true},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil, 2, true},
+ {"ACK", header.TCPFlagAck, nil, 2, true},
+ {"FIN", header.TCPFlagFin, nil, 2, false},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}, 2, true},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset
+ conn.Drain(t)
+ // Ignore whatever incrementing that this out-of-order packet might cause
+ // to the AckNum.
+ localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
+ conn.Send(t, testbench.TCP{
+ Flags: testbench.Uint8(tt.tcpFlags),
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
+ }, tt.payload...)
+ timeout := 3 * time.Second
+ gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ if tt.expectACK && err != nil {
+ t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
+ }
+ if !tt.expectACK && gotACK != nil {
+ t.Fatalf("expected no ACK packet within %s but got one: %s", timeout, gotACK)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
new file mode 100644
index 000000000..37f3b56dd
--- /dev/null
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -0,0 +1,109 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_paws_mechanism_test
+
+import (
+ "encoding/hex"
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestPAWSMechanism(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ options := make([]byte, header.TCPOptionTSLength)
+ header.EncodeTSOption(currentTS(), 0, options)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
+ synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("didn't get synack during handshake: %s", err)
+ }
+ parsedSynOpts := header.ParseSynOptions(synAck.Options, true)
+ if !parsedSynOpts.TS {
+ t.Fatalf("expected TSOpt from DUT, options we got:\n%s", hex.Dump(synAck.Options))
+ }
+ tsecr := parsedSynOpts.TSVal
+ header.EncodeTSOption(currentTS(), tsecr, options)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ sampleData := []byte("Sample Data")
+ sentTSVal := currentTS()
+ header.EncodeTSOption(sentTSVal, tsecr, options)
+ // 3ms here is chosen arbitrarily to make sure we have increasing timestamps
+ // every time we send one, it should not cause any flakiness because timestamps
+ // only need to be non-decreasing.
+ time.Sleep(3 * time.Millisecond)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected an ACK but got none: %s", err)
+ }
+
+ parsedOpts := header.ParseTCPOptions(gotTCP.Options)
+ if !parsedOpts.TS {
+ t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options))
+ }
+ if parsedOpts.TSVal < tsecr {
+ t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr)
+ }
+ if parsedOpts.TSEcr != sentTSVal {
+ t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal)
+ }
+ tsecr = parsedOpts.TSVal
+ lastAckNum := gotTCP.AckNum
+
+ badTSVal := sentTSVal - 100
+ header.EncodeTSOption(badTSVal, tsecr, options)
+ // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness
+ // due to the exact same reasoning discussed above.
+ time.Sleep(3 * time.Millisecond)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+
+ gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err)
+ }
+ parsedOpts = header.ParseTCPOptions(gotTCP.Options)
+ if !parsedOpts.TS {
+ t.Fatalf("expected TS option in response, options we got:\n%s", hex.Dump(gotTCP.Options))
+ }
+ if parsedOpts.TSVal < tsecr {
+ t.Fatalf("TSVal should be non-decreasing, but %d < %d", parsedOpts.TSVal, tsecr)
+ }
+ if parsedOpts.TSEcr != sentTSVal {
+ t.Fatalf("TSEcr should match our sent TSVal, %d != %d", parsedOpts.TSEcr, sentTSVal)
+ }
+}
+
+func currentTS() uint32 {
+ return uint32(time.Now().UnixNano() / 1e6)
+}
diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
new file mode 100644
index 000000000..d9f3ea0f2
--- /dev/null
+++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
@@ -0,0 +1,132 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_queue_receive_in_syn_sent_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/hex"
+ "errors"
+ "flag"
+ "net"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestQueueReceiveInSynSent tests receive behavior when the TCP state
+// is SYN-SENT.
+// It tests for 2 variants where the receive is blocked and:
+// (1) we complete handshake and send sample data.
+// (2) we send a TCP RST.
+func TestQueueReceiveInSynSent(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ reset bool
+ }{
+ {description: "Send DATA", reset: false},
+ {description: "Send RST", reset: true},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ sampleData := []byte("Sample Data")
+
+ dut.SetNonBlocking(t, socket, true)
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
+ }
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ t.Fatalf("expected a SYN from DUT, but got none: %s", err)
+ }
+
+ if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
+ t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
+ }
+
+ // Test blocking read.
+ dut.SetNonBlocking(t, socket, false)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+ wg.Add(1)
+ var block sync.WaitGroup
+ block.Add(1)
+ go func() {
+ defer wg.Done()
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
+ defer cancel()
+
+ block.Done()
+ // Issue RECEIVE call in SYN-SENT, this should be queued for
+ // process until the connection is established.
+ n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0)
+ if tt.reset {
+ if err != syscall.Errno(unix.ECONNREFUSED) {
+ t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
+ }
+ if n != -1 {
+ t.Errorf("expected return value %d, got %d", -1, n)
+ }
+ return
+ }
+ if n == -1 {
+ t.Errorf("failed to recv on DUT: %s", err)
+ }
+ if got := buff[:n]; !bytes.Equal(got, sampleData) {
+ t.Errorf("received data doesn't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData))
+ }
+ }()
+
+ // Wait for the goroutine to be scheduled and before it
+ // blocks on endpoint receive.
+ block.Wait()
+ // The following sleep is used to prevent the connection
+ // from being established before we are blocked on Recv.
+ time.Sleep(100 * time.Millisecond)
+
+ if tt.reset {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ return
+ }
+
+ // Bring the connection to Established.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK from DUT, but got none: %s", err)
+ }
+
+ // Send sample payload and expect an ACK.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected an ACK from DUT, but got none: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go
new file mode 100644
index 000000000..b4aeaab57
--- /dev/null
+++ b/test/packetimpact/tests/tcp_reordering_test.go
@@ -0,0 +1,174 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reordering_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+func TestReorderingWindow(t *testing.T) {
+ dut := tb.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ // Enable SACK.
+ opts := make([]byte, 40)
+ optsOff := 0
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeNOP(opts[optsOff:])
+ optsOff += header.EncodeSACKPermittedOption(opts[optsOff:])
+
+ // Ethernet guarantees that the MTU is at least 1500 bytes.
+ const minMTU = 1500
+ const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+ optsOff += header.EncodeMSSOption(mss, opts[optsOff:])
+
+ conn.ConnectWithOptions(t, opts[:optsOff])
+
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ if tb.Native {
+ // Linux has changed its handling of reordering, force the old behavior.
+ dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno"))
+ }
+
+ pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG)
+ if !tb.Native {
+ // netstack does not impliment TCP_MAXSEG correctly. Fake it
+ // here. Netstack uses the max SACK size which is 32. The MSS
+ // option is 8 bytes, making the total 36 bytes.
+ pls = mss - 36
+ }
+
+ payload := make([]byte, pls)
+
+ seqNum1 := *conn.RemoteSeqNum(t)
+ const numPkts = 10
+ // Send some packets, checking that we receive each.
+ for i, sn := 0, seqNum1; i < numPkts; i++ {
+ dut.Send(t, acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ seqNum2 := *conn.RemoteSeqNum(t)
+
+ // SACK packets #2-4.
+ sackBlock := make([]byte, 40)
+ sbOff := 0
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeNOP(sackBlock[sbOff:])
+ sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
+ seqNum1.Add(seqnum.Size(len(payload))),
+ seqNum1.Add(seqnum.Size(4 * len(payload))),
+ }}, sackBlock[sbOff:])
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+
+ // ACK first packet.
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))})
+
+ // Check for retransmit.
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second)
+ if err != nil {
+ t.Error("Expect for retransmit:", err)
+ }
+ if gotOne == nil {
+ t.Error("expected a retransmitted packet within a second but got none")
+ }
+
+ // ACK all send packets with a DSACK block for packet #1. This tells
+ // the other end that we got both the original and retransmit for
+ // packet #1.
+ dsackBlock := make([]byte, 40)
+ dsbOff := 0
+ dsbOff += header.EncodeNOP(dsackBlock[dsbOff:])
+ dsbOff += header.EncodeNOP(dsackBlock[dsbOff:])
+ dsbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
+ seqNum1.Add(seqnum.Size(len(payload))),
+ seqNum1.Add(seqnum.Size(4 * len(payload))),
+ }}, dsackBlock[dsbOff:])
+
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
+
+ // Send half of the original window of packets, checking that we
+ // received each.
+ for i, sn := 0, seqNum2; i < numPkts/2; i++ {
+ dut.Send(t, acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ if !tb.Native {
+ // The window should now be halved, so we should receive any
+ // more, even if we send them.
+ dut.Send(t, acceptFd, payload, 0)
+ if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
+ }
+ return
+ }
+
+ // Linux reduces the window by three. Check that we can receive the rest.
+ for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ {
+ dut.Send(t, acceptFd, payload, 0)
+
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ sn.UpdateForward(seqnum.Size(len(payload)))
+ if err != nil {
+ t.Errorf("Expect #%d: %s", i+1, err)
+ continue
+ }
+ if gotOne == nil {
+ t.Errorf("#%d: expected a packet within a second but got none", i+1)
+ }
+ }
+
+ // The window should now be full.
+ dut.Send(t, acceptFd, payload, 0)
+ if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
new file mode 100644
index 000000000..072014ff8
--- /dev/null
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -0,0 +1,84 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_retransmits_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestRetransmits tests retransmits occur at exponentially increasing
+// time intervals.
+func TestRetransmits(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip sending this ACK.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ startRTO := time.Second
+ current := startRTO
+ first := time.Now()
+ dut.Send(t, acceptFd, sampleData, 0)
+ seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // Expect retransmits of the same segment.
+ for i := 0; i < 5; i++ {
+ start := time.Now()
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
+ t.Fatalf("expected payload was not received: %s loop %d", err, i)
+ }
+ if i == 0 {
+ startRTO = time.Now().Sub(first)
+ current = 2 * startRTO
+ continue
+ }
+ // Check if the probes came at exponentially increasing intervals.
+ if p := time.Since(start); p < current-startRTO {
+ t.Fatalf("retransmit came sooner interval %d probe %d", p, i)
+ }
+ current *= 2
+ }
+}
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
new file mode 100644
index 000000000..f91b06ba1
--- /dev/null
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -0,0 +1,105 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_send_window_sizes_piggyback_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestSendWindowSizesPiggyback tests cases where segment sizes are close to
+// sender window size and checks for ACK piggybacking for each of those case.
+func TestSendWindowSizesPiggyback(t *testing.T) {
+ sampleData := []byte("Sample Data")
+ segmentSize := uint16(len(sampleData))
+ // Advertise receive window sizes that are lesser, equal to or greater than
+ // enqueued segment size and check for segment transmits. The test attempts
+ // to enqueue a segment on the dut before acknowledging previous segment and
+ // lets the dut piggyback any ACKs along with the enqueued segment.
+ for _, tt := range []struct {
+ description string
+ windowSize uint16
+ expectedPayload1 []byte
+ expectedPayload2 []byte
+ enqueue bool
+ }{
+ // Expect the first segment to be split as it cannot be accomodated in
+ // the sender window. This means we need not enqueue a new segment after
+ // the first segment.
+ {"WindowSmallerThanSegment", segmentSize - 1, sampleData[:(segmentSize - 1)], sampleData[(segmentSize - 1):], false /* enqueue */},
+
+ {"WindowEqualToSegment", segmentSize, sampleData, sampleData, true /* enqueue */},
+
+ // Expect the second segment to not be split as its size is greater than
+ // the available sender window size. The segments should not be split
+ // when there is pending unacknowledged data and the segment-size is
+ // greater than available sender window.
+ {"WindowGreaterThanSegment", segmentSize + 1, sampleData, sampleData, true /* enqueue */},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+
+ dut.Send(t, acceptFd, sampleData, 0)
+ expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1}
+ if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Expect any enqueued segment to be transmitted by the dut along with
+ // piggybacked ACK for our data.
+
+ if tt.enqueue {
+ // Enqueue a segment for the dut to transmit.
+ dut.Send(t, acceptFd, sampleData, 0)
+ }
+
+ // Send ACK for the previous segment along with data for the dut to
+ // receive and ACK back. Sending this ACK would make room for the dut
+ // to transmit any enqueued segment.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
+
+ // Expect the dut to piggyback the ACK for received data along with
+ // the segment enqueued for transmit.
+ expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2}
+ if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
new file mode 100644
index 000000000..57d034dd1
--- /dev/null
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_syn_reset_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestTCPSynRcvdReset tests transition from SYN-RCVD to CLOSED.
+func TestTCPSynRcvdReset(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ // Expect the connection to have transitioned SYN-RCVD to CLOSED.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go
new file mode 100644
index 000000000..eac8eb19d
--- /dev/null
+++ b/test/packetimpact/tests/tcp_synsent_reset_test.go
@@ -0,0 +1,90 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_synsent_reset_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ tb.RegisterFlags(flag.CommandLine)
+}
+
+// dutSynSentState sets up the dut connection in SYN-SENT state.
+func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
+ t.Helper()
+
+ dut := tb.NewDUT(t)
+
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
+ port := uint16(9001)
+ conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port})
+
+ sa := unix.SockaddrInet4{Port: int(port)}
+ copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4())
+ // Bring the dut to SYN-SENT state with a non-blocking connect.
+ dut.Connect(t, clientFD, &sa)
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN\n")
+ }
+
+ return &dut, &conn, port, clientPort
+}
+
+// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition.
+func TestTCPSynSentReset(t *testing.T) {
+ dut, conn, _, _ := dutSynSentState(t)
+ defer conn.Close(t)
+ defer dut.TearDown()
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ // Expect the connection to have closed.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
+
+// TestTCPSynSentRcvdReset tests RFC793, p70, SYN-SENT to SYN-RCVD to CLOSED
+// transitions.
+func TestTCPSynSentRcvdReset(t *testing.T) {
+ dut, c, remotePort, clientPort := dutSynSentState(t)
+ defer dut.TearDown()
+ defer c.Close(t)
+
+ conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort})
+ defer conn.Close(t)
+ // Initiate new SYN connection with the same port pair
+ // (simultaneous open case), expect the dut connection to move to
+ // SYN-RCVD state
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK %s\n", err)
+ }
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
+ // Expect the connection to have transitioned SYN-RCVD to CLOSED.
+ // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
new file mode 100644
index 000000000..551dc78e7
--- /dev/null
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_user_timeout_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) {
+ sampleData := make([]byte, 100)
+ for i := range sampleData {
+ sampleData[i] = uint8(i)
+ }
+ conn.Drain(t)
+ dut.Send(t, fd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ t.Fatalf("expected data but got none: %w", err)
+ }
+}
+
+func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) {
+ dut.Close(t, fd)
+}
+
+func TestTCPUserTimeout(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ userTimeout time.Duration
+ sendDelay time.Duration
+ }{
+ {"NoUserTimeout", 0, 3 * time.Second},
+ {"ACKBeforeUserTimeout", 5 * time.Second, 4 * time.Second},
+ {"ACKAfterUserTimeout", 5 * time.Second, 7 * time.Second},
+ } {
+ for _, ttf := range []struct {
+ description string
+ f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32)
+ }{
+ {"AfterPayload", sendPayload},
+ {"AfterFIN", sendFIN},
+ } {
+ t.Run(tt.description+ttf.description, func(t *testing.T) {
+ // Create a socket, listen, TCP handshake, and accept.
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ if tt.userTimeout != 0 {
+ dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds()))
+ }
+
+ ttf.f(t, &conn, &dut, acceptFD)
+
+ time.Sleep(tt.sendDelay)
+ conn.Drain(t)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ // If TCP_USER_TIMEOUT was set and the above delay was longer than the
+ // TCP_USER_TIMEOUT then the DUT should send a RST in response to the
+ // testbench's packet.
+ expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
+ expectTimeout := 5 * time.Second
+ got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
+ if expectRST && err != nil {
+ t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
+ }
+ if !expectRST && got != nil {
+ t.Errorf("expected no RST packet within %s but got one: %s", expectTimeout, got)
+ }
+ })
+ }
+ }
+}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
new file mode 100644
index 000000000..5b001fbec
--- /dev/null
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -0,0 +1,73 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_window_shrink_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestWindowShrink(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+
+ dut.Send(t, acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ // We close our receiving window here
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ dut.Send(t, acceptFd, []byte("Sample Data"), 0)
+ // Note: There is another kind of zero-window probing which Windows uses (by sending one
+ // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
+ // the following lines.
+ expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
new file mode 100644
index 000000000..da93267d6
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_retransmit_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbeRetransmit tests retransmits of zero window probes
+// to be sent at exponentially inreasing time intervals.
+func TestZeroWindowProbeRetransmit(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Send and receive sample data to the dut.
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Check for the dut to keep the connection alive as long as the zero window
+ // probes are acknowledged. Check if the zero window probes are sent at
+ // exponentially increasing intervals. The timeout intervals are function
+ // of the recorded first zero probe transmission duration.
+ //
+ // Advertize zero receive window again.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
+
+ startProbeDuration := time.Second
+ current := startProbeDuration
+ first := time.Now()
+ // Ask the dut to send out data.
+ dut.Send(t, acceptFd, sampleData, 0)
+ // Expect the dut to keep the connection alive as long as the remote is
+ // acknowledging the zero-window probes.
+ for i := 0; i < 5; i++ {
+ start := time.Now()
+ // Expect zero-window probe with a timeout which is a function of the typical
+ // first retransmission time. The retransmission times is supposed to
+ // exponentially increase.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
+ t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i)
+ }
+ if i == 0 {
+ startProbeDuration = time.Now().Sub(first)
+ current = 2 * startProbeDuration
+ continue
+ }
+ // Check if the probes came at exponentially increasing intervals.
+ if got, want := time.Since(start), current-startProbeDuration; got < want {
+ t.Errorf("got zero probe %d after %s, want >= %s", i, got, want)
+ }
+ // Acknowledge the zero-window probes from the dut.
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ current *= 2
+ }
+ // Advertize non-zero window.
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the dut to recover and transmit data.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
new file mode 100644
index 000000000..44cac42f8
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -0,0 +1,112 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbe tests few cases of zero window probing over the
+// same connection.
+func TestZeroWindowProbe(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ start := time.Now()
+ // Send and receive sample data to the dut.
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ sendTime := time.Now().Sub(start)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Test 1: Check for receive of a zero window probe, record the duration for
+ // probe to be sent.
+ //
+ // Advertize zero window to the dut.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ // Expected sequence number of the zero window probe.
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
+ // Expected ack number of the ACK for the probe.
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
+
+ // Expect there are no zero-window probes sent until there is data to be sent out
+ // from the dut.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
+ t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err)
+ }
+
+ start = time.Now()
+ // Ask the dut to send out data.
+ dut.Send(t, acceptFd, sampleData, 0)
+ // Expect zero-window probe from the dut.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
+ }
+ // Expect the probe to be sent after some time. Compare against the previous
+ // time recorded when the dut immediately sends out data on receiving the
+ // send command.
+ if startProbeDuration := time.Now().Sub(start); startProbeDuration <= sendTime {
+ t.Fatalf("expected the first probe to be sent out after retransmission interval, got %s want > %s", startProbeDuration, sendTime)
+ }
+
+ // Test 2: Check if the dut recovers on advertizing non-zero receive window.
+ // and sends out the sample payload after the send window opens.
+ //
+ // Advertize non-zero window to the dut and ack the zero window probe.
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ // Expect the dut to recover and transmit data.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+
+ // Test 3: Sanity check for dut's processing of a similar probe it sent.
+ // Check if the dut responds as we do for a similar probe sent to it.
+ // Basically with sequence number to one byte behind the unacknowledged
+ // sequence number.
+ p := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with ack number: %d: %s", p, err)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
new file mode 100644
index 000000000..09a1c653f
--- /dev/null
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -0,0 +1,98 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_zero_window_probe_usertimeout_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+// TestZeroWindowProbeUserTimeout sanity tests user timeout when we are
+// retransmitting zero window probes.
+func TestZeroWindowProbeUserTimeout(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
+ conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
+
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ // Send and receive sample data to the dut.
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+
+ // Test 1: Check for receive of a zero window probe, record the duration for
+ // probe to be sent.
+ //
+ // Advertize zero window to the dut.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+
+ // Expected sequence number of the zero window probe.
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
+ start := time.Now()
+ // Ask the dut to send out data.
+ dut.Send(t, acceptFd, sampleData, 0)
+ // Expect zero-window probe from the dut.
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
+ }
+ // Record the duration for first probe, the dut sends the zero window probe after
+ // a retransmission time interval.
+ startProbeDuration := time.Now().Sub(start)
+
+ // Test 2: Check if the dut times out the connection by honoring usertimeout
+ // when the dut is sending zero-window probes.
+ //
+ // Reduce the retransmit timeout.
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
+ // Advertize zero window again.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ // Ask the dut to send out data that would trigger zero window probe retransmissions.
+ dut.Send(t, acceptFd, sampleData, 0)
+
+ // Wait for the connection to timeout after multiple zero-window probe retransmissions.
+ time.Sleep(8 * startProbeDuration)
+
+ // Expect the connection to have timed out and closed which would cause the dut
+ // to reply with a RST to the ACK we send.
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ t.Fatalf("expected a TCP RST")
+ }
+}
diff --git a/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go
new file mode 100644
index 000000000..17f32ef65
--- /dev/null
+++ b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_any_addr_recv_unicast_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestAnyRecvUnicastUDP(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
+ defer dut.Close(t, boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
+ conn.SendIP(
+ t,
+ testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP(testbench.RemoteIPv4).To4()))},
+ testbench.UDP{},
+ &testbench.Payload{Bytes: payload},
+ )
+ got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
+ }
+}
diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
new file mode 100644
index 000000000..d30177e64
--- /dev/null
+++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_discard_mcast_source_addr_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "net"
+ "syscall"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+var oneSecond = unix.Timeval{Sec: 1, Usec: 0}
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
+ defer dut.Close(t, remoteFD)
+ dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ for _, mcastAddr := range []net.IP{
+ net.IPv4allsys,
+ net.IPv4allrouter,
+ net.IPv4(224, 0, 1, 42),
+ net.IPv4(232, 1, 2, 3),
+ } {
+ t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) {
+ conn.SendIP(
+ t,
+ testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))},
+ testbench.UDP{},
+ )
+
+ ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
+ if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
+ }
+ })
+ }
+}
+
+func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6))
+ defer dut.Close(t, remoteFD)
+ dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
+ conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ for _, mcastAddr := range []net.IP{
+ net.IPv6interfacelocalallnodes,
+ net.IPv6linklocalallnodes,
+ net.IPv6linklocalallrouters,
+ net.ParseIP("fe01::42"),
+ net.ParseIP("fe02::4242"),
+ } {
+ t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) {
+ conn.SendIPv6(
+ t,
+ testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))},
+ testbench.UDP{},
+ )
+ ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
+ if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
new file mode 100644
index 000000000..df35d16c8
--- /dev/null
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -0,0 +1,363 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_icmp_error_propagation_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "net"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+type connectionMode bool
+
+func (c connectionMode) String() string {
+ if c {
+ return "Connected"
+ }
+ return "Connectionless"
+}
+
+type icmpError int
+
+const (
+ portUnreachable icmpError = iota
+ timeToLiveExceeded
+)
+
+func (e icmpError) String() string {
+ switch e {
+ case portUnreachable:
+ return "PortUnreachable"
+ case timeToLiveExceeded:
+ return "TimeToLiveExpired"
+ }
+ return "Unknown ICMP error"
+}
+
+func (e icmpError) ToICMPv4() *testbench.ICMPv4 {
+ switch e {
+ case portUnreachable:
+ return &testbench.ICMPv4{
+ Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable),
+ Code: testbench.ICMPv4Code(header.ICMPv4PortUnreachable)}
+ case timeToLiveExceeded:
+ return &testbench.ICMPv4{
+ Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded),
+ Code: testbench.ICMPv4Code(header.ICMPv4TTLExceeded)}
+ }
+ return nil
+}
+
+type errorDetection struct {
+ name string
+ useValidConn bool
+ f func(context.Context, *testing.T, testData)
+}
+
+type testData struct {
+ dut *testbench.DUT
+ conn *testbench.UDPIPv4
+ remoteFD int32
+ remotePort uint16
+ cleanFD int32
+ cleanPort uint16
+ wantErrno syscall.Errno
+}
+
+// wantErrno computes the errno to expect given the connection mode of a UDP
+// socket and the ICMP error it will receive.
+func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
+ if c && icmpErr == portUnreachable {
+ return syscall.Errno(unix.ECONNREFUSED)
+ }
+ return syscall.Errno(0)
+}
+
+// sendICMPError sends an ICMP error message in response to a UDP datagram.
+func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) {
+ t.Helper()
+
+ layers := (*testbench.Connection)(conn).CreateFrame(t, nil)
+ layers = layers[:len(layers)-1]
+ ip, ok := udp.Prev().(*testbench.IPv4)
+ if !ok {
+ t.Fatalf("expected %s to be IPv4", udp.Prev())
+ }
+ if icmpErr == timeToLiveExceeded {
+ *ip.TTL = 1
+ // Let serialization recalculate the checksum since we set the TTL
+ // to 1.
+ ip.Checksum = nil
+ }
+ // Note that the ICMP payload is valid in this case because the UDP
+ // payload is empty. If the UDP payload were not empty, the packet
+ // length during serialization may not be calculated correctly,
+ // resulting in a mal-formed packet.
+ layers = append(layers, icmpErr.ToICMPv4(), ip, udp)
+
+ (*testbench.Connection)(conn).SendFrameStateless(t, layers)
+}
+
+// testRecv tests observing the ICMP error through the recv syscall. A packet
+// is sent to the DUT, and if wantErrno is non-zero, then the first recv should
+// fail and the second should succeed. Otherwise if wantErrno is zero then the
+// first recv should succeed immediately.
+func testRecv(ctx context.Context, t *testing.T, d testData) {
+ t.Helper()
+
+ // Check that receiving on the clean socket works.
+ d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort})
+ d.dut.Recv(t, d.cleanFD, 100, 0)
+
+ d.conn.Send(t, testbench.UDP{})
+
+ if d.wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0)
+ if ret != -1 {
+ t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ }
+ if err != d.wantErrno {
+ t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ }
+ }
+
+ d.dut.Recv(t, d.remoteFD, 100, 0)
+}
+
+// testSendTo tests observing the ICMP error through the send syscall. If
+// wantErrno is non-zero, the first send should fail and a subsequent send
+// should suceed; while if wantErrno is zero then the first send should just
+// succeed.
+func testSendTo(ctx context.Context, t *testing.T, d testData) {
+ // Check that sending on the clean socket works.
+ d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err)
+ }
+
+ if d.wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(ctx, time.Second)
+ defer cancel()
+ ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
+
+ if ret != -1 {
+ t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ }
+ if err != d.wantErrno {
+ t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ }
+ }
+
+ d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet as expected: %s", err)
+ }
+}
+
+func testSockOpt(_ context.Context, t *testing.T, d testData) {
+ // Check that there's no pending error on the clean socket.
+ if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) {
+ t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno)
+ }
+
+ if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
+ t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno)
+ }
+
+ // Check that after clearing socket error, sending doesn't fail.
+ d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet as expected: %s", err)
+ }
+}
+
+// TestUDPICMPErrorPropagation tests that ICMP error messages in response to
+// UDP datagrams are processed correctly. RFC 1122 section 4.1.3.3 states that:
+// "UDP MUST pass to the application layer all ICMP error messages that it
+// receives from the IP layer."
+//
+// The test cases are parametrized in 3 dimensions: 1. the UDP socket is either
+// put into connection mode or left connectionless, 2. the ICMP message type
+// and code, and 3. the method by which the ICMP error is observed on the
+// socket: sendto, recv, or getsockopt(SO_ERROR).
+//
+// Linux's udp(7) man page states: "All fatal errors will be passed to the user
+// as an error return even when the socket is not connected. This includes
+// asynchronous errors received from the network." In practice, the only
+// combination of parameters to the test that causes an error to be observable
+// on the UDP socket is receiving a port unreachable message on a connected
+// socket.
+func TestUDPICMPErrorPropagation(t *testing.T) {
+ for _, connect := range []connectionMode{true, false} {
+ for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} {
+ wantErrno := wantErrno(connect, icmpErr)
+
+ for _, errDetect := range []errorDetection{
+ errorDetection{"SendTo", false, testSendTo},
+ // Send to an address that's different from the one that caused an ICMP
+ // error to be returned.
+ errorDetection{"SendToValid", true, testSendTo},
+ errorDetection{"Recv", false, testRecv},
+ errorDetection{"SockOpt", false, testSockOpt},
+ } {
+ t.Run(fmt.Sprintf("%s/%s/%s", connect, icmpErr, errDetect.name), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
+ defer dut.Close(t, remoteFD)
+
+ // Create a second, clean socket on the DUT to ensure that the ICMP
+ // error messages only affect the sockets they are intended for.
+ cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
+ defer dut.Close(t, cleanFD)
+
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ if connect {
+ dut.Connect(t, remoteFD, conn.LocalAddr(t))
+ dut.Connect(t, cleanFD, conn.LocalAddr(t))
+ }
+
+ dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t))
+ udp, err := conn.Expect(t, testbench.UDP{}, time.Second)
+ if err != nil {
+ t.Fatalf("did not receive message from DUT: %s", err)
+ }
+
+ sendICMPError(t, &conn, icmpErr, udp)
+
+ errDetectConn := &conn
+ if errDetect.useValidConn {
+ // connClean is a UDP socket on the test runner that was not
+ // involved in the generation of the ICMP error. As such,
+ // interactions between it and the the DUT should be independent of
+ // the ICMP error at least at the port level.
+ connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer connClean.Close(t)
+
+ errDetectConn = &connClean
+ }
+
+ errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno})
+ })
+ }
+ }
+ }
+}
+
+// TestICMPErrorDuringUDPRecv tests behavior when a UDP socket is in the middle
+// of a blocking recv and receives an ICMP error.
+func TestICMPErrorDuringUDPRecv(t *testing.T) {
+ for _, connect := range []connectionMode{true, false} {
+ for _, icmpErr := range []icmpError{portUnreachable, timeToLiveExceeded} {
+ wantErrno := wantErrno(connect, icmpErr)
+
+ t.Run(fmt.Sprintf("%s/%s", connect, icmpErr), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
+ defer dut.Close(t, remoteFD)
+
+ // Create a second, clean socket on the DUT to ensure that the ICMP
+ // error messages only affect the sockets they are intended for.
+ cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero)
+ defer dut.Close(t, cleanFD)
+
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ if connect {
+ dut.Connect(t, remoteFD, conn.LocalAddr(t))
+ dut.Connect(t, cleanFD, conn.LocalAddr(t))
+ }
+
+ dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t))
+ udp, err := conn.Expect(t, testbench.UDP{}, time.Second)
+ if err != nil {
+ t.Fatalf("did not receive message from DUT: %s", err)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ if wantErrno != syscall.Errno(0) {
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0)
+ if ret != -1 {
+ t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
+ return
+ }
+ if err != wantErrno {
+ t.Errorf("recv during ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, wantErrno)
+ return
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 {
+ t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 {
+ t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err)
+ }
+ }()
+
+ // TODO(b/155684889) This sleep is to allow time for the DUT to
+ // actually call recv since we want the ICMP error to arrive during the
+ // blocking recv, and should be replaced when a better synchronization
+ // alternative is available.
+ time.Sleep(2 * time.Second)
+
+ sendICMPError(t, &conn, icmpErr, udp)
+
+ conn.Send(t, testbench.UDP{DstPort: &cleanPort})
+ conn.Send(t, testbench.UDP{})
+ wg.Wait()
+ })
+ }
+ }
+}
diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
new file mode 100644
index 000000000..526173969
--- /dev/null
+++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
@@ -0,0 +1,110 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_recv_mcast_bcast_test
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "net"
+ "syscall"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+func TestUDPRecvMcastBcast(t *testing.T) {
+ subnetBcastAddr := broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32))
+
+ for _, v := range []struct {
+ bound, to net.IP
+ }{
+ {bound: net.IPv4zero, to: subnetBcastAddr},
+ {bound: net.IPv4zero, to: net.IPv4bcast},
+ {bound: net.IPv4zero, to: net.IPv4allsys},
+
+ {bound: subnetBcastAddr, to: subnetBcastAddr},
+ {bound: subnetBcastAddr, to: net.IPv4bcast},
+
+ {bound: net.IPv4bcast, to: net.IPv4bcast},
+ {bound: net.IPv4allsys, to: net.IPv4allsys},
+ } {
+ t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound)
+ defer dut.Close(t, boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
+ conn.SendIP(
+ t,
+ testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))},
+ testbench.UDP{},
+ &testbench.Payload{Bytes: payload},
+ )
+ got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
+ dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0})
+ defer dut.Close(t, boundFD)
+ conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ for _, to := range []net.IP{
+ broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)),
+ net.IPv4(255, 255, 255, 255),
+ net.IPv4(224, 0, 0, 1),
+ } {
+ t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) {
+ payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */)
+ conn.SendIP(
+ t,
+ testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))},
+ testbench.UDP{},
+ &testbench.Payload{Bytes: payload},
+ )
+ ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0)
+ if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
+ }
+ })
+ }
+}
+
+func broadcastAddr(ip net.IP, mask net.IPMask) net.IP {
+ ip4 := ip.To4()
+ for i := range ip4 {
+ ip4[i] |= ^mask[i]
+ }
+ return ip4
+}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
new file mode 100644
index 000000000..91b967400
--- /dev/null
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -0,0 +1,104 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp_send_recv_dgram_test
+
+import (
+ "flag"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/google/go-cmp/cmp"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.RegisterFlags(flag.CommandLine)
+}
+
+type udpConn interface {
+ Send(*testing.T, testbench.UDP, ...testbench.Layer)
+ ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error)
+ Drain(*testing.T)
+ Close(*testing.T)
+}
+
+func TestUDP(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ defer dut.TearDown()
+
+ for _, isIPv4 := range []bool{true, false} {
+ ipVersionName := "IPv6"
+ if isIPv4 {
+ ipVersionName = "IPv4"
+ }
+ t.Run(ipVersionName, func(t *testing.T) {
+ var addr string
+ if isIPv4 {
+ addr = testbench.RemoteIPv4
+ } else {
+ addr = testbench.RemoteIPv6
+ }
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr))
+ defer dut.Close(t, boundFD)
+
+ var conn udpConn
+ var localAddr unix.Sockaddr
+ if isIPv4 {
+ v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ localAddr = v4Conn.LocalAddr(t)
+ conn = &v4Conn
+ } else {
+ v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
+ localAddr = v6Conn.LocalAddr(t)
+ conn = &v6Conn
+ }
+ defer conn.Close(t)
+
+ testCases := []struct {
+ name string
+ payload []byte
+ }{
+ {"emptypayload", nil},
+ {"small payload", []byte("hello world")},
+ {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)},
+ // Even though UDP allows larger dgrams we don't test it here as
+ // they need to be fragmented and written out as individual
+ // frames.
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Run("Send", func(t *testing.T) {
+ conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload})
+ got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff)
+ }
+ })
+ t.Run("Recv", func(t *testing.T) {
+ conn.Drain(t)
+ if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want {
+ t.Fatalf("short write got: %d, want: %d", got, want)
+ }
+ if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil {
+ t.Fatal(err)
+ }
+ })
+ })
+ }
+ })
+ }
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
new file mode 100644
index 000000000..471d8c2ab
--- /dev/null
+++ b/test/perf/BUILD
@@ -0,0 +1,117 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ test = "//test/perf/linux:clock_getres_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:clock_gettime_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:death_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:epoll_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:fork_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:futex_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ shard_count = 10,
+ tags = ["nogotsan"],
+ test = "//test/perf/linux:getdents_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:getpid_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ tags = ["nogotsan"],
+ test = "//test/perf/linux:gettid_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:mapping_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:open_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:pipe_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:randread_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:read_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:sched_yield_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:send_recv_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:seqwrite_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:signal_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:sleep_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:stat_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ add_overlay = True,
+ test = "//test/perf/linux:unlink_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:write_benchmark",
+)
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
new file mode 100644
index 000000000..b4e907826
--- /dev/null
+++ b/test/perf/linux/BUILD
@@ -0,0 +1,356 @@
+load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "getpid_benchmark",
+ testonly = 1,
+ srcs = [
+ "getpid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "send_recv_benchmark",
+ testonly = 1,
+ srcs = [
+ "send_recv_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/syscalls/linux:socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "gettid_benchmark",
+ testonly = 1,
+ srcs = [
+ "gettid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "sched_yield_benchmark",
+ testonly = 1,
+ srcs = [
+ "sched_yield_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_getres_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_getres_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "clock_gettime_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_gettime_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "open_benchmark",
+ testonly = 1,
+ srcs = [
+ "open_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "read_benchmark",
+ testonly = 1,
+ srcs = [
+ "read_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "randread_benchmark",
+ testonly = 1,
+ srcs = [
+ "randread_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "write_benchmark",
+ testonly = 1,
+ srcs = [
+ "write_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "seqwrite_benchmark",
+ testonly = 1,
+ srcs = [
+ "seqwrite_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "pipe_benchmark",
+ testonly = 1,
+ srcs = [
+ "pipe_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fork_benchmark",
+ testonly = 1,
+ srcs = [
+ "fork_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "futex_benchmark",
+ testonly = 1,
+ srcs = [
+ "futex_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "epoll_benchmark",
+ testonly = 1,
+ srcs = [
+ "epoll_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "death_benchmark",
+ testonly = 1,
+ srcs = [
+ "death_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "mapping_benchmark",
+ testonly = 1,
+ srcs = [
+ "mapping_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "signal_benchmark",
+ testonly = 1,
+ srcs = [
+ "signal_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getdents_benchmark",
+ testonly = 1,
+ srcs = [
+ "getdents_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sleep_benchmark",
+ testonly = 1,
+ srcs = [
+ "sleep_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "stat_benchmark",
+ testonly = 1,
+ srcs = [
+ "stat_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "unlink_benchmark",
+ testonly = 1,
+ srcs = [
+ "unlink_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc
new file mode 100644
index 000000000..b051293ad
--- /dev/null
+++ b/test/perf/linux/clock_getres_benchmark.cc
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// clock_getres(1) is very nearly a no-op syscall, but it does require copying
+// out to a userspace struct. It thus provides a nice small copy-out benchmark.
+void BM_ClockGetRes(benchmark::State& state) {
+ struct timespec ts;
+ for (auto _ : state) {
+ clock_getres(CLOCK_MONOTONIC, &ts);
+ }
+}
+
+BENCHMARK(BM_ClockGetRes);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc
new file mode 100644
index 000000000..6691bebd9
--- /dev/null
+++ b/test/perf/linux/clock_gettime_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pthread.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_ClockGettimeThreadCPUTime(benchmark::State& state) {
+ clockid_t clockid;
+ ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid));
+ struct timespec tp;
+
+ for (auto _ : state) {
+ clock_gettime(clockid, &tp);
+ }
+}
+
+BENCHMARK(BM_ClockGettimeThreadCPUTime);
+
+void BM_VDSOClockGettime(benchmark::State& state) {
+ const clockid_t clock = state.range(0);
+ struct timespec tp;
+ absl::Time start = absl::Now();
+
+ // Don't benchmark the calibration phase.
+ while (absl::Now() < start + absl::Milliseconds(2100)) {
+ clock_gettime(clock, &tp);
+ }
+
+ for (auto _ : state) {
+ clock_gettime(clock, &tp);
+ }
+}
+
+BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc
new file mode 100644
index 000000000..cb2b6fd07
--- /dev/null
+++ b/test/perf/linux/death_benchmark.cc
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing
+// the ability of gVisor (on whatever platform) to execute all the related
+// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH.
+TEST(DeathTest, ZeroEqualsOne) {
+ EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc
new file mode 100644
index 000000000..0b121338a
--- /dev/null
+++ b/test/perf/linux/epoll_benchmark.cc
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Returns a new eventfd.
+PosixErrorOr<FileDescriptor> NewEventFD() {
+ int fd = eventfd(0, /* flags = */ 0);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "eventfd");
+ }
+ return FileDescriptor(fd);
+}
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollTimeout(benchmark::State& state) {
+ constexpr int kFDsPerEpoll = 3;
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+ }
+
+ struct epoll_event result[kFDsPerEpoll];
+ int timeout_ms = state.range(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms));
+ }
+}
+
+BENCHMARK(BM_EpollTimeout)->Range(0, 8);
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollAllEvents(benchmark::State& state) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ const int fds_per_epoll = state.range(0);
+ constexpr uint64_t kEventVal = 5;
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < fds_per_epoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+
+ ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)),
+ SyscallSucceedsWithValue(sizeof(kEventVal)));
+ }
+
+ std::vector<struct epoll_event> result(fds_per_epoll);
+
+ for (auto _ : state) {
+ EXPECT_EQ(fds_per_epoll,
+ epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0));
+ }
+}
+
+BENCHMARK(BM_EpollAllEvents)->Range(2, 1024);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc
new file mode 100644
index 000000000..84fdbc8a0
--- /dev/null
+++ b/test/perf/linux/fork_benchmark.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/barrier.h"
+#include "benchmark/benchmark.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBusyMax = 250;
+
+// Do some CPU-bound busy-work.
+int busy(int max) {
+ // Prevent the compiler from optimizing this work away,
+ volatile int count = 0;
+
+ for (int i = 1; i < max; i++) {
+ for (int j = 2; j < i / 2; j++) {
+ if (i % j == 0) {
+ count++;
+ }
+ }
+ }
+
+ return count;
+}
+
+void BM_CPUBoundUniprocess(benchmark::State& state) {
+ for (auto _ : state) {
+ busy(kBusyMax);
+ }
+}
+
+BENCHMARK(BM_CPUBoundUniprocess);
+
+void BM_CPUBoundAsymmetric(benchmark::State& state) {
+ const size_t max = state.max_iterations;
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < max; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ ASSERT_TRUE(state.KeepRunningBatch(max));
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ ASSERT_FALSE(state.KeepRunning());
+}
+
+BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime();
+
+void BM_CPUBoundSymmetric(benchmark::State& state) {
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ ASSERT_FALSE(state.KeepRunning());
+ });
+
+ const int processes = state.range(0);
+ for (int i = 0; i < processes; i++) {
+ size_t cur = (state.max_iterations + (processes - 1)) / processes;
+ if ((state.iterations() + cur) >= state.max_iterations) {
+ cur = state.max_iterations - state.iterations();
+ }
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < cur; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ if (cur > 0) {
+ // We can have a zero cur here, depending.
+ ASSERT_TRUE(state.KeepRunningBatch(cur));
+ }
+ children.push_back(child);
+ }
+}
+
+BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime();
+
+// Child routine for ProcessSwitch/ThreadSwitch.
+// Reads from readfd and writes the result to writefd.
+void SwitchChild(int readfd, int writefd) {
+ while (1) {
+ char buf;
+ int ret = ReadFd(readfd, &buf, 1);
+ if (ret == 0) {
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "read failed");
+
+ ret = WriteFd(writefd, &buf, 1);
+ if (ret == -1) {
+ TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure");
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "write failed");
+ }
+}
+
+// Send bytes in a loop through a series of pipes, each passing through a
+// different process.
+//
+// Proc 0 Proc 1
+// * ----------> *
+// ^ Pipe 1 |
+// | |
+// | Pipe 0 | Pipe 2
+// | |
+// | |
+// | Pipe 3 v
+// * <---------- *
+// Proc 3 Proc 2
+//
+// This exercises context switching through multiple processes.
+void BM_ProcessSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two processes.
+ const int num_processes = state.range(0);
+ ASSERT_GE(num_processes, 2);
+
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ });
+
+ // Must come after children, as the FDs must be closed before the children
+ // will exit.
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_processes; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This process is one of the processes in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_processes; i++) {
+ // Read from current pipe index, write to next.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = (i + 1) % num_processes;
+ const int write_fd = write_fds[write_index].get();
+
+ // std::vector isn't safe to use from the fork child.
+ FileDescriptor* read_array = read_fds.data();
+ FileDescriptor* write_array = write_fds.data();
+
+ pid_t child = fork();
+ if (!child) {
+ // Close all other FDs.
+ for (int j = 0; j < num_processes; j++) {
+ if (j != read_index) {
+ read_array[j].reset();
+ }
+ if (j != write_index) {
+ write_array[j].reset();
+ }
+ }
+
+ SwitchChild(read_fd, write_fd);
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ children.push_back(child);
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+}
+
+BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime();
+
+// Equivalent to BM_ThreadSwitch using threads instead of processes.
+void BM_ThreadSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two threads.
+ const int num_threads = state.range(0);
+ ASSERT_GE(num_threads, 2);
+
+ // Must come after threads, as the FDs must be closed before the children
+ // will exit.
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_threads; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This thread is one of the threads in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_threads; i++) {
+ // Read from current pipe index, write to next.
+ //
+ // Transfer ownership of the FDs to the thread.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].release();
+
+ const int write_index = (i + 1) % num_threads;
+ const int write_fd = write_fds[write_index].release();
+
+ threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] {
+ FileDescriptor read(read_fd);
+ FileDescriptor write(write_fd);
+ SwitchChild(read.get(), write.get());
+ }));
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+
+ // The two FDs still owned by this thread are closed, causing the next thread
+ // to exit its loop and close its FDs, and so on until all threads exit.
+}
+
+BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime();
+
+void BM_ThreadStart(benchmark::State& state) {
+ const int num_threads = state.range(0);
+
+ for (auto _ : state) {
+ state.PauseTiming();
+
+ auto barrier = new absl::Barrier(num_threads + 1);
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+
+ state.ResumeTiming();
+
+ for (size_t i = 0; i < num_threads; ++i) {
+ threads.emplace_back(std::make_unique<ScopedThread>([barrier] {
+ if (barrier->Block()) {
+ delete barrier;
+ }
+ }));
+ }
+
+ if (barrier->Block()) {
+ delete barrier;
+ }
+
+ state.PauseTiming();
+
+ for (const auto& thread : threads) {
+ thread->Join();
+ }
+
+ state.ResumeTiming();
+ }
+}
+
+BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime();
+
+// Benchmark the complete fork + exit + wait.
+void BM_ProcessLifecycle(benchmark::State& state) {
+ const int num_procs = state.range(0);
+
+ std::vector<pid_t> pids(num_procs);
+ for (auto _ : state) {
+ for (size_t i = 0; i < num_procs; ++i) {
+ int pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ pids[i] = pid;
+ }
+
+ for (const int pid : pids) {
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0),
+ SyscallSucceedsWithValue(pid));
+ }
+ }
+}
+
+BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
new file mode 100644
index 000000000..e686041c9
--- /dev/null
+++ b/test/perf/linux/futex_benchmark.cc
@@ -0,0 +1,198 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/futex.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, nullptr);
+}
+
+inline int FutexWaitMonotonicTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* timeout) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, val, timeout);
+}
+
+inline int FutexWaitMonotonicDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE, val, deadline,
+ nullptr, FUTEX_BITSET_MATCH_ANY);
+}
+
+inline int FutexWaitRealtimeDeadline(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* deadline) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
+ val, deadline, nullptr, FUTEX_BITSET_MATCH_ANY);
+}
+
+inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
+ return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count);
+}
+
+// This just uses FUTEX_WAKE on an address with nothing waiting, very simple.
+void BM_FutexWakeNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWake(&v, 1) == 0);
+ }
+}
+
+BENCHMARK(BM_FutexWakeNop)->MinTime(5);
+
+// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
+// syscall won't wait.
+void BM_FutexWaitNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWait(&v, 1) == -1 && errno == EAGAIN);
+ }
+}
+
+BENCHMARK(BM_FutexWaitNop)->MinTime(5);
+
+// This uses FUTEX_WAIT with a timeout on an address whose value never
+// changes, such that it always times out. Timeout overhead can be estimated by
+// timer overruns for short timeouts.
+void BM_FutexWaitMonotonicTimeout(benchmark::State& state) {
+ const absl::Duration timeout = absl::Nanoseconds(state.range(0));
+ std::atomic<int32_t> v(0);
+ auto ts = absl::ToTimespec(timeout);
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitMonotonicTimeout)
+ ->MinTime(5)
+ ->UseRealTime()
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(100)
+ ->Arg(1000)
+ ->Arg(10000);
+
+// This uses FUTEX_WAIT_BITSET with a deadline that is in the past. This allows
+// estimation of the overhead of setting up a timer for a deadline (as opposed
+// to a timeout as specified for FUTEX_WAIT).
+void BM_FutexWaitMonotonicDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitMonotonicDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitMonotonicDeadline)->MinTime(5);
+
+// This is equivalent to BM_FutexWaitMonotonicDeadline, but uses CLOCK_REALTIME
+// instead of CLOCK_MONOTONIC for the deadline.
+void BM_FutexWaitRealtimeDeadline(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ struct timespec ts = {};
+
+ for (auto _ : state) {
+ TEST_PCHECK(FutexWaitRealtimeDeadline(&v, 0, &ts) == -1 &&
+ errno == ETIMEDOUT);
+ }
+}
+
+BENCHMARK(BM_FutexWaitRealtimeDeadline)->MinTime(5);
+
+int64_t GetCurrentMonotonicTimeNanos() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1);
+ return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+void SpinNanos(int64_t delay_ns) {
+ if (delay_ns <= 0) {
+ return;
+ }
+ const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns;
+ while (GetCurrentMonotonicTimeNanos() < end) {
+ // spin
+ }
+}
+
+// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
+// wakeup to another thread, which spins for delay_us and then sends a futex
+// wakeup back. The time per iteration is 2 * (delay_us + kBeforeWakeDelayNs +
+// futex/scheduling overhead).
+void BM_FutexRoundtripDelayed(benchmark::State& state) {
+ const int delay_us = state.range(0);
+ const int64_t delay_ns = delay_us * 1000;
+ // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
+ // the probability that the wakeup comes before the wait, preventing the wait
+ // from ever taking effect and causing the benchmark to underestimate the
+ // actual wakeup time.
+ constexpr int64_t kBeforeWakeDelayNs = 500;
+ std::atomic<int32_t> v(0);
+ ScopedThread t([&] {
+ for (int i = 0; i < state.max_iterations; i++) {
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 0) {
+ FutexWait(&v, 0);
+ }
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(0, std::memory_order_release);
+ FutexWake(&v, 1);
+ }
+ });
+ for (auto _ : state) {
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(1, std::memory_order_release);
+ FutexWake(&v, 1);
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 1) {
+ FutexWait(&v, 1);
+ }
+ }
+}
+
+BENCHMARK(BM_FutexRoundtripDelayed)
+ ->MinTime(5)
+ ->UseRealTime()
+ ->Arg(0)
+ ->Arg(10)
+ ->Arg(20)
+ ->Arg(50)
+ ->Arg(100);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
new file mode 100644
index 000000000..d8e81fa8c
--- /dev/null
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+#ifndef SYS_getdents64
+#if defined(__x86_64__)
+#define SYS_getdents64 217
+#elif defined(__aarch64__)
+#define SYS_getdents64 217
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_getdents64
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBufferSize = 65536;
+
+PosixErrorOr<TempPath> CreateDirectory(int count,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir());
+
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (int i = 0; i < count; i++) {
+ auto file = NewTempRelPath();
+ auto res = MknodAt(dfd, file, S_IFREG | 0644, 0);
+ RETURN_IF_ERRNO(res);
+ files->push_back(file);
+ }
+
+ return std::move(dir);
+}
+
+PosixError CleanupDirectory(const TempPath& dir,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (auto it = files->begin(); it != files->end(); ++it) {
+ auto res = UnlinkAt(dfd, *it, 0);
+ RETURN_IF_ERRNO(res);
+ }
+ return NoError();
+}
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a single FD.
+void BM_GetdentsSameFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds());
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime();
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a new FD each time.
+void BM_GetdentsNewFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
new file mode 100644
index 000000000..db74cb264
--- /dev/null
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Getpid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_getpid);
+ }
+}
+
+BENCHMARK(BM_Getpid);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc
new file mode 100644
index 000000000..8f4961f5e
--- /dev/null
+++ b/test/perf/linux/gettid_benchmark.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Gettid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_gettid);
+ }
+}
+
+BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc
new file mode 100644
index 000000000..39c30fe69
--- /dev/null
+++ b/test/perf/linux/mapping_benchmark.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Conservative value for /proc/sys/vm/max_map_count, which limits the number of
+// VMAs, minus a safety margin for VMAs that already exist for the test binary.
+// The default value for max_map_count is
+// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530.
+constexpr size_t kMaxVMAs = 64001;
+
+// Map then unmap pages without touching them.
+void BM_MapUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map, touch, then unmap pages.
+void BM_MapTouchUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ char* end = c + pages * kPageSize;
+ while (c < end) {
+ *c = 42;
+ c += kPageSize;
+ }
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map and touch many pages, unmapping all at once.
+//
+// NOTE(b/111429208): This is a regression test to ensure performant mapping and
+// allocation even with tons of mappings.
+void BM_MapTouchMany(benchmark::State& state) {
+ // Number of pages to map.
+ const int page_count = state.range(0);
+
+ while (state.KeepRunning()) {
+ std::vector<void*> pages;
+
+ for (int i = 0; i < page_count; i++) {
+ void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ *c = 42;
+
+ pages.push_back(addr);
+ }
+
+ for (void* addr : pages) {
+ int ret = munmap(addr, kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+ }
+
+ state.SetBytesProcessed(kPageSize * page_count * state.iterations());
+}
+
+BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime();
+
+void BM_PageFault(benchmark::State& state) {
+ // Map the region in which we will take page faults. To ensure that each page
+ // fault maps only a single page, each page we touch must correspond to a
+ // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However,
+ // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that
+ // if there are background threads running, they can't inadvertently creating
+ // mappings in our gaps that are unmapped when the test ends.
+ size_t test_pages = kMaxVMAs;
+ // Ensure that test_pages is odd, since we want the test region to both
+ // begin and end with a mapped page.
+ if (test_pages % 2 == 0) {
+ test_pages--;
+ }
+ const size_t test_region_bytes = test_pages * kPageSize;
+ // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on
+ // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE
+ // so that Linux pre-allocates the shmem file used to back the mapping.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE));
+ for (size_t i = 0; i < test_pages / 2; i++) {
+ ASSERT_THAT(
+ mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)),
+ kPageSize, PROT_NONE),
+ SyscallSucceeds());
+ }
+
+ const size_t mapped_pages = test_pages / 2 + 1;
+ // "Start" at the end of the mapped region to force the mapped region to be
+ // reset, since we mapped it with MAP_POPULATE.
+ size_t cur_page = mapped_pages;
+ for (auto _ : state) {
+ if (cur_page >= mapped_pages) {
+ // We've reached the end of our mapped region and have to reset it to
+ // incur page faults again.
+ state.PauseTiming();
+ ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED),
+ SyscallSucceeds());
+ cur_page = 0;
+ state.ResumeTiming();
+ }
+ const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize);
+ const char c = *reinterpret_cast<volatile char*>(addr);
+ benchmark::DoNotOptimize(c);
+ cur_page++;
+ }
+}
+
+BENCHMARK(BM_PageFault)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc
new file mode 100644
index 000000000..68008f6d5
--- /dev/null
+++ b/test/perf/linux/open_benchmark.cc
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Open(benchmark::State& state) {
+ const int size = state.range(0);
+ std::vector<TempPath> cache;
+ for (int i = 0; i < size; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ cache.emplace_back(std::move(path));
+ }
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(cache[chosen].path().c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ close(fd);
+ }
+}
+
+BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc
new file mode 100644
index 000000000..8f5f6a2a3
--- /dev/null
+++ b/test/perf/linux/pipe_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <cerrno>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Pipe(benchmark::State& state) {
+ int fds[2];
+ TEST_CHECK(pipe(fds) == 0);
+
+ const int size = state.range(0);
+ std::vector<char> wbuf(size);
+ std::vector<char> rbuf(size);
+ RandomizeBuffer(wbuf.data(), size);
+
+ ScopedThread t([&] {
+ auto const fd = fds[1];
+ for (int i = 0; i < state.max_iterations; i++) {
+ TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size);
+ }
+ });
+
+ for (auto _ : state) {
+ TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size);
+ }
+
+ t.Join();
+
+ close(fds[0]);
+ close(fds[1]);
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc
new file mode 100644
index 000000000..b0eb8c24e
--- /dev/null
+++ b/test/perf/linux/randread_benchmark.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Create a 1GB file that will be read from at random positions. This should
+// invalid any performance gains from caching.
+const uint64_t kFileSize = 1ULL << 30;
+
+// How many bytes to write at once to initialize the file used to read from.
+const uint32_t kWriteSize = 65536;
+
+// Largest benchmarked read unit.
+const uint32_t kMaxRead = 1UL << 26;
+
+TempPath CreateFile(uint64_t file_size) {
+ auto path = TempPath::CreateFile().ValueOrDie();
+ FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie();
+
+ // Try to minimize syscalls by using maximum size writev() requests.
+ std::vector<char> buffer(kWriteSize);
+ RandomizeBuffer(buffer.data(), buffer.size());
+ const std::vector<std::vector<struct iovec>> iovecs_list =
+ GenerateIovecs(file_size, buffer.data(), buffer.size());
+ for (const auto& iovecs : iovecs_list) {
+ TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0);
+ }
+
+ return path;
+}
+
+// Global test state, initialized once per process lifetime.
+struct GlobalState {
+ const TempPath tmpfile;
+ explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {}
+};
+
+GlobalState& GetGlobalState() {
+ // This gets created only once throughout the lifetime of the process.
+ // Use a dynamically allocated object (that is never deleted) to avoid order
+ // of destruction of static storage variables issues.
+ static GlobalState* const state =
+ // The actual file size is the maximum random seek range (kFileSize) + the
+ // maximum read size so we can read that number of bytes at the end of the
+ // file.
+ new GlobalState(CreateFile(kFileSize + kMaxRead));
+ return *state;
+}
+
+void BM_RandRead(benchmark::State& state) {
+ const int size = state.range(0);
+
+ GlobalState& global_state = GetGlobalState();
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY));
+ std::vector<char> buf(size);
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(),
+ rand_r(&seed) % kFileSize) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc
new file mode 100644
index 000000000..62445867d
--- /dev/null
+++ b/test/perf/linux/read_benchmark.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Read(benchmark::State& state) {
+ const int size = state.range(0);
+ const std::string contents(size, 0);
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY));
+
+ std::vector<char> buf(size);
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc
new file mode 100644
index 000000000..6756b5575
--- /dev/null
+++ b/test/perf/linux/sched_yield_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Sched_yield(benchmark::State& state) {
+ for (auto ignored : state) {
+ TEST_CHECK(sched_yield() == 0);
+ }
+}
+
+BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc
new file mode 100644
index 000000000..d73e49523
--- /dev/null
+++ b/test/perf/linux/send_recv_benchmark.cc
@@ -0,0 +1,372 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "benchmark/benchmark.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr ssize_t kMessageSize = 1024;
+
+class Message {
+ public:
+ explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {}
+
+ explicit Message(int byte, int sz) : Message(byte, sz, 0) {}
+
+ explicit Message(int byte, int sz, int cmsg_sz)
+ : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) {
+ iov_.iov_base = buffer_.data();
+ iov_.iov_len = sz;
+ hdr_.msg_iov = &iov_;
+ hdr_.msg_iovlen = 1;
+ hdr_.msg_control = cmsg_buffer_.data();
+ hdr_.msg_controllen = cmsg_sz;
+ }
+
+ struct msghdr* header() {
+ return &hdr_;
+ }
+
+ private:
+ std::vector<char> buffer_;
+ std::vector<char> cmsg_buffer_;
+ struct iovec iov_ = {};
+ struct msghdr hdr_ = {};
+};
+
+void BM_Recvmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvmsg)->UseRealTime();
+
+void BM_Sendmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendmsg(send_socket.get(), send_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendmsg)->UseRealTime();
+
+void BM_Recvfrom(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&send_socket, &send_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ }
+ });
+
+ int bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvfrom)->UseRealTime();
+
+void BM_Sendto(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&recv_socket, &recv_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendto)->UseRealTime();
+
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate
+// space for control messages. Note that we do not expect to receive any.
+void BM_RecvmsgWithControlBuf(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ absl::Notification notification;
+ Message send_msg('a');
+ // Create a msghdr with a buffer allocated for control messages.
+ Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24);
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime();
+
+// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes.
+//
+// state.Args[0] indicates whether the underlying socket should be blocking or
+// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking.
+// state.Args[1] is the size of the payload to be used per sendmsg call.
+void BM_SendmsgTCP(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // Check if we want to run the test w/ a blocking send socket
+ // or non-blocking.
+ const int blocking = state.range(0);
+ if (!blocking) {
+ // Set the send FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds());
+ }
+
+ absl::Notification notification;
+
+ // Get the buffer size we should use for this iteration of the test.
+ const int buf_size = state.range(1);
+ Message send_msg('a', buf_size), recv_msg(0, buf_size);
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ int ncalls = 0;
+ for (auto ignored : state) {
+ int sent = 0;
+ while (true) {
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ struct msghdr* snd_header = send_msg.header();
+ iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent;
+ iov.iov_len = snd_header->msg_iov->iov_len - sent;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+ int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0);
+ ncalls++;
+ if (n > 0) {
+ sent += n;
+ if (sent == buf_size) {
+ break;
+ }
+ // n can be > 0 but less than requested size. In which case we don't
+ // poll.
+ continue;
+ }
+ // Poll the fd for it to become writable.
+ struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10),
+ SyscallSucceedsWithValue(0));
+ }
+ bytes_sent += static_cast<int64_t>(sent);
+ }
+
+ notification.Notify();
+ send_socket.reset();
+ state.SetBytesProcessed(bytes_sent);
+}
+
+void Args(benchmark::internal::Benchmark* benchmark) {
+ for (int blocking = 0; blocking < 2; blocking++) {
+ for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) {
+ benchmark->Args({blocking, buf_size});
+ }
+ }
+}
+
+BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc
new file mode 100644
index 000000000..af49e4477
--- /dev/null
+++ b/test/perf/linux/seqwrite_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The maximum file size of the test file, when writes get beyond this point
+// they wrap around. This should be large enough to blow away caches.
+const uint64_t kMaxFile = 1 << 30;
+
+// Perform writes of various sizes sequentially to one file. Wraps around if it
+// goes above a certain maximum file size.
+void BM_SeqWrite(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ // Start writes at offset 0.
+ uint64_t offset = 0;
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) ==
+ buf.size());
+ offset += buf.size();
+ // Wrap around if going above the maximum file size.
+ if (offset >= kMaxFile) {
+ offset = 0;
+ }
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
new file mode 100644
index 000000000..cec679191
--- /dev/null
+++ b/test/perf/linux/signal_benchmark.cc
@@ -0,0 +1,61 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <string.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void FixupHandler(int sig, siginfo_t* si, void* void_ctx) {
+ static unsigned int dataval = 0;
+
+ // Skip the offending instruction.
+ ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx);
+ ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval);
+}
+
+void BM_FaultSignalFixup(benchmark::State& state) {
+ // Set up the signal handler.
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_sigaction = FixupHandler;
+ sa.sa_flags = SA_SIGINFO;
+ TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0);
+
+ // Fault, fault, fault.
+ for (auto _ : state) {
+ // Trigger the segfault.
+ asm volatile(
+ "movq $0, %%rax\n"
+ "movq $0x77777777, (%%rax)\n"
+ :
+ :
+ : "rax");
+ }
+}
+
+BENCHMARK(BM_FaultSignalFixup)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc
new file mode 100644
index 000000000..99ef05117
--- /dev/null
+++ b/test/perf/linux/sleep_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Sleep for 'param' nanoseconds.
+void BM_Sleep(benchmark::State& state) {
+ const int nanoseconds = state.range(0);
+
+ for (auto _ : state) {
+ struct timespec ts;
+ ts.tv_sec = 0;
+ ts.tv_nsec = nanoseconds;
+
+ int ret;
+ do {
+ ret = syscall(SYS_nanosleep, &ts, &ts);
+ if (ret < 0) {
+ TEST_CHECK(errno == EINTR);
+ }
+ } while (ret < 0);
+ }
+}
+
+BENCHMARK(BM_Sleep)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(1000) // 1us
+ ->Arg(1000 * 1000) // 1ms
+ ->Arg(10 * 1000 * 1000) // 10ms
+ ->Arg(50 * 1000 * 1000) // 50ms
+ ->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc
new file mode 100644
index 000000000..f15424482
--- /dev/null
+++ b/test/perf/linux/stat_benchmark.cc
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a file in a nested directory hierarchy at least `depth` directories
+// deep, and stats that file multiple times.
+void BM_Stat(benchmark::State& state) {
+ // Create nested directories with given depth.
+ int depth = state.range(0);
+ const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string dir_path = top_dir.path();
+
+ while (depth-- > 0) {
+ // Don't use TempPath because it will make paths too long to use.
+ //
+ // The top_dir destructor will clean up this whole tree.
+ dir_path = JoinPath(dir_path, absl::StrCat(depth));
+ ASSERT_NO_ERRNO(Mkdir(dir_path, 0755));
+ }
+
+ // Create the file that will be stat'd.
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path));
+
+ struct stat st;
+ for (auto _ : state) {
+ ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds());
+ }
+}
+
+BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc
new file mode 100644
index 000000000..92243a042
--- /dev/null
+++ b/test/perf/linux/unlink_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a directory containing `files` files, and unlinks all the files.
+void BM_Unlink(benchmark::State& state) {
+ // Create directory with given files.
+ const int file_count = state.range(0);
+
+ // We unlink all files on each iteration, but report this as a "batch"
+ // iteration so that reported times are per file.
+ TempPath dir;
+ while (state.KeepRunningBatch(file_count)) {
+ state.PauseTiming();
+ // N.B. dir is declared outside the loop so that destruction of the previous
+ // iteration's directory occurs here, inside of PauseTiming.
+ dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ std::vector<TempPath> files;
+ for (int i = 0; i < file_count; i++) {
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ files.push_back(std::move(file));
+ }
+ state.ResumeTiming();
+
+ while (!files.empty()) {
+ // Destructor unlinks.
+ files.pop_back();
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc
new file mode 100644
index 000000000..7b060c70e
--- /dev/null
+++ b/test/perf/linux/write_benchmark.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Write(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), size);
+
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/root/BUILD b/test/root/BUILD
index d5dd9bca2..a9130b34f 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -1,11 +1,11 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/vm:defs.bzl", "vm_test")
package(licenses = ["notice"])
go_library(
name = "root",
srcs = ["root.go"],
- importpath = "gvisor.dev/gvisor/test/root",
)
go_test(
@@ -17,28 +17,39 @@ go_test(
"crictl_test.go",
"main_test.go",
"oom_score_adj_test.go",
+ "runsc_test.go",
],
data = [
"//runsc",
],
- embed = [":root"],
+ library = ":root",
tags = [
# Requires docker and runsc to be configured before the test runs.
- # Also test only runs as root.
+ # Also, the test needs to be run as root. Note that below, the
+ # root_vm_test relies on the default runtime 'runsc' being installed by
+ # the default installer.
"manual",
"local",
],
visibility = ["//:sandbox"],
deps = [
- "//runsc/boot",
+ "//pkg/cleanup",
+ "//pkg/test/criutil",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
"//runsc/cgroup",
"//runsc/container",
- "//runsc/criutil",
- "//runsc/dockerutil",
"//runsc/specutils",
- "//runsc/testutil",
- "//test/root/testdata",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@com_github_cenkalti_backoff//:go_default_library",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
+
+vm_test(
+ name = "root_vm_test",
+ size = "large",
+ shard_count = 1,
+ targets = [":root_test"],
+)
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
index 76f1e4f2a..a26b83081 100644
--- a/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -16,6 +16,7 @@ package root
import (
"bufio"
+ "context"
"fmt"
"io/ioutil"
"os"
@@ -24,10 +25,11 @@ import (
"strconv"
"strings"
"testing"
+ "time"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/cgroup"
- "gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/testutil"
)
func verifyPid(pid int, path string) error {
@@ -52,15 +54,82 @@ func verifyPid(pid int, path string) error {
if scanner.Err() != nil {
return scanner.Err()
}
- return fmt.Errorf("got: %s, want: %d", gots, pid)
+ return fmt.Errorf("got: %v, want: %d", gots, pid)
+}
+
+func TestMemCgroup(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Start a new container and allocate the specified about of memory.
+ allocMemSize := 128 << 20
+ allocMemLimit := 2 * allocMemSize
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ Memory: allocMemLimit, // Must be in bytes.
+ }, "python3", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Extract the ID to lookup the cgroup.
+ gid := d.ID()
+ t.Logf("cgroup ID: %s", gid)
+
+ // Wait when the container will allocate memory.
+ memUsage := 0
+ start := time.Now()
+ for time.Since(start) < 30*time.Second {
+ // Sleep for a brief period of time after spawning the
+ // container (so that Docker can create the cgroup etc.
+ // or after looping below (so the application can start).
+ time.Sleep(100 * time.Millisecond)
+
+ // Read the cgroup memory limit.
+ path := filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.limit_in_bytes")
+ outRaw, err := ioutil.ReadFile(path)
+ if err != nil {
+ // It's possible that the container does not exist yet.
+ continue
+ }
+ out := strings.TrimSpace(string(outRaw))
+ memLimit, err := strconv.Atoi(out)
+ if err != nil {
+ t.Fatalf("Atoi(%v): %v", out, err)
+ }
+ if memLimit != allocMemLimit {
+ // The group may not have had the correct limit set yet.
+ continue
+ }
+
+ // Read the cgroup memory usage.
+ path = filepath.Join("/sys/fs/cgroup/memory/docker", gid, "memory.max_usage_in_bytes")
+ outRaw, err = ioutil.ReadFile(path)
+ if err != nil {
+ t.Fatalf("error reading usage: %v", err)
+ }
+ out = strings.TrimSpace(string(outRaw))
+ memUsage, err = strconv.Atoi(out)
+ if err != nil {
+ t.Fatalf("Atoi(%v): %v", out, err)
+ }
+ t.Logf("read usage: %v, wanted: %v", memUsage, allocMemSize)
+
+ // Are we done?
+ if memUsage >= allocMemSize {
+ return
+ }
+ }
+
+ t.Fatalf("%vMB is less than %vMB", memUsage>>20, allocMemSize>>20)
}
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
func TestCgroup(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("cgroup-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
// This is not a comprehensive list of attributes.
//
@@ -69,84 +138,133 @@ func TestCgroup(t *testing.T) {
// are often run on a single core virtual machine, and there is only a single
// CPU available in our current set, and every container's set.
attrs := []struct {
- arg string
+ field string
+ value int64
ctrl string
file string
want string
skipIfNotFound bool
}{
{
- arg: "--cpu-shares=1000",
- ctrl: "cpu",
- file: "cpu.shares",
- want: "1000",
+ field: "cpu-shares",
+ value: 1000,
+ ctrl: "cpu",
+ file: "cpu.shares",
+ want: "1000",
},
{
- arg: "--cpu-period=2000",
- ctrl: "cpu",
- file: "cpu.cfs_period_us",
- want: "2000",
+ field: "cpu-period",
+ value: 2000,
+ ctrl: "cpu",
+ file: "cpu.cfs_period_us",
+ want: "2000",
},
{
- arg: "--cpu-quota=3000",
- ctrl: "cpu",
- file: "cpu.cfs_quota_us",
- want: "3000",
+ field: "cpu-quota",
+ value: 3000,
+ ctrl: "cpu",
+ file: "cpu.cfs_quota_us",
+ want: "3000",
},
{
- arg: "--kernel-memory=100MB",
- ctrl: "memory",
- file: "memory.kmem.limit_in_bytes",
- want: "104857600",
+ field: "kernel-memory",
+ value: 100 << 20,
+ ctrl: "memory",
+ file: "memory.kmem.limit_in_bytes",
+ want: "104857600",
},
{
- arg: "--memory=1GB",
- ctrl: "memory",
- file: "memory.limit_in_bytes",
- want: "1073741824",
+ field: "memory",
+ value: 1 << 30,
+ ctrl: "memory",
+ file: "memory.limit_in_bytes",
+ want: "1073741824",
},
{
- arg: "--memory-reservation=500MB",
- ctrl: "memory",
- file: "memory.soft_limit_in_bytes",
- want: "524288000",
+ field: "memory-reservation",
+ value: 500 << 20,
+ ctrl: "memory",
+ file: "memory.soft_limit_in_bytes",
+ want: "524288000",
},
{
- arg: "--memory-swap=2GB",
+ field: "memory-swap",
+ value: 2 << 30,
ctrl: "memory",
file: "memory.memsw.limit_in_bytes",
want: "2147483648",
skipIfNotFound: true, // swap may be disabled on the machine.
},
{
- arg: "--memory-swappiness=5",
- ctrl: "memory",
- file: "memory.swappiness",
- want: "5",
+ field: "memory-swappiness",
+ value: 5,
+ ctrl: "memory",
+ file: "memory.swappiness",
+ want: "5",
+ },
+ {
+ field: "blkio-weight",
+ value: 750,
+ ctrl: "blkio",
+ file: "blkio.weight",
+ want: "750",
+ skipIfNotFound: true, // blkio groups may not be available.
},
{
- arg: "--blkio-weight=750",
- ctrl: "blkio",
- file: "blkio.weight",
- want: "750",
+ field: "pids-limit",
+ value: 1000,
+ ctrl: "pids",
+ file: "pids.max",
+ want: "1000",
},
}
- args := make([]string, 0, len(attrs))
+ // Make configs.
+ conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000")
+
+ // Add Cgroup arguments to configs.
for _, attr := range attrs {
- args = append(args, attr.arg)
+ switch attr.field {
+ case "cpu-shares":
+ hostconf.Resources.CPUShares = attr.value
+ case "cpu-period":
+ hostconf.Resources.CPUPeriod = attr.value
+ case "cpu-quota":
+ hostconf.Resources.CPUQuota = attr.value
+ case "kernel-memory":
+ hostconf.Resources.KernelMemory = attr.value
+ case "memory":
+ hostconf.Resources.Memory = attr.value
+ case "memory-reservation":
+ hostconf.Resources.MemoryReservation = attr.value
+ case "memory-swap":
+ hostconf.Resources.MemorySwap = attr.value
+ case "memory-swappiness":
+ val := attr.value
+ hostconf.Resources.MemorySwappiness = &val
+ case "blkio-weight":
+ hostconf.Resources.BlkioWeight = uint16(attr.value)
+ case "pids-limit":
+ val := attr.value
+ hostconf.Resources.PidsLimit = &val
+
+ }
}
- args = append(args, "alpine", "sleep", "10000")
- if err := d.Run(args...); err != nil {
- t.Fatal("docker create failed:", err)
+ // Create container.
+ if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("create failed with: %v", err)
}
- defer d.CleanUp()
- gid, err := d.ID()
- if err != nil {
- t.Fatalf("Docker.ID() failed: %v", err)
+ // Start container.
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("start failed with: %v", err)
}
+
+ // Lookup the relevant cgroup ID.
+ gid := d.ID()
t.Logf("cgroup ID: %s", gid)
// Check list of attributes defined above.
@@ -161,7 +279,7 @@ func TestCgroup(t *testing.T) {
t.Fatalf("failed to read %q: %v", path, err)
}
if got := strings.TrimSpace(string(out)); got != attr.want {
- t.Errorf("arg: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.arg, attr.ctrl, attr.file, got, attr.want)
+ t.Errorf("field: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.field, attr.ctrl, attr.file, got, attr.want)
}
}
@@ -179,7 +297,7 @@ func TestCgroup(t *testing.T) {
"pids",
"systemd",
}
- pid, err := d.SandboxPid()
+ pid, err := d.SandboxPid(ctx)
if err != nil {
t.Fatalf("SandboxPid: %v", err)
}
@@ -191,25 +309,34 @@ func TestCgroup(t *testing.T) {
}
}
+// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's
+// cgroups are created correctly relative to each other.
func TestCgroupParent(t *testing.T) {
- if err := dockerutil.Pull("alpine"); err != nil {
- t.Fatal("docker pull failed:", err)
- }
- d := dockerutil.MakeDocker("cgroup-test")
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
- parent := testutil.RandomName("runsc")
- if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil {
- t.Fatal("docker create failed:", err)
+ // Construct a known cgroup name.
+ parent := testutil.RandomID("runsc-")
+ conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000")
+ hostconf.Resources.CgroupParent = parent
+
+ if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil {
+ t.Fatalf("create failed with: %v", err)
}
- defer d.CleanUp()
- gid, err := d.ID()
- if err != nil {
- t.Fatalf("Docker.ID() failed: %v", err)
+
+ if err := d.Start(ctx); err != nil {
+ t.Fatalf("start failed with: %v", err)
}
+
+ // Extract the ID to look up the cgroup.
+ gid := d.ID()
t.Logf("cgroup ID: %s", gid)
// Check that sandbox is inside cgroup.
- pid, err := d.SandboxPid()
+ pid, err := d.SandboxPid(ctx)
if err != nil {
t.Fatalf("SandboxPid: %v", err)
}
diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go
index be0f63d18..58fcd6f08 100644
--- a/test/root/chroot_test.go
+++ b/test/root/chroot_test.go
@@ -16,6 +16,7 @@
package root
import (
+ "context"
"fmt"
"io/ioutil"
"os/exec"
@@ -24,19 +25,23 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
)
// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned
// up after the sandbox is destroyed.
func TestChroot(t *testing.T) {
- d := dockerutil.MakeDocker("chroot-test")
- if err := d.Run("alpine", "sleep", "10000"); err != nil {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
- pid, err := d.SandboxPid()
+ pid, err := d.SandboxPid(ctx)
if err != nil {
t.Fatalf("Docker.SandboxPid(): %v", err)
}
@@ -72,20 +77,24 @@ func TestChroot(t *testing.T) {
t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc")
}
- d.CleanUp()
+ d.CleanUp(ctx)
}
func TestChrootGofer(t *testing.T) {
- d := dockerutil.MakeDocker("chroot-test")
- if err := d.Run("alpine", "sleep", "10000"); err != nil {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ if err := d.Spawn(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
- defer d.CleanUp()
// It's tricky to find gofers. Get sandbox PID first, then find parent. From
// parent get all immediate children, remove the sandbox, and everything else
// are gofers.
- sandPID, err := d.SandboxPid()
+ sandPID, err := d.SandboxPid(ctx)
if err != nil {
t.Fatalf("Docker.SandboxPid(): %v", err)
}
diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go
index 3f90c4c6a..df91fa0fe 100644
--- a/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -16,196 +16,362 @@ package root
import (
"bytes"
+ "encoding/json"
"fmt"
"io"
"io/ioutil"
- "log"
"net/http"
"os"
"os/exec"
"path"
- "path/filepath"
+ "regexp"
+ "strconv"
"strings"
+ "sync"
"testing"
"time"
- "gvisor.dev/gvisor/runsc/criutil"
- "gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
- "gvisor.dev/gvisor/test/root/testdata"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/test/criutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
// Tests for crictl have to be run as root (rather than in a user namespace)
// because crictl creates named network namespaces in /var/run/netns/.
-// TestCrictlSanity refers to b/112433158.
-func TestCrictlSanity(t *testing.T) {
- // Setup containerd and crictl.
- crictl, cleanup, err := setup(t)
- if err != nil {
- t.Fatalf("failed to setup crictl: %v", err)
+// Sandbox returns a JSON config for a simple sandbox. Sandbox names must be
+// unique so different names should be used when running tests on the same
+// containerd instance.
+func Sandbox(name string) string {
+ // Sandbox is a default JSON config for a sandbox.
+ s := map[string]interface{}{
+ "metadata": map[string]string{
+ "name": name,
+ "namespace": "default",
+ "uid": testutil.RandomID(""),
+ },
+ "linux": map[string]string{},
+ "log_directory": "/tmp",
}
- defer cleanup()
- podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.Httpd)
+
+ v, err := json.Marshal(s)
if err != nil {
- t.Fatal(err)
+ // This shouldn't happen.
+ panic(err)
}
+ return string(v)
+}
- // Look for the httpd page.
- if err = httpGet(crictl, podID, "index.html"); err != nil {
- t.Fatalf("failed to get page: %v", err)
+// SimpleSpec returns a JSON config for a simple container that runs the
+// specified command in the specified image.
+func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) string {
+ s := map[string]interface{}{
+ "metadata": map[string]string{
+ "name": name,
+ },
+ "image": map[string]string{
+ "image": testutil.ImageByName(image),
+ },
+ // Log files are not deleted after root tests are run. Log to random
+ // paths to ensure logs are fresh.
+ "log_path": fmt.Sprintf("%s.log", testutil.RandomID(name)),
+ "stdin": false,
+ "tty": false,
+ }
+ if len(cmd) > 0 { // Omit if empty.
+ s["command"] = cmd
+ }
+ for k, v := range extra {
+ s[k] = v // Extra settings.
+ }
+ v, err := json.Marshal(s)
+ if err != nil {
+ // This shouldn't happen.
+ panic(err)
}
+ return string(v)
+}
+
+// Httpd is a JSON config for an httpd container.
+var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil)
+
+// TestCrictlSanity refers to b/112433158.
+func TestCrictlSanity(t *testing.T) {
+ for _, version := range allVersions {
+ t.Run(version, func(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t, version)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+ podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), Httpd)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
- // Stop everything.
- if err := crictl.StopPodAndContainer(podID, contID); err != nil {
- t.Fatal(err)
+ // Look for the httpd page.
+ if err = httpGet(crictl, podID, "index.html"); err != nil {
+ t.Fatalf("failed to get page: %v", err)
+ }
+
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
}
}
+// HttpdMountPaths is a JSON config for an httpd container with additional
+// mounts.
+var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interface{}{
+ "mounts": []map[string]interface{}{
+ map[string]interface{}{
+ "container_path": "/var/run/secrets/kubernetes.io/serviceaccount",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx",
+ "readonly": true,
+ },
+ map[string]interface{}{
+ "container_path": "/etc/hosts",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts",
+ "readonly": false,
+ },
+ map[string]interface{}{
+ "container_path": "/dev/termination-log",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580",
+ "readonly": false,
+ },
+ map[string]interface{}{
+ "container_path": "/usr/local/apache2/htdocs/test",
+ "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064",
+ "readonly": true,
+ },
+ },
+ "linux": map[string]interface{}{},
+})
+
// TestMountPaths refers to b/117635704.
func TestMountPaths(t *testing.T) {
- // Setup containerd and crictl.
- crictl, cleanup, err := setup(t)
- if err != nil {
- t.Fatalf("failed to setup crictl: %v", err)
- }
- defer cleanup()
- podID, contID, err := crictl.StartPodAndContainer("httpd", testdata.Sandbox, testdata.HttpdMountPaths)
- if err != nil {
- t.Fatal(err)
- }
+ for _, version := range allVersions {
+ t.Run(version, func(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t, version)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
+ podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), HttpdMountPaths)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
- // Look for the directory available at /test.
- if err = httpGet(crictl, podID, "test"); err != nil {
- t.Fatalf("failed to get page: %v", err)
- }
+ // Look for the directory available at /test.
+ if err = httpGet(crictl, podID, "test"); err != nil {
+ t.Fatalf("failed to get page: %v", err)
+ }
- // Stop everything.
- if err := crictl.StopPodAndContainer(podID, contID); err != nil {
- t.Fatal(err)
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
}
}
// TestMountPaths refers to b/118728671.
func TestMountOverSymlinks(t *testing.T) {
- // Setup containerd and crictl.
- crictl, cleanup, err := setup(t)
- if err != nil {
- t.Fatalf("failed to setup crictl: %v", err)
- }
- defer cleanup()
- podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, testdata.MountOverSymlink)
- if err != nil {
- t.Fatal(err)
- }
+ for _, version := range allVersions {
+ t.Run(version, func(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t, version)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
- out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf")
- if err != nil {
- t.Fatal(err)
- }
- if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) {
- t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out))
- }
+ spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/resolv", Sandbox("default"), spec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
- etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf")
- if err != nil {
- t.Fatal(err)
- }
- tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf")
- if err != nil {
- t.Fatal(err)
- }
- if tmp != etc {
- t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp))
- }
+ out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf")
+ if err != nil {
+ t.Fatalf("readlink failed: %v, out: %s", err, out)
+ }
+ if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) {
+ t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out))
+ }
+
+ etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf")
+ if err != nil {
+ t.Fatalf("cat failed: %v, out: %s", err, etc)
+ }
+ tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf")
+ if err != nil {
+ t.Fatalf("cat failed: %v, out: %s", err, out)
+ }
+ if tmp != etc {
+ t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp))
+ }
- // Stop everything.
- if err := crictl.StopPodAndContainer(podID, contID); err != nil {
- t.Fatal(err)
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
}
}
// TestHomeDir tests that the HOME environment variable is set for
-// multi-containers.
+// Pod containers.
func TestHomeDir(t *testing.T) {
- // Setup containerd and crictl.
- crictl, cleanup, err := setup(t)
- if err != nil {
- t.Fatalf("failed to setup crictl: %v", err)
- }
- defer cleanup()
- contSpec := testdata.SimpleSpec("root", "k8s.gcr.io/busybox", []string{"sleep", "1000"})
- podID, contID, err := crictl.StartPodAndContainer("k8s.gcr.io/busybox", testdata.Sandbox, contSpec)
- if err != nil {
- t.Fatal(err)
- }
+ for _, version := range allVersions {
+ t.Run(version, func(t *testing.T) {
+ // Setup containerd and crictl.
+ crictl, cleanup, err := setup(t, version)
+ if err != nil {
+ t.Fatalf("failed to setup crictl: %v", err)
+ }
+ defer cleanup()
- t.Run("root container", func(t *testing.T) {
- out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME")
- if err != nil {
- t.Fatal(err)
- }
- if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
- t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
- }
- })
+ // Note that container ID returned here is a sub-container. All Pod
+ // containers are sub-containers. The root container of the sandbox is the
+ // pause container.
+ t.Run("sub-container", func(t *testing.T) {
+ contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("subcont-sandbox"), contSpec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
- t.Run("sub-container", func(t *testing.T) {
- // Create a sub container in the same pod.
- subContSpec := testdata.SimpleSpec("subcontainer", "k8s.gcr.io/busybox", []string{"sleep", "1000"})
- subContID, err := crictl.StartContainer(podID, "k8s.gcr.io/busybox", testdata.Sandbox, subContSpec)
- if err != nil {
- t.Fatal(err)
- }
+ out, err := crictl.Logs(contID)
+ if err != nil {
+ t.Fatalf("failed retrieving container logs: %v, out: %s", err, out)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
+ }
- out, err := crictl.Exec(subContID, "sh", "-c", "echo $HOME")
- if err != nil {
- t.Fatal(err)
- }
- if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
- t.Fatalf("Home directory invalid. Got %q, Want: %q", got, want)
- }
+ // Stop everything; note that the pod may have already stopped.
+ crictl.StopPodAndContainer(podID, contID)
+ })
- if err := crictl.StopContainer(subContID); err != nil {
- t.Fatal(err)
- }
- })
+ // Tests that HOME is set for the exec process.
+ t.Run("exec", func(t *testing.T) {
+ contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil)
+ podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("exec-sandbox"), contSpec)
+ if err != nil {
+ t.Fatalf("start failed: %v", err)
+ }
- // Stop everything.
- if err := crictl.StopPodAndContainer(podID, contID); err != nil {
- t.Fatal(err)
- }
+ out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("failed retrieving container logs: %v, out: %s", err, out)
+ }
+ if got, want := strings.TrimSpace(string(out)), "/root"; got != want {
+ t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want)
+ }
+ // Stop everything.
+ if err := crictl.StopPodAndContainer(podID, contID); err != nil {
+ t.Fatalf("stop failed: %v", err)
+ }
+ })
+ })
+ }
}
+const containerdRuntime = "runsc"
+
+const v1Template = `
+disabled_plugins = ["restart"]
+[plugins.cri]
+ disable_tcp_service = true
+[plugins.linux]
+ shim = "%s"
+ shim_debug = true
+[plugins.cri.containerd.runtimes.` + containerdRuntime + `]
+ runtime_type = "io.containerd.runtime.v1.linux"
+ runtime_engine = "%s"
+ runtime_root = "%s/root/runsc"
+`
+
+const v2Template = `
+disabled_plugins = ["restart"]
+[plugins.cri]
+ disable_tcp_service = true
+[plugins.linux]
+ shim_debug = true
+[plugins.cri.containerd.runtimes.` + containerdRuntime + `]
+ runtime_type = "io.containerd.` + containerdRuntime + `.v1"
+[plugins.cri.containerd.runtimes.` + containerdRuntime + `.options]
+ TypeUrl = "io.containerd.` + containerdRuntime + `.v1.options"
+`
+
+const (
+ // v1 is the containerd API v1.
+ v1 string = "v1"
+
+ // v1 is the containerd API v21.
+ v2 string = "v2"
+)
+
+// allVersions is the set of known versions.
+var allVersions = []string{v1, v2}
+
// setup sets up before a test. Specifically it:
// * Creates directories and a socket for containerd to utilize.
// * Runs containerd and waits for it to reach a "ready" state for testing.
// * Returns a cleanup function that should be called at the end of the test.
-func setup(t *testing.T) (*criutil.Crictl, func(), error) {
- var cleanups []func()
- cleanupFunc := func() {
- for i := len(cleanups) - 1; i >= 0; i-- {
- cleanups[i]()
- }
- }
- cleanup := specutils.MakeCleanup(cleanupFunc)
- defer cleanup.Clean()
-
+func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) {
// Create temporary containerd root and state directories, and a socket
// via which crictl and containerd communicate.
containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root")
if err != nil {
t.Fatalf("failed to create containerd root: %v", err)
}
- cleanups = append(cleanups, func() { os.RemoveAll(containerdRoot) })
+ cu := cleanup.Make(func() { os.RemoveAll(containerdRoot) })
+ defer cu.Clean()
+ t.Logf("Using containerd root: %s", containerdRoot)
+
containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state")
if err != nil {
t.Fatalf("failed to create containerd state: %v", err)
}
- cleanups = append(cleanups, func() { os.RemoveAll(containerdState) })
- sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock")
+ cu.Add(func() { os.RemoveAll(containerdState) })
+ t.Logf("Using containerd state: %s", containerdState)
+
+ sockDir, err := ioutil.TempDir(testutil.TmpDir(), "containerd-sock")
+ if err != nil {
+ t.Fatalf("failed to create containerd socket directory: %v", err)
+ }
+ cu.Add(func() { os.RemoveAll(sockDir) })
+ sockAddr := path.Join(sockDir, "test.sock")
+ t.Logf("Using containerd socket: %s", sockAddr)
+
+ // Extract the containerd version.
+ versionCmd := exec.Command(getContainerd(), "-v")
+ out, err := versionCmd.CombinedOutput()
+ if err != nil {
+ t.Fatalf("error extracting containerd version: %v (%s)", err, string(out))
+ }
+ r := regexp.MustCompile(" v([0-9]+)\\.([0-9]+)\\.([0-9+])")
+ vs := r.FindStringSubmatch(string(out))
+ if len(vs) != 4 {
+ t.Fatalf("error unexpected version string: %s", string(out))
+ }
+ major, err := strconv.ParseUint(vs[1], 10, 64)
+ if err != nil {
+ t.Fatalf("error parsing containerd major version: %v (%s)", err, string(out))
+ }
+ minor, err := strconv.ParseUint(vs[2], 10, 64)
+ if err != nil {
+ t.Fatalf("error parsing containerd minor version: %v (%s)", err, string(out))
+ }
+ t.Logf("Using containerd version: %d.%d", major, minor)
// We rewrite a configuration. This is based on the current docker
// configuration for the runtime under test.
@@ -213,50 +379,125 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) {
if err != nil {
t.Fatalf("error discovering runtime path: %v", err)
}
- config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime))
+ t.Logf("Using runtime: %v", runtime)
+
+ // Construct a PATH that includes the runtime directory. This is
+ // because the shims will be installed there, and containerd may infer
+ // the binary name and search the PATH.
+ runtimeDir := path.Dir(runtime)
+ modifiedPath := os.Getenv("PATH")
+ if modifiedPath != "" {
+ modifiedPath = ":" + modifiedPath // We prepend below.
+ }
+ modifiedPath = path.Dir(getContainerd()) + modifiedPath
+ modifiedPath = runtimeDir + ":" + modifiedPath
+ t.Logf("Using PATH: %v", modifiedPath)
+
+ var (
+ config string
+ runpArgs []string
+ )
+ switch version {
+ case v1:
+ // This is only supported less than 1.3.
+ if major > 1 || (major == 1 && minor >= 3) {
+ t.Skipf("skipping unsupported containerd (want less than 1.3, got %d.%d)", major, minor)
+ }
+
+ // We provide the shim, followed by the runtime, and then a
+ // temporary root directory.
+ config = fmt.Sprintf(v1Template, criutil.ResolvePath("gvisor-containerd-shim"), runtime, containerdRoot)
+ case v2:
+ // This is only supported past 1.2.
+ if major < 1 || (major == 1 && minor <= 1) {
+ t.Skipf("skipping incompatible containerd (want at least 1.2, got %d.%d)", major, minor)
+ }
+
+ // The runtime is provided via parameter. Note that the v2 shim
+ // binary name is always containerd-shim-* so we don't actually
+ // care about the docker runtime name.
+ config = v2Template
+ default:
+ t.Fatalf("unknown version: %d", version)
+ }
+ t.Logf("Using config: %s", config)
+
+ // Generate the configuration for the test.
+ configFile, configCleanup, err := testutil.WriteTmpFile("containerd-config", config)
if err != nil {
t.Fatalf("failed to write containerd config")
}
- cleanups = append(cleanups, func() { os.RemoveAll(config) })
+ cu.Add(configCleanup)
// Start containerd.
- containerd := exec.Command(getContainerd(),
- "--config", config,
+ args := []string{
+ getContainerd(),
+ "--config", configFile,
"--log-level", "debug",
"--root", containerdRoot,
"--state", containerdState,
- "--address", sockAddr)
- cleanups = append(cleanups, func() {
- if err := testutil.KillCommand(containerd); err != nil {
- log.Printf("error killing containerd: %v", err)
- }
- })
- containerdStderr, err := containerd.StderrPipe()
- if err != nil {
- t.Fatalf("failed to get containerd stderr: %v", err)
+ "--address", sockAddr,
}
- containerdStdout, err := containerd.StdoutPipe()
+ t.Logf("Using args: %s", strings.Join(args, " "))
+ cmd := exec.Command(args[0], args[1:]...)
+ cmd.Env = append(os.Environ(), "PATH="+modifiedPath)
+
+ // Include output in logs.
+ stderrPipe, err := cmd.StderrPipe()
if err != nil {
- t.Fatalf("failed to get containerd stdout: %v", err)
+ t.Fatalf("failed to create stderr pipe: %v", err)
}
- if err := containerd.Start(); err != nil {
+ cu.Add(func() { stderrPipe.Close() })
+ stdoutPipe, err := cmd.StdoutPipe()
+ if err != nil {
+ t.Fatalf("failed to create stdout pipe: %v", err)
+ }
+ cu.Add(func() { stdoutPipe.Close() })
+ var (
+ wg sync.WaitGroup
+ stderr bytes.Buffer
+ stdout bytes.Buffer
+ )
+ startupR, startupW := io.Pipe()
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ io.Copy(io.MultiWriter(startupW, &stderr), stderrPipe)
+ }()
+ go func() {
+ defer wg.Done()
+ io.Copy(io.MultiWriter(startupW, &stdout), stdoutPipe)
+ }()
+ cu.Add(func() {
+ wg.Wait()
+ t.Logf("containerd stdout: %s", stdout.String())
+ t.Logf("containerd stderr: %s", stderr.String())
+ })
+
+ // Start the process.
+ if err := cmd.Start(); err != nil {
t.Fatalf("failed running containerd: %v", err)
}
- // Wait for containerd to boot. Then put all containerd output into a
- // buffer to be logged at the end of the test.
- testutil.WaitUntilRead(containerdStderr, "Start streaming server", nil, 10*time.Second)
- stdoutBuf := &bytes.Buffer{}
- stderrBuf := &bytes.Buffer{}
- go func() { io.Copy(stdoutBuf, containerdStdout) }()
- go func() { io.Copy(stderrBuf, containerdStderr) }()
- cleanups = append(cleanups, func() {
- t.Logf("containerd stdout: %s", string(stdoutBuf.Bytes()))
- t.Logf("containerd stderr: %s", string(stderrBuf.Bytes()))
+ // Wait for containerd to boot.
+ if err := testutil.WaitUntilRead(startupR, "Start streaming server", nil, 10*time.Second); err != nil {
+ t.Fatalf("failed to start containerd: %v", err)
+ }
+
+ // Discard all subsequent data.
+ go io.Copy(ioutil.Discard, startupR)
+
+ // Create the crictl interface.
+ cc := criutil.NewCrictl(t, sockAddr, runpArgs)
+ cu.Add(cc.CleanUp)
+
+ // Kill must be the last cleanup (as it will be executed first).
+ cu.Add(func() {
+ // Best effort: ignore errors.
+ testutil.KillCommand(cmd)
})
- cleanup.Release()
- return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil
+ return cc, cu.Release(), nil
}
// httpGet GETs the contents of a file served from a pod on port 80.
diff --git a/test/root/main_test.go b/test/root/main_test.go
index d74dec85f..9fb17e0dd 100644
--- a/test/root/main_test.go
+++ b/test/root/main_test.go
@@ -21,7 +21,7 @@ import (
"testing"
"github.com/syndtr/gocapability/capability"
- "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/runsc/specutils"
)
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
index 126f0975a..4243eb59e 100644
--- a/test/root/oom_score_adj_test.go
+++ b/test/root/oom_score_adj_test.go
@@ -20,10 +20,10 @@ import (
"testing"
specs "github.com/opencontainers/runtime-spec/specs-go"
- "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/pkg/cleanup"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
)
var (
@@ -40,15 +40,6 @@ var (
// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
// single container sandbox.
func TestOOMScoreAdjSingle(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
-
- conf := testutil.TestConfig()
- conf.RootDir = rootDir
-
ppid, err := specutils.GetParentPid(os.Getpid())
if err != nil {
t.Fatalf("getting parent pid: %v", err)
@@ -89,11 +80,11 @@ func TestOOMScoreAdjSingle(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.Name, func(t *testing.T) {
- id := testutil.UniqueContainerID()
+ id := testutil.RandomContainerID()
s := testutil.NewSpecWithArgs("sleep", "1000")
s.Process.OOMScoreAdj = testCase.OOMScoreAdj
- containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id})
+ containers, cleanup, err := startContainers(t, []*specs.Spec{s}, []string{id})
if err != nil {
t.Fatalf("error starting containers: %v", err)
}
@@ -131,15 +122,6 @@ func TestOOMScoreAdjSingle(t *testing.T) {
// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
// multi-container sandbox.
func TestOOMScoreAdjMulti(t *testing.T) {
- rootDir, err := testutil.SetupRootDir()
- if err != nil {
- t.Fatalf("error creating root dir: %v", err)
- }
- defer os.RemoveAll(rootDir)
-
- conf := testutil.TestConfig()
- conf.RootDir = rootDir
-
ppid, err := specutils.GetParentPid(os.Getpid())
if err != nil {
t.Fatalf("getting parent pid: %v", err)
@@ -257,7 +239,7 @@ func TestOOMScoreAdjMulti(t *testing.T) {
}
}
- containers, cleanup, err := startContainers(conf, specs, ids)
+ containers, cleanup, err := startContainers(t, specs, ids)
if err != nil {
t.Fatalf("error starting containers: %v", err)
}
@@ -321,7 +303,7 @@ func TestOOMScoreAdjMulti(t *testing.T) {
func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
var specs []*specs.Spec
var ids []string
- rootID := testutil.UniqueContainerID()
+ rootID := testutil.RandomContainerID()
for i, cmd := range cmds {
spec := testutil.NewSpecWithArgs(cmd...)
@@ -335,35 +317,34 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
specutils.ContainerdSandboxIDAnnotation: rootID,
}
- ids = append(ids, testutil.UniqueContainerID())
+ ids = append(ids, testutil.RandomContainerID())
}
specs = append(specs, spec)
}
return specs, ids
}
-func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
- if len(conf.RootDir) == 0 {
- panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.")
- }
-
+func startContainers(t *testing.T, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
var containers []*container.Container
- var bundles []string
- cleanup := func() {
- for _, c := range containers {
- c.Destroy()
- }
- for _, b := range bundles {
- os.RemoveAll(b)
- }
+
+ // All containers must share the same root.
+ rootDir, clean, err := testutil.SetupRootDir()
+ if err != nil {
+ t.Fatalf("error creating root dir: %v", err)
}
+ cu := cleanup.Make(clean)
+ defer cu.Clean()
+
+ // Point this to from the configuration.
+ conf := testutil.TestConfig(t)
+ conf.RootDir = rootDir
+
for i, spec := range specs {
- bundleDir, err := testutil.SetupBundleDir(spec)
+ bundleDir, clean, err := testutil.SetupBundleDir(spec)
if err != nil {
- cleanup()
- return nil, nil, fmt.Errorf("error setting up container: %v", err)
+ return nil, nil, fmt.Errorf("error setting up bundle: %v", err)
}
- bundles = append(bundles, bundleDir)
+ cu.Add(clean)
args := container.Args{
ID: ids[i],
@@ -372,15 +353,14 @@ func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*c
}
cont, err := container.New(conf, args)
if err != nil {
- cleanup()
return nil, nil, fmt.Errorf("error creating container: %v", err)
}
containers = append(containers, cont)
if err := cont.Start(conf); err != nil {
- cleanup()
return nil, nil, fmt.Errorf("error starting container: %v", err)
}
}
- return containers, cleanup, nil
+
+ return containers, cu.Release(), nil
}
diff --git a/test/root/runsc_test.go b/test/root/runsc_test.go
new file mode 100644
index 000000000..25204bebb
--- /dev/null
+++ b/test/root/runsc_test.go
@@ -0,0 +1,151 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/cenkalti/backoff"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/test/testutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestDoKill checks that when "runsc do..." is killed, the sandbox process is
+// also terminated. This ensures that parent death signal is propagate to the
+// sandbox process correctly.
+func TestDoKill(t *testing.T) {
+ // Make the sandbox process be reparented here when it's killed, so we can
+ // wait for it.
+ if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil {
+ t.Fatalf("prctl(PR_SET_CHILD_SUBREAPER): %v", err)
+ }
+
+ cmd := exec.Command(specutils.ExePath, "do", "sleep", "10000")
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ cmd.Stderr = buf
+ cmd.Start()
+
+ var pid int
+ findSandbox := func() error {
+ var err error
+ pid, err = sandboxPid(cmd.Process.Pid)
+ if err != nil {
+ return &backoff.PermanentError{Err: err}
+ }
+ if pid == 0 {
+ return fmt.Errorf("sandbox process not found")
+ }
+ return nil
+ }
+ if err := testutil.Poll(findSandbox, 10*time.Second); err != nil {
+ t.Fatalf("failed to find sandbox: %v", err)
+ }
+ t.Logf("Found sandbox, pid: %d", pid)
+
+ if err := cmd.Process.Kill(); err != nil {
+ t.Fatalf("failed to kill run process: %v", err)
+ }
+ cmd.Wait()
+ t.Logf("Parent process killed (%d). Output: %s", cmd.Process.Pid, buf.String())
+
+ ch := make(chan struct{})
+ go func() {
+ defer func() { ch <- struct{}{} }()
+ t.Logf("Waiting for sandbox process (%d) termination", pid)
+ if _, err := unix.Wait4(pid, nil, 0, nil); err != nil {
+ t.Errorf("error waiting for sandbox process (%d): %v", pid, err)
+ }
+ }()
+ select {
+ case <-ch:
+ // Done
+ case <-time.After(5 * time.Second):
+ t.Fatalf("timeout waiting for sandbox process (%d) to exit", pid)
+ }
+}
+
+// sandboxPid looks for the sandbox process inside the process tree starting
+// from "pid". It returns 0 and no error if no sandbox process is found. It
+// returns error if anything failed.
+func sandboxPid(pid int) (int, error) {
+ cmd := exec.Command("pgrep", "-P", strconv.Itoa(pid))
+ buf := &bytes.Buffer{}
+ cmd.Stdout = buf
+ if err := cmd.Start(); err != nil {
+ return 0, err
+ }
+ ps, err := cmd.Process.Wait()
+ if err != nil {
+ return 0, err
+ }
+ if ps.ExitCode() == 1 {
+ // pgrep returns 1 when no process is found.
+ return 0, nil
+ }
+
+ var children []int
+ for _, line := range strings.Split(buf.String(), "\n") {
+ if len(line) == 0 {
+ continue
+ }
+ child, err := strconv.Atoi(line)
+ if err != nil {
+ return 0, err
+ }
+
+ cmdline, err := ioutil.ReadFile(filepath.Join("/proc", line, "cmdline"))
+ if err != nil {
+ if os.IsNotExist(err) {
+ // Raced with process exit.
+ continue
+ }
+ return 0, err
+ }
+ args := strings.SplitN(string(cmdline), "\x00", 2)
+ if len(args) == 0 {
+ return 0, fmt.Errorf("malformed cmdline file: %q", cmdline)
+ }
+ // The sandbox process has the first argument set to "runsc-sandbox".
+ if args[0] == "runsc-sandbox" {
+ return child, nil
+ }
+
+ children = append(children, child)
+ }
+
+ // Sandbox process wasn't found, try another level down.
+ for _, pid := range children {
+ sand, err := sandboxPid(pid)
+ if err != nil {
+ return 0, err
+ }
+ if sand != 0 {
+ return sand, nil
+ }
+ // Not found, continue the search.
+ }
+ return 0, nil
+}
diff --git a/test/root/testdata/BUILD b/test/root/testdata/BUILD
deleted file mode 100644
index 125633680..000000000
--- a/test/root/testdata/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "testdata",
- srcs = [
- "busybox.go",
- "containerd_config.go",
- "httpd.go",
- "httpd_mount_paths.go",
- "sandbox.go",
- "simple.go",
- ],
- importpath = "gvisor.dev/gvisor/test/root/testdata",
- visibility = [
- "//visibility:public",
- ],
-)
diff --git a/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go
deleted file mode 100644
index e12f1ec88..000000000
--- a/test/root/testdata/containerd_config.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package testdata contains data required for root tests.
-package testdata
-
-import "fmt"
-
-// containerdConfigTemplate is a .toml config for containerd. It contains a
-// formatting verb so the runtime field can be set via fmt.Sprintf.
-const containerdConfigTemplate = `
-disabled_plugins = ["restart"]
-[plugins.linux]
- runtime = "%s"
- runtime_root = "/tmp/test-containerd/runsc"
- shim = "/usr/local/bin/gvisor-containerd-shim"
- shim_debug = true
-
-[plugins.cri.containerd.runtimes.runsc]
- runtime_type = "io.containerd.runtime.v1.linux"
- runtime_engine = "%s"
-`
-
-// ContainerdConfig returns a containerd config file with the specified
-// runtime.
-func ContainerdConfig(runtime string) string {
- return fmt.Sprintf(containerdConfigTemplate, runtime, runtime)
-}
diff --git a/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go
deleted file mode 100644
index ac3f4446a..000000000
--- a/test/root/testdata/httpd_mount_paths.go
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package testdata
-
-// HttpdMountPaths is a JSON config for an httpd container with additional
-// mounts.
-const HttpdMountPaths = `
-{
- "metadata": {
- "name": "httpd"
- },
- "image":{
- "image": "httpd"
- },
- "mounts": [
- {
- "container_path": "/var/run/secrets/kubernetes.io/serviceaccount",
- "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/volumes/kubernetes.io~secret/default-token-2rpfx",
- "readonly": true
- },
- {
- "container_path": "/etc/hosts",
- "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/etc-hosts",
- "readonly": false
- },
- {
- "container_path": "/dev/termination-log",
- "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064/containers/httpd/d1709580",
- "readonly": false
- },
- {
- "container_path": "/usr/local/apache2/htdocs/test",
- "host_path": "/var/lib/kubelet/pods/82bae206-cdf5-11e8-b245-8cdcd43ac064",
- "readonly": true
- }
- ],
- "linux": {
- },
- "log_path": "httpd.log"
-}
-`
diff --git a/test/runner/BUILD b/test/runner/BUILD
new file mode 100644
index 000000000..582d2946d
--- /dev/null
+++ b/test/runner/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "bzl_library", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ data = [
+ "//runsc",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/test/testutil",
+ "//runsc/specutils",
+ "//test/runner/gtest",
+ "//test/uds",
+ "@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@com_github_syndtr_gocapability//capability:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
new file mode 100644
index 000000000..2d64934b0
--- /dev/null
+++ b/test/runner/defs.bzl
@@ -0,0 +1,249 @@
+"""Defines a rule for syscall test targets."""
+
+load("//tools:defs.bzl", "default_platform", "loopback", "platforms")
+
+def _runner_test_impl(ctx):
+ # Generate a runner binary.
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "set -euf -x -o pipefail",
+ "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then",
+ " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ "fi",
+ "exec %s %s %s\n" % (
+ ctx.files.runner[0].short_path,
+ " ".join(ctx.attr.runner_args),
+ ctx.files.test[0].short_path,
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ target.data_runfiles.files
+ for target in (ctx.attr.runner, ctx.attr.test)
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.runner + ctx.files.test,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_runner_test = rule(
+ attrs = {
+ "runner": attr.label(
+ default = "//test/runner:runner",
+ ),
+ "test": attr.label(
+ mandatory = True,
+ ),
+ "runner_args": attr.string_list(),
+ "data": attr.label_list(
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _runner_test_impl,
+)
+
+def _syscall_test(
+ test,
+ shard_count,
+ size,
+ platform,
+ use_tmpfs,
+ tags,
+ network = "none",
+ file_access = "exclusive",
+ overlay = False,
+ add_uds_tree = False,
+ vfs2 = False,
+ fuse = False):
+ # Prepend "runsc" to non-native platform names.
+ full_platform = platform if platform == "native" else "runsc_" + platform
+
+ # Name the test appropriately.
+ name = test.split(":")[1] + "_" + full_platform
+ if file_access == "shared":
+ name += "_shared"
+ if overlay:
+ name += "_overlay"
+ if vfs2:
+ name += "_vfs2"
+ if fuse:
+ name += "_fuse"
+ if network != "none":
+ name += "_" + network + "net"
+
+ # Apply all tags.
+ if tags == None:
+ tags = []
+
+ # Add the full_platform and file access in a tag to make it easier to run
+ # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
+ tags += [full_platform, "file_" + file_access]
+
+ # Hash this target into one of 15 buckets. This can be used to
+ # randomly split targets between different workflows.
+ hash15 = hash(native.package_name() + name) % 15
+ tags.append("hash15:" + str(hash15))
+
+ # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
+ # we figure out how to request ipv4 sockets on Guitar machines.
+ if network == "host":
+ tags.append("noguitar")
+ tags.append("block-network")
+
+ # Disable off-host networking.
+ tags.append("requires-net:loopback")
+
+ runner_args = [
+ # Arguments are passed directly to runner binary.
+ "--platform=" + platform,
+ "--network=" + network,
+ "--use-tmpfs=" + str(use_tmpfs),
+ "--file-access=" + file_access,
+ "--overlay=" + str(overlay),
+ "--add-uds-tree=" + str(add_uds_tree),
+ "--vfs2=" + str(vfs2),
+ "--fuse=" + str(fuse),
+ ]
+
+ # Call the rule above.
+ _runner_test(
+ name = name,
+ test = test,
+ runner_args = runner_args,
+ data = [loopback],
+ size = size,
+ tags = tags,
+ shard_count = shard_count,
+ )
+
+def syscall_test(
+ test,
+ shard_count = 5,
+ size = "small",
+ use_tmpfs = False,
+ add_overlay = False,
+ add_uds_tree = False,
+ add_hostinet = False,
+ vfs2 = True,
+ fuse = False,
+ tags = None):
+ """syscall_test is a macro that will create targets for all platforms.
+
+ Args:
+ test: the test target.
+ shard_count: shards for defined tests.
+ size: the defined test size.
+ use_tmpfs: use tmpfs in the defined tests.
+ add_overlay: add an overlay test.
+ add_uds_tree: add a UDS test.
+ add_hostinet: add a hostinet test.
+ tags: starting test tags.
+ """
+ if not tags:
+ tags = []
+
+ vfs2_tags = list(tags)
+ if vfs2:
+ # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2
+ vfs2_tags.append("vfs2")
+ if fuse:
+ vfs2_tags.append("fuse")
+
+ else:
+ # Don't automatically run tests tests not yet passing.
+ vfs2_tags.append("manual")
+ vfs2_tags.append("noguitar")
+ vfs2_tags.append("notap")
+
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + vfs2_tags,
+ vfs2 = True,
+ fuse = fuse,
+ )
+ if fuse:
+ # Only generate *_vfs2_fuse target if fuse parameter is enabled.
+ return
+
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = "native",
+ use_tmpfs = False,
+ add_uds_tree = add_uds_tree,
+ tags = list(tags),
+ )
+
+ for (platform, platform_tags) in platforms.items():
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platform_tags + tags,
+ )
+
+ # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests.
+ if add_overlay:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ overlay = True,
+ )
+
+ if add_hostinet:
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ network = "host",
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ )
+
+ if not use_tmpfs:
+ # Also test shared gofer access.
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + tags,
+ file_access = "shared",
+ )
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = default_platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platforms[default_platform] + vfs2_tags,
+ file_access = "shared",
+ vfs2 = True,
+ )
diff --git a/test/runner/gtest/BUILD b/test/runner/gtest/BUILD
new file mode 100644
index 000000000..de4b2727c
--- /dev/null
+++ b/test/runner/gtest/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gtest",
+ srcs = ["gtest.go"],
+ visibility = ["//:sandbox"],
+)
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
new file mode 100644
index 000000000..e4445e01b
--- /dev/null
+++ b/test/runner/gtest/gtest.go
@@ -0,0 +1,170 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gtest contains helpers for running google-test tests from Go.
+package gtest
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+)
+
+var (
+ // listTestFlag is the flag that will list tests in gtest binaries.
+ listTestFlag = "--gtest_list_tests"
+
+ // filterTestFlag is the flag that will filter tests in gtest binaries.
+ filterTestFlag = "--gtest_filter"
+
+ // listBechmarkFlag is the flag that will list benchmarks in gtest binaries.
+ listBenchmarkFlag = "--benchmark_list_tests"
+
+ // filterBenchmarkFlag is the flag that will run specified benchmarks.
+ filterBenchmarkFlag = "--benchmark_filter"
+)
+
+// TestCase is a single gtest test case.
+type TestCase struct {
+ // Suite is the suite for this test.
+ Suite string
+
+ // Name is the name of this individual test.
+ Name string
+
+ // all indicates that this will run without flags. This takes
+ // precendence over benchmark below.
+ all bool
+
+ // benchmark indicates that this is a benchmark. In this case, the
+ // suite will be empty, and we will use the appropriate test and
+ // benchmark flags.
+ benchmark bool
+}
+
+// FullName returns the name of the test including the suite. It is suitable to
+// pass to "-gtest_filter".
+func (tc TestCase) FullName() string {
+ return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
+}
+
+// Args returns arguments to be passed when invoking the test.
+func (tc TestCase) Args() []string {
+ if tc.all {
+ return []string{} // No arguments.
+ }
+ if tc.benchmark {
+ return []string{
+ fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ fmt.Sprintf("%s=", filterTestFlag),
+ }
+ }
+ return []string{
+ fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()),
+ }
+}
+
+// ParseTestCases calls a gtest test binary to list its test and returns a
+// slice with the name and suite of each test.
+//
+// If benchmarks is true, then benchmarks will be included in the list of test
+// cases provided. Note that this requires the binary to support the
+// benchmarks_list_tests flag.
+func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) {
+ // Run to extract test cases.
+ args := append([]string{listTestFlag}, extraArgs...)
+ cmd := exec.Command(testBin, args...)
+ out, err := cmd.Output()
+ if err != nil {
+ // We failed to list tests with the given flags. Just
+ // return something that will run the binary with no
+ // flags, which should execute all tests.
+ return []TestCase{
+ TestCase{
+ Suite: "Default",
+ Name: "All",
+ all: true,
+ },
+ }, nil
+ }
+
+ // Parse test output.
+ var t []TestCase
+ var suite string
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // New suite?
+ if !strings.HasPrefix(line, " ") {
+ suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
+ continue
+ }
+
+ // Individual test.
+ name := strings.TrimSpace(line)
+
+ // Do we have a suite yet?
+ if suite == "" {
+ return nil, fmt.Errorf("test without a suite: %v", name)
+ }
+
+ // Add this individual test.
+ t = append(t, TestCase{
+ Suite: suite,
+ Name: name,
+ })
+ }
+
+ // Finished?
+ if !benchmarks {
+ return t, nil
+ }
+
+ // Run again to extract benchmarks.
+ args = append([]string{listBenchmarkFlag}, extraArgs...)
+ cmd = exec.Command(testBin, args...)
+ out, err = cmd.Output()
+ if err != nil {
+ // We were able to enumerate tests above, but not benchmarks?
+ // We requested them, so we return an error in this case.
+ exitErr, ok := err.(*exec.ExitError)
+ if !ok {
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err)
+ }
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
+ }
+
+ benches := strings.Trim(string(out), "\n")
+ if len(benches) == 0 {
+ return t, nil
+ }
+
+ // Parse benchmark output.
+ for _, line := range strings.Split(benches, "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // Single benchmark.
+ name := strings.TrimSpace(line)
+
+ // Add the single benchmark.
+ t = append(t, TestCase{
+ Suite: "Benchmarks",
+ Name: name,
+ benchmark: true,
+ })
+ }
+ return t, nil
+}
diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go
index 856398994..5ac91310d 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/runner/runner.go
@@ -30,25 +30,25 @@ import (
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "github.com/syndtr/gocapability/capability"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/testutil"
- "gvisor.dev/gvisor/test/syscalls/gtest"
+ "gvisor.dev/gvisor/test/runner/gtest"
"gvisor.dev/gvisor/test/uds"
)
-// Location of syscall tests, relative to the repo root.
-const testDir = "test/syscalls/linux"
-
var (
- testName = flag.String("test-name", "", "name of test binary to run")
debug = flag.Bool("debug", false, "enable debug logs")
strace = flag.Bool("strace", false, "enable strace logs")
platform = flag.String("platform", "ptrace", "platform to run on")
+ network = flag.String("network", "none", "network stack to run on (sandbox, host, none)")
useTmpfs = flag.Bool("use-tmpfs", false, "mounts tmpfs for /tmp")
fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay")
+ vfs2 = flag.Bool("vfs2", false, "enable VFS2")
+ fuse = flag.Bool("fuse", false, "enable FUSE")
parallel = flag.Bool("parallel", false, "run tests in parallel")
runscPath = flag.String("runsc", "", "path to runsc binary")
@@ -102,10 +102,17 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
}
- cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ cmd := exec.Command(testBin, tc.Args()...)
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
+
+ if specutils.HasCapabilities(capability.CAP_NET_ADMIN) {
+ cmd.SysProcAttr = &syscall.SysProcAttr{
+ Cloneflags: syscall.CLONE_NEWNET,
+ }
+ }
+
if err := cmd.Run(); err != nil {
ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus)
t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus())
@@ -118,26 +125,26 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
//
// Returns an error if the sandboxed application exits non-zero.
func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
- bundleDir, err := testutil.SetupBundleDir(spec)
+ bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
return fmt.Errorf("SetupBundleDir failed: %v", err)
}
- defer os.RemoveAll(bundleDir)
+ defer cleanup()
- rootDir, err := testutil.SetupRootDir()
+ rootDir, cleanup, err := testutil.SetupRootDir()
if err != nil {
return fmt.Errorf("SetupRootDir failed: %v", err)
}
- defer os.RemoveAll(rootDir)
+ defer cleanup()
name := tc.FullName()
- id := testutil.UniqueContainerID()
+ id := testutil.RandomContainerID()
log.Infof("Running test %q in container %q", name, id)
specutils.LogSpec(spec)
args := []string{
"-root", rootDir,
- "-network=none",
+ "-network", *network,
"-log-format=text",
"-TESTONLY-unsafe-nonroot=true",
"-net-raw=true",
@@ -149,6 +156,12 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
if *overlay {
args = append(args, "-overlay")
}
+ if *vfs2 {
+ args = append(args, "-vfs2")
+ if *fuse {
+ args = append(args, "-fuse")
+ }
+ }
if *debug {
args = append(args, "-debug", "-log-packets=true")
}
@@ -159,12 +172,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
args = append(args, "-fsgofer-host-uds")
}
- if outDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
- tdir := filepath.Join(outDir, strings.Replace(name, "/", "_", -1))
- if err := os.MkdirAll(tdir, 0755); err != nil {
+ testLogDir := ""
+ if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
+ // Create log directory dedicated for this test.
+ testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1))
+ if err := os.MkdirAll(testLogDir, 0755); err != nil {
return fmt.Errorf("could not create test dir: %v", err)
}
- debugLogDir, err := ioutil.TempDir(tdir, "runsc")
+ debugLogDir, err := ioutil.TempDir(testLogDir, "runsc")
if err != nil {
return fmt.Errorf("could not create temp dir: %v", err)
}
@@ -200,22 +215,25 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
sig := make(chan os.Signal, 1)
+ defer close(sig)
signal.Notify(sig, syscall.SIGTERM)
+ defer signal.Stop(sig)
go func() {
s, ok := <-sig
if !ok {
return
}
log.Warningf("%s: Got signal: %v", name, s)
- done := make(chan bool)
- go func() {
- dArgs := append(args, "-alsologtostderr=true", "debug", "--stacks", id)
- cmd := exec.Command(*runscPath, dArgs...)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- cmd.Run()
+ done := make(chan bool, 1)
+ dArgs := append([]string{}, args...)
+ dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id)
+ go func(dArgs []string) {
+ debug := exec.Command(*runscPath, dArgs...)
+ debug.Stdout = os.Stdout
+ debug.Stderr = os.Stderr
+ debug.Run()
done <- true
- }()
+ }(dArgs)
timeout := time.After(3 * time.Second)
select {
@@ -225,19 +243,21 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
}
log.Warningf("Send SIGTERM to the sandbox process")
- dArgs := append(args, "debug",
+ dArgs = append(args, "debug",
fmt.Sprintf("--signal=%d", syscall.SIGTERM),
id)
- cmd = exec.Command(*runscPath, dArgs...)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- cmd.Run()
+ signal := exec.Command(*runscPath, dArgs...)
+ signal.Stdout = os.Stdout
+ signal.Stderr = os.Stderr
+ signal.Run()
}()
err = cmd.Run()
-
- signal.Stop(sig)
- close(sig)
+ if err == nil && len(testLogDir) > 0 {
+ // If the test passed, then we erase the log directory. This speeds up
+ // uploading logs in continuous integration & saves on disk space.
+ os.RemoveAll(testLogDir)
+ }
return err
}
@@ -294,7 +314,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Run a new container with the test executable and filter for the
// given test suite and name.
- spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
@@ -302,6 +322,7 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Test spec comes with pre-defined mounts that we don't want. Reset it.
spec.Mounts = nil
+ testTmpDir := "/tmp"
if *useTmpfs {
// Forces '/tmp' to be mounted as tmpfs, otherwise test that rely on
// features only available in gVisor's internal tmpfs may fail.
@@ -327,17 +348,38 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
t.Fatalf("could not chmod temp dir: %v", err)
}
- spec.Mounts = append(spec.Mounts, specs.Mount{
- Type: "bind",
- Destination: "/tmp",
- Source: tmpDir,
- })
+ // "/tmp" is not replaced with a tmpfs mount inside the sandbox
+ // when it's not empty. This ensures that testTmpDir uses gofer
+ // in exclusive mode.
+ testTmpDir = tmpDir
+ if *fileAccess == "shared" {
+ // All external mounts except the root mount are shared.
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Type: "bind",
+ Destination: "/tmp",
+ Source: tmpDir,
+ })
+ testTmpDir = "/tmp"
+ }
}
- // Set environment variable that indicates we are
- // running in gVisor and with the given platform.
+ // Set environment variables that indicate we are running in gVisor with
+ // the given platform, network, and filesystem stack.
platformVar := "TEST_ON_GVISOR"
- env := append(os.Environ(), platformVar+"="+*platform)
+ networkVar := "GVISOR_NETWORK"
+ env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network)
+ vfsVar := "GVISOR_VFS"
+ if *vfs2 {
+ env = append(env, vfsVar+"=VFS2")
+ fuseVar := "FUSE_ENABLED"
+ if *fuse {
+ env = append(env, fuseVar+"=TRUE")
+ } else {
+ env = append(env, fuseVar+"=FALSE")
+ }
+ } else {
+ env = append(env, vfsVar+"=VFS1")
+ }
// Remove env variables that cause the gunit binary to write output
// files, since they will stomp on eachother, and on the output files
@@ -350,12 +392,8 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
// be backed by tmpfs.
- for i, kv := range env {
- if strings.HasPrefix(kv, "TEST_TMPDIR=") {
- env[i] = "TEST_TMPDIR=/tmp"
- break
- }
- }
+ env = filterEnv(env, []string{"TEST_TMPDIR"})
+ env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir))
spec.Process.Env = env
@@ -372,12 +410,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
}
}
-// filterEnv returns an environment with the blacklisted variables removed.
-func filterEnv(env, blacklist []string) []string {
+// filterEnv returns an environment with the excluded variables removed.
+func filterEnv(env, exclude []string) []string {
var out []string
for _, kv := range env {
ok := true
- for _, k := range blacklist {
+ for _, k := range exclude {
if strings.HasPrefix(kv, k+"=") {
ok = false
break
@@ -401,9 +439,10 @@ func matchString(a, b string) (bool, error) {
func main() {
flag.Parse()
- if *testName == "" {
- fatalf("test-name flag must be provided")
+ if flag.NArg() != 1 {
+ fatalf("test must be provided")
}
+ testBin := flag.Args()[0] // Only argument.
log.SetLevel(log.Info)
if *debug {
@@ -433,34 +472,31 @@ func main() {
}
}
- // Get path to test binary.
- fullTestName := filepath.Join(testDir, *testName)
- testBin, err := testutil.FindFile(fullTestName)
- if err != nil {
- fatalf("FindFile(%q) failed: %v", fullTestName, err)
- }
-
// Get all test cases in each binary.
- testCases, err := gtest.ParseTestCases(testBin)
+ testCases, err := gtest.ParseTestCases(testBin, true)
if err != nil {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
// Get subset of tests corresponding to shard.
- begin, end, err := testutil.TestBoundsForShard(len(testCases))
+ indices, err := testutil.TestIndicesForShard(len(testCases))
if err != nil {
fatalf("TestsForShard() failed: %v", err)
}
- testCases = testCases[begin:end]
+
+ // Resolve the absolute path for the binary.
+ testBin, err = filepath.Abs(testBin)
+ if err != nil {
+ fatalf("Abs() failed: %v", err)
+ }
// Run the tests.
var tests []testing.InternalTest
- for _, tc := range testCases {
+ for _, tci := range indices {
// Capture tc.
- tc := tc
- testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name)
+ tc := testCases[tci]
tests = append(tests, testing.InternalTest{
- Name: testName,
+ Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name),
F: func(t *testing.T) {
if *parallel {
t.Parallel()
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 2e125525b..066338ee3 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,53 +1,46 @@
-# These packages are used to run language runtime tests inside gVisor sandboxes.
-
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
-load("//test/runtimes:build_defs.bzl", "runtime_test")
+load("//tools:defs.bzl", "bzl_library")
+load("//test/runtimes:defs.bzl", "runtime_test")
package(licenses = ["notice"])
-go_binary(
- name = "runner",
- testonly = 1,
- srcs = ["runner.go"],
- deps = [
- "//runsc/dockerutil",
- "//runsc/testutil",
- ],
-)
-
runtime_test(
- blacklist_file = "blacklist_go1.12.csv",
- image = "gcr.io/gvisor-presubmit/go1.12",
+ name = "go1.12",
+ exclude_file = "exclude_go1.12.csv",
lang = "go",
+ shard_count = 8,
)
runtime_test(
- blacklist_file = "blacklist_java11.csv",
- image = "gcr.io/gvisor-presubmit/java11",
+ name = "java11",
+ batch = 100,
+ exclude_file = "exclude_java11.csv",
lang = "java",
+ shard_count = 16,
)
runtime_test(
- blacklist_file = "blacklist_nodejs12.4.0.csv",
- image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
+ name = "nodejs12.4.0",
+ exclude_file = "exclude_nodejs12.4.0.csv",
lang = "nodejs",
+ shard_count = 8,
)
runtime_test(
- blacklist_file = "blacklist_php7.3.6.csv",
- image = "gcr.io/gvisor-presubmit/php7.3.6",
+ name = "php7.3.6",
+ exclude_file = "exclude_php7.3.6.csv",
lang = "php",
+ shard_count = 8,
)
runtime_test(
- blacklist_file = "blacklist_python3.7.3.csv",
- image = "gcr.io/gvisor-presubmit/python3.7.3",
+ name = "python3.7.3",
+ exclude_file = "exclude_python3.7.3.csv",
lang = "python",
+ shard_count = 8,
)
-go_test(
- name = "blacklist_test",
- size = "small",
- srcs = ["blacklist_test.go"],
- embed = [":runner"],
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
)
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
deleted file mode 100644
index e41e78f77..000000000
--- a/test/runtimes/README.md
+++ /dev/null
@@ -1,41 +0,0 @@
-# Runtimes Tests Dockerfiles
-
-The Dockerfiles defined under this path are configured to host the execution of
-the runtimes language tests. Each Dockerfile can support the language indicated
-by its directory.
-
-The following runtimes are currently supported:
-
-- Go 1.12
-- Java 11
-- Node.js 12
-- PHP 7.3
-- Python 3.7
-
-#### Prerequisites:
-
-1) [Install and configure Docker](https://docs.docker.com/install/)
-
-2) Build each Docker container from the runtimes/images directory:
-
-```bash
-$ cd images
-$ docker build -f Dockerfile_$LANG [-t $NAME] .
-```
-
-### Testing:
-
-If the prerequisites have been fulfilled, you can run the tests with the
-following command:
-
-```bash
-$ docker run --rm -it $NAME [FLAG]
-```
-
-Running the command with no flags will cause all the available tests to execute.
-
-Flags can be added for additional functionality:
-
-- --list: Print a list of all available tests
-- --test &lt;name&gt;: Run a single test from the list of available tests
-- --v: Print the language version
diff --git a/test/runtimes/blacklist_go1.12.csv b/test/runtimes/blacklist_go1.12.csv
deleted file mode 100644
index 8c8ae0c5d..000000000
--- a/test/runtimes/blacklist_go1.12.csv
+++ /dev/null
@@ -1,16 +0,0 @@
-test name,bug id,comment
-cgo_errors,,FLAKY
-cgo_test,,FLAKY
-go_test:cmd/go,,FLAKY
-go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing
-go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug.
-go_test:os,b/118780122,we have a pollable filesystem but that's a surprise
-go_test:os/signal,b/118780860,/dev/pts not properly supported
-go_test:runtime,b/118782341,sigtrap not reported or caught or something
-go_test:syscall,b/118781998,bad bytes -- bad mem addr
-race,b/118782931,thread sanitizer. Works as intended: b/62219744.
-runtime:cpu124,b/118778254,segmentation fault
-test:0_1,,FLAKY
-testasan,,
-testcarchive,b/118782924,no sigpipe
-testshared,,FLAKY
diff --git a/test/runtimes/blacklist_java11.csv b/test/runtimes/blacklist_java11.csv
deleted file mode 100644
index c012e5a56..000000000
--- a/test/runtimes/blacklist_java11.csv
+++ /dev/null
@@ -1,126 +0,0 @@
-test name,bug id,comment
-com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker
-com/sun/jdi/NashornPopFrameTest.java,,
-com/sun/jdi/ProcessAttachTest.java,,
-com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker
-com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,,
-com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,,
-com/sun/tools/attach/AttachSelf.java,,
-com/sun/tools/attach/BasicTests.java,,
-com/sun/tools/attach/PermissionTest.java,,
-com/sun/tools/attach/StartManagementAgent.java,,
-com/sun/tools/attach/TempDirTest.java,,
-com/sun/tools/attach/modules/Driver.java,,
-java/lang/Character/CheckScript.java,,Fails in Docker
-java/lang/Character/CheckUnicode.java,,Fails in Docker
-java/lang/Class/GetPackageBootLoaderChildLayer.java,,
-java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker
-java/lang/String/nativeEncoding/StringPlatformChars.java,,
-java/net/DatagramSocket/ReuseAddressTest.java,,
-java/net/DatagramSocket/SendDatagramToBadAddress.java,b/78473345,
-java/net/Inet4Address/PingThis.java,,
-java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103,
-java/net/MulticastSocket/MulticastTTL.java,,
-java/net/MulticastSocket/Promiscuous.java,,
-java/net/MulticastSocket/SetLoopbackMode.java,,
-java/net/MulticastSocket/SetTTLAndGetTTL.java,,
-java/net/MulticastSocket/Test.java,,
-java/net/MulticastSocket/TestDefaults.java,,
-java/net/MulticastSocket/TimeToLive.java,,
-java/net/NetworkInterface/NetworkInterfaceStreamTest.java,,
-java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported
-java/net/Socket/TrafficClass.java,b/78527818,Not supported on gVisor
-java/net/Socket/UrgentDataTest.java,b/111515323,
-java/net/Socket/setReuseAddress/Basic.java,b/78519214,SO_REUSEADDR enabled by default
-java/net/SocketOption/OptionsTest.java,,Fails in Docker
-java/net/SocketOption/TcpKeepAliveTest.java,,
-java/net/SocketPermission/SocketPermissionTest.java,,
-java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker
-java/net/httpclient/RequestBuilderTest.java,,Fails in Docker
-java/net/httpclient/ShortResponseBody.java,,
-java/net/httpclient/ShortResponseBodyWithRetry.java,,
-java/nio/channels/AsyncCloseAndInterrupt.java,,
-java/nio/channels/AsynchronousServerSocketChannel/Basic.java,,
-java/nio/channels/AsynchronousSocketChannel/Basic.java,b/77921528,SO_KEEPALIVE is not settable
-java/nio/channels/DatagramChannel/BasicMulticastTests.java,,
-java/nio/channels/DatagramChannel/SocketOptionTests.java,,Fails in Docker
-java/nio/channels/DatagramChannel/UseDGWithIPv6.java,,
-java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker
-java/nio/channels/Selector/OutOfBand.java,,
-java/nio/channels/Selector/SelectWithConsumer.java,,Flaky
-java/nio/channels/ServerSocketChannel/SocketOptionTests.java,,
-java/nio/channels/SocketChannel/LingerOnClose.java,,
-java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901,
-java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker
-java/rmi/activation/Activatable/extLoadedImpl/ext.sh,,
-java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,,
-java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
-java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
-java/util/Calendar/JapaneseEraNameTest.java,,
-java/util/Currency/CurrencyTest.java,,Fails in Docker
-java/util/Currency/ValidateISO4217.java,,Fails in Docker
-java/util/Locale/LSRDataTest.java,,
-java/util/concurrent/locks/Lock/TimedAcquireLeak.java,,
-java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker
-java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,,
-java/util/logging/TestLoggerWeakRefLeak.java,,
-javax/imageio/AppletResourceTest.java,,
-javax/management/security/HashedPasswordFileTest.java,,
-javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker
-javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,,
-jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
-jdk/jfr/event/runtime/TestThreadParkEvent.java,,
-jdk/jfr/event/sampling/TestNative.java,,
-jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,,
-jdk/jfr/jcmd/TestJcmdConfigure.java,,
-jdk/jfr/jcmd/TestJcmdDump.java,,
-jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,,
-jdk/jfr/jcmd/TestJcmdDumpLimited.java,,
-jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,,
-jdk/jfr/jcmd/TestJcmdLegacy.java,,
-jdk/jfr/jcmd/TestJcmdSaveToFile.java,,
-jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,,
-jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,,
-jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,,
-jdk/jfr/jcmd/TestJcmdStartStopDefault.java,,
-jdk/jfr/jcmd/TestJcmdStartWithOptions.java,,
-jdk/jfr/jcmd/TestJcmdStartWithSettings.java,,
-jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,,
-jdk/jfr/jvm/TestJfrJavaBase.java,,
-jdk/jfr/startupargs/TestStartRecording.java,,
-jdk/modules/incubator/ImageModules.java,,
-jdk/net/Sockets/ExtOptionTest.java,,
-jdk/net/Sockets/QuickAckTest.java,,
-lib/security/cacerts/VerifyCACerts.java,,
-sun/management/jmxremote/bootstrap/CustomLauncherTest.java,,
-sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,,
-sun/management/jmxremote/bootstrap/LocalManagementTest.java,,
-sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,,
-sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,,
-sun/management/jmxremote/startstop/JMXStartStopTest.java,,
-sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,,
-sun/management/jmxremote/startstop/JMXStatusTest.java,,
-sun/text/resources/LocaleDataTest.java,,
-sun/tools/jcmd/TestJcmdSanity.java,,
-sun/tools/jhsdb/AlternateHashingTest.java,,
-sun/tools/jhsdb/BasicLauncherTest.java,,
-sun/tools/jhsdb/HeapDumpTest.java,,
-sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,,
-sun/tools/jinfo/BasicJInfoTest.java,,
-sun/tools/jinfo/JInfoTest.java,,
-sun/tools/jmap/BasicJMapTest.java,,
-sun/tools/jstack/BasicJStackTest.java,,
-sun/tools/jstack/DeadlockDetectionTest.java,,
-sun/tools/jstatd/TestJstatdExternalRegistry.java,,
-sun/tools/jstatd/TestJstatdPort.java,,Flaky
-sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky
-sun/util/calendar/zi/TestZoneInfo310.java,,
-tools/jar/modularJar/Basic.java,,
-tools/jar/multiRelease/Basic.java,,
-tools/jimage/JImageExtractTest.java,,
-tools/jimage/JImageTest.java,,
-tools/jlink/JLinkTest.java,,
-tools/jlink/plugins/IncludeLocalesPluginTest.java,,
-tools/jmod/hashes/HashesTest.java,,
-tools/launcher/BigJar.java,b/111611473,
-tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,,
diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv
deleted file mode 100644
index 4ab4e2927..000000000
--- a/test/runtimes/blacklist_nodejs12.4.0.csv
+++ /dev/null
@@ -1,47 +0,0 @@
-test name,bug id,comment
-benchmark/test-benchmark-fs.js,,
-benchmark/test-benchmark-module.js,,
-benchmark/test-benchmark-napi.js,,
-doctool/test-make-doc.js,b/68848110,Expected to fail.
-fixtures/test-error-first-line-offset.js,,
-fixtures/test-fs-readfile-error.js,,
-fixtures/test-fs-stat-sync-overflow.js,,
-internet/test-dgram-broadcast-multi-process.js,,
-internet/test-dgram-multicast-multi-process.js,,
-internet/test-dgram-multicast-set-interface-lo.js,,
-parallel/test-cluster-dgram-reuse.js,b/64024294,
-parallel/test-dgram-bind-fd.js,b/132447356,
-parallel/test-dgram-create-socket-handle-fd.js,b/132447238,
-parallel/test-dgram-createSocket-type.js,b/68847739,
-parallel/test-dgram-socket-buffer-size.js,b/68847921,
-parallel/test-fs-access.js,,
-parallel/test-fs-write-stream-double-close.js,,
-parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
-parallel/test-fs-write-stream.js,,
-parallel/test-http2-respond-file-error-pipe-offset.js,,
-parallel/test-os.js,,
-parallel/test-process-uid-gid.js,,
-pseudo-tty/test-assert-colors.js,,
-pseudo-tty/test-assert-no-color.js,,
-pseudo-tty/test-assert-position-indicator.js,,
-pseudo-tty/test-async-wrap-getasyncid-tty.js,,
-pseudo-tty/test-fatal-error.js,,
-pseudo-tty/test-handle-wrap-isrefed-tty.js,,
-pseudo-tty/test-readable-tty-keepalive.js,,
-pseudo-tty/test-set-raw-mode-reset-process-exit.js,,
-pseudo-tty/test-set-raw-mode-reset-signal.js,,
-pseudo-tty/test-set-raw-mode-reset.js,,
-pseudo-tty/test-stderr-stdout-handle-sigwinch.js,,
-pseudo-tty/test-stdout-read.js,,
-pseudo-tty/test-tty-color-support.js,,
-pseudo-tty/test-tty-isatty.js,,
-pseudo-tty/test-tty-stdin-call-end.js,,
-pseudo-tty/test-tty-stdin-end.js,,
-pseudo-tty/test-stdin-write.js,,
-pseudo-tty/test-tty-stdout-end.js,,
-pseudo-tty/test-tty-stdout-resize.js,,
-pseudo-tty/test-tty-stream-constructors.js,,
-pseudo-tty/test-tty-window-size.js,,
-pseudo-tty/test-tty-wrap.js,,
-pummel/test-net-pingpong.js,,
-pummel/test-vm-memleak.js,,
diff --git a/test/runtimes/blacklist_python3.7.3.csv b/test/runtimes/blacklist_python3.7.3.csv
deleted file mode 100644
index 2b9947212..000000000
--- a/test/runtimes/blacklist_python3.7.3.csv
+++ /dev/null
@@ -1,27 +0,0 @@
-test name,bug id,comment
-test_asynchat,b/76031995,SO_REUSEADDR
-test_asyncio,,Fails on Docker.
-test_asyncore,b/76031995,SO_REUSEADDR
-test_epoll,,
-test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode.
-test_ftplib,,Fails in Docker
-test_httplib,b/76031995,SO_REUSEADDR
-test_imaplib,,
-test_logging,,
-test_multiprocessing_fork,,Flaky. Sometimes times out.
-test_multiprocessing_forkserver,,Flaky. Sometimes times out.
-test_multiprocessing_main_handling,,Flaky. Sometimes times out.
-test_multiprocessing_spawn,,Flaky. Sometimes times out.
-test_nntplib,b/76031995,tests should not set SO_REUSEADDR
-test_poplib,,Fails on Docker
-test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted
-test_pty,b/76157709,out of pty devices
-test_readline,b/76157709,out of pty devices
-test_resource,b/76174079,
-test_selectors,b/76116849,OSError not raised with epoll
-test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets
-test_socket,b/75983380,
-test_ssl,b/76031995,SO_REUSEADDR
-test_subprocess,,
-test_support,b/76031995,SO_REUSEADDR
-test_telnetlib,b/76031995,SO_REUSEADDR
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
deleted file mode 100644
index 7c11624b4..000000000
--- a/test/runtimes/build_defs.bzl
+++ /dev/null
@@ -1,57 +0,0 @@
-"""Defines a rule for runtime test targets."""
-
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-# runtime_test is a macro that will create targets to run the given test target
-# with different runtime options.
-def runtime_test(
- lang,
- image,
- shard_count = 50,
- size = "enormous",
- blacklist_file = ""):
- args = [
- "--lang",
- lang,
- "--image",
- image,
- ]
- data = [
- ":runner",
- ]
- if blacklist_file != "":
- args += ["--blacklist_file", "test/runtimes/" + blacklist_file]
- data += [blacklist_file]
-
- # Add a test that the blacklist parses correctly.
- blacklist_test(lang, blacklist_file)
-
- sh_test(
- name = lang + "_test",
- srcs = ["runner.sh"],
- args = args,
- data = data,
- size = size,
- shard_count = shard_count,
- tags = [
- # Requires docker and runsc to be configured before the test runs.
- "manual",
- "local",
- ],
- )
-
-def blacklist_test(lang, blacklist_file):
- """Test that a blacklist parses correctly."""
- go_test(
- name = lang + "_blacklist_test",
- embed = [":runner"],
- srcs = ["blacklist_test.go"],
- args = ["--blacklist_file", "test/runtimes/" + blacklist_file],
- data = [blacklist_file],
- )
-
-def sh_test(**kwargs):
- """Wraps the standard sh_test."""
- native.sh_test(
- **kwargs
- )
diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl
new file mode 100644
index 000000000..702522d86
--- /dev/null
+++ b/test/runtimes/defs.bzl
@@ -0,0 +1,90 @@
+"""Defines a rule for runtime test targets."""
+
+load("//tools:defs.bzl", "go_test")
+
+def _runtime_test_impl(ctx):
+ # Construct arguments.
+ args = [
+ "--lang",
+ ctx.attr.lang,
+ "--image",
+ ctx.attr.image,
+ "--batch",
+ str(ctx.attr.batch),
+ ]
+ if ctx.attr.exclude_file:
+ args += [
+ "--exclude_file",
+ ctx.files.exclude_file[0].short_path,
+ ]
+
+ # Build a runner.
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "%s %s $@\n" % (ctx.files._runner[0].short_path, " ".join(args)),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return the runner.
+ return [DefaultInfo(
+ executable = runner,
+ runfiles = ctx.runfiles(
+ files = ctx.files._runner + ctx.files.exclude_file + ctx.files._proctor,
+ collect_default = True,
+ collect_data = True,
+ ),
+ )]
+
+_runtime_test = rule(
+ implementation = _runtime_test_impl,
+ attrs = {
+ "image": attr.string(
+ mandatory = False,
+ ),
+ "lang": attr.string(
+ mandatory = True,
+ ),
+ "exclude_file": attr.label(
+ mandatory = False,
+ allow_single_file = True,
+ ),
+ "batch": attr.int(
+ default = 50,
+ mandatory = False,
+ ),
+ "_runner": attr.label(
+ default = "//test/runtimes/runner:runner",
+ executable = True,
+ cfg = "target",
+ ),
+ "_proctor": attr.label(
+ default = "//test/runtimes/proctor:proctor",
+ executable = True,
+ cfg = "target",
+ ),
+ },
+ test = True,
+)
+
+def runtime_test(name, **kwargs):
+ _runtime_test(
+ name = name,
+ image = name, # Resolved as images/runtimes/%s.
+ tags = [
+ "local",
+ "manual",
+ ],
+ size = "enormous",
+ **kwargs
+ )
+
+def exclude_test(name, exclude_file):
+ """Test that a exclude file parses correctly."""
+ go_test(
+ name = name + "_exclude_test",
+ library = ":runner",
+ srcs = ["exclude_test.go"],
+ args = ["--exclude_file", "test/runtimes/" + exclude_file],
+ data = [exclude_file],
+ )
diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude_go1.12.csv
new file mode 100644
index 000000000..81e02cf64
--- /dev/null
+++ b/test/runtimes/exclude_go1.12.csv
@@ -0,0 +1,13 @@
+test name,bug id,comment
+cgo_errors,,FLAKY
+cgo_test,,FLAKY
+go_test:cmd/go,,FLAKY
+go_test:net,b/162473575,setsockopt: protocol not available.
+go_test:os,b/118780122,we have a pollable filesystem but that's a surprise
+go_test:os/signal,b/118780860,/dev/pts not properly supported. Also being tracked in b/29356795.
+go_test:runtime,b/118782341,sigtrap not reported or caught or something. Also being tracked in b/33003106.
+go_test:syscall,b/118781998,bad bytes -- bad mem addr; FcntlFlock(F_GETLK) not supported.
+runtime:cpu124,b/118778254,segmentation fault
+test:0_1,,FLAKY
+testcarchive,b/118782924,no sigpipe
+testshared,,FLAKY
diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude_java11.csv
new file mode 100644
index 000000000..997a29cad
--- /dev/null
+++ b/test/runtimes/exclude_java11.csv
@@ -0,0 +1,208 @@
+test name,bug id,comment
+com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker
+com/sun/jdi/NashornPopFrameTest.java,,
+com/sun/jdi/ProcessAttachTest.java,,
+com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker
+com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,,
+com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,,
+com/sun/tools/attach/AttachSelf.java,,
+com/sun/tools/attach/BasicTests.java,,
+com/sun/tools/attach/PermissionTest.java,,
+com/sun/tools/attach/StartManagementAgent.java,,
+com/sun/tools/attach/TempDirTest.java,,
+com/sun/tools/attach/modules/Driver.java,,
+java/lang/Character/CheckScript.java,,Fails in Docker
+java/lang/Character/CheckUnicode.java,,Fails in Docker
+java/lang/Class/GetPackageBootLoaderChildLayer.java,,
+java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker
+java/lang/module/ModuleDescriptorTest.java,,
+java/lang/String/nativeEncoding/StringPlatformChars.java,,
+java/net/CookieHandler/B6791927.java,,java.lang.RuntimeException: Expiration date shouldn't be 0
+java/net/ipv6tests/TcpTest.java,,java.net.ConnectException: Connection timed out (Connection timed out)
+java/net/ipv6tests/UdpTest.java,,Times out
+java/net/Inet6Address/B6558853.java,,Times out
+java/net/InetAddress/CheckJNI.java,,java.net.ConnectException: Connection timed out (Connection timed out)
+java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103,
+java/net/MulticastSocket/B6425815.java,,java.net.SocketException: Protocol not available (Error getting socket option)
+java/net/MulticastSocket/B6427403.java,,java.net.SocketException: Protocol not available
+java/net/MulticastSocket/MulticastTTL.java,,
+java/net/MulticastSocket/NetworkInterfaceEmptyGetInetAddressesTest.java,,java.net.SocketException: Protocol not available (Error getting socket option)
+java/net/MulticastSocket/NoLoopbackPackets.java,,java.net.SocketException: Protocol not available
+java/net/MulticastSocket/Promiscuous.java,,
+java/net/MulticastSocket/SetLoopbackMode.java,,
+java/net/MulticastSocket/SetTTLAndGetTTL.java,,
+java/net/MulticastSocket/Test.java,,
+java/net/MulticastSocket/TestDefaults.java,,
+java/net/MulticastSocket/TimeToLive.java,,
+java/net/NetworkInterface/NetworkInterfaceStreamTest.java,,
+java/net/Socket/LinkLocal.java,,java.net.SocketTimeoutException: Receive timed out
+java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported
+java/net/Socket/UrgentDataTest.java,b/111515323,
+java/net/SocketOption/OptionsTest.java,,Fails in Docker
+java/net/SocketPermission/SocketPermissionTest.java,,
+java/net/URLConnection/6212146/TestDriver.java,,Fails in Docker
+java/net/httpclient/RequestBuilderTest.java,,Fails in Docker
+java/nio/channels/DatagramChannel/BasicMulticastTests.java,,
+java/nio/channels/DatagramChannel/SocketOptionTests.java,,java.net.SocketException: Invalid argument
+java/nio/channels/DatagramChannel/UseDGWithIPv6.java,,
+java/nio/channels/FileChannel/directio/DirectIOTest.java,,Fails in Docker
+java/nio/channels/FileChannel/directio/PwriteDirect.java,,java.io.IOException: Invalid argument
+java/nio/channels/Selector/OutOfBand.java,,
+java/nio/channels/Selector/SelectWithConsumer.java,,Flaky
+java/nio/channels/ServerSocketChannel/SocketOptionTests.java,,
+java/nio/channels/SocketChannel/LingerOnClose.java,,
+java/nio/channels/SocketChannel/SocketOptionTests.java,b/77965901,
+java/nio/channels/spi/SelectorProvider/inheritedChannel/InheritedChannelTest.java,,Fails in Docker
+java/rmi/activation/Activatable/extLoadedImpl/ext.sh,,
+java/rmi/transport/checkLeaseInfoLeak/CheckLeaseLeak.java,,
+java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
+java/text/Format/NumberFormat/CurrencyFormat.java,,Fails in Docker
+java/util/Calendar/JapaneseEraNameTest.java,,
+java/util/Currency/CurrencyTest.java,,Fails in Docker
+java/util/Currency/ValidateISO4217.java,,Fails in Docker
+java/util/EnumSet/BogusEnumSet.java,,"java.io.InvalidClassException: java.util.EnumSet; local class incompatible: stream classdesc serialVersionUID = -2409567991088730183, local class serialVersionUID = 1009687484059888093"
+java/util/Locale/Bug8040211.java,,java.lang.RuntimeException: Failed.
+java/util/Locale/LSRDataTest.java,,
+java/util/Properties/CompatibilityTest.java,,"java.lang.RuntimeException: jdk.internal.org.xml.sax.SAXParseException; Internal DTD subset is not allowed. The Properties XML document must have the following DOCTYPE declaration: <!DOCTYPE properties SYSTEM ""http://java.sun.com/dtd/properties.dtd"">"
+java/util/ResourceBundle/Control/XMLResourceBundleTest.java,,java.util.MissingResourceException: Can't find bundle for base name XmlRB locale
+java/util/ResourceBundle/modules/xmlformat/xmlformat.sh,,Timeout reached: 60000. Process is not alive!
+java/util/TimeZone/TimeZoneTest.java,,Uncaught exception thrown in test method TestShortZoneIDs
+java/util/concurrent/locks/Lock/TimedAcquireLeak.java,,
+java/util/jar/JarFile/mrjar/MultiReleaseJarAPI.java,,Fails in Docker
+java/util/logging/LogManager/Configuration/updateConfiguration/SimpleUpdateConfigWithInputStreamTest.java,,
+java/util/logging/TestLoggerWeakRefLeak.java,,
+java/util/spi/ResourceBundleControlProvider/UserDefaultControlTest.java,,java.util.MissingResourceException: Can't find bundle for base name com.foo.XmlRB locale
+javax/imageio/AppletResourceTest.java,,
+javax/imageio/plugins/jpeg/JPEGsNotAcceleratedTest.java,,java.awt.HeadlessException: No X11 DISPLAY variable was set but this program performed an operation which requires it.
+javax/management/security/HashedPasswordFileTest.java,,
+javax/net/ssl/DTLS/DTLSBufferOverflowUnderflowTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSDataExchangeTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSEnginesClosureTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSHandshakeTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSHandshakeWithReplicatedPacketsTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSIncorrectAppDataTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSMFLNTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSNotEnabledRC4Test.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSRehandshakeTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSRehandshakeWithCipherChangeTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSRehandshakeWithDataExTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSSequenceNumberTest.java,,Compilation failed
+javax/net/ssl/DTLS/DTLSUnsupportedCiphersTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10BufferOverflowUnderflowTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10DataExchangeTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10EnginesClosureTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10HandshakeTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10HandshakeWithReplicatedPacketsTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10IncorrectAppDataTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10MFLNTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10NotEnabledRC4Test.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10RehandshakeTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithCipherChangeTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10RehandshakeWithDataExTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10SequenceNumberTest.java,,Compilation failed
+javax/net/ssl/DTLSv10/DTLSv10UnsupportedCiphersTest.java,,Compilation failed
+javax/net/ssl/SSLSession/JSSERenegotiate.java,,Fails in Docker
+javax/net/ssl/TLS/TLSDataExchangeTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSEnginesClosureTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSHandshakeTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSMFLNTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSNotEnabledRC4Test.java,,Compilation failed
+javax/net/ssl/TLS/TLSRehandshakeTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSRehandshakeWithDataExTest.java,,Compilation failed
+javax/net/ssl/TLS/TLSUnsupportedCiphersTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSDataExchangeTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSEnginesClosureTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSHandshakeTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSMFLNTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSNotEnabledRC4Test.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSRehandshakeTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSRehandshakeWithDataExTest.java,,Compilation failed
+javax/net/ssl/TLSv1/TLSUnsupportedCiphersTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSDataExchangeTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSEnginesClosureTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSHandshakeTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSMFLNTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSNotEnabledRC4Test.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSRehandshakeTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSRehandshakeWithCipherChangeTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSRehandshakeWithDataExTest.java,,Compilation failed
+javax/net/ssl/TLSv11/TLSUnsupportedCiphersTest.java,,Compilation failed
+javax/net/ssl/TLSv12/TLSEnginesClosureTest.java,,Compilation failed
+javax/sound/sampled/AudioInputStream/FrameLengthAfterConversion.java,,
+jdk/jfr/cmd/TestHelp.java,,java.lang.RuntimeException: 'Available commands are:' missing from stdout/stderr
+jdk/jfr/cmd/TestPrint.java,,Missing file' missing from stdout/stderr
+jdk/jfr/cmd/TestPrintDefault.java,,java.lang.RuntimeException: 'JVMInformation' missing from stdout/stderr
+jdk/jfr/cmd/TestPrintJSON.java,,javax.script.ScriptException: <eval>:1:17 Expected an operand but found eof var jsonObject = ^ in <eval> at line number 1 at column number 17
+jdk/jfr/cmd/TestPrintXML.java,,org.xml.sax.SAXParseException; lineNumber: 1; columnNumber: 1; Premature end of file.
+jdk/jfr/cmd/TestReconstruct.java,,java.lang.RuntimeException: 'Too few arguments' missing from stdout/stderr
+jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr
+jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr
+jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event
+jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default'
+jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold
+jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base
+jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,,
+jdk/jfr/event/runtime/TestThreadParkEvent.java,,
+jdk/jfr/event/sampling/TestNative.java,,
+jdk/jfr/jcmd/TestJcmdChangeLogLevel.java,,
+jdk/jfr/jcmd/TestJcmdConfigure.java,,
+jdk/jfr/jcmd/TestJcmdDump.java,,
+jdk/jfr/jcmd/TestJcmdDumpGeneratedFilename.java,,
+jdk/jfr/jcmd/TestJcmdDumpLimited.java,,
+jdk/jfr/jcmd/TestJcmdDumpPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdLegacy.java,,
+jdk/jfr/jcmd/TestJcmdSaveToFile.java,,
+jdk/jfr/jcmd/TestJcmdStartDirNotExist.java,,
+jdk/jfr/jcmd/TestJcmdStartInvaldFile.java,,
+jdk/jfr/jcmd/TestJcmdStartPathToGCRoots.java,,
+jdk/jfr/jcmd/TestJcmdStartStopDefault.java,,
+jdk/jfr/jcmd/TestJcmdStartWithOptions.java,,
+jdk/jfr/jcmd/TestJcmdStartWithSettings.java,,
+jdk/jfr/jcmd/TestJcmdStopInvalidFile.java,,
+jdk/jfr/jvm/TestGetAllEventClasses.java,,Compilation failed
+jdk/jfr/jvm/TestJfrJavaBase.java,,
+jdk/jfr/startupargs/TestStartRecording.java,,
+jdk/modules/incubator/ImageModules.java,,
+jdk/net/Sockets/ExtOptionTest.java,,
+jdk/net/Sockets/QuickAckTest.java,,
+lib/security/cacerts/VerifyCACerts.java,,
+sun/management/jmxremote/bootstrap/CustomLauncherTest.java,,
+sun/management/jmxremote/bootstrap/JvmstatCountersTest.java,,
+sun/management/jmxremote/bootstrap/LocalManagementTest.java,,
+sun/management/jmxremote/bootstrap/RmiRegistrySslTest.java,,
+sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,,
+sun/management/jmxremote/startstop/JMXStartStopTest.java,,
+sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,,
+sun/management/jmxremote/startstop/JMXStatusTest.java,,
+sun/management/jdp/JdpDefaultsTest.java,,
+sun/management/jdp/JdpJmxRemoteDynamicPortTest.java,,
+sun/management/jdp/JdpOffTest.java,,
+sun/management/jdp/JdpSpecificAddressTest.java,,
+sun/text/resources/LocaleDataTest.java,,
+sun/tools/jcmd/TestJcmdSanity.java,,
+sun/tools/jhsdb/AlternateHashingTest.java,,
+sun/tools/jhsdb/BasicLauncherTest.java,,
+sun/tools/jhsdb/HeapDumpTest.java,,
+sun/tools/jhsdb/heapconfig/JMapHeapConfigTest.java,,
+sun/tools/jinfo/BasicJInfoTest.java,,
+sun/tools/jinfo/JInfoTest.java,,
+sun/tools/jmap/BasicJMapTest.java,,
+sun/tools/jstack/BasicJStackTest.java,,
+sun/tools/jstack/DeadlockDetectionTest.java,,
+sun/tools/jstatd/TestJstatdExternalRegistry.java,,
+sun/tools/jstatd/TestJstatdPort.java,,Flaky
+sun/tools/jstatd/TestJstatdPortAndServer.java,,Flaky
+sun/util/calendar/zi/TestZoneInfo310.java,,
+tools/jar/modularJar/Basic.java,,
+tools/jar/multiRelease/Basic.java,,
+tools/jimage/JImageExtractTest.java,,
+tools/jimage/JImageTest.java,,
+tools/jlink/JLinkTest.java,,
+tools/jlink/plugins/IncludeLocalesPluginTest.java,,
+tools/jmod/hashes/HashesTest.java,,
+tools/launcher/BigJar.java,b/111611473,
+tools/launcher/HelpFlagsTest.java,,java.lang.AssertionError: HelpFlagsTest failed: Tool jfr not covered by this test. Add specification to jdkTools array!
+tools/launcher/VersionCheck.java,,java.lang.AssertionError: VersionCheck failed: testToolVersion: [jfr];
+tools/launcher/modules/patch/systemmodules/PatchSystemModules.java,,
diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv
new file mode 100644
index 000000000..1d8e65fd0
--- /dev/null
+++ b/test/runtimes/exclude_nodejs12.4.0.csv
@@ -0,0 +1,55 @@
+test name,bug id,comment
+benchmark/test-benchmark-fs.js,,
+benchmark/test-benchmark-napi.js,,
+doctool/test-make-doc.js,b/68848110,Expected to fail.
+internet/test-dgram-multicast-set-interface-lo.js,b/162798882,
+internet/test-doctool-versions.js,,
+internet/test-uv-threadpool-schedule.js,,
+parallel/test-cluster-dgram-reuse.js,b/64024294,
+parallel/test-dgram-bind-fd.js,b/132447356,
+parallel/test-dgram-socket-buffer-size.js,b/68847921,
+parallel/test-dns-channel-timeout.js,b/161893056,
+parallel/test-fs-access.js,,
+parallel/test-fs-watchfile.js,,Flaky - File already exists error
+parallel/test-fs-write-stream.js,,Flaky
+parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2
+parallel/test-os.js,b/63997097,
+parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE
+parallel/test-process-uid-gid.js,,
+parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE
+parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE
+parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE
+parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE
+parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE
+parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE
+parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE
+pseudo-tty/test-assert-colors.js,b/162801321,
+pseudo-tty/test-assert-no-color.js,b/162801321,
+pseudo-tty/test-assert-position-indicator.js,b/162801321,
+pseudo-tty/test-async-wrap-getasyncid-tty.js,b/162801321,
+pseudo-tty/test-fatal-error.js,b/162801321,
+pseudo-tty/test-handle-wrap-isrefed-tty.js,b/162801321,
+pseudo-tty/test-readable-tty-keepalive.js,b/162801321,
+pseudo-tty/test-set-raw-mode-reset-process-exit.js,b/162801321,
+pseudo-tty/test-set-raw-mode-reset-signal.js,b/162801321,
+pseudo-tty/test-set-raw-mode-reset.js,b/162801321,
+pseudo-tty/test-stderr-stdout-handle-sigwinch.js,b/162801321,
+pseudo-tty/test-stdout-read.js,b/162801321,
+pseudo-tty/test-tty-color-support.js,b/162801321,
+pseudo-tty/test-tty-isatty.js,b/162801321,
+pseudo-tty/test-tty-stdin-call-end.js,b/162801321,
+pseudo-tty/test-tty-stdin-end.js,b/162801321,
+pseudo-tty/test-stdin-write.js,b/162801321,
+pseudo-tty/test-tty-stdout-end.js,b/162801321,
+pseudo-tty/test-tty-stdout-resize.js,b/162801321,
+pseudo-tty/test-tty-stream-constructors.js,b/162801321,
+pseudo-tty/test-tty-window-size.js,b/162801321,
+pseudo-tty/test-tty-wrap.js,b/162801321,
+pummel/test-heapdump-http2.js,,Flaky
+pummel/test-net-pingpong.js,,
+pummel/test-vm-memleak.js,b/162799436,
+sequential/test-child-process-pass-fd.js,b/63926391,Flaky
+sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE
+sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout
+tick-processor/test-tick-processor-builtin.js,,
diff --git a/test/runtimes/blacklist_php7.3.6.csv b/test/runtimes/exclude_php7.3.6.csv
index 456bf7487..2ce979dc8 100644
--- a/test/runtimes/blacklist_php7.3.6.csv
+++ b/test/runtimes/exclude_php7.3.6.csv
@@ -8,22 +8,31 @@ ext/mbstring/tests/bug77165.phpt,,
ext/mbstring/tests/bug77454.phpt,,
ext/mbstring/tests/mb_convert_encoding_leak.phpt,,
ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,,
-ext/standard/tests/file/filetype_variation.phpt,,
-ext/standard/tests/file/fopen_variation19.phpt,,
+ext/session/tests/session_module_name_variation4.phpt,,Flaky
+ext/session/tests/session_set_save_handler_class_018.phpt,,
+ext/session/tests/session_set_save_handler_iface_003.phpt,,
+ext/session/tests/session_set_save_handler_sid_001.phpt,,
+ext/session/tests/session_set_save_handler_variation4.phpt,,
+ext/standard/tests/file/fopen_variation19.phpt,b/162894964,
+ext/standard/tests/file/lstat_stat_variation14.phpt,,Flaky
ext/standard/tests/file/php_fd_wrapper_01.phpt,,
ext/standard/tests/file/php_fd_wrapper_02.phpt,,
ext/standard/tests/file/php_fd_wrapper_03.phpt,,
ext/standard/tests/file/php_fd_wrapper_04.phpt,,
-ext/standard/tests/file/realpath_bug77484.phpt,,
+ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,
ext/standard/tests/file/rename_variation.phpt,b/68717309,
-ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,,
-ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341,
+ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,
ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,,
ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,,
-ext/standard/tests/network/bug20134.phpt,,
+ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky
+ext/standard/tests/streams/stream_socket_sendto.phpt,,
+ext/standard/tests/strings/007.phpt,,
+sapi/cli/tests/upload_2G.phpt,,
tests/output/stream_isatty_err.phpt,b/68720279,
tests/output/stream_isatty_in-err.phpt,b/68720282,
tests/output/stream_isatty_in-out-err.phpt,,
tests/output/stream_isatty_in-out.phpt,b/68720299,
tests/output/stream_isatty_out-err.phpt,b/68720311,
tests/output/stream_isatty_out.phpt,b/68720325,
+Zend/tests/concat_003.phpt,b/162896021,
diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude_python3.7.3.csv
new file mode 100644
index 000000000..8760f8951
--- /dev/null
+++ b/test/runtimes/exclude_python3.7.3.csv
@@ -0,0 +1,21 @@
+test name,bug id,comment
+test_asyncio,,Fails on Docker.
+test_asyncore,b/162973328,
+test_epoll,b/162983393,
+test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode.
+test_httplib,b/163000009,OSError: [Errno 98] Address already in use
+test_imaplib,b/162979661,
+test_logging,b/162980079,
+test_multiprocessing_fork,,Flaky. Sometimes times out.
+test_multiprocessing_forkserver,,Flaky. Sometimes times out.
+test_multiprocessing_main_handling,,Flaky. Sometimes times out.
+test_multiprocessing_spawn,,Flaky. Sometimes times out.
+test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted
+test_pty,b/162979921,
+test_readline,b/162980389,TestReadline hangs forever
+test_resource,b/76174079,
+test_selectors,b/76116849,OSError not raised with epoll
+test_smtplib,b/162980434,unclosed sockets
+test_signal,,Flaky - signal: alarm clock
+test_socket,b/75983380,
+test_subprocess,b/162980831,
diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12
deleted file mode 100644
index ab9d6abf3..000000000
--- a/test/runtimes/images/Dockerfile_go1.12
+++ /dev/null
@@ -1,10 +0,0 @@
-# Go is easy, since we already have everything we need to compile the proctor
-# binary and run the tests in the golang Docker image.
-FROM golang:1.12
-ADD ["proctor/", "/go/src/proctor/"]
-RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
-
-# Pre-compile the tests so we don't need to do so in each test run.
-RUN ["go", "tool", "dist", "test", "-compile-only"]
-
-ENTRYPOINT ["/proctor", "--runtime=go"]
diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/proctor/BUILD
index 09dc6c42f..f76e2ddc0 100644
--- a/test/runtimes/images/proctor/BUILD
+++ b/test/runtimes/proctor/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+load("//tools:defs.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -12,15 +12,17 @@ go_binary(
"proctor.go",
"python.go",
],
- visibility = ["//test/runtimes/images:__subpackages__"],
+ pure = True,
+ visibility = ["//test/runtimes:__pkg__"],
)
go_test(
name = "proctor_test",
size = "small",
srcs = ["proctor_test.go"],
- embed = [":proctor"],
+ library = ":proctor",
+ pure = True,
deps = [
- "//runsc/testutil",
+ "//pkg/test/testutil",
],
)
diff --git a/test/runtimes/images/proctor/go.go b/test/runtimes/proctor/go.go
index 3e2d5d8db..d0ae844e6 100644
--- a/test/runtimes/images/proctor/go.go
+++ b/test/runtimes/proctor/go.go
@@ -74,17 +74,26 @@ func (goRunner) ListTests() ([]string, error) {
return append(toolSlice, diskFiltered...), nil
}
-// TestCmd implements TestRunner.TestCmd.
-func (goRunner) TestCmd(test string) *exec.Cmd {
- // Check if test exists on disk by searching for file of the same name.
- // This will determine whether or not it is a Go test on disk.
- if strings.HasSuffix(test, ".go") {
- // Test has suffix ".go" which indicates a disk test, run it as such.
- cmd := exec.Command("go", "run", "run.go", "-v", "--", test)
+// TestCmds implements TestRunner.TestCmds.
+func (goRunner) TestCmds(tests []string) []*exec.Cmd {
+ var toolTests, onDiskTests []string
+ for _, test := range tests {
+ if strings.HasSuffix(test, ".go") {
+ onDiskTests = append(onDiskTests, test)
+ } else {
+ toolTests = append(toolTests, "^"+test+"$")
+ }
+ }
+
+ var cmds []*exec.Cmd
+ if len(toolTests) > 0 {
+ cmds = append(cmds, exec.Command("go", "tool", "dist", "test", "-v", "-no-rebuild", "-run", strings.Join(toolTests, "\\|")))
+ }
+ if len(onDiskTests) > 0 {
+ cmd := exec.Command("go", append([]string{"run", "run.go", "-v", "--"}, onDiskTests...)...)
cmd.Dir = goTestDir
- return cmd
+ cmds = append(cmds, cmd)
}
- // No ".go" suffix, run as a tool test.
- return exec.Command("go", "tool", "dist", "test", "-run", test)
+ return cmds
}
diff --git a/test/runtimes/images/proctor/java.go b/test/runtimes/proctor/java.go
index 8b362029d..d456fa681 100644
--- a/test/runtimes/images/proctor/java.go
+++ b/test/runtimes/proctor/java.go
@@ -60,12 +60,17 @@ func (javaRunner) ListTests() ([]string, error) {
return testSlice, nil
}
-// TestCmd implements TestRunner.TestCmd.
-func (javaRunner) TestCmd(test string) *exec.Cmd {
- args := []string{
- "-noreport",
- "-dir:" + javaTestDir,
- test,
- }
- return exec.Command("jtreg", args...)
+// TestCmds implements TestRunner.TestCmds.
+func (javaRunner) TestCmds(tests []string) []*exec.Cmd {
+ args := append(
+ []string{
+ "-agentvm", // Execute each action using a pool of reusable JVMs.
+ "-dir:" + javaTestDir, // Base directory for test files and directories.
+ "-noreport", // Do not generate a final report.
+ "-timeoutFactor:20", // Extend the default timeout (2 min) of all tests by this factor.
+ "-verbose:nopass", // Verbose output but supress it for tests that passed.
+ },
+ tests...,
+ )
+ return []*exec.Cmd{exec.Command("jtreg", args...)}
}
diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/proctor/nodejs.go
index bd57db444..dead5af4f 100644
--- a/test/runtimes/images/proctor/nodejs.go
+++ b/test/runtimes/proctor/nodejs.go
@@ -39,8 +39,8 @@ func (nodejsRunner) ListTests() ([]string, error) {
return testSlice, nil
}
-// TestCmd implements TestRunner.TestCmd.
-func (nodejsRunner) TestCmd(test string) *exec.Cmd {
- args := []string{filepath.Join("tools", "test.py"), test}
- return exec.Command("/usr/bin/python", args...)
+// TestCmds implements TestRunner.TestCmds.
+func (nodejsRunner) TestCmds(tests []string) []*exec.Cmd {
+ args := append([]string{filepath.Join("tools", "test.py"), "--timeout=180"}, tests...)
+ return []*exec.Cmd{exec.Command("/usr/bin/python", args...)}
}
diff --git a/test/runtimes/images/proctor/php.go b/test/runtimes/proctor/php.go
index 9115040e1..6a83d64e3 100644
--- a/test/runtimes/images/proctor/php.go
+++ b/test/runtimes/proctor/php.go
@@ -17,6 +17,7 @@ package main
import (
"os/exec"
"regexp"
+ "strings"
)
var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`)
@@ -35,8 +36,8 @@ func (phpRunner) ListTests() ([]string, error) {
return testSlice, nil
}
-// TestCmd implements TestRunner.TestCmd.
-func (phpRunner) TestCmd(test string) *exec.Cmd {
- args := []string{"test", "TESTS=" + test}
- return exec.Command("make", args...)
+// TestCmds implements TestRunner.TestCmds.
+func (phpRunner) TestCmds(tests []string) []*exec.Cmd {
+ args := []string{"test", "TESTS=" + strings.Join(tests, " ")}
+ return []*exec.Cmd{exec.Command("make", args...)}
}
diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/proctor/proctor.go
index e6178e82b..9e0642424 100644
--- a/test/runtimes/images/proctor/proctor.go
+++ b/test/runtimes/proctor/proctor.go
@@ -25,6 +25,7 @@ import (
"os/signal"
"path/filepath"
"regexp"
+ "strings"
"syscall"
)
@@ -34,15 +35,18 @@ type TestRunner interface {
// ListTests returns a string slice of tests available to run.
ListTests() ([]string, error)
- // TestCmd returns an *exec.Cmd that will run the given test.
- TestCmd(test string) *exec.Cmd
+ // TestCmds returns a slice of *exec.Cmd that will run the given tests.
+ // There is no correlation between the number of exec.Cmds returned and the
+ // number of tests. It could return one command to run all tests or a few
+ // commands that collectively run all.
+ TestCmds(tests []string) []*exec.Cmd
}
var (
- runtime = flag.String("runtime", "", "name of runtime")
- list = flag.Bool("list", false, "list all available tests")
- test = flag.String("test", "", "run a single test from the list of available tests")
- pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ testNames = flag.String("tests", "", "run a subset of the available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
)
func main() {
@@ -74,14 +78,25 @@ func main() {
return
}
- // Run a single test.
- if *test == "" {
- log.Fatalf("test flag must be provided")
+ var tests []string
+ if *testNames == "" {
+ // Run every test.
+ tests, err = tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to get all tests: %v", err)
+ }
+ } else {
+ // Run subset of test.
+ tests = strings.Split(*testNames, ",")
}
- cmd := tr.TestCmd(*test)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- log.Fatalf("FAIL: %v", err)
+
+ // Run tests.
+ cmds := tr.TestCmds(tests)
+ for _, cmd := range cmds {
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
}
}
diff --git a/test/runtimes/images/proctor/proctor_test.go b/test/runtimes/proctor/proctor_test.go
index 6bb61d142..6ef2de085 100644
--- a/test/runtimes/images/proctor/proctor_test.go
+++ b/test/runtimes/proctor/proctor_test.go
@@ -23,24 +23,24 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
func touch(t *testing.T, name string) {
t.Helper()
f, err := os.Create(name)
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error creating file %q: %v", name, err)
}
if err := f.Close(); err != nil {
- t.Fatal(err)
+ t.Fatalf("error closing file %q: %v", name, err)
}
}
func TestSearchEmptyDir(t *testing.T) {
td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest")
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error creating searchtest: %v", err)
}
defer os.RemoveAll(td)
@@ -60,7 +60,7 @@ func TestSearchEmptyDir(t *testing.T) {
func TestSearch(t *testing.T) {
td, err := ioutil.TempDir(testutil.TmpDir(), "searchtest")
if err != nil {
- t.Fatal(err)
+ t.Fatalf("error creating searchtest: %v", err)
}
defer os.RemoveAll(td)
@@ -101,14 +101,14 @@ func TestSearch(t *testing.T) {
if strings.HasSuffix(item, "/") {
// This item is a directory, create it.
if err := os.MkdirAll(filepath.Join(td, item), 0755); err != nil {
- t.Fatal(err)
+ t.Fatalf("error making directory: %v", err)
}
} else {
// This item is a file, create the directory and touch file.
// Create directory in which file should be created
fullDirPath := filepath.Join(td, filepath.Dir(item))
if err := os.MkdirAll(fullDirPath, 0755); err != nil {
- t.Fatal(err)
+ t.Fatalf("error making directory: %v", err)
}
// Create file with full path to file.
touch(t, filepath.Join(td, item))
diff --git a/test/runtimes/images/proctor/python.go b/test/runtimes/proctor/python.go
index b9e0fbe6f..7c598801b 100644
--- a/test/runtimes/images/proctor/python.go
+++ b/test/runtimes/proctor/python.go
@@ -42,8 +42,8 @@ func (pythonRunner) ListTests() ([]string, error) {
return toolSlice, nil
}
-// TestCmd implements TestRunner.TestCmd.
-func (pythonRunner) TestCmd(test string) *exec.Cmd {
- args := []string{"-m", "test", test}
- return exec.Command("./python", args...)
+// TestCmds implements TestRunner.TestCmds.
+func (pythonRunner) TestCmds(tests []string) []*exec.Cmd {
+ args := append([]string{"-m", "test"}, tests...)
+ return []*exec.Cmd{exec.Command("./python", args...)}
}
diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD
new file mode 100644
index 000000000..dc0d5d5b4
--- /dev/null
+++ b/test/runtimes/runner/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["main.go"],
+ visibility = ["//test/runtimes:__pkg__"],
+ deps = [
+ "//pkg/log",
+ "//pkg/test/dockerutil",
+ "//pkg/test/testutil",
+ ],
+)
+
+go_test(
+ name = "exclude_test",
+ size = "small",
+ srcs = ["exclude_test.go"],
+ library = ":runner",
+)
diff --git a/test/runtimes/blacklist_test.go b/test/runtimes/runner/exclude_test.go
index 52f49b984..67c2170c8 100644
--- a/test/runtimes/blacklist_test.go
+++ b/test/runtimes/runner/exclude_test.go
@@ -25,13 +25,13 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
-// Test that the blacklist parses without error.
-func TestBlacklists(t *testing.T) {
- bl, err := getBlacklist()
+// Test that the exclude file parses without error.
+func TestExcludelist(t *testing.T) {
+ ex, err := getExcludes()
if err != nil {
- t.Fatalf("error parsing blacklist: %v", err)
+ t.Fatalf("error parsing exclude file: %v", err)
}
- if *blacklistFile != "" && len(bl) == 0 {
- t.Errorf("got empty blacklist for file %q", blacklistFile)
+ if *excludeFile != "" && len(ex) == 0 {
+ t.Errorf("got empty excludes for file %q", *excludeFile)
}
}
diff --git a/test/runtimes/runner.go b/test/runtimes/runner/main.go
index bec37c69d..948e7cf9c 100644
--- a/test/runtimes/runner.go
+++ b/test/runtimes/runner/main.go
@@ -16,29 +16,31 @@
package main
import (
+ "context"
"encoding/csv"
"flag"
"fmt"
"io"
- "log"
"os"
"sort"
"strings"
"testing"
"time"
- "gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
)
var (
- lang = flag.String("lang", "", "language runtime to test")
- image = flag.String("image", "", "docker image with runtime tests")
- blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment")
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment")
+ batchSize = flag.Int("batch", 50, "number of test cases run in one command")
)
// Wait time for each test to run.
-const timeout = 5 * time.Minute
+const timeout = 90 * time.Minute
func main() {
flag.Parse()
@@ -46,7 +48,6 @@ func main() {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
-
os.Exit(runTests())
}
@@ -54,21 +55,27 @@ func main() {
// defered functions before exiting. It returns an exit code that should be
// passed to os.Exit.
func runTests() int {
- // Get tests to blacklist.
- blacklist, err := getBlacklist()
+ // Get tests to exclude..
+ excludes, err := getExcludes()
if err != nil {
- fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error())
+ fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error())
return 1
}
- // Create a single docker container that will be used for all tests.
- d := dockerutil.MakeDocker("gvisor-" + *lang)
- defer d.CleanUp()
+ // Construct the shared docker instance.
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang))
+ defer d.CleanUp(ctx)
+
+ if err := testutil.TouchShardStatusFile(); err != nil {
+ fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err)
+ return 1
+ }
// Get a slice of tests to run. This will also start a single Docker
// container that will be used to run each test. The final test will
// stop the Docker container.
- tests, err := getTests(d, blacklist)
+ tests, err := getTests(ctx, d, excludes)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
return 1
@@ -78,21 +85,19 @@ func runTests() int {
return m.Run()
}
-// getTests returns a slice of tests to run, subject to the shard size and
-// index.
-func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) {
- // Pull the image.
- if err := dockerutil.Pull(*image); err != nil {
- return nil, fmt.Errorf("docker pull %q failed: %v", *image, err)
+// getTests executes all tests as table tests.
+func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) {
+ // Start the container.
+ opts := dockerutil.RunOpts{
+ Image: fmt.Sprintf("runtimes/%s", *image),
}
-
- // Run proctor with --pause flag to keep container alive forever.
- if err := d.Run(*image, "--pause"); err != nil {
+ d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor")
+ if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil {
return nil, fmt.Errorf("docker run failed: %v", err)
}
// Get a list of all tests in the image.
- list, err := d.Exec("/proctor", "--runtime", *lang, "--list")
+ list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list")
if err != nil {
return nil, fmt.Errorf("docker exec failed: %v", err)
}
@@ -101,25 +106,29 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int
// shard.
tests := strings.Fields(list)
sort.Strings(tests)
- begin, end, err := testutil.TestBoundsForShard(len(tests))
+ indices, err := testutil.TestIndicesForShard(len(tests))
if err != nil {
return nil, fmt.Errorf("TestsForShard() failed: %v", err)
}
- log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests))
- tests = tests[begin:end]
var itests []testing.InternalTest
- for _, tc := range tests {
- // Capture tc in this scope.
- tc := tc
+ for i := 0; i < len(indices); i += *batchSize {
+ var tcs []string
+ end := i + *batchSize
+ if end > len(indices) {
+ end = len(indices)
+ }
+ for _, tc := range indices[i:end] {
+ // Add test if not excluded.
+ if _, ok := excludes[tests[tc]]; ok {
+ log.Infof("Skipping test case %s\n", tests[tc])
+ continue
+ }
+ tcs = append(tcs, tests[tc])
+ }
itests = append(itests, testing.InternalTest{
- Name: tc,
+ Name: strings.Join(tcs, ", "),
F: func(t *testing.T) {
- // Is the test blacklisted?
- if _, ok := blacklist[tc]; ok {
- t.Skip("SKIP: blacklisted test %q", tc)
- }
-
var (
now = time.Now()
done = make(chan struct{})
@@ -128,39 +137,36 @@ func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.Int
)
go func() {
- fmt.Printf("RUNNING %s...\n", tc)
- output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc)
+ fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n"))
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ","))
close(done)
}()
select {
case <-done:
if err == nil {
- fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ fmt.Printf("PASS: (%v)\n\n", time.Since(now))
return
}
- t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output)
case <-time.After(timeout):
- t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output)
+ t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output)
}
},
})
}
+
return itests, nil
}
-// getBlacklist reads the blacklist file and returns a set of test names to
+// getBlacklist reads the exclude file and returns a set of test names to
// exclude.
-func getBlacklist() (map[string]struct{}, error) {
- blacklist := make(map[string]struct{})
- if *blacklistFile == "" {
- return blacklist, nil
- }
- file, err := testutil.FindFile(*blacklistFile)
- if err != nil {
- return nil, err
+func getExcludes() (map[string]struct{}, error) {
+ excludes := make(map[string]struct{})
+ if *excludeFile == "" {
+ return excludes, nil
}
- f, err := os.Open(file)
+ f, err := os.Open(*excludeFile)
if err != nil {
return nil, err
}
@@ -181,9 +187,9 @@ func getBlacklist() (map[string]struct{}, error) {
if err != nil {
return nil, err
}
- blacklist[record[0]] = struct{}{}
+ excludes[record[0]] = struct{}{}
}
- return blacklist, nil
+ return excludes, nil
}
// testDeps implements testing.testDeps (an unexported interface), and is
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index a53a23afd..0eadc6b08 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1,15 +1,18 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("//test/syscalls:build_defs.bzl", "syscall_test")
+load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
-syscall_test(test = "//test/syscalls/linux:32bit_test")
+syscall_test(
+ test = "//test/syscalls/linux:32bit_test",
+)
-syscall_test(test = "//test/syscalls/linux:accept_bind_stream_test")
+syscall_test(
+ test = "//test/syscalls/linux:accept_bind_stream_test",
+)
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:accept_bind_test",
)
@@ -18,7 +21,9 @@ syscall_test(
test = "//test/syscalls/linux:access_test",
)
-syscall_test(test = "//test/syscalls/linux:affinity_test")
+syscall_test(
+ test = "//test/syscalls/linux:affinity_test",
+)
syscall_test(
add_overlay = True,
@@ -31,9 +36,13 @@ syscall_test(
test = "//test/syscalls/linux:alarm_test",
)
-syscall_test(test = "//test/syscalls/linux:arch_prctl_test")
+syscall_test(
+ test = "//test/syscalls/linux:arch_prctl_test",
+)
-syscall_test(test = "//test/syscalls/linux:bad_test")
+syscall_test(
+ test = "//test/syscalls/linux:bad_test",
+)
syscall_test(
size = "large",
@@ -41,9 +50,27 @@ syscall_test(
test = "//test/syscalls/linux:bind_test",
)
-syscall_test(test = "//test/syscalls/linux:brk_test")
+syscall_test(
+ test = "//test/syscalls/linux:brk_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_capability_test",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
+ # Takes too long for TSAN. Since this is kind of a stress test that doesn't
+ # involve much concurrency, TSAN's usefulness here is limited anyway.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_stress_test",
+ vfs2 = False,
+)
syscall_test(
add_overlay = True,
@@ -67,16 +94,22 @@ syscall_test(
test = "//test/syscalls/linux:chroot_test",
)
-syscall_test(test = "//test/syscalls/linux:clock_getres_test")
+syscall_test(
+ test = "//test/syscalls/linux:clock_getres_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:clock_gettime_test",
)
-syscall_test(test = "//test/syscalls/linux:clock_nanosleep_test")
+syscall_test(
+ test = "//test/syscalls/linux:clock_nanosleep_test",
+)
-syscall_test(test = "//test/syscalls/linux:concurrency_test")
+syscall_test(
+ test = "//test/syscalls/linux:concurrency_test",
+)
syscall_test(
add_uds_tree = True,
@@ -89,18 +122,27 @@ syscall_test(
test = "//test/syscalls/linux:creat_test",
)
-syscall_test(test = "//test/syscalls/linux:dev_test")
+syscall_test(
+ fuse = "True",
+ test = "//test/syscalls/linux:dev_test",
+)
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:dup_test",
)
-syscall_test(test = "//test/syscalls/linux:epoll_test")
+syscall_test(
+ test = "//test/syscalls/linux:epoll_test",
+)
-syscall_test(test = "//test/syscalls/linux:eventfd_test")
+syscall_test(
+ test = "//test/syscalls/linux:eventfd_test",
+)
-syscall_test(test = "//test/syscalls/linux:exceptions_test")
+syscall_test(
+ test = "//test/syscalls/linux:exceptions_test",
+)
syscall_test(
size = "medium",
@@ -114,7 +156,9 @@ syscall_test(
test = "//test/syscalls/linux:exec_binary_test",
)
-syscall_test(test = "//test/syscalls/linux:exit_test")
+syscall_test(
+ test = "//test/syscalls/linux:exit_test",
+)
syscall_test(
add_overlay = True,
@@ -126,7 +170,9 @@ syscall_test(
test = "//test/syscalls/linux:fallocate_test",
)
-syscall_test(test = "//test/syscalls/linux:fault_test")
+syscall_test(
+ test = "//test/syscalls/linux:fault_test",
+)
syscall_test(
add_overlay = True,
@@ -144,11 +190,17 @@ syscall_test(
test = "//test/syscalls/linux:flock_test",
)
-syscall_test(test = "//test/syscalls/linux:fork_test")
+syscall_test(
+ test = "//test/syscalls/linux:fork_test",
+)
-syscall_test(test = "//test/syscalls/linux:fpsig_fork_test")
+syscall_test(
+ test = "//test/syscalls/linux:fpsig_fork_test",
+)
-syscall_test(test = "//test/syscalls/linux:fpsig_nested_test")
+syscall_test(
+ test = "//test/syscalls/linux:fpsig_nested_test",
+)
syscall_test(
add_overlay = True,
@@ -161,18 +213,26 @@ syscall_test(
test = "//test/syscalls/linux:futex_test",
)
-syscall_test(test = "//test/syscalls/linux:getcpu_host_test")
+syscall_test(
+ test = "//test/syscalls/linux:getcpu_host_test",
+)
-syscall_test(test = "//test/syscalls/linux:getcpu_test")
+syscall_test(
+ test = "//test/syscalls/linux:getcpu_test",
+)
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:getdents_test",
)
-syscall_test(test = "//test/syscalls/linux:getrandom_test")
+syscall_test(
+ test = "//test/syscalls/linux:getrandom_test",
+)
-syscall_test(test = "//test/syscalls/linux:getrusage_test")
+syscall_test(
+ test = "//test/syscalls/linux:getrusage_test",
+)
syscall_test(
size = "medium",
@@ -196,7 +256,9 @@ syscall_test(
test = "//test/syscalls/linux:itimer_test",
)
-syscall_test(test = "//test/syscalls/linux:kill_test")
+syscall_test(
+ test = "//test/syscalls/linux:kill_test",
+)
syscall_test(
add_overlay = True,
@@ -209,13 +271,21 @@ syscall_test(
test = "//test/syscalls/linux:lseek_test",
)
-syscall_test(test = "//test/syscalls/linux:madvise_test")
+syscall_test(
+ test = "//test/syscalls/linux:madvise_test",
+)
-syscall_test(test = "//test/syscalls/linux:memory_accounting_test")
+syscall_test(
+ test = "//test/syscalls/linux:memory_accounting_test",
+)
-syscall_test(test = "//test/syscalls/linux:mempolicy_test")
+syscall_test(
+ test = "//test/syscalls/linux:mempolicy_test",
+)
-syscall_test(test = "//test/syscalls/linux:mincore_test")
+syscall_test(
+ test = "//test/syscalls/linux:mincore_test",
+)
syscall_test(
add_overlay = True,
@@ -225,7 +295,6 @@ syscall_test(
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:mknod_test",
- use_tmpfs = True, # mknod is not supported over gofer.
)
syscall_test(
@@ -249,7 +318,13 @@ syscall_test(
test = "//test/syscalls/linux:msync_test",
)
-syscall_test(test = "//test/syscalls/linux:munmap_test")
+syscall_test(
+ test = "//test/syscalls/linux:munmap_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:network_namespace_test",
+)
syscall_test(
add_overlay = True,
@@ -261,13 +336,28 @@ syscall_test(
test = "//test/syscalls/linux:open_test",
)
-syscall_test(test = "//test/syscalls/linux:packet_socket_raw_test")
+syscall_test(
+ test = "//test/syscalls/linux:packet_socket_raw_test",
+)
-syscall_test(test = "//test/syscalls/linux:packet_socket_test")
+syscall_test(
+ test = "//test/syscalls/linux:packet_socket_test",
+)
-syscall_test(test = "//test/syscalls/linux:partial_bad_buffer_test")
+syscall_test(
+ test = "//test/syscalls/linux:partial_bad_buffer_test",
+)
-syscall_test(test = "//test/syscalls/linux:pause_test")
+syscall_test(
+ test = "//test/syscalls/linux:pause_test",
+)
+
+syscall_test(
+ size = "medium",
+ # Takes too long under gotsan to run.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:ping_socket_test",
+)
syscall_test(
size = "large",
@@ -276,16 +366,22 @@ syscall_test(
test = "//test/syscalls/linux:pipe_test",
)
-syscall_test(test = "//test/syscalls/linux:poll_test")
+syscall_test(
+ test = "//test/syscalls/linux:poll_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:ppoll_test",
)
-syscall_test(test = "//test/syscalls/linux:prctl_setuid_test")
+syscall_test(
+ test = "//test/syscalls/linux:prctl_setuid_test",
+)
-syscall_test(test = "//test/syscalls/linux:prctl_test")
+syscall_test(
+ test = "//test/syscalls/linux:prctl_test",
+)
syscall_test(
add_overlay = True,
@@ -302,23 +398,39 @@ syscall_test(
test = "//test/syscalls/linux:preadv2_test",
)
-syscall_test(test = "//test/syscalls/linux:priority_test")
+syscall_test(
+ test = "//test/syscalls/linux:priority_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:proc_test",
)
-syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_oomscore_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_smaps_test",
+)
-syscall_test(test = "//test/syscalls/linux:proc_net_test")
+syscall_test(
+ test = "//test/syscalls/linux:proc_pid_uid_gid_map_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:pselect_test",
)
-syscall_test(test = "//test/syscalls/linux:ptrace_test")
+syscall_test(
+ test = "//test/syscalls/linux:ptrace_test",
+)
syscall_test(
size = "medium",
@@ -340,11 +452,17 @@ syscall_test(
test = "//test/syscalls/linux:pwrite64_test",
)
-syscall_test(test = "//test/syscalls/linux:raw_socket_hdrincl_test")
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_hdrincl_test",
+)
-syscall_test(test = "//test/syscalls/linux:raw_socket_icmp_test")
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_icmp_test",
+)
-syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test")
+syscall_test(
+ test = "//test/syscalls/linux:raw_socket_test",
+)
syscall_test(
add_overlay = True,
@@ -374,17 +492,37 @@ syscall_test(
test = "//test/syscalls/linux:rename_test",
)
-syscall_test(test = "//test/syscalls/linux:rlimits_test")
+syscall_test(
+ test = "//test/syscalls/linux:rlimits_test",
+)
-syscall_test(test = "//test/syscalls/linux:rtsignal_test")
+syscall_test(
+ test = "//test/syscalls/linux:rseq_test",
+)
-syscall_test(test = "//test/syscalls/linux:sched_test")
+syscall_test(
+ test = "//test/syscalls/linux:rtsignal_test",
+)
-syscall_test(test = "//test/syscalls/linux:sched_yield_test")
+syscall_test(
+ test = "//test/syscalls/linux:signalfd_test",
+)
-syscall_test(test = "//test/syscalls/linux:seccomp_test")
+syscall_test(
+ test = "//test/syscalls/linux:sched_test",
+)
-syscall_test(test = "//test/syscalls/linux:select_test")
+syscall_test(
+ test = "//test/syscalls/linux:sched_yield_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:seccomp_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:select_test",
+)
syscall_test(
shard_count = 20,
@@ -406,21 +544,29 @@ syscall_test(
test = "//test/syscalls/linux:splice_test",
)
-syscall_test(test = "//test/syscalls/linux:sigaction_test")
+syscall_test(
+ test = "//test/syscalls/linux:sigaction_test",
+)
# TODO(b/119826902): Enable once the test passes in runsc.
-# syscall_test(test = "//test/syscalls/linux:sigaltstack_test")
+# syscall_test(vfs2="True",test = "//test/syscalls/linux:sigaltstack_test")
-syscall_test(test = "//test/syscalls/linux:sigiret_test")
+syscall_test(
+ test = "//test/syscalls/linux:sigiret_test",
+)
-syscall_test(test = "//test/syscalls/linux:sigprocmask_test")
+syscall_test(
+ test = "//test/syscalls/linux:sigprocmask_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:sigstop_test",
)
-syscall_test(test = "//test/syscalls/linux:sigtimedwait_test")
+syscall_test(
+ test = "//test/syscalls/linux:sigtimedwait_test",
+)
syscall_test(
size = "medium",
@@ -434,7 +580,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_abstract_test",
)
@@ -445,7 +591,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_domain_test",
)
@@ -458,19 +604,27 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_filesystem_test",
)
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_inet_loopback_test",
)
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
+ # Takes too long for TSAN. Creates a lot of TCP sockets.
+ tags = ["nogotsan"],
+ test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test",
+)
+
+syscall_test(
+ size = "large",
+ shard_count = 50,
test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test",
)
@@ -481,13 +635,13 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_ip_tcp_loopback_test",
)
syscall_test(
size = "medium",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test",
)
@@ -498,7 +652,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_ip_udp_loopback_test",
)
@@ -507,19 +661,41 @@ syscall_test(
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test",
)
-syscall_test(test = "//test/syscalls/linux:socket_ip_unbound_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_ip_unbound_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netdevice_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_netdevice_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_netlink_route_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_route_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:socket_netlink_uevent_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_blocking_local_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_blocking_local_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_blocking_ip_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_blocking_ip_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_local_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_non_stream_blocking_local_test",
+)
-syscall_test(test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test")
+syscall_test(
+ test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test",
+)
syscall_test(
size = "large",
@@ -556,7 +732,7 @@ syscall_test(
syscall_test(
size = "large",
add_overlay = True,
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_unix_pair_test",
)
@@ -595,7 +771,7 @@ syscall_test(
syscall_test(
size = "large",
- shard_count = 10,
+ shard_count = 50,
test = "//test/syscalls/linux:socket_unix_unbound_stream_test",
)
@@ -634,11 +810,17 @@ syscall_test(
test = "//test/syscalls/linux:sync_file_range_test",
)
-syscall_test(test = "//test/syscalls/linux:sysinfo_test")
+syscall_test(
+ test = "//test/syscalls/linux:sysinfo_test",
+)
-syscall_test(test = "//test/syscalls/linux:syslog_test")
+syscall_test(
+ test = "//test/syscalls/linux:syslog_test",
+)
-syscall_test(test = "//test/syscalls/linux:sysret_test")
+syscall_test(
+ test = "//test/syscalls/linux:sysret_test",
+)
syscall_test(
size = "medium",
@@ -646,52 +828,88 @@ syscall_test(
test = "//test/syscalls/linux:tcp_socket_test",
)
-syscall_test(test = "//test/syscalls/linux:tgkill_test")
+syscall_test(
+ test = "//test/syscalls/linux:tgkill_test",
+)
-syscall_test(test = "//test/syscalls/linux:timerfd_test")
+syscall_test(
+ test = "//test/syscalls/linux:timerfd_test",
+)
-syscall_test(test = "//test/syscalls/linux:timers_test")
+syscall_test(
+ test = "//test/syscalls/linux:timers_test",
+)
-syscall_test(test = "//test/syscalls/linux:time_test")
+syscall_test(
+ test = "//test/syscalls/linux:time_test",
+)
-syscall_test(test = "//test/syscalls/linux:tkill_test")
+syscall_test(
+ test = "//test/syscalls/linux:tkill_test",
+)
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:truncate_test",
)
-syscall_test(test = "//test/syscalls/linux:udp_bind_test")
+syscall_test(
+ test = "//test/syscalls/linux:tuntap_test",
+)
+
+syscall_test(
+ add_hostinet = True,
+ test = "//test/syscalls/linux:tuntap_hostinet_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:udp_bind_test",
+)
syscall_test(
size = "medium",
+ add_hostinet = True,
shard_count = 10,
test = "//test/syscalls/linux:udp_socket_test",
)
-syscall_test(test = "//test/syscalls/linux:uidgid_test")
+syscall_test(
+ test = "//test/syscalls/linux:uidgid_test",
+)
-syscall_test(test = "//test/syscalls/linux:uname_test")
+syscall_test(
+ test = "//test/syscalls/linux:uname_test",
+)
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:unlink_test",
)
-syscall_test(test = "//test/syscalls/linux:unshare_test")
+syscall_test(
+ test = "//test/syscalls/linux:unshare_test",
+)
-syscall_test(test = "//test/syscalls/linux:utimes_test")
+syscall_test(
+ test = "//test/syscalls/linux:utimes_test",
+)
syscall_test(
size = "medium",
test = "//test/syscalls/linux:vdso_clock_gettime_test",
)
-syscall_test(test = "//test/syscalls/linux:vdso_test")
+syscall_test(
+ test = "//test/syscalls/linux:vdso_test",
+)
-syscall_test(test = "//test/syscalls/linux:vsyscall_test")
+syscall_test(
+ test = "//test/syscalls/linux:vsyscall_test",
+)
-syscall_test(test = "//test/syscalls/linux:vfork_test")
+syscall_test(
+ test = "//test/syscalls/linux:vfork_test",
+)
syscall_test(
size = "medium",
@@ -704,26 +922,14 @@ syscall_test(
test = "//test/syscalls/linux:write_test",
)
-syscall_test(test = "//test/syscalls/linux:proc_net_unix_test")
-
-syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
-
-syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
-
-go_binary(
- name = "syscall_test_runner",
- testonly = 1,
- srcs = ["syscall_test_runner.go"],
- data = [
- "//runsc",
- ],
- deps = [
- "//pkg/log",
- "//runsc/specutils",
- "//runsc/testutil",
- "//test/syscalls/gtest",
- "//test/uds",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_unix_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_tcp_test",
+)
+
+syscall_test(
+ test = "//test/syscalls/linux:proc_net_udp_test",
)
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
deleted file mode 100644
index dcf5b73ed..000000000
--- a/test/syscalls/build_defs.bzl
+++ /dev/null
@@ -1,136 +0,0 @@
-"""Defines a rule for syscall test targets."""
-
-# syscall_test is a macro that will create targets to run the given test target
-# on the host (native) and runsc.
-def syscall_test(
- test,
- shard_count = 5,
- size = "small",
- use_tmpfs = False,
- add_overlay = False,
- add_uds_tree = False,
- tags = None):
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "native",
- use_tmpfs = False,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "kvm",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- if add_overlay:
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
- add_uds_tree = add_uds_tree,
- tags = tags,
- overlay = True,
- )
-
- if not use_tmpfs:
- # Also test shared gofer access.
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- file_access = "shared",
- )
-
-def _syscall_test(
- test,
- shard_count,
- size,
- platform,
- use_tmpfs,
- tags,
- file_access = "exclusive",
- overlay = False,
- add_uds_tree = False):
- test_name = test.split(":")[1]
-
- # Prepend "runsc" to non-native platform names.
- full_platform = platform if platform == "native" else "runsc_" + platform
-
- name = test_name + "_" + full_platform
- if file_access == "shared":
- name += "_shared"
- if overlay:
- name += "_overlay"
-
- if tags == None:
- tags = []
-
- # Add the full_platform and file access in a tag to make it easier to run
- # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
- tags += [full_platform, "file_" + file_access]
-
- # Add tag to prevent the tests from running in a Bazel sandbox.
- # TODO(b/120560048): Make the tests run without this tag.
- tags.append("no-sandbox")
-
- # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is
- # more stable.
- if platform == "kvm":
- tags += ["manual"]
- tags += ["requires-kvm"]
-
- args = [
- # Arguments are passed directly to syscall_test_runner binary.
- "--test-name=" + test_name,
- "--platform=" + platform,
- "--use-tmpfs=" + str(use_tmpfs),
- "--file-access=" + file_access,
- "--overlay=" + str(overlay),
- "--add-uds-tree=" + str(add_uds_tree),
- ]
-
- sh_test(
- srcs = ["syscall_test_runner.sh"],
- name = name,
- data = [
- ":syscall_test_runner",
- test,
- ],
- args = args,
- size = size,
- tags = tags,
- shard_count = shard_count,
- )
-
-def sh_test(**kwargs):
- """Wraps the standard sh_test."""
- native.sh_test(
- **kwargs
- )
-
-def select_for_linux(for_linux, for_others = []):
- return for_linux
diff --git a/test/syscalls/gtest/BUILD b/test/syscalls/gtest/BUILD
deleted file mode 100644
index 9293f25cb..000000000
--- a/test/syscalls/gtest/BUILD
+++ /dev/null
@@ -1,12 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "gtest",
- srcs = ["gtest.go"],
- importpath = "gvisor.dev/gvisor/test/syscalls/gtest",
- visibility = [
- "//test:__subpackages__",
- ],
-)
diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go
deleted file mode 100644
index bdec8eb07..000000000
--- a/test/syscalls/gtest/gtest.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package gtest contains helpers for running google-test tests from Go.
-package gtest
-
-import (
- "fmt"
- "os/exec"
- "strings"
-)
-
-var (
- // ListTestFlag is the flag that will list tests in gtest binaries.
- ListTestFlag = "--gtest_list_tests"
-
- // FilterTestFlag is the flag that will filter tests in gtest binaries.
- FilterTestFlag = "--gtest_filter"
-)
-
-// TestCase is a single gtest test case.
-type TestCase struct {
- // Suite is the suite for this test.
- Suite string
-
- // Name is the name of this individual test.
- Name string
-}
-
-// FullName returns the name of the test including the suite. It is suitable to
-// pass to "-gtest_filter".
-func (tc TestCase) FullName() string {
- return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
-}
-
-// ParseTestCases calls a gtest test binary to list its test and returns a
-// slice with the name and suite of each test.
-func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) {
- args := append([]string{ListTestFlag}, extraArgs...)
- cmd := exec.Command(testBin, args...)
- out, err := cmd.Output()
- if err != nil {
- exitErr, ok := err.(*exec.ExitError)
- if !ok {
- return nil, fmt.Errorf("could not enumerate gtest tests: %v", err)
- }
- return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr)
- }
-
- var t []TestCase
- var suite string
- for _, line := range strings.Split(string(out), "\n") {
- // Strip comments.
- line = strings.Split(line, "#")[0]
-
- // New suite?
- if !strings.HasPrefix(line, " ") {
- suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
- continue
- }
-
- // Individual test.
- name := strings.TrimSpace(line)
-
- // Do we have a suite yet?
- if suite == "" {
- return nil, fmt.Errorf("test without a suite: %v", name)
- }
-
- // Add this individual test.
- t = append(t, TestCase{
- Suite: suite,
- Name: name,
- })
-
- }
-
- if len(t) == 0 {
- return nil, fmt.Errorf("no tests parsed from %v", testBin)
- }
- return t, nil
-}
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index a7cbee06b..3c825477c 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -15,10 +15,12 @@
#include <string.h>
#include <sys/mman.h>
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
#include "test/util/memory_util.h"
+#include "test/util/platform_util.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
-#include "gtest/gtest.h"
#ifndef __x86_64__
#error "This test is x86-64 specific."
@@ -30,7 +32,6 @@ namespace testing {
namespace {
constexpr char kInt3 = '\xcc';
-
constexpr char kInt80[2] = {'\xcd', '\x80'};
constexpr char kSyscall[2] = {'\x0f', '\x05'};
constexpr char kSysenter[2] = {'\x0f', '\x34'};
@@ -43,6 +44,7 @@ void ExitGroup32(const char instruction[2], int code) {
// Fill with INT 3 in case we execute too far.
memset(m.ptr(), kInt3, m.len());
+ // Copy in the actual instruction.
memcpy(m.ptr(), instruction, 2);
// We're playing *extremely* fast-and-loose with the various syscall ABIs
@@ -71,77 +73,96 @@ void ExitGroup32(const char instruction[2], int code) {
"iretl\n"
"int $3\n"
:
- : [code] "m"(code), [ip] "d"(m.ptr())
- : "rax", "rbx", "rsp");
+ : [ code ] "m"(code), [ ip ] "d"(m.ptr())
+ : "rax", "rbx");
}
constexpr int kExitCode = 42;
TEST(Syscall32Bit, Int80) {
- switch (GvisorPlatform()) {
- case Platform::kKVM:
- // TODO(b/111805002): 32-bit segments are broken (but not explictly
- // disabled).
- return;
- case Platform::kPtrace:
- // TODO(gvisor.dev/issue/167): The ptrace platform does not have a
- // consistent story here.
- return;
- case Platform::kNative:
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
break;
- }
- // Upstream Linux. 32-bit syscalls allowed.
- EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42),
- "");
-}
+ case PlatformSupport::Ignored:
+ // Since the call is ignored, we'll hit the int3 trap.
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode),
+ ::testing::KilledBySignal(SIGTRAP), "");
+ break;
-TEST(Syscall32Bit, Sysenter) {
- switch (GvisorPlatform()) {
- case Platform::kKVM:
- // TODO(b/111805002): See above.
- return;
- case Platform::kPtrace:
- // TODO(gvisor.dev/issue/167): See above.
- return;
- case Platform::kNative:
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kInt80, kExitCode), ::testing::ExitedWithCode(42),
+ "");
break;
}
+}
- if (GetCPUVendor() == CPUVendor::kAMD) {
+TEST(Syscall32Bit, Sysenter) {
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
+ GetCPUVendor() == CPUVendor::kAMD) {
// SYSENTER is an illegal instruction in compatibility mode on AMD.
EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
::testing::KilledBySignal(SIGILL), "");
return;
}
- // Upstream Linux on !AMD, 32-bit syscalls allowed.
- EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode), ::testing::ExitedWithCode(42),
- "");
-}
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
-TEST(Syscall32Bit, Syscall) {
- switch (GvisorPlatform()) {
- case Platform::kKVM:
- // TODO(b/111805002): See above.
- return;
- case Platform::kPtrace:
- // TODO(gvisor.dev/issue/167): See above.
- return;
- case Platform::kNative:
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ // See above, except expected code is SIGSEGV.
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kSysenter, kExitCode),
+ ::testing::ExitedWithCode(42), "");
break;
}
+}
- if (GetCPUVendor() == CPUVendor::kIntel) {
+TEST(Syscall32Bit, Syscall) {
+ if ((PlatformSupport32Bit() == PlatformSupport::Allowed ||
+ PlatformSupport32Bit() == PlatformSupport::Ignored) &&
+ GetCPUVendor() == CPUVendor::kIntel) {
// SYSCALL is an illegal instruction in compatibility mode on Intel.
EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
::testing::KilledBySignal(SIGILL), "");
return;
}
- // Upstream Linux on !Intel, 32-bit syscalls allowed.
- EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode), ::testing::ExitedWithCode(42),
- "");
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
+ break;
+
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ // See above.
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Allowed:
+ EXPECT_EXIT(ExitGroup32(kSyscall, kExitCode),
+ ::testing::ExitedWithCode(42), "");
+ break;
+ }
}
// Far call code called below.
@@ -205,19 +226,20 @@ void FarCall32() {
}
TEST(Call32Bit, Disallowed) {
- switch (GvisorPlatform()) {
- case Platform::kKVM:
- // TODO(b/111805002): See above.
- return;
- case Platform::kPtrace:
- // The ptrace platform cannot prevent switching to compatibility mode.
- ABSL_FALLTHROUGH_INTENDED;
- case Platform::kNative:
+ switch (PlatformSupport32Bit()) {
+ case PlatformSupport::NotSupported:
break;
- }
- // Shouldn't crash.
- FarCall32();
+ case PlatformSupport::Segfault:
+ EXPECT_EXIT(FarCall32(), ::testing::KilledBySignal(SIGSEGV), "");
+ break;
+
+ case PlatformSupport::Ignored:
+ ABSL_FALLTHROUGH_INTENDED;
+ case PlatformSupport::Allowed:
+ // Shouldn't crash.
+ FarCall32();
+ }
}
} // namespace
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 833fbaa09..66a31cd28 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1,11 +1,34 @@
-load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
-load("//test/syscalls:build_defs.bzl", "select_for_linux")
+load("//tools:defs.bzl", "cc_binary", "cc_library", "default_net_util", "gtest", "select_arch", "select_system")
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
)
+exports_files(
+ [
+ "socket.cc",
+ "socket_inet_loopback.cc",
+ "socket_ip_loopback_blocking.cc",
+ "socket_ip_tcp_generic_loopback.cc",
+ "socket_ip_tcp_loopback.cc",
+ "socket_ip_tcp_loopback_blocking.cc",
+ "socket_ip_tcp_loopback_nonblock.cc",
+ "socket_ip_tcp_udp_generic.cc",
+ "socket_ip_udp_loopback.cc",
+ "socket_ip_udp_loopback_blocking.cc",
+ "socket_ip_udp_loopback_nonblock.cc",
+ "socket_ip_unbound.cc",
+ "socket_ipv4_tcp_unbound_external_networking_test.cc",
+ "socket_ipv4_udp_unbound_external_networking_test.cc",
+ "socket_ipv4_udp_unbound_loopback.cc",
+ "tcp_socket.cc",
+ "udp_bind.cc",
+ "udp_socket.cc",
+ ],
+ visibility = ["//:sandbox"],
+)
+
cc_binary(
name = "sigaltstack_check",
testonly = 1,
@@ -70,14 +93,14 @@ cc_library(
srcs = ["base_poll_test.cc"],
hdrs = ["base_poll_test.h"],
deps = [
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -87,11 +110,11 @@ cc_library(
hdrs = ["file_base.h"],
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -109,34 +132,37 @@ cc_library(
)
cc_library(
+ name = "socket_netlink_route_util",
+ testonly = 1,
+ srcs = ["socket_netlink_route_util.cc"],
+ hdrs = ["socket_netlink_route_util.h"],
+ deps = [
+ ":socket_netlink_util",
+ ],
+)
+
+cc_library(
name = "socket_test_util",
testonly = 1,
srcs = [
"socket_test_util.cc",
- ] + select_for_linux(
- [
- "socket_test_util_impl.cc",
- ],
- ),
+ "socket_test_util_impl.cc",
+ ],
hdrs = ["socket_test_util.h"],
- deps = [
- "@com_google_googletest//:gtest",
+ defines = select_system(),
+ deps = default_net_util() + [
+ gtest,
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
+ "@com_google_absl//absl/types:optional",
"//test/util:file_descriptor",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
- ] + select_for_linux([
- ]),
-)
-
-cc_library(
- name = "temp_umask",
- hdrs = ["temp_umask.h"],
+ ],
)
cc_library(
@@ -146,9 +172,9 @@ cc_library(
hdrs = ["unix_domain_socket_test_util.h"],
deps = [
":socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
)
@@ -170,28 +196,33 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "32bit_test",
testonly = 1,
- srcs = ["32bit.cc"],
+ srcs = select_arch(
+ amd64 = ["32bit.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:memory_util",
+ "//test/util:platform_util",
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -204,9 +235,9 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -219,9 +250,9 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -233,10 +264,10 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -248,12 +279,12 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -266,12 +297,11 @@ cc_binary(
],
linkstatic = 1,
deps = [
- # The heapchecker doesn't recognize that io_destroy munmaps.
- "@com_google_googletest//:gtest",
- "@com_google_absl//absl/strings",
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:proc_util",
@@ -288,12 +318,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -306,9 +336,9 @@ cc_binary(
"//:sandbox",
],
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -320,9 +350,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -333,10 +363,26 @@ cc_binary(
linkstatic = 1,
deps = [
":socket_test_util",
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:temp_umask",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_capability_test",
+ testonly = 1,
+ srcs = ["socket_capability.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:capability_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -358,10 +404,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -374,10 +420,10 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -390,14 +436,14 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -410,12 +456,12 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -429,12 +475,12 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:mount_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -444,9 +490,9 @@ cc_binary(
srcs = ["clock_getres.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -456,11 +502,11 @@ cc_binary(
srcs = ["clock_gettime.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -470,12 +516,13 @@ cc_binary(
srcs = ["concurrency.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:platform_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -488,9 +535,9 @@ cc_binary(
":socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -501,10 +548,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -515,9 +562,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -529,11 +576,11 @@ cc_binary(
deps = [
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -546,10 +593,10 @@ cc_binary(
"//test/util:epoll_util",
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -561,24 +608,28 @@ cc_binary(
deps = [
"//test/util:epoll_util",
"//test/util:eventfd_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "exceptions_test",
testonly = 1,
- srcs = ["exceptions.cc"],
+ srcs = select_arch(
+ amd64 = ["exceptions.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
+ "//test/util:platform_util",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -588,10 +639,10 @@ cc_binary(
srcs = ["getcpu.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -601,10 +652,10 @@ cc_binary(
srcs = ["getcpu.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -614,13 +665,13 @@ cc_binary(
srcs = ["getrusage.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -633,14 +684,14 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -663,15 +714,15 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:optional",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest",
],
)
@@ -682,11 +733,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -697,12 +748,17 @@ cc_binary(
linkstatic = 1,
deps = [
":file_base",
+ ":socket_test_util",
"//test/util:cleanup",
+ "//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -712,9 +768,9 @@ cc_binary(
srcs = ["fault.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -725,10 +781,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -740,18 +796,22 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:cleanup",
+ "//test/util:epoll_util",
"//test/util:eventfd_util",
- "//test/util:multiprocess_util",
- "//test/util:posix_error",
- "//test/util:temp_path",
- "//test/util:test_util",
- "//test/util:timer_util",
+ "//test/util:fs_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:multiprocess_util",
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:temp_path",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "//test/util:timer_util",
],
)
@@ -764,16 +824,19 @@ cc_binary(
],
linkstatic = 1,
deps = [
+ ":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:eventfd_util",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -784,13 +847,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -800,11 +863,11 @@ cc_binary(
srcs = ["fpsig_fork.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -814,10 +877,10 @@ cc_binary(
srcs = ["fpsig_nested.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -828,10 +891,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -842,10 +905,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -857,6 +920,9 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:memory_util",
"//test/util:save_util",
"//test/util:temp_path",
@@ -865,9 +931,6 @@ cc_binary(
"//test/util:thread_util",
"//test/util:time_util",
"//test/util:timer_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -880,12 +943,13 @@ cc_binary(
"//test/util:eventfd_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:node_hash_set",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -895,9 +959,9 @@ cc_binary(
srcs = ["getrandom.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -910,12 +974,14 @@ cc_binary(
"//test/util:epoll_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
@@ -930,10 +996,10 @@ cc_binary(
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -957,9 +1023,9 @@ cc_binary(
":socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -970,6 +1036,9 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -977,9 +1046,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -991,15 +1057,15 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1012,14 +1078,14 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1030,10 +1096,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1044,6 +1110,7 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1051,7 +1118,6 @@ cc_binary(
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1062,12 +1128,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/memory",
+ gtest,
"//test/util:memory_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest",
],
)
@@ -1077,11 +1143,11 @@ cc_binary(
srcs = ["mincore.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1091,13 +1157,13 @@ cc_binary(
srcs = ["mkdir.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1108,11 +1174,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1124,12 +1190,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:cleanup",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:rlimit_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1142,13 +1208,13 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1161,6 +1227,9 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:mount_util",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -1168,9 +1237,6 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1180,10 +1246,9 @@ cc_binary(
srcs = ["mremap.cc"],
linkstatic = 1,
deps = [
- # The heap check fails due to MremapDeathTest
- "@com_google_googletest//:gtest",
- "@com_google_absl//absl/strings",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1215,9 +1280,9 @@ cc_binary(
srcs = ["munmap.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1234,14 +1299,14 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1251,14 +1316,14 @@ cc_binary(
srcs = ["open_create.cc"],
linkstatic = 1,
deps = [
- ":temp_umask",
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:temp_path",
+ "//test/util:temp_umask",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1266,17 +1331,18 @@ cc_binary(
name = "packet_socket_raw_test",
testonly = 1,
srcs = ["packet_socket_raw.cc"],
+ defines = select_system(),
linkstatic = 1,
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1290,11 +1356,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1306,16 +1372,16 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:pty_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1327,12 +1393,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:posix_error",
"//test/util:pty_util",
"//test/util:test_main",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1342,15 +1408,15 @@ cc_binary(
srcs = ["partial_bad_buffer.cc"],
linkstatic = 1,
deps = [
- "//test/syscalls/linux:socket_test_util",
+ ":socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1360,13 +1426,28 @@ cc_binary(
srcs = ["pause.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "ping_socket_test",
+ testonly = 1,
+ srcs = ["ping_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -1377,15 +1458,16 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1398,13 +1480,13 @@ cc_binary(
":base_poll_test",
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1415,23 +1497,27 @@ cc_binary(
linkstatic = 1,
deps = [
":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "arch_prctl_test",
testonly = 1,
- srcs = ["arch_prctl.cc"],
+ srcs = select_arch(
+ amd64 = ["arch_prctl.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ "//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1443,12 +1529,12 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -1459,13 +1545,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_googletest//:gtest",
],
)
@@ -1476,10 +1562,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1490,6 +1576,8 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:temp_path",
@@ -1497,8 +1585,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1512,13 +1598,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1530,11 +1616,11 @@ cc_binary(
deps = [
"//test/util:capability_util",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1548,6 +1634,10 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:temp_path",
@@ -1555,10 +1645,6 @@ cc_binary(
"//test/util:thread_util",
"//test/util:time_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1572,11 +1658,24 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "proc_pid_oomscore_test",
+ testonly = 1,
+ srcs = ["proc_pid_oomscore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:fs_util",
"//test/util:test_main",
"//test/util:test_util",
"@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1588,17 +1687,17 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
+ gtest,
"//test/util:memory_util",
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/container:flat_hash_set",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:optional",
- "@com_google_googletest//:gtest",
],
)
@@ -1612,6 +1711,8 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -1619,8 +1720,6 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1631,11 +1730,11 @@ cc_binary(
linkstatic = 1,
deps = [
":base_poll_test",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1645,15 +1744,16 @@ cc_binary(
srcs = ["ptrace.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
+ "//test/util:platform_util",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1663,10 +1763,10 @@ cc_binary(
srcs = ["pwrite64.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1680,12 +1780,12 @@ cc_binary(
deps = [
":file_base",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1699,28 +1799,29 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
cc_binary(
- name = "raw_socket_ipv4_test",
+ name = "raw_socket_test",
testonly = 1,
- srcs = ["raw_socket_ipv4.cc"],
+ srcs = ["raw_socket.cc"],
+ defines = select_system(),
linkstatic = 1,
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1734,10 +1835,10 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:capability_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1748,10 +1849,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1762,10 +1863,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1781,13 +1882,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:timer_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1795,7 +1896,6 @@ cc_binary(
name = "readv_socket_test",
testonly = 1,
srcs = [
- "file_base.h",
"readv_common.cc",
"readv_common.h",
"readv_socket.cc",
@@ -1803,12 +1903,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1822,11 +1922,11 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1843,17 +1943,33 @@ cc_binary(
)
cc_binary(
+ name = "rseq_test",
+ testonly = 1,
+ srcs = ["rseq.cc"],
+ data = ["//test/syscalls/linux/rseq"],
+ linkstatic = 1,
+ deps = [
+ "//test/syscalls/linux/rseq:lib",
+ gtest,
+ "//test/util:logging",
+ "//test/util:multiprocess_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "rtsignal_test",
testonly = 1,
srcs = ["rtsignal.cc"],
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ gtest,
"//test/util:logging",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1863,9 +1979,9 @@ cc_binary(
srcs = ["sched.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1875,9 +1991,9 @@ cc_binary(
srcs = ["sched_yield.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1887,6 +2003,8 @@ cc_binary(
srcs = ["seccomp.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/base:core_headers",
+ gtest,
"//test/util:logging",
"//test/util:memory_util",
"//test/util:multiprocess_util",
@@ -1894,8 +2012,6 @@ cc_binary(
"//test/util:proc_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_googletest//:gtest",
],
)
@@ -1907,14 +2023,14 @@ cc_binary(
deps = [
":base_poll_test",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:rlimit_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1926,13 +2042,13 @@ cc_binary(
deps = [
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1944,12 +2060,14 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
+ ":ip_socket_test_util",
+ ":unix_domain_socket_test_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -1960,13 +2078,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -1976,9 +2094,9 @@ cc_binary(
srcs = ["sigaction.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -1993,28 +2111,34 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:fs_util",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
name = "sigiret_test",
testonly = 1,
- srcs = ["sigiret.cc"],
+ srcs = select_arch(
+ amd64 = ["sigiret.cc"],
+ arm64 = [],
+ ),
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:timer_util",
- "@com_google_googletest//:gtest",
- ],
+ ] + select_arch(
+ amd64 = [],
+ arm64 = ["//test/util:test_main"],
+ ),
)
cc_binary(
@@ -2024,14 +2148,14 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:logging",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -2041,10 +2165,10 @@ cc_binary(
srcs = ["sigprocmask.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2054,13 +2178,13 @@ cc_binary(
srcs = ["sigstop.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2071,13 +2195,13 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -2093,14 +2217,30 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
alwayslink = 1,
)
+cc_binary(
+ name = "socket_stress_test",
+ testonly = 1,
+ srcs = [
+ "socket_generic_stress.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "socket_unix_dgram_test_cases",
testonly = 1,
@@ -2109,8 +2249,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2123,8 +2263,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2140,8 +2280,11 @@ cc_library(
],
deps = [
":socket_test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
+ "//test/util:thread_util",
],
alwayslink = 1,
)
@@ -2158,8 +2301,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2176,9 +2319,9 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:memory_util",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2196,8 +2339,8 @@ cc_library(
":ip_socket_test_util",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2214,8 +2357,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2232,8 +2375,9 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ "@com_google_absl//absl/memory",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2250,8 +2394,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2268,8 +2412,8 @@ cc_library(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2332,9 +2476,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2364,9 +2508,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2396,9 +2540,9 @@ cc_binary(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2496,10 +2640,10 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2515,10 +2659,11 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ "@com_google_absl//absl/container:node_hash_map",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2534,10 +2679,10 @@ cc_binary(
":socket_bind_to_device_util",
":socket_test_util",
"//test/util:capability_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2583,9 +2728,9 @@ cc_binary(
deps = [
":ip_socket_test_util",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2661,17 +2806,52 @@ cc_binary(
srcs = ["socket_inet_loopback.cc"],
linkstatic = 1,
deps = [
+ ":ip_socket_test_util",
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:save_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_inet_loopback_nogotsan_test",
+ testonly = 1,
+ srcs = ["socket_inet_loopback_nogotsan.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:save_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netlink_test",
+ testonly = 1,
+ srcs = ["socket_netlink.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -2681,14 +2861,31 @@ cc_binary(
srcs = ["socket_netlink_route.cc"],
linkstatic = 1,
deps = [
+ ":socket_netlink_route_util",
":socket_netlink_util",
":socket_test_util",
+ "//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "socket_netlink_uevent_test",
+ testonly = 1,
+ srcs = ["socket_netlink_uevent.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_netlink_util",
+ ":socket_test_util",
+ "//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
],
)
@@ -2706,9 +2903,9 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
- "//test/util:test_util",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_util",
],
alwayslink = 1,
)
@@ -2725,11 +2922,11 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2746,10 +2943,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2766,10 +2963,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2786,11 +2983,11 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2807,8 +3004,8 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2825,11 +3022,10 @@ cc_library(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_util",
"//test/util:thread_util",
- "//test/util:timer_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
alwayslink = 1,
)
@@ -2923,9 +3119,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2937,9 +3133,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2951,9 +3147,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2968,9 +3164,9 @@ cc_binary(
":socket_blocking_test_cases",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -2985,9 +3181,9 @@ cc_binary(
":ip_socket_test_util",
":socket_blocking_test_cases",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3002,9 +3198,9 @@ cc_binary(
":socket_non_stream_blocking_test_cases",
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3019,9 +3215,9 @@ cc_binary(
":ip_socket_test_util",
":socket_non_stream_blocking_test_cases",
":socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3037,9 +3233,9 @@ cc_binary(
":socket_unix_cmsg_test_cases",
":socket_unix_test_cases",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3051,9 +3247,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3065,9 +3261,9 @@ cc_binary(
deps = [
":socket_test_util",
":unix_domain_socket_test_util",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3080,10 +3276,10 @@ cc_binary(
":socket_netlink_util",
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/base:endian",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/base:endian",
- "@com_google_googletest//:gtest",
],
)
@@ -3099,12 +3295,12 @@ cc_binary(
"//test/util:cleanup",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3115,11 +3311,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3133,12 +3329,12 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3151,10 +3347,11 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3164,10 +3361,10 @@ cc_binary(
srcs = ["sync.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3177,10 +3374,10 @@ cc_binary(
srcs = ["sysinfo.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3190,9 +3387,9 @@ cc_binary(
srcs = ["syslog.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3202,10 +3399,10 @@ cc_binary(
srcs = ["sysret.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3213,18 +3410,17 @@ cc_binary(
name = "tcp_socket_test",
testonly = 1,
srcs = ["tcp_socket.cc"],
+ defines = select_system(),
linkstatic = 1,
- # FIXME(b/135470853)
- tags = ["flaky"],
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3234,11 +3430,11 @@ cc_binary(
srcs = ["tgkill.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:signal_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3248,10 +3444,10 @@ cc_binary(
srcs = ["time.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3276,15 +3472,15 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3294,11 +3490,11 @@ cc_binary(
srcs = ["tkill.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:logging",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3312,28 +3508,78 @@ cc_binary(
"//test/util:capability_util",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "tuntap_test",
+ testonly = 1,
+ srcs = ["tuntap.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ ":socket_netlink_route_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
- name = "udp_socket_test",
+ name = "tuntap_hostinet_test",
testonly = 1,
- srcs = ["udp_socket.cc"],
+ srcs = ["tuntap_hostinet.cc"],
linkstatic = 1,
deps = [
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_library(
+ name = "udp_socket_test_cases",
+ testonly = 1,
+ srcs = [
+ "udp_socket_errqueue_test_case.cc",
+ "udp_socket_test_cases.cc",
+ ],
+ hdrs = ["udp_socket_test_cases.h"],
+ defines = select_system(),
+ deps = [
+ ":ip_socket_test_util",
":socket_test_util",
":unix_domain_socket_test_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/base:core_headers",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "udp_socket_test",
+ testonly = 1,
+ srcs = ["udp_socket.cc"],
+ linkstatic = 1,
+ deps = [
+ ":udp_socket_test_cases",
],
)
@@ -3345,9 +3591,9 @@ cc_binary(
deps = [
":socket_test_util",
"//test/util:file_descriptor",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3358,14 +3604,14 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:uid_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3376,11 +3622,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:capability_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3393,11 +3639,11 @@ cc_binary(
"//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3407,11 +3653,11 @@ cc_binary(
srcs = ["unshare.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/synchronization",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/synchronization",
- "@com_google_googletest//:gtest",
],
)
@@ -3437,11 +3683,11 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ gtest,
"//test/util:posix_error",
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3451,13 +3697,13 @@ cc_binary(
srcs = ["vfork.cc"],
linkstatic = 1,
deps = [
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:test_util",
"//test/util:time_util",
- "@com_google_absl//absl/flags:flag",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3469,6 +3715,10 @@ cc_binary(
deps = [
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/time",
+ gtest,
"//test/util:logging",
"//test/util:multiprocess_util",
"//test/util:posix_error",
@@ -3477,10 +3727,6 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/synchronization",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
],
)
@@ -3491,10 +3737,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:cleanup",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3505,30 +3751,46 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ gtest,
"//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
],
)
cc_binary(
- name = "semaphore_test",
+ name = "network_namespace_test",
testonly = 1,
- srcs = ["semaphore.cc"],
+ srcs = ["network_namespace.cc"],
linkstatic = 1,
deps = [
+ ":socket_test_util",
+ gtest,
"//test/util:capability_util",
+ "//test/util:posix_error",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "semaphore_test",
+ testonly = 1,
+ srcs = ["semaphore.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
],
)
@@ -3554,10 +3816,10 @@ cc_binary(
linkstatic = 1,
deps = [
"//test/util:file_descriptor",
+ gtest,
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3567,11 +3829,11 @@ cc_binary(
srcs = ["vdso_clock_gettime.cc"],
linkstatic = 1,
deps = [
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -3581,10 +3843,10 @@ cc_binary(
srcs = ["vsyscall.cc"],
linkstatic = 1,
deps = [
+ gtest,
"//test/util:proc_util",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3597,11 +3859,11 @@ cc_binary(
":unix_domain_socket_test_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
- "//test/util:test_main",
- "//test/util:test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
@@ -3613,12 +3875,12 @@ cc_binary(
deps = [
"//test/util:file_descriptor",
"//test/util:fs_util",
+ gtest,
"//test/util:memory_util",
"//test/util:multiprocess_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_googletest//:gtest",
],
)
@@ -3630,10 +3892,10 @@ cc_binary(
deps = [
":ip_socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
],
)
@@ -3645,9 +3907,31 @@ cc_binary(
deps = [
":ip_socket_test_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/strings",
+ gtest,
"//test/util:test_main",
"//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "xattr_test",
+ testonly = 1,
+ srcs = [
+ "file_base.h",
+ "xattr.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
],
)
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index 427c42ede..f65a14fb8 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -13,9 +13,12 @@
// limitations under the License.
#include <stdio.h>
+#include <sys/socket.h>
#include <sys/un.h>
+
#include <algorithm>
#include <vector>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
@@ -139,6 +142,47 @@ TEST_P(AllSocketPairTest, Connect) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ConnectWithWrongType) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+ switch (type) {
+ case SOCK_STREAM:
+ type = SOCK_SEQPACKET;
+ break;
+ case SOCK_SEQPACKET:
+ type = SOCK_STREAM;
+ break;
+ }
+
+ const FileDescriptor another_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds());
+
+ if (sockets->first_addr()->sa_data[0] != 0) {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(EPROTOTYPE));
+ } else {
+ ASSERT_THAT(connect(another_socket.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallFailsWithErrno(ECONNREFUSED));
+ }
+
+ ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+}
+
TEST_P(AllSocketPairTest, ConnectNonListening) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/accept_bind_stream.cc b/test/syscalls/linux/accept_bind_stream.cc
index 7bcd91e9e..4857f160b 100644
--- a/test/syscalls/linux/accept_bind_stream.cc
+++ b/test/syscalls/linux/accept_bind_stream.cc
@@ -14,8 +14,10 @@
#include <stdio.h>
#include <sys/un.h>
+
#include <algorithm>
#include <vector>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index b27d4e10a..806d5729e 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -89,6 +89,7 @@ class AIOTest : public FileTest {
FileTest::TearDown();
if (ctx_ != 0) {
ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
}
@@ -129,7 +130,7 @@ TEST_F(AIOTest, BasicWrite) {
// aio implementation uses aio_ring. gVisor doesn't and returns all zeroes.
// Linux implements aio_ring, so skip the zeroes check.
//
- // TODO(b/65486370): Remove when gVisor implements aio_ring.
+ // TODO(gvisor.dev/issue/204): Remove when gVisor implements aio_ring.
auto ring = reinterpret_cast<struct aio_ring*>(ctx_);
auto magic = IsRunningOnGvisor() ? 0 : AIO_RING_MAGIC;
EXPECT_EQ(ring->magic, magic);
@@ -188,14 +189,19 @@ TEST_F(AIOTest, BadWrite) {
}
TEST_F(AIOTest, ExitWithPendingIo) {
- // Setup a context that is 5 entries deep.
- ASSERT_THAT(SetupContext(5), SyscallSucceeds());
+ // Setup a context that is 100 entries deep.
+ ASSERT_THAT(SetupContext(100), SyscallSucceeds());
struct iocb cb = CreateCallback();
struct iocb* cbs[] = {&cb};
// Submit a request but don't complete it to make it pending.
- EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ for (int i = 0; i < 100; ++i) {
+ EXPECT_THAT(Submit(1, cbs), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(DestroyContext(), SyscallSucceeds());
+ ctx_ = 0;
}
int Submitter(void* arg) {
diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc
index d89269985..940c97285 100644
--- a/test/syscalls/linux/alarm.cc
+++ b/test/syscalls/linux/alarm.cc
@@ -188,6 +188,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/bad.cc b/test/syscalls/linux/bad.cc
index f246a799e..a26fc6af3 100644
--- a/test/syscalls/linux/bad.cc
+++ b/test/syscalls/linux/bad.cc
@@ -22,11 +22,17 @@ namespace gvisor {
namespace testing {
namespace {
+#ifdef __x86_64__
+// get_kernel_syms is not supported in Linux > 2.6, and not implemented in
+// gVisor.
+constexpr uint32_t kNotImplementedSyscall = SYS_get_kernel_syms;
+#elif __aarch64__
+// Use the last of arch_specific_syscalls which are not implemented on arm64.
+constexpr uint32_t kNotImplementedSyscall = __NR_arch_specific_syscall + 15;
+#endif
TEST(BadSyscallTest, NotImplemented) {
- // get_kernel_syms is not supported in Linux > 2.6, and not implemented in
- // gVisor.
- EXPECT_THAT(syscall(SYS_get_kernel_syms), SyscallFailsWithErrno(ENOSYS));
+ EXPECT_THAT(syscall(kNotImplementedSyscall), SyscallFailsWithErrno(ENOSYS));
}
TEST(BadSyscallTest, NegativeOne) {
diff --git a/test/syscalls/linux/chmod.cc b/test/syscalls/linux/chmod.cc
index 7e918b9b2..a06b5cfd6 100644
--- a/test/syscalls/linux/chmod.cc
+++ b/test/syscalls/linux/chmod.cc
@@ -16,6 +16,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <string>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/chroot.cc b/test/syscalls/linux/chroot.cc
index de1611c21..85ec013d5 100644
--- a/test/syscalls/linux/chroot.cc
+++ b/test/syscalls/linux/chroot.cc
@@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <syscall.h>
#include <unistd.h>
+
#include <string>
#include <vector>
@@ -161,12 +162,12 @@ TEST(ChrootTest, DotDotFromOpenFD) {
// getdents on fd should not error.
char buf[1024];
- ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)),
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
SyscallSucceeds());
}
// Test that link resolution in a chroot can escape the root by following an
-// open proc fd.
+// open proc fd. Regression test for b/32316719.
TEST(ChrootTest, ProcFdLinkResolutionInChroot) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_CHROOT)));
diff --git a/test/syscalls/linux/clock_gettime.cc b/test/syscalls/linux/clock_gettime.cc
index c9e3ed6b2..7f6015049 100644
--- a/test/syscalls/linux/clock_gettime.cc
+++ b/test/syscalls/linux/clock_gettime.cc
@@ -14,6 +14,7 @@
#include <pthread.h>
#include <sys/time.h>
+
#include <cerrno>
#include <cstdint>
#include <ctime>
@@ -55,11 +56,6 @@ void spin_ns(int64_t ns) {
// Test that CLOCK_PROCESS_CPUTIME_ID is a superset of CLOCK_THREAD_CPUTIME_ID.
TEST(ClockGettime, CputimeId) {
- // TODO(b/128871825,golang.org/issue/10958): Test times out when there is a
- // small number of core because one goroutine starves the others.
- printf("CPUS: %d\n", std::thread::hardware_concurrency());
- SKIP_IF(std::thread::hardware_concurrency() <= 2);
-
constexpr int kNumThreads = 13; // arbitrary
absl::Duration spin_time = absl::Seconds(1);
diff --git a/test/syscalls/linux/concurrency.cc b/test/syscalls/linux/concurrency.cc
index 4e0a13f8b..7cd6a75bd 100644
--- a/test/syscalls/linux/concurrency.cc
+++ b/test/syscalls/linux/concurrency.cc
@@ -13,12 +13,14 @@
// limitations under the License.
#include <signal.h>
+
#include <atomic>
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/util/platform_util.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -44,7 +46,8 @@ TEST(ConcurrencyTest, SingleProcessMultithreaded) {
}
// Test that multiple threads in this process continue to execute in parallel,
-// even if an unrelated second process is spawned.
+// even if an unrelated second process is spawned. Regression test for
+// b/32119508.
TEST(ConcurrencyTest, MultiProcessMultithreaded) {
// In PID 1, start TIDs 1 and 2, and put both to sleep.
//
@@ -98,6 +101,7 @@ TEST(ConcurrencyTest, MultiProcessMultithreaded) {
// Test that multiple processes can execute concurrently, even if one process
// never yields.
TEST(ConcurrencyTest, MultiProcessConcurrency) {
+ SKIP_IF(PlatformSupportMultiProcess() == PlatformSupport::NotSupported);
pid_t child_pid = fork();
if (child_pid == 0) {
diff --git a/test/syscalls/linux/connect_external.cc b/test/syscalls/linux/connect_external.cc
index bfe1da82e..1edb50e47 100644
--- a/test/syscalls/linux/connect_external.cc
+++ b/test/syscalls/linux/connect_external.cc
@@ -56,7 +56,7 @@ TEST_P(GoferStreamSeqpacketTest, Echo) {
ProtocolSocket proto;
std::tie(env, proto) = GetParam();
- char *val = getenv(env.c_str());
+ char* val = getenv(env.c_str());
ASSERT_NE(val, nullptr);
std::string root(val);
@@ -69,7 +69,7 @@ TEST_P(GoferStreamSeqpacketTest, Echo) {
addr.sun_family = AF_UNIX;
memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
- ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
sizeof(addr)),
SyscallSucceeds());
@@ -92,7 +92,7 @@ TEST_P(GoferStreamSeqpacketTest, NonListening) {
ProtocolSocket proto;
std::tie(env, proto) = GetParam();
- char *val = getenv(env.c_str());
+ char* val = getenv(env.c_str());
ASSERT_NE(val, nullptr);
std::string root(val);
@@ -105,7 +105,7 @@ TEST_P(GoferStreamSeqpacketTest, NonListening) {
addr.sun_family = AF_UNIX;
memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
- ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
sizeof(addr)),
SyscallFailsWithErrno(ECONNREFUSED));
}
@@ -127,7 +127,7 @@ using GoferDgramTest = ::testing::TestWithParam<std::string>;
// unnamed. The server thus has no way to reply to us.
TEST_P(GoferDgramTest, Null) {
std::string env = GetParam();
- char *val = getenv(env.c_str());
+ char* val = getenv(env.c_str());
ASSERT_NE(val, nullptr);
std::string root(val);
@@ -140,7 +140,7 @@ TEST_P(GoferDgramTest, Null) {
addr.sun_family = AF_UNIX;
memcpy(addr.sun_path, socket_path.c_str(), socket_path.length());
- ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr *>(&addr),
+ ASSERT_THAT(connect(sock.get(), reinterpret_cast<struct sockaddr*>(&addr),
sizeof(addr)),
SyscallSucceeds());
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4dd302eed..1d0d584cd 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,6 +153,27 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
+TEST(DevTest, OpenDevFuse) {
+ // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new
+ // device registration is complete.
+ SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor() || !IsFUSEEnabled());
+
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY));
+}
+
+TEST(DevTest, ReadDevFuseWithoutMount) {
+ // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new
+ // device registration is complete.
+ SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor());
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY));
+
+ std::vector<char> buf(1);
+ EXPECT_THAT(ReadFd(fd.get(), buf.data(), sizeof(buf)),
+ SyscallFailsWithErrno(EPERM));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc
index a4f8f3cec..2101e5c9f 100644
--- a/test/syscalls/linux/epoll.cc
+++ b/test/syscalls/linux/epoll.cc
@@ -56,10 +56,6 @@ TEST(EpollTest, AllWritable) {
struct epoll_event result[kFDsPerEpoll];
ASSERT_THAT(RetryEINTR(epoll_wait)(epollfd.get(), result, kFDsPerEpoll, -1),
SyscallSucceedsWithValue(kFDsPerEpoll));
- // TODO(edahlgren): Why do some tests check epoll_event::data, and others
- // don't? Does Linux actually guarantee that, in any of these test cases,
- // epoll_wait will necessarily write out the epoll_events in the order that
- // they were registered?
for (int i = 0; i < kFDsPerEpoll; i++) {
ASSERT_EQ(result[i].events, EPOLLOUT);
}
@@ -426,6 +422,28 @@ TEST(EpollTest, CloseFile) {
SyscallSucceedsWithValue(0));
}
+TEST(EpollTest, PipeReaderHupAfterWriterClosed) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ int pipefds[2];
+ ASSERT_THAT(pipe(pipefds), SyscallSucceeds());
+ FileDescriptor rfd(pipefds[0]);
+ FileDescriptor wfd(pipefds[1]);
+
+ ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), rfd.get(), 0, kMagicConstant));
+ struct epoll_event result[kFDsPerEpoll];
+ // Initially, rfd should not generate any events of interest.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0),
+ SyscallSucceedsWithValue(0));
+ // Close the write end of the pipe.
+ wfd.reset();
+ // rfd should now generate EPOLLHUP, which EPOLL_CTL_ADD unconditionally adds
+ // to the set of events of interest.
+ ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(result[0].events, EPOLLHUP);
+ EXPECT_EQ(result[0].data.u64, kMagicConstant);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc
index 367682c3d..dc794415e 100644
--- a/test/syscalls/linux/eventfd.cc
+++ b/test/syscalls/linux/eventfd.cc
@@ -100,6 +100,23 @@ TEST(EventfdTest, SmallRead) {
ASSERT_THAT(read(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL));
}
+TEST(EventfdTest, IllegalSeek) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ EXPECT_THAT(lseek(efd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(EventfdTest, IllegalPread) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ int l;
+ EXPECT_THAT(pread(efd.get(), &l, sizeof(l), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(EventfdTest, IllegalPwrite) {
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ EXPECT_THAT(pwrite(efd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
TEST(EventfdTest, BigWrite) {
FileDescriptor efd =
ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
@@ -132,6 +149,31 @@ TEST(EventfdTest, BigWriteBigRead) {
EXPECT_EQ(l[0], 1);
}
+TEST(EventfdTest, SpliceFromPipePartialSucceeds) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor pipe_rfd(pipes[0]);
+ const FileDescriptor pipe_wfd(pipes[1]);
+ constexpr uint64_t kVal{1};
+
+ FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK));
+
+ uint64_t event_array[2];
+ event_array[0] = kVal;
+ event_array[1] = kVal;
+ ASSERT_THAT(write(pipe_wfd.get(), event_array, sizeof(event_array)),
+ SyscallSucceedsWithValue(sizeof(event_array)));
+ EXPECT_THAT(splice(pipe_rfd.get(), /*__offin=*/nullptr, efd.get(),
+ /*__offout=*/nullptr, sizeof(event_array[0]) + 1,
+ SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(event_array[0])));
+
+ uint64_t val;
+ ASSERT_THAT(read(efd.get(), &val, sizeof(val)),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(val, kVal);
+}
+
// NotifyNonZero is inherently racy, so random save is disabled.
TEST(EventfdTest, NotifyNonZero_NoRandomSave) {
// Waits will time out at 10 seconds.
diff --git a/test/syscalls/linux/exceptions.cc b/test/syscalls/linux/exceptions.cc
index 370e85166..420b9543f 100644
--- a/test/syscalls/linux/exceptions.cc
+++ b/test/syscalls/linux/exceptions.cc
@@ -16,12 +16,30 @@
#include "gtest/gtest.h"
#include "test/util/logging.h"
+#include "test/util/platform_util.h"
#include "test/util/signal_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+// Default value for the x87 FPU control word. See Intel SDM Vol 1, Ch 8.1.5
+// "x87 FPU Control Word".
+constexpr uint16_t kX87ControlWordDefault = 0x37f;
+
+// Mask for the divide-by-zero exception.
+constexpr uint16_t kX87ControlWordDiv0Mask = 1 << 2;
+
+// Default value for the SSE control register (MXCSR). See Intel SDM Vol 1, Ch
+// 11.6.4 "Initialization of SSE/SSE3 Extensions".
+constexpr uint32_t kMXCSRDefault = 0x1f80;
+
+// Mask for the divide-by-zero exception.
+constexpr uint32_t kMXCSRDiv0Mask = 1 << 9;
+
+// Flag for a pending divide-by-zero exception.
+constexpr uint32_t kMXCSRDiv0Flag = 1 << 2;
+
void inline Halt() { asm("hlt\r\n"); }
void inline SetAlignmentCheck() {
@@ -107,6 +125,170 @@ TEST(ExceptionTest, DivideByZero) {
::testing::KilledBySignal(SIGFPE), "");
}
+// By default, x87 exceptions are masked and simply return a default value.
+TEST(ExceptionTest, X87DivideByZeroMasked) {
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ "fistpl %[quotient]\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ value ] "m"(value), [ divisor ] "m"(divisor));
+
+ EXPECT_EQ(quotient, INT32_MIN);
+}
+
+// When unmasked, division by zero raises SIGFPE.
+TEST(ExceptionTest, X87DivideByZeroUnmasked) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint16_t kControlWord =
+ kX87ControlWordDefault & ~kX87ControlWordDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "fldcw %[cw]\r\n"
+ "fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ "fistpl %[quotient]\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ cw ] "m"(kControlWord), [ value ] "m"(value),
+ [ divisor ] "m"(divisor));
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// Pending exceptions in the x87 status register are not clobbered by syscalls.
+TEST(ExceptionTest, X87StatusClobber) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint16_t kControlWord =
+ kX87ControlWordDefault & ~kX87ControlWordDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "fildl %[value]\r\n"
+ "fidivl %[divisor]\r\n"
+ // Exception is masked, so it does not occur here.
+ "fistpl %[quotient]\r\n"
+
+ // SYS_getpid placed in rax by constraint.
+ "syscall\r\n"
+
+ // Unmask exception. The syscall didn't clobber the pending
+ // exception, so now it can be raised.
+ //
+ // N.B. "a floating-point exception will be generated upon execution
+ // of the *next* floating-point instruction".
+ "fldcw %[cw]\r\n"
+ "fwait\r\n"
+ : [ quotient ] "=m"(quotient)
+ : [ value ] "m"(value), [ divisor ] "m"(divisor), "a"(SYS_getpid),
+ [ cw ] "m"(kControlWord)
+ : "rcx", "r11");
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// By default, SSE exceptions are masked and simply return a default value.
+TEST(ExceptionTest, SSEDivideByZeroMasked) {
+ uint32_t status;
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+ : [ quotient ] "=r"(quotient), [ status ] "=r"(status)
+ : [ value ] "r"(value), [ divisor ] "r"(divisor)
+ : "xmm0", "xmm1");
+
+ EXPECT_EQ(quotient, INT32_MIN);
+}
+
+// When unmasked, division by zero raises SIGFPE.
+TEST(ExceptionTest, SSEDivideByZeroUnmasked) {
+ // See above.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_DFL;
+ auto const cleanup = ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGFPE, sa));
+
+ EXPECT_EXIT(
+ {
+ // Clear the divide by zero exception mask.
+ constexpr uint32_t kMXCSR = kMXCSRDefault & ~kMXCSRDiv0Mask;
+
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm volatile(
+ "ldmxcsr %[mxcsr]\r\n"
+ "cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+ : [ quotient ] "=r"(quotient)
+ : [ mxcsr ] "m"(kMXCSR), [ value ] "r"(value),
+ [ divisor ] "r"(divisor)
+ : "xmm0", "xmm1");
+ },
+ ::testing::KilledBySignal(SIGFPE), "");
+}
+
+// Pending exceptions in the SSE status register are not clobbered by syscalls.
+TEST(ExceptionTest, SSEStatusClobber) {
+ uint32_t mxcsr;
+ int32_t quotient;
+ int32_t value = 1;
+ int32_t divisor = 0;
+ asm("cvtsi2ssl %[value], %%xmm0\r\n"
+ "cvtsi2ssl %[divisor], %%xmm1\r\n"
+ "divss %%xmm1, %%xmm0\r\n"
+ // Exception is masked, so it does not occur here.
+ "cvtss2sil %%xmm0, %[quotient]\r\n"
+
+ // SYS_getpid placed in rax by constraint.
+ "syscall\r\n"
+
+ // Intel SDM Vol 1, Ch 10.2.3.1 "SIMD Floating-Point Mask and Flag Bits":
+ // "If LDMXCSR or FXRSTOR clears a mask bit and sets the corresponding
+ // exception flag bit, a SIMD floating-point exception will not be
+ // generated as a result of this change. The unmasked exception will be
+ // generated only upon the execution of the next SSE/SSE2/SSE3 instruction
+ // that detects the unmasked exception condition."
+ //
+ // Though ambiguous, empirical evidence indicates that this means that
+ // exception flags set in the status register will never cause an
+ // exception to be raised; only a new exception condition will do so.
+ //
+ // Thus here we just check for the flag itself rather than trying to raise
+ // the exception.
+ "stmxcsr %[mxcsr]\r\n"
+ : [ quotient ] "=r"(quotient), [ mxcsr ] "+m"(mxcsr)
+ : [ value ] "r"(value), [ divisor ] "r"(divisor), "a"(SYS_getpid)
+ : "xmm0", "xmm1", "rcx", "r11");
+
+ EXPECT_TRUE(mxcsr & kMXCSRDiv0Flag);
+}
+
TEST(ExceptionTest, IOAccessFault) {
// See above.
struct sigaction sa = {};
@@ -143,6 +325,7 @@ TEST(ExceptionTest, AlignmentHalt) {
}
TEST(ExceptionTest, AlignmentCheck) {
+ SKIP_IF(PlatformSupportAlignmentCheck() != PlatformSupport::Allowed);
// See above.
struct sigaction sa = {};
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index 21a5ffd40..c5acfc794 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -47,23 +47,14 @@ namespace testing {
namespace {
-constexpr char kBasicWorkload[] = "exec_basic_workload";
-constexpr char kExitScript[] = "exit_script";
-constexpr char kStateWorkload[] = "exec_state_workload";
-constexpr char kProcExeWorkload[] = "exec_proc_exe_workload";
-constexpr char kAssertClosedWorkload[] = "exec_assert_closed_workload";
-constexpr char kPriorityWorkload[] = "priority_execve";
-
-std::string WorkloadPath(absl::string_view binary) {
- std::string full_path;
- char* test_src = getenv("TEST_SRCDIR");
- if (test_src) {
- full_path = JoinPath(test_src, "__main__/test/syscalls/linux", binary);
- }
-
- TEST_CHECK(full_path.empty() == false);
- return full_path;
-}
+constexpr char kBasicWorkload[] = "test/syscalls/linux/exec_basic_workload";
+constexpr char kExitScript[] = "test/syscalls/linux/exit_script";
+constexpr char kStateWorkload[] = "test/syscalls/linux/exec_state_workload";
+constexpr char kProcExeWorkload[] =
+ "test/syscalls/linux/exec_proc_exe_workload";
+constexpr char kAssertClosedWorkload[] =
+ "test/syscalls/linux/exec_assert_closed_workload";
+constexpr char kPriorityWorkload[] = "test/syscalls/linux/priority_execve";
constexpr char kExit42[] = "--exec_exit_42";
constexpr char kExecWithThread[] = "--exec_exec_with_thread";
@@ -171,44 +162,44 @@ TEST(ExecTest, EmptyPath) {
}
TEST(ExecTest, Basic) {
- CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {},
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {},
ArgEnvExitStatus(0, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n"));
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n"));
}
TEST(ExecTest, OneArg) {
- CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "1"},
- {}, ArgEnvExitStatus(1, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "1"}, {},
+ ArgEnvExitStatus(1, 0),
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n"));
}
TEST(ExecTest, FiveArg) {
- CheckExec(WorkloadPath(kBasicWorkload),
- {WorkloadPath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
+ CheckExec(RunfilePath(kBasicWorkload),
+ {RunfilePath(kBasicWorkload), "1", "2", "3", "4", "5"}, {},
ArgEnvExitStatus(5, 0),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
TEST(ExecTest, OneEnv) {
- CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)}, {"1"},
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)}, {"1"},
ArgEnvExitStatus(0, 1),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n"));
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n"));
}
TEST(ExecTest, FiveEnv) {
- CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload)},
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload)},
{"1", "2", "3", "4", "5"}, ArgEnvExitStatus(0, 5),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
+ absl::StrCat(RunfilePath(kBasicWorkload), "\n1\n2\n3\n4\n5\n"));
}
TEST(ExecTest, OneArgOneEnv) {
- CheckExec(WorkloadPath(kBasicWorkload), {WorkloadPath(kBasicWorkload), "arg"},
+ CheckExec(RunfilePath(kBasicWorkload), {RunfilePath(kBasicWorkload), "arg"},
{"env"}, ArgEnvExitStatus(1, 1),
- absl::StrCat(WorkloadPath(kBasicWorkload), "\narg\nenv\n"));
+ absl::StrCat(RunfilePath(kBasicWorkload), "\narg\nenv\n"));
}
TEST(ExecTest, InterpreterScript) {
- CheckExec(WorkloadPath(kExitScript), {WorkloadPath(kExitScript), "25"}, {},
+ CheckExec(RunfilePath(kExitScript), {RunfilePath(kExitScript), "25"}, {},
ArgEnvExitStatus(25, 0), "");
}
@@ -216,7 +207,7 @@ TEST(ExecTest, InterpreterScript) {
TEST(ExecTest, InterpreterScriptArgSplit) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo bar"),
@@ -230,7 +221,7 @@ TEST(ExecTest, InterpreterScriptArgSplit) {
TEST(ExecTest, InterpreterScriptArgvZero) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
@@ -244,7 +235,7 @@ TEST(ExecTest, InterpreterScriptArgvZero) {
TEST(ExecTest, InterpreterScriptArgvZeroRelative) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
@@ -261,7 +252,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroRelative) {
TEST(ExecTest, InterpreterScriptArgvZeroAdded) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path()), 0755));
@@ -274,7 +265,7 @@ TEST(ExecTest, InterpreterScriptArgvZeroAdded) {
TEST(ExecTest, InterpreterScriptArgNUL) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(),
@@ -289,7 +280,7 @@ TEST(ExecTest, InterpreterScriptArgNUL) {
TEST(ExecTest, InterpreterScriptTrailingWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " "), 0755));
@@ -302,7 +293,7 @@ TEST(ExecTest, InterpreterScriptTrailingWhitespace) {
TEST(ExecTest, InterpreterScriptArgWhitespace) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " foo"), 0755));
@@ -325,7 +316,7 @@ TEST(ExecTest, InterpreterScriptNoPath) {
TEST(ExecTest, ExecFn) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(), absl::StrCat("#!", link.path(), " PrintExecFn"),
@@ -342,7 +333,7 @@ TEST(ExecTest, ExecFn) {
}
TEST(ExecTest, ExecName) {
- std::string path = WorkloadPath(kStateWorkload);
+ std::string path = RunfilePath(kStateWorkload);
CheckExec(path, {path, "PrintExecName"}, {}, ArgEnvExitStatus(0, 0),
absl::StrCat(Basename(path).substr(0, 15), "\n"));
@@ -351,7 +342,7 @@ TEST(ExecTest, ExecName) {
TEST(ExecTest, ExecNameScript) {
// Symlink through /tmp to ensure the path is short enough.
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kStateWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kStateWorkload)));
TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
GetAbsoluteTestTmpdir(),
@@ -405,13 +396,13 @@ TEST(ExecStateTest, HandlerReset) {
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
ExecveArray args = {
- WorkloadPath(kStateWorkload),
+ RunfilePath(kStateWorkload),
"CheckSigHandler",
absl::StrCat(SIGUSR1),
absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_DFL))),
};
- CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// Ignored signal dispositions are not reset.
@@ -421,13 +412,13 @@ TEST(ExecStateTest, IgnorePreserved) {
ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
ExecveArray args = {
- WorkloadPath(kStateWorkload),
+ RunfilePath(kStateWorkload),
"CheckSigHandler",
absl::StrCat(SIGUSR1),
absl::StrCat(absl::Hex(reinterpret_cast<uintptr_t>(SIG_IGN))),
};
- CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// Signal masks are not reset on exec
@@ -438,12 +429,12 @@ TEST(ExecStateTest, SignalMask) {
ASSERT_THAT(sigprocmask(SIG_BLOCK, &s, nullptr), SyscallSucceeds());
ExecveArray args = {
- WorkloadPath(kStateWorkload),
+ RunfilePath(kStateWorkload),
"CheckSigBlocked",
absl::StrCat(SIGUSR1),
};
- CheckExec(WorkloadPath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
+ CheckExec(RunfilePath(kStateWorkload), args, {}, W_EXITCODE(0, 0), "");
}
// itimers persist across execve.
@@ -471,7 +462,7 @@ TEST(ExecStateTest, ItimerPreserved) {
}
};
- std::string filename = WorkloadPath(kStateWorkload);
+ std::string filename = RunfilePath(kStateWorkload);
ExecveArray argv = {
filename,
"CheckItimerEnabled",
@@ -495,8 +486,8 @@ TEST(ExecStateTest, ItimerPreserved) {
TEST(ProcSelfExe, ChangesAcrossExecve) {
// See exec_proc_exe_workload for more details. We simply
// assert that the /proc/self/exe link changes across execve.
- CheckExec(WorkloadPath(kProcExeWorkload),
- {WorkloadPath(kProcExeWorkload),
+ CheckExec(RunfilePath(kProcExeWorkload),
+ {RunfilePath(kProcExeWorkload),
ASSERT_NO_ERRNO_AND_VALUE(ProcessExePath(getpid()))},
{}, W_EXITCODE(0, 0), "");
}
@@ -507,8 +498,8 @@ TEST(ExecTest, CloexecNormalFile) {
const FileDescriptor fd_closed_on_exec =
ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
- CheckExec(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload),
+ CheckExec(RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload),
absl::StrCat(fd_closed_on_exec.get())},
{}, W_EXITCODE(0, 0), "");
@@ -517,10 +508,10 @@ TEST(ExecTest, CloexecNormalFile) {
const FileDescriptor fd_open_on_exec =
ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY));
- CheckExec(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload),
- absl::StrCat(fd_open_on_exec.get())},
- {}, W_EXITCODE(2, 0), "");
+ CheckExec(
+ RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd_open_on_exec.get())},
+ {}, W_EXITCODE(2, 0), "");
}
TEST(ExecTest, CloexecEventfd) {
@@ -528,19 +519,65 @@ TEST(ExecTest, CloexecEventfd) {
ASSERT_THAT(efd = eventfd(0, EFD_CLOEXEC), SyscallSucceeds());
FileDescriptor fd(efd);
- CheckExec(WorkloadPath(kAssertClosedWorkload),
- {WorkloadPath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {},
+ CheckExec(RunfilePath(kAssertClosedWorkload),
+ {RunfilePath(kAssertClosedWorkload), absl::StrCat(fd.get())}, {},
W_EXITCODE(0, 0), "");
}
+constexpr int kLinuxMaxSymlinks = 40;
+
+TEST(ExecTest, SymlinkLimitExceeded) {
+ std::string path = RunfilePath(kBasicWorkload);
+
+ // Hold onto TempPath objects so they are not destructed prematurely.
+ std::vector<TempPath> symlinks;
+ for (int i = 0; i < kLinuxMaxSymlinks + 1; i++) {
+ symlinks.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateSymlinkTo("/tmp", path)));
+ path = symlinks[i].path();
+ }
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(path, {path}, {}, /*child=*/nullptr, &execve_errno));
+ EXPECT_EQ(execve_errno, ELOOP);
+}
+
+TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) {
+ std::string tmp_dir = "/tmp";
+ std::string interpreter_path = "/bin/echo";
+ TempPath script = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ tmp_dir, absl::StrCat("#!", interpreter_path), 0755));
+ std::string script_path = script.path();
+
+ // Hold onto TempPath objects so they are not destructed prematurely.
+ std::vector<TempPath> interpreter_symlinks;
+ std::vector<TempPath> script_symlinks;
+ // Replace both the interpreter and script paths with symlink chains of just
+ // over half the symlink limit each; this is the minimum required to test that
+ // the symlink limit applies separately to each traversal, while tolerating
+ // some symlinks in the resolution of (the original) interpreter_path and
+ // script_path.
+ for (int i = 0; i < (kLinuxMaxSymlinks / 2) + 1; i++) {
+ interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(tmp_dir, interpreter_path)));
+ interpreter_path = interpreter_symlinks[i].path();
+ script_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(tmp_dir, script_path)));
+ script_path = script_symlinks[i].path();
+ }
+
+ CheckExec(script_path, {script_path}, {}, ArgEnvExitStatus(0, 0), "");
+}
+
TEST(ExecveatTest, BasicWithFDCWD) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
CheckExecveat(AT_FDCWD, path, {path}, {}, /*flags=*/0, ArgEnvExitStatus(0, 0),
absl::StrCat(path, "\n"));
}
TEST(ExecveatTest, Basic) {
- std::string absolute_path = WorkloadPath(kBasicWorkload);
+ std::string absolute_path = RunfilePath(kBasicWorkload);
std::string parent_dir = std::string(Dirname(absolute_path));
std::string base = std::string(Basename(absolute_path));
const FileDescriptor dirfd =
@@ -551,7 +588,7 @@ TEST(ExecveatTest, Basic) {
}
TEST(ExecveatTest, FDNotADirectory) {
- std::string absolute_path = WorkloadPath(kBasicWorkload);
+ std::string absolute_path = RunfilePath(kBasicWorkload);
std::string base = std::string(Basename(absolute_path));
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(absolute_path, 0));
@@ -563,13 +600,13 @@ TEST(ExecveatTest, FDNotADirectory) {
}
TEST(ExecveatTest, AbsolutePathWithFDCWD) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
CheckExecveat(AT_FDCWD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0,
absl::StrCat(path, "\n"));
}
TEST(ExecveatTest, AbsolutePath) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
// File descriptor should be ignored when an absolute path is given.
const int32_t badFD = -1;
CheckExecveat(badFD, path, {path}, {}, ArgEnvExitStatus(0, 0), 0,
@@ -577,7 +614,7 @@ TEST(ExecveatTest, AbsolutePath) {
}
TEST(ExecveatTest, EmptyPathBasic) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
CheckExecveat(fd.get(), "", {path}, {}, AT_EMPTY_PATH, ArgEnvExitStatus(0, 0),
@@ -585,7 +622,7 @@ TEST(ExecveatTest, EmptyPathBasic) {
}
TEST(ExecveatTest, EmptyPathWithDirFD) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
std::string parent_dir = std::string(Dirname(path));
const FileDescriptor dirfd =
ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
@@ -598,7 +635,7 @@ TEST(ExecveatTest, EmptyPathWithDirFD) {
}
TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
int execve_errno;
@@ -608,7 +645,7 @@ TEST(ExecveatTest, EmptyPathWithoutEmptyPathFlag) {
}
TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) {
- std::string path = WorkloadPath(kBasicWorkload);
+ std::string path = RunfilePath(kBasicWorkload);
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_PATH));
CheckExecveat(fd.get(), path, {path}, {}, AT_EMPTY_PATH,
@@ -616,7 +653,7 @@ TEST(ExecveatTest, AbsolutePathWithEmptyPathFlag) {
}
TEST(ExecveatTest, RelativePathWithEmptyPathFlag) {
- std::string absolute_path = WorkloadPath(kBasicWorkload);
+ std::string absolute_path = RunfilePath(kBasicWorkload);
std::string parent_dir = std::string(Dirname(absolute_path));
std::string base = std::string(Basename(absolute_path));
const FileDescriptor dirfd =
@@ -629,7 +666,7 @@ TEST(ExecveatTest, RelativePathWithEmptyPathFlag) {
TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) {
std::string parent_dir = "/tmp";
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload)));
const FileDescriptor dirfd =
ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_DIRECTORY));
std::string base = std::string(Basename(link.path()));
@@ -641,10 +678,35 @@ TEST(ExecveatTest, SymlinkNoFollowWithRelativePath) {
EXPECT_EQ(execve_errno, ELOOP);
}
+TEST(ExecveatTest, UnshareFiles) {
+ TempPath tempFile = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(GetAbsoluteTestTmpdir(), "bar", 0755));
+ const FileDescriptor fd_closed_on_exec =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC));
+
+ ExecveArray argv = {"test"};
+ ExecveArray envp;
+ std::string child_path = RunfilePath(kBasicWorkload);
+ pid_t child =
+ syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, 0, 0, 0, 0);
+ if (child == 0) {
+ execve(child_path.c_str(), argv.get(), envp.get());
+ _exit(1);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_EQ(status, 0);
+
+ struct stat st;
+ EXPECT_THAT(fstat(fd_closed_on_exec.get(), &st), SyscallSucceeds());
+}
+
TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) {
std::string parent_dir = "/tmp";
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo(parent_dir, WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo(parent_dir, RunfilePath(kBasicWorkload)));
std::string path = link.path();
int execve_errno;
@@ -656,7 +718,7 @@ TEST(ExecveatTest, SymlinkNoFollowWithAbsolutePath) {
TEST(ExecveatTest, SymlinkNoFollowAndEmptyPath) {
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateSymlinkTo("/tmp", WorkloadPath(kBasicWorkload)));
+ TempPath::CreateSymlinkTo("/tmp", RunfilePath(kBasicWorkload)));
std::string path = link.path();
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, 0));
@@ -681,6 +743,39 @@ TEST(ExecveatTest, SymlinkNoFollowWithNormalFile) {
ArgEnvExitStatus(0, 0), "");
}
+TEST(ExecveatTest, BasicWithCloexecFD) {
+ std::string path = RunfilePath(kBasicWorkload);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC));
+
+ CheckExecveat(fd.get(), "", {path}, {}, AT_SYMLINK_NOFOLLOW | AT_EMPTY_PATH,
+ ArgEnvExitStatus(0, 0), absl::StrCat(path, "\n"));
+}
+
+TEST(ExecveatTest, InterpreterScriptWithCloexecFD) {
+ std::string path = RunfilePath(kExitScript);
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CLOEXEC));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(fd.get(), "", {path}, {},
+ AT_EMPTY_PATH, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
+TEST(ExecveatTest, InterpreterScriptWithCloexecDirFD) {
+ std::string absolute_path = RunfilePath(kExitScript);
+ std::string parent_dir = std::string(Dirname(absolute_path));
+ std::string base = std::string(Basename(absolute_path));
+ const FileDescriptor dirfd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent_dir, O_CLOEXEC | O_DIRECTORY));
+
+ int execve_errno;
+ ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(dirfd.get(), base, {base}, {},
+ /*flags=*/0, /*child=*/nullptr,
+ &execve_errno));
+ EXPECT_EQ(execve_errno, ENOENT);
+}
+
TEST(ExecveatTest, InvalidFlags) {
int execve_errno;
ASSERT_NO_ERRNO_AND_VALUE(ForkAndExecveat(
@@ -701,7 +796,7 @@ TEST(GetpriorityTest, ExecveMaintainsPriority) {
// Program run (priority_execve) will exit(X) where
// X=getpriority(PRIO_PROCESS,0). Check that this exit value is prio.
- CheckExec(WorkloadPath(kPriorityWorkload), {WorkloadPath(kPriorityWorkload)},
+ CheckExec(RunfilePath(kPriorityWorkload), {RunfilePath(kPriorityWorkload)},
{}, W_EXITCODE(expected_exit_code, 0), "");
}
@@ -747,26 +842,28 @@ void ExecFromThread() {
bool ValidateProcCmdlineVsArgv(const int argc, const char* const* argv) {
auto contents_or = GetContents("/proc/self/cmdline");
if (!contents_or.ok()) {
- std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error();
+ std::cerr << "Unable to get /proc/self/cmdline: " << contents_or.error()
+ << std::endl;
return false;
}
auto contents = contents_or.ValueOrDie();
if (contents.back() != '\0') {
- std::cerr << "Non-null terminated /proc/self/cmdline!";
+ std::cerr << "Non-null terminated /proc/self/cmdline!" << std::endl;
return false;
}
contents.pop_back();
std::vector<std::string> procfs_cmdline = absl::StrSplit(contents, '\0');
if (static_cast<int>(procfs_cmdline.size()) != argc) {
- std::cerr << "argc = " << argc << " != " << procfs_cmdline.size();
+ std::cerr << "argc = " << argc << " != " << procfs_cmdline.size()
+ << std::endl;
return false;
}
for (int i = 0; i < argc; ++i) {
if (procfs_cmdline[i] != argv[i]) {
std::cerr << "Procfs command line argument " << i << " mismatch "
- << procfs_cmdline[i] << " != " << argv[i];
+ << procfs_cmdline[i] << " != " << argv[i] << std::endl;
return false;
}
}
@@ -803,6 +900,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 0a3931e5a..18d2f22c1 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -20,6 +20,7 @@
#include <sys/types.h>
#include <sys/user.h>
#include <unistd.h>
+
#include <algorithm>
#include <functional>
#include <iterator>
@@ -47,10 +48,17 @@ namespace {
using ::testing::AnyOf;
using ::testing::Eq;
-#ifndef __x86_64__
+#if !defined(__x86_64__) && !defined(__aarch64__)
// The assembly stub and ELF internal details must be ported to other arches.
-#error "Test only supported on x86-64"
-#endif // __x86_64__
+#error "Test only supported on x86-64/arm64"
+#endif // __x86_64__ || __aarch64__
+
+#if defined(__x86_64__)
+#define EM_TYPE EM_X86_64
+#define IP_REG(p) ((p).rip)
+#define RAX_REG(p) ((p).rax)
+#define RDI_REG(p) ((p).rdi)
+#define RETURN_REG(p) ((p).rax)
// amd64 stub that calls PTRACE_TRACEME and sends itself SIGSTOP.
const char kPtraceCode[] = {
@@ -138,6 +146,76 @@ const char kPtraceCode[] = {
// Size of a syscall instruction.
constexpr int kSyscallSize = 2;
+#elif defined(__aarch64__)
+#define EM_TYPE EM_AARCH64
+#define IP_REG(p) ((p).pc)
+#define RAX_REG(p) ((p).regs[8])
+#define RDI_REG(p) ((p).regs[0])
+#define RETURN_REG(p) ((p).regs[0])
+
+const char kPtraceCode[] = {
+ // MOVD $117, R8 /* ptrace */
+ '\xa8',
+ '\x0e',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R0 /* PTRACE_TRACEME */
+ '\x00',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R1 /* pid */
+ '\x01',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R2 /* addr */
+ '\x02',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // MOVD $0, R3 /* data */
+ '\x03',
+ '\x00',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $172, R8 /* getpid */
+ '\x88',
+ '\x15',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+ // MOVD $129, R8 /* kill, R0=pid */
+ '\x28',
+ '\x10',
+ '\x80',
+ '\xd2',
+ // MOVD $19, R1 /* SIGSTOP */
+ '\x61',
+ '\x02',
+ '\x80',
+ '\xd2',
+ // SVC
+ '\x01',
+ '\x00',
+ '\x00',
+ '\xd4',
+};
+// Size of a syscall instruction.
+constexpr int kSyscallSize = 4;
+#else
+#error "Unknown architecture"
+#endif
+
// This test suite tests executable loading in the kernel (ELF and interpreter
// scripts).
@@ -280,7 +358,7 @@ ElfBinary<64> StandardElf() {
elf.header.e_ident[EI_DATA] = ELFDATA2LSB;
elf.header.e_ident[EI_VERSION] = EV_CURRENT;
elf.header.e_type = ET_EXEC;
- elf.header.e_machine = EM_X86_64;
+ elf.header.e_machine = EM_TYPE;
elf.header.e_version = EV_CURRENT;
elf.header.e_phoff = sizeof(elf.header);
elf.header.e_phentsize = sizeof(decltype(elf)::ElfPhdr);
@@ -326,9 +404,15 @@ TEST(ElfTest, Execute) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
- // RIP is just beyond the final syscall instruction.
- EXPECT_EQ(regs.rip, elf.header.e_entry + sizeof(kPtraceCode));
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
+ // RIP/PC is just beyond the final syscall instruction.
+ EXPECT_EQ(IP_REG(regs), elf.header.e_entry + sizeof(kPtraceCode));
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
{0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
@@ -354,7 +438,12 @@ TEST(ElfTest, MissingText) {
ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0),
SyscallSucceedsWithValue(child));
// It runs off the end of the zeroes filling the end of the page.
+#if defined(__x86_64__)
EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV) << status;
+#elif defined(__aarch64__)
+ // 0 is an invalid instruction opcode on arm64.
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGILL) << status;
+#endif
}
// Typical ELF with a data + bss segment
@@ -717,9 +806,16 @@ TEST(ElfTest, PIE) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// text page.
@@ -786,9 +882,15 @@ TEST(ElfTest, PIENonZeroStart) {
// RIP tells us which page the first segment was loaded into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t load_addr = IP_REG(regs) & ~(kPageSize - 1);
// The ELF is loaded at an arbitrary address, not the first PT_LOAD vaddr.
//
@@ -909,9 +1011,15 @@ TEST(ElfTest, ELFInterpreter) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1083,9 +1191,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
// RIP tells us which page the first segment of the interpreter was loaded
// into.
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
- const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
+ const uint64_t interp_load_addr = IP_REG(regs) & ~(kPageSize - 1);
EXPECT_THAT(
child, ContainsMappings(std::vector<ProcMapsEntry>({
@@ -1479,14 +1593,21 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_NO_ERRNO(WaitStopped(child));
struct user_regs_struct regs;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// RIP is just beyond the final syscall instruction. Rewind to execute a brk
// syscall.
- regs.rip -= kSyscallSize;
- regs.rax = __NR_brk;
- regs.rdi = 0;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child, 0, &regs), SyscallSucceeds());
+ IP_REG(regs) -= kSyscallSize;
+ RAX_REG(regs) = __NR_brk;
+ RDI_REG(regs) = 0;
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
// Resume the child, waiting for syscall entry.
ASSERT_THAT(ptrace(PTRACE_SYSCALL, child, 0, 0), SyscallSucceeds());
@@ -1503,7 +1624,12 @@ TEST(ExecveTest, BrkAfterBinary) {
ASSERT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGTRAP)
<< "status = " << status;
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child, 0, &regs), SyscallSucceeds());
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+ // Read exactly the full register set.
+ EXPECT_EQ(iov.iov_len, sizeof(regs));
// brk is after the text page.
//
@@ -1511,7 +1637,7 @@ TEST(ExecveTest, BrkAfterBinary) {
// address will be, but it is always beyond the final page in the binary.
// i.e., it does not start immediately after memsz in the middle of a page.
// Userspace may expect to use that space.
- EXPECT_GE(regs.rax, 0x41000);
+ EXPECT_GE(RETURN_REG(regs), 0x41000);
}
} // namespace
diff --git a/test/syscalls/linux/exec_proc_exe_workload.cc b/test/syscalls/linux/exec_proc_exe_workload.cc
index b790fe5be..2989379b7 100644
--- a/test/syscalls/linux/exec_proc_exe_workload.cc
+++ b/test/syscalls/linux/exec_proc_exe_workload.cc
@@ -21,6 +21,12 @@
#include "test/util/posix_error.h"
int main(int argc, char** argv, char** envp) {
+ // This is annoying. Because remote build systems may put these binaries
+ // in a content-addressable-store, you may wind up with /proc/self/exe
+ // pointing to some random path (but with a sensible argv[0]).
+ //
+ // Therefore, this test simply checks that the /proc/self/exe
+ // is absolute and *doesn't* match argv[1].
std::string exe =
gvisor::testing::ProcessExePath(getpid()).ValueOrDie();
if (exe[0] != '/') {
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
index 1c3d00287..cabc2b751 100644
--- a/test/syscalls/linux/fallocate.cc
+++ b/test/syscalls/linux/fallocate.cc
@@ -15,16 +15,27 @@
#include <errno.h>
#include <fcntl.h>
#include <signal.h>
+#include <sys/eventfd.h>
#include <sys/resource.h>
+#include <sys/signalfd.h>
+#include <sys/socket.h>
#include <sys/stat.h>
+#include <sys/timerfd.h>
#include <syscall.h>
#include <time.h>
#include <unistd.h>
+#include <ctime>
+
#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/cleanup.h"
+#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -33,7 +44,7 @@ namespace testing {
namespace {
int fallocate(int fd, int mode, off_t offset, off_t len) {
- return syscall(__NR_fallocate, fd, mode, offset, len);
+ return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len);
}
class AllocateTest : public FileTest {
@@ -47,27 +58,33 @@ TEST_F(AllocateTest, Fallocate) {
EXPECT_EQ(buf.st_size, 0);
// Grow to ten bytes.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Allocate to a smaller size should be noop.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Grow again.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 20);
// Grow with offset.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 30);
// Grow with offset beyond EOF.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
+ ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
+ EXPECT_EQ(buf.st_size, 40);
+
+ // Given length 0 should fail with EINVAL.
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 50, 0),
+ SyscallFailsWithErrno(EINVAL));
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 40);
}
@@ -136,6 +153,34 @@ TEST_F(AllocateTest, FallocateRlimit) {
ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds());
}
+TEST_F(AllocateTest, FallocateOtherFDs) {
+ int fd;
+ ASSERT_THAT(fd = timerfd_create(CLOCK_MONOTONIC, 0), SyscallSucceeds());
+ auto timer_fd = FileDescriptor(fd);
+ EXPECT_THAT(fallocate(timer_fd.get(), 0, 0, 10),
+ SyscallFailsWithErrno(ENODEV));
+
+ sigset_t mask;
+ sigemptyset(&mask);
+ ASSERT_THAT(fd = signalfd(-1, &mask, 0), SyscallSucceeds());
+ auto sfd = FileDescriptor(fd);
+ EXPECT_THAT(fallocate(sfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ auto efd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE));
+ EXPECT_THAT(fallocate(efd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ auto sockfd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ EXPECT_THAT(fallocate(sockfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+
+ int socks[2];
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks),
+ SyscallSucceeds());
+ auto sock0 = FileDescriptor(socks[0]);
+ auto sock1 = FileDescriptor(socks[1]);
+ EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV));
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/fault.cc b/test/syscalls/linux/fault.cc
index f6e19026f..a85750382 100644
--- a/test/syscalls/linux/fault.cc
+++ b/test/syscalls/linux/fault.cc
@@ -37,6 +37,9 @@ int GetPcFromUcontext(ucontext_t* uc, uintptr_t* pc) {
#elif defined(__i386__)
*pc = uc->uc_mcontext.gregs[REG_EIP];
return 1;
+#elif defined(__aarch64__)
+ *pc = uc->uc_mcontext.pc;
+ return 1;
#else
return 0;
#endif
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 8a45be12a..34016d4bd 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -14,10 +14,14 @@
#include <fcntl.h>
#include <signal.h>
+#include <sys/types.h>
#include <syscall.h>
#include <unistd.h>
+#include <iostream>
+#include <list>
#include <string>
+#include <vector>
#include "gtest/gtest.h"
#include "absl/base/macros.h"
@@ -30,10 +34,13 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/cleanup.h"
#include "test/util/eventfd_util.h"
+#include "test/util/fs_util.h"
#include "test/util/multiprocess_util.h"
#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
#include "test/util/timer_util.h"
ABSL_FLAG(std::string, child_setlock_on, "",
@@ -53,10 +60,6 @@ ABSL_FLAG(int32_t, socket_fd, -1,
namespace gvisor {
namespace testing {
-// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
-// because "it isn't needed", even though Linux can return it via F_GETFL.
-constexpr int kOLargeFile = 00100000;
-
class FcntlLockTest : public ::testing::Test {
public:
void SetUp() override {
@@ -116,6 +119,15 @@ PosixErrorOr<Cleanup> SubprocessLock(std::string const& path, bool for_write,
return std::move(cleanup);
}
+TEST(FcntlTest, SetCloExecBadFD) {
+ // Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
+ FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
+ auto fd = f.get();
+ f.reset();
+ ASSERT_THAT(fcntl(fd, F_GETFD), SyscallFailsWithErrno(EBADF));
+ ASSERT_THAT(fcntl(fd, F_SETFD, FD_CLOEXEC), SyscallFailsWithErrno(EBADF));
+}
+
TEST(FcntlTest, SetCloExec) {
// Open an eventfd file descriptor with FD_CLOEXEC descriptor flag not set.
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0));
@@ -183,45 +195,85 @@ TEST(FcntlTest, SetFlags) {
EXPECT_EQ(rflags, expected);
}
-TEST_F(FcntlLockTest, SetLockBadFd) {
+void TestLock(int fd, short lock_type = F_RDLCK) { // NOLINT, type in flock
struct flock fl;
- fl.l_type = F_WRLCK;
+ fl.l_type = lock_type;
fl.l_whence = SEEK_SET;
fl.l_start = 0;
- // len 0 has a special meaning: lock all bytes despite how
- // large the file grows.
+ // len 0 locks all bytes despite how large the file grows.
fl.l_len = 0;
- EXPECT_THAT(fcntl(-1, F_SETLK, &fl), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallSucceeds());
}
-TEST_F(FcntlLockTest, SetLockPipe) {
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
-
+void TestLockBadFD(int fd,
+ short lock_type = F_RDLCK) { // NOLINT, type in flock
struct flock fl;
- fl.l_type = F_WRLCK;
+ fl.l_type = lock_type;
fl.l_whence = SEEK_SET;
fl.l_start = 0;
- // Same as SetLockBadFd, but doesn't matter, we expect this to fail.
+ // len 0 locks all bytes despite how large the file grows.
fl.l_len = 0;
- EXPECT_THAT(fcntl(fds[0], F_SETLK, &fl), SyscallFailsWithErrno(EBADF));
- EXPECT_THAT(close(fds[0]), SyscallSucceeds());
- EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+ EXPECT_THAT(fcntl(fd, F_SETLK, &fl), SyscallFailsWithErrno(EBADF));
}
+TEST_F(FcntlLockTest, SetLockBadFd) { TestLockBadFD(-1); }
+
TEST_F(FcntlLockTest, SetLockDir) {
auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0666));
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000));
+ TestLock(fd.get());
+}
- struct flock fl;
- fl.l_type = F_RDLCK;
- fl.l_whence = SEEK_SET;
- fl.l_start = 0;
- // Same as SetLockBadFd.
- fl.l_len = 0;
+TEST_F(FcntlLockTest, SetLockSymlink) {
+ // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH
+ // is supported.
+ SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(fcntl(fd.get(), F_SETLK, &fl), SyscallSucceeds());
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000));
+ TestLockBadFD(fd.get());
+}
+
+TEST_F(FcntlLockTest, SetLockProc) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000));
+ TestLock(fd.get());
+}
+
+TEST_F(FcntlLockTest, SetLockPipe) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ TestLock(fds[0]);
+ TestLockBadFD(fds[0], F_WRLCK);
+
+ TestLock(fds[1], F_WRLCK);
+ TestLockBadFD(fds[1]);
+
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST_F(FcntlLockTest, SetLockSocket) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX));
+ ASSERT_THAT(
+ bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ TestLock(sock);
+ EXPECT_THAT(close(sock), SyscallSucceeds());
}
TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) {
@@ -233,8 +285,7 @@ TEST_F(FcntlLockTest, SetLockBadOpenFlagsWrite) {
fl0.l_type = F_WRLCK;
fl0.l_whence = SEEK_SET;
fl0.l_start = 0;
- // Same as SetLockBadFd.
- fl0.l_len = 0;
+ fl0.l_len = 0; // Lock all file
// Expect that setting a write lock using a read only file descriptor
// won't work.
@@ -696,7 +747,7 @@ TEST_F(FcntlLockTest, SetWriteLockThenBlockingWriteLock) {
<< "Exited with code: " << status;
}
-// This test will veirfy that blocking works as expected when another process
+// This test will verify that blocking works as expected when another process
// holds a read lock when obtaining a write lock. This test will hold the lock
// for some amount of time and then wait for the second process to send over the
// socket_fd the amount of time it was blocked for before the lock succeeded.
@@ -906,14 +957,346 @@ TEST(FcntlTest, DupAfterO_ASYNC) {
EXPECT_EQ(after & O_ASYNC, O_ASYNC);
}
-TEST(FcntlTest, GetOwn) {
+TEST(FcntlTest, GetOwnNone) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Use the raw syscall because the glibc wrapper may convert F_{GET,SET}OWN
+ // into F_{GET,SET}OWN_EX.
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
+TEST(FcntlTest, GetOwnExNone) {
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
- ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ f_owner_ex owner = {};
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &owner),
SyscallSucceedsWithValue(0));
}
+TEST(FcntlTest, SetOwnInvalidPid) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 12345678),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnInvalidPgrp) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -12345678),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pgid;
+ EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid),
+ SyscallSucceedsWithValue(0));
+
+ // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the
+ // negative return value as an error, converting the return value to -1 and
+ // setting errno accordingly.
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, F_OWNER_PGRP);
+ EXPECT_EQ(got_owner.pid, pgid);
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnUnset) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Set and unset pid.
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+
+ // Set and unset pgid.
+ pid_t pgid;
+ EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
+// F_SETOWN flips the sign of negative values, an operation that is guarded
+// against overflow.
+TEST(FcntlTest, SetOwnOverflow) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, INT_MIN),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FcntlTest, SetOwnExInvalidType) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = __pid_type(-1);
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(FcntlTest, SetOwnExInvalidTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_TID;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExInvalidPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExInvalidPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PGRP;
+ owner.pid = -1;
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST(FcntlTest, SetOwnExTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_TID;
+ EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(owner.pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(owner.pid));
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceedsWithValue(0));
+
+ // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the
+ // negative return value as an error, converting the return value to -1 and
+ // setting errno accordingly.
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+ MaybeSave();
+}
+
+TEST(FcntlTest, SetOwnExUnset) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ // Set and unset pid.
+ f_owner_ex owner = {};
+ owner.type = F_OWNER_PID;
+ EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+ owner.pid = 0;
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+
+ // Set and unset pgid.
+ owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds());
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+ owner.pid = 0;
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner),
+ SyscallSucceedsWithValue(0));
+
+ EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN),
+ SyscallSucceedsWithValue(0));
+ MaybeSave();
+}
+
+TEST(FcntlTest, GetOwnExTid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_TID;
+ EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceedsWithValue(0));
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+TEST(FcntlTest, GetOwnExPid) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PID;
+ EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceedsWithValue(0));
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+TEST(FcntlTest, GetOwnExPgrp) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ f_owner_ex set_owner = {};
+ set_owner.type = F_OWNER_PGRP;
+ EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds());
+
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner),
+ SyscallSucceedsWithValue(0));
+
+ f_owner_ex got_owner = {};
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(got_owner.type, set_owner.type);
+ EXPECT_EQ(got_owner.pid, set_owner.pid);
+}
+
+// Make sure that making multiple concurrent changes to async signal generation
+// does not cause any race issues.
+TEST(FcntlTest, SetFlSetOwnDoNotRace) {
+ FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
+
+ pid_t pid;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
+
+ constexpr absl::Duration runtime = absl::Milliseconds(300);
+ auto setAsync = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
+ auto resetAsync = [&s, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds());
+ sched_yield();
+ }
+ };
+ auto setOwn = [&s, &pid, &runtime] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid),
+ SyscallSucceeds());
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> threads;
+ for (int i = 0; i < 10; i++) {
+ threads.emplace_back(setAsync);
+ threads.emplace_back(resetAsync);
+ threads.emplace_back(setOwn);
+ }
+}
+
} // namespace
} // namespace testing
@@ -943,8 +1326,7 @@ int main(int argc, char** argv) {
fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
- // Test the fcntl, no need to log, the error is unambiguously
- // from fcntl at this point.
+ // Test the fcntl.
int err = 0;
int ret = 0;
@@ -957,6 +1339,8 @@ int main(int argc, char** argv) {
if (ret == -1 && errno != 0) {
err = errno;
+ std::cerr << "CHILD lock " << setlock_on << " failed " << err
+ << std::endl;
}
// If there is a socket fd let's send back the time in microseconds it took
@@ -971,5 +1355,5 @@ int main(int argc, char** argv) {
exit(err);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/file_base.h b/test/syscalls/linux/file_base.h
index 4d155b618..fb418e052 100644
--- a/test/syscalls/linux/file_base.h
+++ b/test/syscalls/linux/file_base.h
@@ -27,6 +27,7 @@
#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>
+
#include <cstring>
#include <string>
@@ -51,17 +52,6 @@ class FileTest : public ::testing::Test {
test_file_fd_ = ASSERT_NO_ERRNO_AND_VALUE(
Open(test_file_name_, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR));
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // test_fifo_name_ = NewTempAbsPath();
- // ASSERT_THAT(mknod(test_fifo_name_.c_str()), S_IFIFO|0644, 0,
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[1] = open(test_fifo_name_.c_str(),
- // O_WRONLY),
- // SyscallSucceeds());
- // ASSERT_THAT(test_fifo_[0] = open(test_fifo_name_.c_str(),
- // O_RDONLY),
- // SyscallSucceeds());
-
ASSERT_THAT(pipe(test_pipe_), SyscallSucceeds());
ASSERT_THAT(fcntl(test_pipe_[0], F_SETFL, O_NONBLOCK), SyscallSucceeds());
}
@@ -95,110 +85,15 @@ class FileTest : public ::testing::Test {
CloseFile();
UnlinkFile();
ClosePipes();
-
- // FIXME(edahlgren): enable when mknod syscall is supported.
- // close(test_fifo_[0]);
- // close(test_fifo_[1]);
- // unlink(test_fifo_name_.c_str());
}
+ protected:
std::string test_file_name_;
- std::string test_fifo_name_;
FileDescriptor test_file_fd_;
- int test_fifo_[2];
int test_pipe_[2];
};
-class SocketTest : public ::testing::Test {
- public:
- void SetUp() override {
- test_unix_stream_socket_[0] = -1;
- test_unix_stream_socket_[1] = -1;
- test_unix_dgram_socket_[0] = -1;
- test_unix_dgram_socket_[1] = -1;
- test_unix_seqpacket_socket_[0] = -1;
- test_unix_seqpacket_socket_[1] = -1;
- test_tcp_socket_[0] = -1;
- test_tcp_socket_[1] = -1;
-
- ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_),
- SyscallSucceeds());
- ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK),
- SyscallSucceeds());
- ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_),
- SyscallSucceeds());
- ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK),
- SyscallSucceeds());
- ASSERT_THAT(
- socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_),
- SyscallSucceeds());
- ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK),
- SyscallSucceeds());
- }
-
- void TearDown() override {
- close(test_unix_stream_socket_[0]);
- close(test_unix_stream_socket_[1]);
-
- close(test_unix_dgram_socket_[0]);
- close(test_unix_dgram_socket_[1]);
-
- close(test_unix_seqpacket_socket_[0]);
- close(test_unix_seqpacket_socket_[1]);
-
- close(test_tcp_socket_[0]);
- close(test_tcp_socket_[1]);
- }
-
- int test_unix_stream_socket_[2];
- int test_unix_dgram_socket_[2];
- int test_unix_seqpacket_socket_[2];
- int test_tcp_socket_[2];
-};
-
-// MatchesStringLength checks that a tuple argument of (struct iovec *, int)
-// corresponding to an iovec array and its length, contains data that matches
-// the string length strlen.
-MATCHER_P(MatchesStringLength, strlen, "") {
- struct iovec* iovs = arg.first;
- int niov = arg.second;
- int offset = 0;
- for (int i = 0; i < niov; i++) {
- offset += iovs[i].iov_len;
- }
- if (offset != static_cast<int>(strlen)) {
- *result_listener << offset;
- return false;
- }
- return true;
-}
-
-// MatchesStringValue checks that a tuple argument of (struct iovec *, int)
-// corresponding to an iovec array and its length, contains data that matches
-// the string value str.
-MATCHER_P(MatchesStringValue, str, "") {
- struct iovec* iovs = arg.first;
- int len = strlen(str);
- int niov = arg.second;
- int offset = 0;
- for (int i = 0; i < niov; i++) {
- struct iovec iov = iovs[i];
- if (len < offset) {
- *result_listener << "strlen " << len << " < offset " << offset;
- return false;
- }
- if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) {
- absl::string_view iovec_string(static_cast<char*>(iov.iov_base),
- iov.iov_len);
- *result_listener << iovec_string << " @offset " << offset;
- return false;
- }
- offset += iov.iov_len;
- }
- return true;
-}
-
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc
index b4a91455d..638a93979 100644
--- a/test/syscalls/linux/flock.cc
+++ b/test/syscalls/linux/flock.cc
@@ -14,12 +14,14 @@
#include <errno.h>
#include <sys/file.h>
+
#include <string>
#include "gtest/gtest.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/syscalls/linux/file_base.h"
+#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -33,11 +35,6 @@ namespace {
class FlockTest : public FileTest {};
-TEST_F(FlockTest, BadFD) {
- // EBADF: fd is not an open file descriptor.
- ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF));
-}
-
TEST_F(FlockTest, InvalidOpCombinations) {
// The operation cannot be both exclusive and shared.
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_SH | LOCK_NB),
@@ -56,15 +53,6 @@ TEST_F(FlockTest, NoOperationSpecified) {
SyscallFailsWithErrno(EINVAL));
}
-TEST(FlockTestNoFixture, FlockSupportsPipes) {
- int fds[2];
- ASSERT_THAT(pipe(fds), SyscallSucceeds());
-
- EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds());
- EXPECT_THAT(close(fds[0]), SyscallSucceeds());
- EXPECT_THAT(close(fds[1]), SyscallSucceeds());
-}
-
TEST_F(FlockTest, TestSimpleExLock) {
// Test that we can obtain an exclusive lock (no other holders)
// and that we can unlock it.
@@ -582,6 +570,66 @@ TEST_F(FlockTest, BlockingLockFirstExclusiveSecondExclusive_NoRandomSave) {
EXPECT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceeds());
}
+TEST(FlockTestNoFixture, BadFD) {
+ // EBADF: fd is not an open file descriptor.
+ ASSERT_THAT(flock(-1, 0), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockDir) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSymlink) {
+ // TODO(gvisor.dev/issue/2782): Replace with IsRunningWithVFS1() when O_PATH
+ // is supported.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(GetAbsoluteTestTmpdir(), file.path()));
+
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(symlink.path(), O_RDONLY | O_PATH, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(FlockTestNoFixture, FlockProc) {
+ auto fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/status", O_RDONLY, 0000));
+ EXPECT_THAT(flock(fd.get(), LOCK_EX | LOCK_NB), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockPipe) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+
+ EXPECT_THAT(flock(fds[0], LOCK_EX | LOCK_NB), SyscallSucceeds());
+ // Check that the pipe was locked above.
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallFailsWithErrno(EAGAIN));
+
+ EXPECT_THAT(flock(fds[0], LOCK_UN), SyscallSucceeds());
+ EXPECT_THAT(flock(fds[1], LOCK_EX | LOCK_NB), SyscallSucceeds());
+
+ EXPECT_THAT(close(fds[0]), SyscallSucceeds());
+ EXPECT_THAT(close(fds[1]), SyscallSucceeds());
+}
+
+TEST(FlockTestNoFixture, FlockSocket) {
+ int sock = socket(AF_UNIX, SOCK_STREAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true /* abstract */, AF_UNIX));
+ ASSERT_THAT(
+ bind(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(flock(sock, LOCK_EX | LOCK_NB), SyscallSucceeds());
+ EXPECT_THAT(close(sock), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/fork.cc b/test/syscalls/linux/fork.cc
index dd6e1a422..853f6231a 100644
--- a/test/syscalls/linux/fork.cc
+++ b/test/syscalls/linux/fork.cc
@@ -20,6 +20,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <atomic>
#include <cstdlib>
@@ -214,6 +215,8 @@ TEST_F(ForkTest, PrivateMapping) {
EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
}
+// CPUID is x86 specific.
+#ifdef __x86_64__
// Test that cpuid works after a fork.
TEST_F(ForkTest, Cpuid) {
pid_t child = Fork();
@@ -226,6 +229,7 @@ TEST_F(ForkTest, Cpuid) {
}
EXPECT_THAT(Wait(child), SyscallSucceedsWithValue(0));
}
+#endif
TEST_F(ForkTest, Mmap) {
pid_t child = Fork();
@@ -267,7 +271,7 @@ TEST_F(ForkTest, Alarm) {
EXPECT_EQ(0, alarmed);
}
-// Child cannot affect parent private memory.
+// Child cannot affect parent private memory. Regression test for b/24137240.
TEST_F(ForkTest, PrivateMemory) {
std::atomic<uint32_t> local(0);
@@ -294,6 +298,9 @@ TEST_F(ForkTest, PrivateMemory) {
}
// Kernel-accessed buffers should remain coherent across COW.
+//
+// The buffer must be >= usermem.ZeroCopyMinBytes, as UnsafeAccess operates
+// differently. Regression test for b/33811887.
TEST_F(ForkTest, COWSegment) {
constexpr int kBufSize = 1024;
char* read_buf = private_;
@@ -424,7 +431,6 @@ TEST(CloneTest, NewUserNamespacePermitsAllOtherNamespaces) {
<< "status = " << status;
}
-#ifdef __x86_64__
// Clone with CLONE_SETTLS and a non-canonical TLS address is rejected.
TEST(CloneTest, NonCanonicalTLS) {
constexpr uintptr_t kNonCanonical = 1ull << 48;
@@ -433,11 +439,25 @@ TEST(CloneTest, NonCanonicalTLS) {
// on this.
char stack;
+ // The raw system call interface on x86-64 is:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, int *child_tid,
+ // unsigned long tls);
+ //
+ // While on arm64, the order of the last two arguments is reversed:
+ // long clone(unsigned long flags, void *stack,
+ // int *parent_tid, unsigned long tls,
+ // int *child_tid);
+#if defined(__x86_64__)
EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
nullptr, kNonCanonical),
SyscallFailsWithErrno(EPERM));
-}
+#elif defined(__aarch64__)
+ EXPECT_THAT(syscall(__NR_clone, SIGCHLD | CLONE_SETTLS, &stack, nullptr,
+ kNonCanonical, nullptr),
+ SyscallFailsWithErrno(EPERM));
#endif
+}
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc
index e7e9f06a1..c47567b4e 100644
--- a/test/syscalls/linux/fpsig_fork.cc
+++ b/test/syscalls/linux/fpsig_fork.cc
@@ -27,9 +27,22 @@ namespace testing {
namespace {
+#ifdef __x86_64__
#define GET_XMM(__var, __xmm) \
asm volatile("movq %%" #__xmm ", %0" : "=r"(__var))
#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var))
+#define GET_FP0(__var) GET_XMM(__var, xmm0)
+#define SET_FP0(__var) SET_XMM(__var, xmm0)
+#elif __aarch64__
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+#define GET_FPREG(var, regname) \
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
+#define SET_FPREG(var, regname) \
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
+#define GET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
+#endif
int parent, child;
@@ -40,7 +53,10 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
TEST_CHECK_MSG(child >= 0, "fork failed");
uint64_t val = SIGUSR1;
- SET_XMM(val, xmm0);
+ SET_FP0(val);
+ uint64_t got;
+ GET_FP0(got);
+ TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()");
}
TEST(FPSigTest, Fork) {
@@ -67,8 +83,9 @@ TEST(FPSigTest, Fork) {
// be the one clobbered.
uint64_t expected = 0xdeadbeeffacefeed;
- SET_XMM(expected, xmm0);
+ SET_FP0(expected);
+#ifdef __x86_64__
asm volatile(
"movl %[killnr], %%eax;"
"movl %[parent], %%edi;"
@@ -76,14 +93,23 @@ TEST(FPSigTest, Fork) {
"movl %[sig], %%edx;"
"syscall;"
:
- : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent),
- [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1)
+ : [ killnr ] "i"(__NR_tgkill), [ parent ] "rm"(parent),
+ [ tid ] "rm"(parent_tid), [ sig ] "i"(SIGUSR1)
: "rax", "rdi", "rsi", "rdx",
// Clobbered by syscall.
"rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(parent), "r"(parent_tid), "r"(SIGUSR1));
+#endif
uint64_t got;
- GET_XMM(got, xmm0);
+ GET_FP0(got);
if (getpid() == parent) { // Parent.
int status;
diff --git a/test/syscalls/linux/fpsig_nested.cc b/test/syscalls/linux/fpsig_nested.cc
index 395463aed..302d928d1 100644
--- a/test/syscalls/linux/fpsig_nested.cc
+++ b/test/syscalls/linux/fpsig_nested.cc
@@ -26,9 +26,22 @@ namespace testing {
namespace {
+#ifdef __x86_64__
#define GET_XMM(__var, __xmm) \
asm volatile("movq %%" #__xmm ", %0" : "=r"(__var))
#define SET_XMM(__var, __xmm) asm volatile("movq %0, %%" #__xmm : : "r"(__var))
+#define GET_FP0(__var) GET_XMM(__var, xmm0)
+#define SET_FP0(__var) SET_XMM(__var, xmm0)
+#elif __aarch64__
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+#define GET_FPREG(var, regname) \
+ asm volatile("str " __stringify(regname) ", %0" : "=m"(var))
+#define SET_FPREG(var, regname) \
+ asm volatile("ldr " __stringify(regname) ", %0" : "=m"(var))
+#define GET_FP0(var) GET_FPREG(var, d0)
+#define SET_FP0(var) SET_FPREG(var, d0)
+#endif
int pid;
int tid;
@@ -40,20 +53,21 @@ void sigusr2(int s, siginfo_t* siginfo, void* _uc) {
uint64_t val = SIGUSR2;
// Record the value of %xmm0 on entry and then clobber it.
- GET_XMM(entryxmm[1], xmm0);
- SET_XMM(val, xmm0);
- GET_XMM(exitxmm[1], xmm0);
+ GET_FP0(entryxmm[1]);
+ SET_FP0(val);
+ GET_FP0(exitxmm[1]);
}
void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
uint64_t val = SIGUSR1;
// Record the value of %xmm0 on entry and then clobber it.
- GET_XMM(entryxmm[0], xmm0);
- SET_XMM(val, xmm0);
+ GET_FP0(entryxmm[0]);
+ SET_FP0(val);
// Send a SIGUSR2 to ourself. The signal mask is configured such that
// the SIGUSR2 handler will run before this handler returns.
+#ifdef __x86_64__
asm volatile(
"movl %[killnr], %%eax;"
"movl %[pid], %%edi;"
@@ -61,15 +75,24 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
"movl %[sig], %%edx;"
"syscall;"
:
- : [killnr] "i"(__NR_tgkill), [pid] "rm"(pid), [tid] "rm"(tid),
- [sig] "i"(SIGUSR2)
+ : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid),
+ [ sig ] "i"(SIGUSR2)
: "rax", "rdi", "rsi", "rdx",
// Clobbered by syscall.
"rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR2));
+#endif
// Record value of %xmm0 again to verify that the nested signal handler
// does not clobber it.
- GET_XMM(exitxmm[0], xmm0);
+ GET_FP0(exitxmm[0]);
}
TEST(FPSigTest, NestedSignals) {
@@ -98,8 +121,9 @@ TEST(FPSigTest, NestedSignals) {
// to signal the current thread ensures that this is the clobbered thread.
uint64_t expected = 0xdeadbeeffacefeed;
- SET_XMM(expected, xmm0);
+ SET_FP0(expected);
+#ifdef __x86_64__
asm volatile(
"movl %[killnr], %%eax;"
"movl %[pid], %%edi;"
@@ -107,14 +131,23 @@ TEST(FPSigTest, NestedSignals) {
"movl %[sig], %%edx;"
"syscall;"
:
- : [killnr] "i"(__NR_tgkill), [pid] "rm"(pid), [tid] "rm"(tid),
- [sig] "i"(SIGUSR1)
+ : [ killnr ] "i"(__NR_tgkill), [ pid ] "rm"(pid), [ tid ] "rm"(tid),
+ [ sig ] "i"(SIGUSR1)
: "rax", "rdi", "rsi", "rdx",
// Clobbered by syscall.
"rcx", "r11");
+#elif __aarch64__
+ asm volatile(
+ "mov x8, %0\n"
+ "mov x0, %1\n"
+ "mov x1, %2\n"
+ "mov x2, %3\n"
+ "svc #0\n" ::"r"(__NR_tgkill),
+ "r"(pid), "r"(tid), "r"(SIGUSR1));
+#endif
uint64_t got;
- GET_XMM(got, xmm0);
+ GET_FP0(got);
//
// The checks below verifies the following:
diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc
index d3e3f998c..90b1f0508 100644
--- a/test/syscalls/linux/futex.cc
+++ b/test/syscalls/linux/futex.cc
@@ -18,6 +18,7 @@
#include <sys/syscall.h>
#include <sys/time.h>
#include <sys/types.h>
+#include <syscall.h>
#include <unistd.h>
#include <algorithm>
@@ -239,6 +240,27 @@ TEST_P(PrivateAndSharedFutexTest, Wake1_NoRandomSave) {
EXPECT_THAT(futex_wake(IsPrivate(), &a, 1), SyscallSucceedsWithValue(1));
}
+TEST_P(PrivateAndSharedFutexTest, Wake0_NoRandomSave) {
+ constexpr int kInitialValue = 1;
+ std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
+
+ // Prevent save/restore from interrupting futex_wait, which will cause it to
+ // return EAGAIN instead of the expected result if futex_wait is restarted
+ // after we change the value of a below.
+ DisableSave ds;
+ ScopedThread thread([&] {
+ EXPECT_THAT(futex_wait(IsPrivate(), &a, kInitialValue),
+ SyscallSucceedsWithValue(0));
+ });
+ absl::SleepFor(kWaiterStartupDelay);
+
+ // Change a so that if futex_wake happens before futex_wait, the latter
+ // returns EAGAIN instead of hanging the test.
+ a.fetch_add(1);
+ // The Linux kernel wakes one waiter even if val is 0 or negative.
+ EXPECT_THAT(futex_wake(IsPrivate(), &a, 0), SyscallSucceedsWithValue(1));
+}
+
TEST_P(PrivateAndSharedFutexTest, WakeAll_NoRandomSave) {
constexpr int kInitialValue = 1;
std::atomic<int> a = ATOMIC_VAR_INIT(kInitialValue);
@@ -716,6 +738,97 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) {
}
}
+int get_robust_list(int pid, struct robust_list_head** head_ptr,
+ size_t* len_ptr) {
+ return syscall(__NR_get_robust_list, pid, head_ptr, len_ptr);
+}
+
+int set_robust_list(struct robust_list_head* head, size_t len) {
+ return syscall(__NR_set_robust_list, head, len);
+}
+
+TEST(RobustFutexTest, BasicSetGet) {
+ struct robust_list_head hd = {};
+ struct robust_list_head* hd_ptr = &hd;
+
+ // Set!
+ EXPECT_THAT(set_robust_list(hd_ptr, sizeof(hd)), SyscallSucceedsWithValue(0));
+
+ // Get!
+ struct robust_list_head* new_hd_ptr = hd_ptr;
+ size_t len;
+ EXPECT_THAT(get_robust_list(0, &new_hd_ptr, &len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(new_hd_ptr, hd_ptr);
+ EXPECT_EQ(len, sizeof(hd));
+}
+
+TEST(RobustFutexTest, GetFromOtherTid) {
+ // Get the current tid and list head.
+ pid_t tid = gettid();
+ struct robust_list_head* hd_ptr = {};
+ size_t len;
+ EXPECT_THAT(get_robust_list(0, &hd_ptr, &len), SyscallSucceedsWithValue(0));
+
+ // Create a new thread.
+ ScopedThread t([&] {
+ // Current tid list head should be different from parent tid.
+ struct robust_list_head* got_hd_ptr = {};
+ EXPECT_THAT(get_robust_list(0, &got_hd_ptr, &len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_NE(hd_ptr, got_hd_ptr);
+
+ // Get the parent list head by passing its tid.
+ EXPECT_THAT(get_robust_list(tid, &got_hd_ptr, &len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(hd_ptr, got_hd_ptr);
+ });
+
+ // Wait for thread.
+ t.Join();
+}
+
+TEST(RobustFutexTest, InvalidSize) {
+ struct robust_list_head* hd = {};
+ EXPECT_THAT(set_robust_list(hd, sizeof(*hd) + 1),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(RobustFutexTest, PthreadMutexAttr) {
+ constexpr int kNumMutexes = 3;
+
+ // Create a bunch of robust mutexes.
+ pthread_mutexattr_t attrs[kNumMutexes];
+ pthread_mutex_t mtxs[kNumMutexes];
+ for (int i = 0; i < kNumMutexes; i++) {
+ TEST_PCHECK(pthread_mutexattr_init(&attrs[i]) == 0);
+ TEST_PCHECK(pthread_mutexattr_setrobust(&attrs[i], PTHREAD_MUTEX_ROBUST) ==
+ 0);
+ TEST_PCHECK(pthread_mutex_init(&mtxs[i], &attrs[i]) == 0);
+ }
+
+ // Start thread to lock the mutexes and then exit.
+ ScopedThread t([&] {
+ for (int i = 0; i < kNumMutexes; i++) {
+ TEST_PCHECK(pthread_mutex_lock(&mtxs[i]) == 0);
+ }
+ pthread_exit(NULL);
+ });
+
+ // Wait for thread.
+ t.Join();
+
+ // Now try to take the mutexes.
+ for (int i = 0; i < kNumMutexes; i++) {
+ // Should get EOWNERDEAD.
+ EXPECT_EQ(pthread_mutex_lock(&mtxs[i]), EOWNERDEAD);
+ // Make the mutex consistent.
+ EXPECT_EQ(pthread_mutex_consistent(&mtxs[i]), 0);
+ // Unlock.
+ EXPECT_EQ(pthread_mutex_unlock(&mtxs[i]), 0);
+ }
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc
index fe9cfafe8..b040cdcf7 100644
--- a/test/syscalls/linux/getdents.cc
+++ b/test/syscalls/linux/getdents.cc
@@ -23,6 +23,7 @@
#include <sys/types.h>
#include <syscall.h>
#include <unistd.h>
+
#include <map>
#include <string>
#include <unordered_map>
@@ -31,6 +32,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/node_hash_set.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "test/util/eventfd_util.h"
@@ -227,19 +229,28 @@ class GetdentsTest : public ::testing::Test {
// Multiple template parameters are not allowed, so we must use explicit
// template specialization to set the syscall number.
+
+// SYS_getdents isn't defined on arm64.
+#ifdef __x86_64__
template <>
int GetdentsTest<struct linux_dirent>::SyscallNum() {
return SYS_getdents;
}
+#endif
template <>
int GetdentsTest<struct linux_dirent64>::SyscallNum() {
return SYS_getdents64;
}
-// Test both legacy getdents and getdents64.
+#ifdef __x86_64__
+// Test both legacy getdents and getdents64 on x86_64.
typedef ::testing::Types<struct linux_dirent, struct linux_dirent64>
GetdentsTypes;
+#elif __aarch64__
+// Test only getdents64 on arm64.
+typedef ::testing::Types<struct linux_dirent64> GetdentsTypes;
+#endif
TYPED_TEST_SUITE(GetdentsTest, GetdentsTypes);
// N.B. TYPED_TESTs require explicitly using this-> to access members of
@@ -383,7 +394,7 @@ TYPED_TEST(GetdentsTest, ProcSelfFd) {
// Make the buffer very small since we want to iterate.
typename TestFixture::DirentBufferType dirents(
2 * sizeof(typename TestFixture::LinuxDirentType));
- std::unordered_set<int> prev_fds;
+ absl::node_hash_set<int> prev_fds;
while (true) {
dirents.Reset();
int rv;
diff --git a/test/syscalls/linux/getrandom.cc b/test/syscalls/linux/getrandom.cc
index f97f60029..f87cdd7a1 100644
--- a/test/syscalls/linux/getrandom.cc
+++ b/test/syscalls/linux/getrandom.cc
@@ -29,6 +29,8 @@ namespace {
#define SYS_getrandom 318
#elif defined(__i386__)
#define SYS_getrandom 355
+#elif defined(__aarch64__)
+#define SYS_getrandom 278
#else
#error "Unknown architecture"
#endif
diff --git a/test/syscalls/linux/getrusage.cc b/test/syscalls/linux/getrusage.cc
index 9bdb1e4cd..0e51d42a8 100644
--- a/test/syscalls/linux/getrusage.cc
+++ b/test/syscalls/linux/getrusage.cc
@@ -67,7 +67,7 @@ TEST(GetrusageTest, Grandchild) {
pid = fork();
if (pid == 0) {
int flags = MAP_ANONYMOUS | MAP_POPULATE | MAP_PRIVATE;
- void *addr =
+ void* addr =
mmap(nullptr, kGrandchildSizeKb * 1024, PROT_WRITE, flags, -1, 0);
TEST_PCHECK(addr != MAP_FAILED);
} else {
diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc
index 7384c27dc..5cb325a9e 100644
--- a/test/syscalls/linux/inotify.cc
+++ b/test/syscalls/linux/inotify.cc
@@ -18,7 +18,9 @@
#include <sys/epoll.h>
#include <sys/inotify.h>
#include <sys/ioctl.h>
+#include <sys/sendfile.h>
#include <sys/time.h>
+#include <sys/xattr.h>
#include <atomic>
#include <list>
@@ -28,11 +30,13 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
+#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/epoll_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -330,9 +334,32 @@ PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path,
return wd;
}
-TEST(Inotify, InotifyFdNotWritable) {
+TEST(Inotify, IllegalSeek) {
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
- EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalPread) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ int val;
+ EXPECT_THAT(pread(fd.get(), &val, sizeof(val), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalPwrite) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ EXPECT_THAT(pwrite(fd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST(Inotify, IllegalWrite) {
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0));
+ int val = 0;
+ EXPECT_THAT(write(fd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(Inotify, InitFlags) {
+ EXPECT_THAT(inotify_init1(IN_NONBLOCK | IN_CLOEXEC), SyscallSucceeds());
+ EXPECT_THAT(inotify_init1(12345), SyscallFailsWithErrno(EINVAL));
}
TEST(Inotify, NonBlockingReadReturnsEagain) {
@@ -395,7 +422,7 @@ TEST(Inotify, CanDeleteFileAfterRemovingWatch) {
file1.reset();
}
-TEST(Inotify, CanRemoveWatchAfterDeletingFile) {
+TEST(Inotify, RemoveWatchAfterDeletingFileFails) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
@@ -491,17 +518,23 @@ TEST(Inotify, DeletingChildGeneratesEvents) {
Event(IN_DELETE, root_wd, Basename(file1_path))}));
}
+// Creating a file in "parent/child" should generate events for child, but not
+// parent.
TEST(Inotify, CreatingFileGeneratesEvents) {
- const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath child =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS));
const int wd = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS));
// Create a new file in the directory.
const TempPath file1 =
- ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(child.path()));
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
@@ -554,6 +587,47 @@ TEST(Inotify, WritingFileGeneratesModifyEvent) {
ASSERT_THAT(events, Are({Event(IN_MODIFY, wd, Basename(file1.path()))}));
}
+TEST(Inotify, SizeZeroReadWriteGeneratesNothing) {
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+
+ // Read from the empty file.
+ int val;
+ ASSERT_THAT(read(file1_fd.get(), &val, sizeof(val)),
+ SyscallSucceedsWithValue(0));
+
+ // Write zero bytes.
+ ASSERT_THAT(write(file1_fd.get(), "", 0), SyscallSucceedsWithValue(0));
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+}
+
+TEST(Inotify, FailedFileCreationGeneratesNoEvents) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string dir_path = dir.path();
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), dir_path, IN_ALL_EVENTS));
+
+ const char* p = dir_path.c_str();
+ ASSERT_THAT(mkdir(p, 0777), SyscallFails());
+ ASSERT_THAT(mknod(p, S_IFIFO, 0777), SyscallFails());
+ ASSERT_THAT(symlink(p, p), SyscallFails());
+ ASSERT_THAT(link(p, p), SyscallFails());
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({}));
+}
+
TEST(Inotify, WatchSetAfterOpenReportsCloseFdEvent) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const FileDescriptor fd =
@@ -602,7 +676,7 @@ TEST(Inotify, ChildrenDeletionInWatchedDirGeneratesEvent) {
Event(IN_DELETE | IN_ISDIR, wd, Basename(dir1_path))}));
}
-TEST(Inotify, WatchTargetDeletionGeneratesEvent) {
+TEST(Inotify, RmdirOnWatchedTargetGeneratesEvent) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
@@ -977,7 +1051,7 @@ TEST(Inotify, WatchOnRelativePath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
// Change working directory to root.
- const char* old_working_dir = get_current_dir_name();
+ const FileDescriptor cwd = ASSERT_NO_ERRNO_AND_VALUE(Open(".", O_PATH));
EXPECT_THAT(chdir(root.path().c_str()), SyscallSucceeds());
// Add a watch on file1 with a relative path.
@@ -997,7 +1071,7 @@ TEST(Inotify, WatchOnRelativePath) {
// continue to hold a reference, random save/restore tests can fail if a save
// is triggered after "root" is unlinked; we can't save deleted fs objects
// with active references.
- EXPECT_THAT(chdir(old_working_dir), SyscallSucceeds());
+ EXPECT_THAT(fchdir(cwd.get()), SyscallSucceeds());
}
TEST(Inotify, ZeroLengthReadWriteDoesNotGenerateEvent) {
@@ -1055,9 +1129,9 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
const TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
- const FileDescriptor root_fd =
+ FileDescriptor root_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(root.path(), O_RDONLY));
- const FileDescriptor file1_fd =
+ FileDescriptor file1_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDWR));
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
@@ -1091,6 +1165,11 @@ TEST(Inotify, ChmodGeneratesAttribEvent_NoRandomSave) {
ASSERT_THAT(fchmodat(root_fd.get(), file1_basename.c_str(), S_IWGRP, 0),
SyscallSucceeds());
verify_chmod_events();
+
+ // Make sure the chmod'ed file descriptors are destroyed before DisableSave
+ // is destructed.
+ root_fd.reset();
+ file1_fd.reset();
}
TEST(Inotify, TruncateGeneratesModifyEvent) {
@@ -1223,7 +1302,7 @@ TEST(Inotify, LinkGeneratesAttribAndCreateEvents) {
InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
const int rc = link(file1.path().c_str(), link1.path().c_str());
- // link(2) is only supported on tmpfs in the sandbox.
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
(errno == EPERM || errno == ENOENT));
ASSERT_THAT(rc, SyscallSucceeds());
@@ -1246,7 +1325,7 @@ TEST(Inotify, UtimesGeneratesAttribEvent) {
const int wd = ASSERT_NO_ERRNO_AND_VALUE(
InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
- struct timeval times[2] = {{1, 0}, {2, 0}};
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
EXPECT_THAT(futimes(file1_fd.get(), times), SyscallSucceeds());
const std::vector<Event> events =
@@ -1317,21 +1396,27 @@ TEST(Inotify, HardlinksReuseSameWatch) {
Event(IN_DELETE, root_wd, Basename(file1_path))}));
}
+// Calling mkdir within "parent/child" should generate an event for child, but
+// not parent.
TEST(Inotify, MkdirGeneratesCreateEventWithDirFlag) {
- const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath child =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
- const int root_wd = ASSERT_NO_ERRNO_AND_VALUE(
- InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), parent.path(), IN_ALL_EVENTS));
+ const int child_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), child.path(), IN_ALL_EVENTS));
- const TempPath dir1(NewTempAbsPathInDir(root.path()));
+ const TempPath dir1(NewTempAbsPathInDir(child.path()));
ASSERT_THAT(mkdir(dir1.path().c_str(), 0777), SyscallSucceeds());
const std::vector<Event> events =
ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
ASSERT_THAT(
events,
- Are({Event(IN_CREATE | IN_ISDIR, root_wd, Basename(dir1.path()))}));
+ Are({Event(IN_CREATE | IN_ISDIR, child_wd, Basename(dir1.path()))}));
}
TEST(Inotify, MultipleInotifyInstancesAndWatchesAllGetEvents) {
@@ -1419,20 +1504,26 @@ TEST(Inotify, DuplicateWatchReturnsSameWatchDescriptor) {
TEST(Inotify, UnmatchedEventsAreDiscarded) {
const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- const TempPath file1 =
+ TempPath file1 =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path()));
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
- ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS));
- const FileDescriptor file1_fd =
+ FileDescriptor file1_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY));
- const std::vector<Event> events =
- ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
// We only asked for access events, the open event should be discarded.
ASSERT_THAT(events, Are({}));
+
+ // IN_IGNORED events are always generated, regardless of the mask.
+ file1_fd.reset();
+ file1.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_IGNORED, wd)}));
}
TEST(Inotify, AddWatchWithInvalidEventMaskFails) {
@@ -1591,6 +1682,754 @@ TEST(Inotify, EpollNoDeadlock) {
}
}
+TEST(Inotify, Fallocate) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ // Do an arbitrary modification with fallocate.
+ ASSERT_THAT(RetryEINTR(fallocate)(fd.get(), 0, 0, 123), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_MODIFY, wd)}));
+}
+
+TEST(Inotify, Sendfile) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(root.path(), "x", 0644));
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor in =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+ const FileDescriptor out =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
+
+ // Create separate inotify instances for the in and out fds. If both watches
+ // were on the same instance, we would have discrepancies between Linux and
+ // gVisor (order of events, duplicate events), which is not that important
+ // since inotify is asynchronous anyway.
+ const FileDescriptor in_inotify =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const FileDescriptor out_inotify =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int in_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(in_inotify.get(), in_file.path(), IN_ALL_EVENTS));
+ const int out_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(out_inotify.get(), out_file.path(), IN_ALL_EVENTS));
+
+ ASSERT_THAT(sendfile(out.get(), in.get(), /*offset=*/nullptr, 1),
+ SyscallSucceeds());
+
+ // Expect a single access event and a single modify event.
+ std::vector<Event> in_events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(in_inotify.get()));
+ std::vector<Event> out_events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(out_inotify.get()));
+ EXPECT_THAT(in_events, Are({Event(IN_ACCESS, in_wd)}));
+ EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)}));
+}
+
+// On Linux, inotify behavior is not very consistent with splice(2). We try our
+// best to emulate Linux for very basic calls to splice.
+TEST(Inotify, SpliceOnWatchTarget) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr, 1, /*flags=*/0),
+ SyscallSucceedsWithValue(1));
+
+ // Surprisingly, events are not generated in Linux if we read from a file.
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({}));
+
+ EXPECT_THAT(splice(pipes[0], nullptr, fd.get(), nullptr, 1, /*flags=*/0),
+ SyscallSucceedsWithValue(1));
+
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+}
+
+TEST(Inotify, SpliceOnInotifyFD) {
+ int pipes[2];
+ ASSERT_THAT(pipe2(pipes, O_NONBLOCK), SyscallSucceeds());
+
+ const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ root.path(), "some content", TempPath::kDefaultFileMode));
+
+ const FileDescriptor file1_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_RDONLY));
+ const int watcher = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS));
+
+ char buf;
+ EXPECT_THAT(read(file1_fd.get(), &buf, 1), SyscallSucceeds());
+
+ EXPECT_THAT(splice(fd.get(), nullptr, pipes[1], nullptr,
+ sizeof(struct inotify_event) + 1, SPLICE_F_NONBLOCK),
+ SyscallSucceedsWithValue(sizeof(struct inotify_event)));
+
+ const FileDescriptor read_fd(pipes[0]);
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(read_fd.get()));
+ ASSERT_THAT(events, Are({Event(IN_ACCESS, watcher)}));
+}
+
+// Watches on a parent should not be triggered by actions on a hard link to one
+// of its children that has a different parent.
+TEST(Inotify, LinkOnOtherParent) {
+ const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path()));
+ std::string link_path = NewTempAbsPathInDir(dir2.path());
+
+ const int rc = link(file.path().c_str(), link_path.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir1.path(), IN_ALL_EVENTS));
+
+ // Perform various actions on the link outside of dir1, which should trigger
+ // no inotify events.
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR));
+ int val = 0;
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds());
+ ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds());
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+}
+
+TEST(Inotify, Xattr) {
+ // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer.
+ SKIP_IF(IsRunningOnGvisor());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string path = file.path();
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR));
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), path, IN_ALL_EVENTS));
+
+ const char* cpath = path.c_str();
+ const char* name = "user.test";
+ int val = 123;
+ ASSERT_THAT(setxattr(cpath, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(getxattr(cpath, name, &val, sizeof(val)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ char list[100];
+ ASSERT_THAT(listxattr(cpath, list, sizeof(list)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(removexattr(cpath, name), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(fsetxattr(fd.get(), name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+
+ ASSERT_THAT(fgetxattr(fd.get(), name, &val, sizeof(val)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(flistxattr(fd.get(), list, sizeof(list)), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({}));
+
+ ASSERT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd)}));
+}
+
+TEST(Inotify, Exec) {
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath bin = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), "/bin/true"));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(fd.get(), bin.path(), IN_ALL_EVENTS));
+
+ // Perform exec.
+ ScopedThread t([&bin]() {
+ ASSERT_THAT(execl(bin.path().c_str(), bin.path().c_str(), (char*)nullptr),
+ SyscallSucceeds());
+ });
+ t.Join();
+
+ std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get()));
+ EXPECT_THAT(events, Are({Event(IN_OPEN, wd), Event(IN_ACCESS, wd)}));
+}
+
+// Watches without IN_EXCL_UNLINK, should continue to emit events for file
+// descriptors after their corresponding files have been unlinked.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
+TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) {
+ const DisableSave ds;
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), "123", TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), dir.path(), IN_ALL_EVENTS));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS));
+
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+ int val = 0;
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ Event(IN_ACCESS, dir_wd, Basename(file.path())),
+ Event(IN_ACCESS, file_wd),
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_CLOSE_WRITE, dir_wd, Basename(file.path())),
+ Event(IN_CLOSE_WRITE, file_wd),
+ Event(IN_DELETE_SELF, file_wd),
+ Event(IN_IGNORED, file_wd),
+ }));
+}
+
+// Watches created with IN_EXCL_UNLINK will stop emitting events on fds for
+// children that have already been unlinked.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
+TEST(Inotify, ExcludeUnlink_NoRandomSave) {
+ const DisableSave ds;
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Unlink the child, which should cause further operations on the open file
+ // descriptor to be ignored.
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+ int val = 0;
+ ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ }));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_DELETE_SELF, file_wd),
+ Event(IN_IGNORED, file_wd),
+ }));
+}
+
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
+TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is
+ // deleted.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const DisableSave ds;
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath dir =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(parent.path()));
+ std::string dirPath = dir.path();
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dirPath.c_str(), O_RDONLY | O_DIRECTORY));
+ const int parent_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), parent.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int self_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Unlink the dir, and then close the open fd.
+ ASSERT_THAT(rmdir(dirPath.c_str()), SyscallSucceeds());
+ dir.reset();
+
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ // No close event should appear.
+ ASSERT_THAT(events,
+ Are({Event(IN_DELETE | IN_ISDIR, parent_wd, Basename(dirPath))}));
+
+ fd.reset();
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ ASSERT_THAT(events, Are({
+ Event(IN_DELETE_SELF, self_wd),
+ Event(IN_IGNORED, self_wd),
+ }));
+}
+
+// If "dir/child" and "dir/child2" are links to the same file, and "dir/child"
+// is unlinked, a watch on "dir" with IN_EXCL_UNLINK will exclude future events
+// for fds on "dir/child" but not "dir/child2".
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
+TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) {
+ const DisableSave ds;
+ // TODO(gvisor.dev/issue/1624): This test fails on VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ std::string path1 = file.path();
+ std::string path2 = NewTempAbsPathInDir(dir.path());
+
+ const int rc = link(path1.c_str(), path2.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(IsRunningOnGvisor() && rc != 0 &&
+ (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+ const FileDescriptor fd1 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR));
+ const FileDescriptor fd2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(path2.c_str(), O_RDWR));
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // After unlinking path1, only events on the fd for path2 should be generated.
+ ASSERT_THAT(unlink(path1.c_str()), SyscallSucceeds());
+ ASSERT_THAT(write(fd1.get(), "x", 1), SyscallSucceeds());
+ ASSERT_THAT(write(fd2.get(), "x", 1), SyscallSucceeds());
+
+ const std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_DELETE, wd, Basename(path1)),
+ Event(IN_MODIFY, wd, Basename(path2)),
+ }));
+}
+
+// On native Linux, actions of data type FSNOTIFY_EVENT_INODE are not affected
+// by IN_EXCL_UNLINK (see
+// fs/notify/inotify/inotify_fsnotify.c:inotify_handle_event). Inode-level
+// events include changes to metadata and extended attributes.
+//
+// We need to disable S/R because there are filesystems where we cannot re-open
+// fds to an unlinked file across S/R, e.g. gofer-backed filesytems.
+TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) {
+ const DisableSave ds;
+
+ const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR));
+
+ // NOTE(b/157163751): Create another link before unlinking. This is needed for
+ // the gofer filesystem in gVisor, where open fds will not work once the link
+ // count hits zero. In VFS2, we end up skipping the gofer test anyway, because
+ // hard links are not supported for gofer fs.
+ if (IsRunningOnGvisor()) {
+ std::string link_path = NewTempAbsPath();
+ const int rc = link(file.path().c_str(), link_path.c_str());
+ // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox.
+ SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT));
+ ASSERT_THAT(rc, SyscallSucceeds());
+ }
+
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+ const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(
+ inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK));
+
+ // Even after unlinking, inode-level operations will trigger events regardless
+ // of IN_EXCL_UNLINK.
+ ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
+
+ // Perform various actions on fd.
+ ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds());
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, file_wd),
+ Event(IN_DELETE, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, dir_wd, Basename(file.path())),
+ Event(IN_MODIFY, file_wd),
+ }));
+
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
+ ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, dir_wd, Basename(file.path())),
+ Event(IN_ATTRIB, file_wd),
+ }));
+
+ // S/R is disabled on this entire test due to behavior with unlink; it must
+ // also be disabled after this point because of fchmod.
+ ASSERT_THAT(fchmod(fd.get(), 0777), SyscallSucceeds());
+ events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_ATTRIB, dir_wd, Basename(file.path())),
+ Event(IN_ATTRIB, file_wd),
+ }));
+}
+
+TEST(Inotify, OneShot) {
+ // TODO(gvisor.dev/issue/1624): IN_ONESHOT not supported in VFS1.
+ SKIP_IF(IsRunningWithVFS1());
+
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inotify_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+
+ const int wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(inotify_fd.get(), file.path(), IN_MODIFY | IN_ONESHOT));
+
+ // Open an fd, write to it, and then close it.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ ASSERT_THAT(write(fd.get(), "x", 1), SyscallSucceedsWithValue(1));
+ fd.reset();
+
+ // We should get a single event followed by IN_IGNORED indicating removal
+ // of the one-shot watch. Prior activity (i.e. open) that is not in the mask
+ // should not trigger removal, and activity after removal (i.e. close) should
+ // not generate events.
+ std::vector<Event> events =
+ ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get()));
+ EXPECT_THAT(events, Are({
+ Event(IN_MODIFY, wd),
+ Event(IN_IGNORED, wd),
+ }));
+
+ // The watch should already have been removed.
+ EXPECT_THAT(inotify_rm_watch(inotify_fd.get(), wd),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when inotify instances and watch targets are concurrently being
+// destroyed.
+TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ // A file descriptor protected by a mutex. This ensures that while a
+ // descriptor is in use, it cannot be closed and reused for a different file
+ // description.
+ struct atomic_fd {
+ int fd;
+ absl::Mutex mu;
+ };
+
+ // Set up initial inotify instances.
+ constexpr int num_fds = 3;
+ std::vector<atomic_fd> fds(num_fds);
+ for (int i = 0; i < num_fds; i++) {
+ int fd;
+ ASSERT_THAT(fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ fds[i].fd = fd;
+ }
+
+ // Set up initial watch targets.
+ std::vector<std::string> paths;
+ for (int i = 0; i < 3; i++) {
+ paths.push_back(NewTempAbsPath());
+ ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+
+ constexpr absl::Duration runtime = absl::Seconds(4);
+
+ // Constantly replace each inotify instance with a new one.
+ auto replace_fds = [&] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& afd : fds) {
+ int new_fd;
+ ASSERT_THAT(new_fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ absl::MutexLock l(&afd.mu);
+ ASSERT_THAT(close(afd.fd), SyscallSucceeds());
+ afd.fd = new_fd;
+ for (auto& p : paths) {
+ // inotify_add_watch may fail if the file at p was deleted.
+ ASSERT_THAT(inotify_add_watch(afd.fd, p.c_str(), IN_ALL_EVENTS),
+ AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT)));
+ }
+ }
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> ts;
+ for (int i = 0; i < 3; i++) {
+ ts.emplace_back(replace_fds);
+ }
+
+ // Constantly replace each watch target with a new one.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& p : paths) {
+ ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds());
+ ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+ sched_yield();
+ }
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when adding/removing watches occurs concurrently with the
+// removal of their targets.
+TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ // Set up inotify instances.
+ constexpr int num_fds = 3;
+ std::vector<int> fds(num_fds);
+ for (int i = 0; i < num_fds; i++) {
+ ASSERT_THAT(fds[i] = inotify_init1(IN_NONBLOCK), SyscallSucceeds());
+ }
+
+ // Set up initial watch targets.
+ std::vector<std::string> paths;
+ for (int i = 0; i < 3; i++) {
+ paths.push_back(NewTempAbsPath());
+ ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+
+ constexpr absl::Duration runtime = absl::Seconds(1);
+
+ // Constantly add/remove watches for each inotify instance/watch target pair.
+ auto add_remove_watches = [&] {
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (int fd : fds) {
+ for (auto& p : paths) {
+ // Do not assert on inotify_add_watch and inotify_rm_watch. They may
+ // fail if the file at p was deleted. inotify_add_watch may also fail
+ // if another thread beat us to adding a watch.
+ const int wd = inotify_add_watch(fd, p.c_str(), IN_ALL_EVENTS);
+ if (wd > 0) {
+ inotify_rm_watch(fd, wd);
+ }
+ }
+ }
+ sched_yield();
+ }
+ };
+
+ std::list<ScopedThread> ts;
+ for (int i = 0; i < 15; i++) {
+ ts.emplace_back(add_remove_watches);
+ }
+
+ // Constantly replace each watch target with a new one.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ for (auto& p : paths) {
+ ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds());
+ ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds());
+ }
+ sched_yield();
+ }
+}
+
+// This test helps verify that the lock order of filesystem and inotify locks
+// is respected when many inotify events and filesystem operations occur
+// simultaneously.
+TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) {
+ const DisableSave ds; // Too many syscalls.
+
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const std::string dir = parent.path();
+
+ // mu protects file, which will change on rename.
+ absl::Mutex mu;
+ std::string file = NewTempAbsPathInDir(dir);
+ ASSERT_THAT(mknod(file.c_str(), 0644 | S_IFREG, 0), SyscallSucceeds());
+
+ const absl::Duration runtime = absl::Milliseconds(300);
+
+ // Add/remove watches on dir and file.
+ ScopedThread add_remove_watches([&] {
+ const FileDescriptor ifd =
+ ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK));
+ int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS));
+ int file_wd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS));
+ }
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ ASSERT_THAT(inotify_rm_watch(ifd.get(), file_wd), SyscallSucceeds());
+ ASSERT_THAT(inotify_rm_watch(ifd.get(), dir_wd), SyscallSucceeds());
+ dir_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS));
+ {
+ absl::ReaderMutexLock l(&mu);
+ file_wd = ASSERT_NO_ERRNO_AND_VALUE(
+ InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS));
+ }
+ sched_yield();
+ }
+ });
+
+ // Modify attributes on dir and file.
+ ScopedThread stats([&] {
+ int fd, dir_fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds());
+ }
+ ASSERT_THAT(dir_fd = open(dir.c_str(), O_RDONLY | O_DIRECTORY),
+ SyscallSucceeds());
+ const struct timeval times[2] = {{1, 0}, {2, 0}};
+
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ {
+ absl::ReaderMutexLock l(&mu);
+ EXPECT_THAT(utimes(file.c_str(), times), SyscallSucceeds());
+ }
+ EXPECT_THAT(futimes(fd, times), SyscallSucceeds());
+ EXPECT_THAT(utimes(dir.c_str(), times), SyscallSucceeds());
+ EXPECT_THAT(futimes(dir_fd, times), SyscallSucceeds());
+ sched_yield();
+ }
+ });
+
+ // Modify extended attributes on dir and file.
+ ScopedThread xattrs([&] {
+ // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer.
+ if (!IsRunningOnGvisor()) {
+ int fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds());
+ }
+
+ const char* name = "user.test";
+ int val = 123;
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(
+ setxattr(file.c_str(), name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ ASSERT_THAT(removexattr(file.c_str(), name), SyscallSucceeds());
+ }
+
+ ASSERT_THAT(fsetxattr(fd, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+ ASSERT_THAT(fremovexattr(fd, name), SyscallSucceeds());
+ sched_yield();
+ }
+ }
+ });
+
+ // Read and write file's contents. Read and write dir's entries.
+ ScopedThread read_write([&] {
+ int fd;
+ {
+ absl::ReaderMutexLock l(&mu);
+ ASSERT_THAT(fd = open(file.c_str(), O_RDWR), SyscallSucceeds());
+ }
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ int val = 123;
+ ASSERT_THAT(write(fd, &val, sizeof(val)), SyscallSucceeds());
+ ASSERT_THAT(read(fd, &val, sizeof(val)), SyscallSucceeds());
+ TempPath new_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir));
+ ASSERT_NO_ERRNO(ListDir(dir, false));
+ new_file.reset();
+ sched_yield();
+ }
+ });
+
+ // Rename file.
+ for (auto start = absl::Now(); absl::Now() - start < runtime;) {
+ const std::string new_path = NewTempAbsPathInDir(dir);
+ {
+ absl::WriterMutexLock l(&mu);
+ ASSERT_THAT(rename(file.c_str(), new_path.c_str()), SyscallSucceeds());
+ file = new_path;
+ }
+ sched_yield();
+ }
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/ioctl.cc b/test/syscalls/linux/ioctl.cc
index c4f8bff08..b0a07a064 100644
--- a/test/syscalls/linux/ioctl.cc
+++ b/test/syscalls/linux/ioctl.cc
@@ -215,7 +215,8 @@ TEST_F(IoctlTest, FIOASYNCSelfTarget2) {
auto mask_cleanup =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_UNBLOCK, SIGIO));
- pid_t pid = getpid();
+ pid_t pid = -1;
+ EXPECT_THAT(pid = getpid(), SyscallSucceeds());
EXPECT_THAT(ioctl(pair->second_fd(), FIOSETOWN, &pid), SyscallSucceeds());
int set = 1;
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index 57e99596f..98d07ae85 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "test/syscalls/linux/ip_socket_test_util.h"
+
#include <net/if.h>
#include <netinet/in.h>
-#include <sys/ioctl.h>
#include <sys/socket.h>
-#include <cstring>
-#include "test/syscalls/linux/ip_socket_test_util.h"
+#include <cstring>
namespace gvisor {
namespace testing {
@@ -34,12 +34,11 @@ uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
}
PosixErrorOr<int> InterfaceIndex(std::string name) {
- // TODO(igudger): Consider using netlink.
- ifreq req = {};
- memcpy(req.ifr_name, name.c_str(), name.size());
- ASSIGN_OR_RETURN_ERRNO(auto sock, Socket(AF_INET, SOCK_DGRAM, 0));
- RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(sock.get(), SIOCGIFINDEX, &req));
- return req.ifr_ifindex;
+ int index = if_nametoindex(name.c_str());
+ if (index) {
+ return index;
+ }
+ return PosixError(errno);
}
namespace {
@@ -78,6 +77,33 @@ SocketPairKind DualStackTCPAcceptBindSocketPair(int type) {
/* dual_stack = */ true)};
}
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv6 TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected IPv4 TCP socket");
+ return SocketPairKind{description, AF_INET, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET, type | SOCK_STREAM, 0,
+ /* dual_stack = */ false)};
+}
+
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type) {
+ std::string description =
+ absl::StrCat(DescribeSocketType(type), "connected dual stack TCP socket");
+ return SocketPairKind{description, AF_INET6, type | SOCK_STREAM, IPPROTO_TCP,
+ TCPAcceptBindPersistentListenerSocketPairCreator(
+ AF_INET6, type | SOCK_STREAM, 0,
+ /* dual_stack = */ true)};
+}
+
SocketPairKind IPv6UDPBidirectionalBindSocketPair(int type) {
std::string description =
absl::StrCat(DescribeSocketType(type), "connected IPv6 UDP socket");
@@ -149,17 +175,17 @@ SocketKind IPv6TCPUnboundSocket(int type) {
PosixError IfAddrHelper::Load() {
Release();
RETURN_ERROR_IF_SYSCALL_FAIL(getifaddrs(&ifaddr_));
- return PosixError(0);
+ return NoError();
}
void IfAddrHelper::Release() {
if (ifaddr_) {
freeifaddrs(ifaddr_);
+ ifaddr_ = nullptr;
}
- ifaddr_ = nullptr;
}
-std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
+std::vector<std::string> IfAddrHelper::InterfaceList(int family) const {
std::vector<std::string> names;
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
@@ -170,7 +196,7 @@ std::vector<std::string> IfAddrHelper::InterfaceList(int family) {
return names;
}
-sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
+const sockaddr* IfAddrHelper::GetAddr(int family, std::string name) const {
for (auto ifa = ifaddr_; ifa != NULL; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == NULL || ifa->ifa_addr->sa_family != family) {
continue;
@@ -182,28 +208,28 @@ sockaddr* IfAddrHelper::GetAddr(int family, std::string name) {
return nullptr;
}
-PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) {
+PosixErrorOr<int> IfAddrHelper::GetIndex(std::string name) const {
return InterfaceIndex(name);
}
-std::string GetAddr4Str(in_addr* a) {
+std::string GetAddr4Str(const in_addr* a) {
char str[INET_ADDRSTRLEN];
inet_ntop(AF_INET, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddr6Str(in6_addr* a) {
+std::string GetAddr6Str(const in6_addr* a) {
char str[INET6_ADDRSTRLEN];
inet_ntop(AF_INET6, a, str, sizeof(str));
return std::string(str);
}
-std::string GetAddrStr(sockaddr* a) {
+std::string GetAddrStr(const sockaddr* a) {
if (a->sa_family == AF_INET) {
- auto src = &(reinterpret_cast<sockaddr_in*>(a)->sin_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in*>(a)->sin_addr);
return GetAddr4Str(src);
} else if (a->sa_family == AF_INET6) {
- auto src = &(reinterpret_cast<sockaddr_in6*>(a)->sin6_addr);
+ auto src = &(reinterpret_cast<const sockaddr_in6*>(a)->sin6_addr);
return GetAddr6Str(src);
}
return std::string("<invalid>");
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 072230d85..9c3859fcd 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -26,25 +26,6 @@
namespace gvisor {
namespace testing {
-// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source:
-// Linux kernel, include/net/tcp_states.h.
-enum {
- TCP_ESTABLISHED = 1,
- TCP_SYN_SENT,
- TCP_SYN_RECV,
- TCP_FIN_WAIT1,
- TCP_FIN_WAIT2,
- TCP_TIME_WAIT,
- TCP_CLOSE,
- TCP_CLOSE_WAIT,
- TCP_LAST_ACK,
- TCP_LISTEN,
- TCP_CLOSING,
- TCP_NEW_SYN_RECV,
-
- TCP_MAX_STATES
-};
-
// Extracts the IP address from an inet sockaddr in network byte order.
uint32_t IPFromInetSockaddr(const struct sockaddr* addr);
@@ -69,6 +50,21 @@ SocketPairKind IPv4TCPAcceptBindSocketPair(int type);
// given type bound to the IPv4 loopback.
SocketPairKind DualStackTCPAcceptBindSocketPair(int type);
+// IPv6TCPAcceptBindPersistentListenerSocketPair is like
+// IPv6TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv6TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// IPv4TCPAcceptBindPersistentListenerSocketPair is like
+// IPv4TCPAcceptBindSocketPair except it uses a persistent listening socket to
+// create all socket pairs.
+SocketPairKind IPv4TCPAcceptBindPersistentListenerSocketPair(int type);
+
+// DualStackTCPAcceptBindPersistentListenerSocketPair is like
+// DualStackTCPAcceptBindSocketPair except it uses a persistent listening socket
+// to create all socket pairs.
+SocketPairKind DualStackTCPAcceptBindPersistentListenerSocketPair(int type);
+
// IPv6UDPBidirectionalBindSocketPair returns a SocketPairKind that represents
// SocketPairs created with bind() and connect() syscalls with AF_INET6 and the
// given type bound to the IPv6 loopback.
@@ -88,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type);
// SocketPairs created with AF_INET and the given type.
SocketPairKind IPv4UDPUnboundSocketPair(int type);
-// IPv4UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type.
+// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_DGRAM, and the given type.
SocketKind IPv4UDPUnboundSocket(int type);
-// IPv6UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type.
+// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_DGRAM, and the given type.
SocketKind IPv6UDPUnboundSocket(int type);
-// IPv4TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type.
+// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_STREAM and the given type.
SocketKind IPv4TCPUnboundSocket(int type);
-// IPv6TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type.
+// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_STREAM and the given type.
SocketKind IPv6TCPUnboundSocket(int type);
// IfAddrHelper is a helper class that determines the local interfaces present
@@ -114,24 +110,24 @@ class IfAddrHelper {
PosixError Load();
void Release();
- std::vector<std::string> InterfaceList(int family);
+ std::vector<std::string> InterfaceList(int family) const;
- struct sockaddr* GetAddr(int family, std::string name);
- PosixErrorOr<int> GetIndex(std::string name);
+ const sockaddr* GetAddr(int family, std::string name) const;
+ PosixErrorOr<int> GetIndex(std::string name) const;
private:
struct ifaddrs* ifaddr_;
};
// GetAddr4Str returns the given IPv4 network address structure as a string.
-std::string GetAddr4Str(in_addr* a);
+std::string GetAddr4Str(const in_addr* a);
// GetAddr6Str returns the given IPv6 network address structure as a string.
-std::string GetAddr6Str(in6_addr* a);
+std::string GetAddr6Str(const in6_addr* a);
// GetAddrStr returns the given IPv4 or IPv6 network address structure as a
// string.
-std::string GetAddrStr(sockaddr* a);
+std::string GetAddrStr(const sockaddr* a);
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h
index 616bea550..0719c60a4 100644
--- a/test/syscalls/linux/iptables.h
+++ b/test/syscalls/linux/iptables.h
@@ -188,7 +188,7 @@ struct ipt_replace {
unsigned int num_counters;
// The unchanged values from each ipt_entry's counters.
- struct xt_counters *counters;
+ struct xt_counters* counters;
// The entries to write to the table. This will run past the size defined by
// sizeof(srtuct ipt_replace);
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 930d2b940..e397d5f57 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -246,7 +246,7 @@ int TestSIGPROFFairness(absl::Duration sleep) {
// The number of samples on the main thread should be very low as it did
// nothing.
- TEST_CHECK(result.main_thread_samples < 60);
+ TEST_CHECK(result.main_thread_samples < 80);
// Both workers should get roughly equal number of samples.
TEST_CHECK(result.worker_samples.size() == 2);
@@ -267,6 +267,20 @@ int TestSIGPROFFairness(absl::Duration sleep) {
// Random save/restore is disabled as it introduces additional latency and
// unpredictable distribution patterns.
TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) {
+ // On the KVM and ptrace platforms, switches between sentry and application
+ // context are sometimes extremely slow, causing the itimer to send SIGPROF to
+ // a thread that either already has one pending or has had SIGPROF delivered,
+ // but hasn't handled it yet (and thus therefore still has SIGPROF masked). In
+ // either case, since itimer signals are group-directed, signal sending falls
+ // back to notifying the thread group leader. ItimerSignalTest() fails if "too
+ // many" signals are delivered to the thread group leader, so these tests are
+ // flaky on these platforms.
+ //
+ // TODO(b/143247272): Clarify why context switches are so slow on KVM.
+ const auto gvisor_platform = GvisorPlatform();
+ SKIP_IF(gvisor_platform == Platform::kKVM ||
+ gvisor_platform == Platform::kPtrace);
+
pid_t child;
int execve_errno;
auto kill = ASSERT_NO_ERRNO_AND_VALUE(
@@ -288,6 +302,11 @@ TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyActive_NoRandomSave) {
// Random save/restore is disabled as it introduces additional latency and
// unpredictable distribution patterns.
TEST(ItimerTest, DeliversSIGPROFToThreadsRoughlyFairlyIdle_NoRandomSave) {
+ // See comment in DeliversSIGPROFToThreadsRoughlyFairlyActive.
+ const auto gvisor_platform = GvisorPlatform();
+ SKIP_IF(gvisor_platform == Platform::kKVM ||
+ gvisor_platform == Platform::kPtrace);
+
pid_t child;
int execve_errno;
auto kill = ASSERT_NO_ERRNO_AND_VALUE(
@@ -343,6 +362,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
index dd5352954..544681168 100644
--- a/test/syscalls/linux/link.cc
+++ b/test/syscalls/linux/link.cc
@@ -55,7 +55,8 @@ TEST(LinkTest, CanCreateLinkFile) {
const std::string newname = NewTempAbsPath();
// Get the initial link count.
- uint64_t initial_link_count = ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path()));
+ uint64_t initial_link_count =
+ ASSERT_NO_ERRNO_AND_VALUE(Links(oldfile.path()));
EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), SyscallSucceeds());
@@ -78,8 +79,13 @@ TEST(LinkTest, PermissionDenied) {
// Make the file "unsafe" to link by making it only readable, but not
// writable.
- const auto oldfile =
+ const auto unwriteable_file =
ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400));
+ const std::string special_path = NewTempAbsPath();
+ ASSERT_THAT(mkfifo(special_path.c_str(), 0666), SyscallSucceeds());
+ const auto setuid_file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666 | S_ISUID));
+
const std::string newname = NewTempAbsPath();
// Do setuid in a separate thread so that after finishing this test, the
@@ -96,8 +102,14 @@ TEST(LinkTest, PermissionDenied) {
EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
SyscallSucceeds());
- EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()),
+ EXPECT_THAT(link(unwriteable_file.path().c_str(), newname.c_str()),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(link(special_path.c_str(), newname.c_str()),
SyscallFailsWithErrno(EPERM));
+ if (!IsRunningWithVFS1()) {
+ EXPECT_THAT(link(setuid_file.path().c_str(), newname.c_str()),
+ SyscallFailsWithErrno(EPERM));
+ }
});
}
diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc
index a8af8e545..6ce1e6cc3 100644
--- a/test/syscalls/linux/lseek.cc
+++ b/test/syscalls/linux/lseek.cc
@@ -53,7 +53,7 @@ TEST(LseekTest, NegativeOffset) {
// A 32-bit off_t is not large enough to represent an offset larger than
// maximum file size on standard file systems, so it isn't possible to cause
// overflow.
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST(LseekTest, Overflow) {
// HA! Classic Linux. We really should have an EOVERFLOW
// here, since we're seeking to something that cannot be
diff --git a/test/syscalls/linux/madvise.cc b/test/syscalls/linux/madvise.cc
index 7fd0ea20c..5a1973f60 100644
--- a/test/syscalls/linux/madvise.cc
+++ b/test/syscalls/linux/madvise.cc
@@ -139,7 +139,7 @@ TEST(MadviseDontneedTest, IgnoresPermissions) {
TEST(MadviseDontforkTest, AddressLength) {
auto m =
ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(kPageSize, PROT_NONE, MAP_PRIVATE));
- char *addr = static_cast<char *>(m.ptr());
+ char* addr = static_cast<char*>(m.ptr());
// Address must be page aligned.
EXPECT_THAT(madvise(addr + 1, kPageSize, MADV_DONTFORK),
@@ -168,9 +168,9 @@ TEST(MadviseDontforkTest, DontforkShared) {
Mapping m = ASSERT_NO_ERRNO_AND_VALUE(Mmap(
nullptr, kPageSize * 2, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0));
- const Mapping ms1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize);
+ const Mapping ms1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize);
const Mapping ms2 =
- Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize);
+ Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize);
m.release();
ASSERT_THAT(madvise(ms2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
@@ -197,11 +197,11 @@ TEST(MadviseDontforkTest, DontforkAnonPrivate) {
// Mmap three anonymous pages and MADV_DONTFORK the middle page.
Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
MmapAnon(kPageSize * 3, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- const Mapping mp1 = Mapping(reinterpret_cast<void *>(m.addr()), kPageSize);
+ const Mapping mp1 = Mapping(reinterpret_cast<void*>(m.addr()), kPageSize);
const Mapping mp2 =
- Mapping(reinterpret_cast<void *>(m.addr() + kPageSize), kPageSize);
+ Mapping(reinterpret_cast<void*>(m.addr() + kPageSize), kPageSize);
const Mapping mp3 =
- Mapping(reinterpret_cast<void *>(m.addr() + 2 * kPageSize), kPageSize);
+ Mapping(reinterpret_cast<void*>(m.addr() + 2 * kPageSize), kPageSize);
m.release();
ASSERT_THAT(madvise(mp2.ptr(), kPageSize, MADV_DONTFORK), SyscallSucceeds());
diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc
index e57b49a4a..f8b7f7938 100644
--- a/test/syscalls/linux/memfd.cc
+++ b/test/syscalls/linux/memfd.cc
@@ -16,6 +16,7 @@
#include <fcntl.h>
#include <linux/magic.h>
#include <linux/memfd.h>
+#include <linux/unistd.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/statfs.h>
diff --git a/test/syscalls/linux/memory_accounting.cc b/test/syscalls/linux/memory_accounting.cc
index ff2f49863..94aea4077 100644
--- a/test/syscalls/linux/memory_accounting.cc
+++ b/test/syscalls/linux/memory_accounting.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <sys/mman.h>
+
#include <map>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc
index 9d5f47651..059fad598 100644
--- a/test/syscalls/linux/mempolicy.cc
+++ b/test/syscalls/linux/mempolicy.cc
@@ -43,17 +43,17 @@ namespace {
#define MPOL_MF_MOVE (1 << 1)
#define MPOL_MF_MOVE_ALL (1 << 2)
-int get_mempolicy(int *policy, uint64_t *nmask, uint64_t maxnode, void *addr,
+int get_mempolicy(int* policy, uint64_t* nmask, uint64_t maxnode, void* addr,
int flags) {
return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags);
}
-int set_mempolicy(int mode, uint64_t *nmask, uint64_t maxnode) {
+int set_mempolicy(int mode, uint64_t* nmask, uint64_t maxnode) {
return syscall(SYS_set_mempolicy, mode, nmask, maxnode);
}
-int mbind(void *addr, unsigned long len, int mode,
- const unsigned long *nodemask, unsigned long maxnode,
+int mbind(void* addr, unsigned long len, int mode,
+ const unsigned long* nodemask, unsigned long maxnode,
unsigned flags) {
return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags);
}
@@ -68,7 +68,7 @@ Cleanup ScopedMempolicy() {
// Temporarily change the memory policy for the calling thread within the
// caller's scope.
-PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t *nmask,
+PosixErrorOr<Cleanup> ScopedSetMempolicy(int mode, uint64_t* nmask,
uint64_t maxnode) {
if (set_mempolicy(mode, nmask, maxnode)) {
return PosixError(errno, "set_mempolicy");
diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc
index cf138d328..4036a9275 100644
--- a/test/syscalls/linux/mkdir.cc
+++ b/test/syscalls/linux/mkdir.cc
@@ -18,10 +18,10 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -36,21 +36,12 @@ class MkdirTest : public ::testing::Test {
// TearDown unlinks created files.
void TearDown() override {
- // FIXME(edahlgren): We don't currently implement rmdir.
- // We do this unconditionally because there's no harm in trying.
- rmdir(dirname_.c_str());
+ EXPECT_THAT(rmdir(dirname_.c_str()), SyscallSucceeds());
}
std::string dirname_;
};
-TEST_F(MkdirTest, DISABLED_CanCreateReadbleDir) {
- ASSERT_THAT(mkdir(dirname_.c_str(), 0444), SyscallSucceeds());
- ASSERT_THAT(
- open(JoinPath(dirname_, "anything").c_str(), O_RDWR | O_CREAT, 0666),
- SyscallFailsWithErrno(EACCES));
-}
-
TEST_F(MkdirTest, CanCreateWritableDir) {
ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds());
std::string filename = JoinPath(dirname_, "anything");
@@ -84,10 +75,11 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
- auto parent = ASSERT_NO_ERRNO_AND_VALUE(
- TempPath::CreateDirWith(GetAbsoluteTestTmpdir(), 0555));
- auto dir = JoinPath(parent.path(), "foo");
- ASSERT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ ASSERT_THAT(mkdir(dirname_.c_str(), 0555), SyscallSucceeds());
+ auto dir = JoinPath(dirname_.c_str(), "foo");
+ EXPECT_THAT(mkdir(dir.c_str(), 0777), SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(open(JoinPath(dirname_, "file").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EACCES));
}
} // namespace
diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc
index 4c45766c7..05dfb375a 100644
--- a/test/syscalls/linux/mknod.cc
+++ b/test/syscalls/linux/mknod.cc
@@ -15,6 +15,7 @@
#include <errno.h>
#include <fcntl.h>
#include <sys/stat.h>
+#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
@@ -39,7 +40,28 @@ TEST(MknodTest, RegularFile) {
EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds());
}
-TEST(MknodTest, MknodAtRegularFile) {
+TEST(MknodTest, RegularFilePermissions) {
+ const std::string node = NewTempAbsPath();
+ mode_t newUmask = 0077;
+ umask(newUmask);
+
+ // Attempt to open file with mode 0777. Not specifying file type should create
+ // a regualar file.
+ mode_t perms = S_IRWXU | S_IRWXG | S_IRWXO;
+ EXPECT_THAT(mknod(node.c_str(), perms, 0), SyscallSucceeds());
+
+ // In the absence of a default ACL, the permissions of the created node are
+ // (mode & ~umask). -- mknod(2)
+ mode_t wantPerms = perms & ~newUmask;
+ struct stat st;
+ ASSERT_THAT(stat(node.c_str(), &st), SyscallSucceeds());
+ ASSERT_EQ(st.st_mode & 0777, wantPerms);
+
+ // "Zero file type is equivalent to type S_IFREG." - mknod(2)
+ ASSERT_EQ(st.st_mode & S_IFMT, S_IFREG);
+}
+
+TEST(MknodTest, MknodAtFIFO) {
const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
const std::string fifo_relpath = NewTempRelPath();
const std::string fifo = JoinPath(dir.path(), fifo_relpath);
@@ -72,7 +94,7 @@ TEST(MknodTest, MknodOnExistingPathFails) {
TEST(MknodTest, UnimplementedTypesReturnError) {
const std::string path = NewTempAbsPath();
- if (IsRunningOnGvisor()) {
+ if (IsRunningWithVFS1()) {
ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0),
SyscallFailsWithErrno(EOPNOTSUPP));
}
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
index 283c21ed3..78ac96bed 100644
--- a/test/syscalls/linux/mlock.cc
+++ b/test/syscalls/linux/mlock.cc
@@ -16,6 +16,7 @@
#include <sys/resource.h>
#include <sys/syscall.h>
#include <unistd.h>
+
#include <cerrno>
#include <cstring>
@@ -59,7 +60,6 @@ bool IsPageMlocked(uintptr_t addr) {
return true;
}
-
TEST(MlockTest, Basic) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(CanMlock()));
auto const mapping = ASSERT_NO_ERRNO_AND_VALUE(
@@ -199,8 +199,10 @@ TEST(MunlockallTest, Basic) {
}
#ifndef SYS_mlock2
-#ifdef __x86_64__
+#if defined(__x86_64__)
#define SYS_mlock2 325
+#elif defined(__aarch64__)
+#define SYS_mlock2 284
#endif
#endif
diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc
index a112316e9..6d3227ab6 100644
--- a/test/syscalls/linux/mmap.cc
+++ b/test/syscalls/linux/mmap.cc
@@ -28,6 +28,7 @@
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
+
#include <vector>
#include "gmock/gmock.h"
@@ -360,7 +361,7 @@ TEST_F(MMapTest, MapFixed) {
}
// 64-bit addresses work too
-#ifdef __x86_64__
+#if defined(__x86_64__) || defined(__aarch64__)
TEST_F(MMapTest, MapFixed64) {
EXPECT_THAT(Map(0x300000000000, kPageSize, PROT_NONE,
MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0),
@@ -570,6 +571,12 @@ const uint8_t machine_code[] = {
0xb8, 0x2a, 0x00, 0x00, 0x00, // movl $42, %eax
0xc3, // retq
};
+#elif defined(__aarch64__)
+const uint8_t machine_code[] = {
+ 0x40, 0x05, 0x80, 0x52, // mov w0, #42
+ 0xc0, 0x03, 0x5f, 0xd6, // ret
+};
+#endif
// PROT_EXEC allows code execution
TEST_F(MMapTest, ProtExec) {
@@ -604,7 +611,6 @@ TEST_F(MMapTest, NoProtExecDeath) {
EXPECT_EXIT(func(), ::testing::KilledBySignal(SIGSEGV), "");
}
-#endif
TEST_F(MMapTest, NoExceedLimitData) {
void* prevbrk;
@@ -813,23 +819,27 @@ class MMapFileTest : public MMapTest {
}
};
+class MMapFileParamTest
+ : public MMapFileTest,
+ public ::testing::WithParamInterface<std::tuple<int, int>> {
+ protected:
+ int prot() const { return std::get<0>(GetParam()); }
+
+ int flags() const { return std::get<1>(GetParam()); }
+};
+
// MAP_POPULATE allowed.
// There isn't a good way to verify it actually did anything.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, MapPopulate) {
- ASSERT_THAT(
- Map(0, kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE, fd_.get(), 0),
- SyscallSucceeds());
+TEST_P(MMapFileParamTest, MapPopulate) {
+ ASSERT_THAT(Map(0, kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0),
+ SyscallSucceeds());
}
// MAP_POPULATE on a short file.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, MapPopulateShort) {
- ASSERT_THAT(Map(0, 2 * kPageSize, PROT_READ, MAP_PRIVATE | MAP_POPULATE,
- fd_.get(), 0),
- SyscallSucceeds());
+TEST_P(MMapFileParamTest, MapPopulateShort) {
+ ASSERT_THAT(
+ Map(0, 2 * kPageSize, prot(), flags() | MAP_POPULATE, fd_.get(), 0),
+ SyscallSucceeds());
}
// Read contents from mapped file.
@@ -900,16 +910,6 @@ TEST_F(MMapFileTest, WritePrivateOnReadOnlyFd) {
reinterpret_cast<volatile char*>(addr));
}
-// MAP_PRIVATE PROT_READ is not allowed on write-only FDs.
-TEST_F(MMapFileTest, ReadPrivateOnWriteOnlyFd) {
- const FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY));
-
- uintptr_t addr;
- EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_PRIVATE, fd.get(), 0),
- SyscallFailsWithErrno(EACCES));
-}
-
// MAP_SHARED PROT_WRITE not allowed on read-only FDs.
TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) {
const FileDescriptor fd =
@@ -921,28 +921,13 @@ TEST_F(MMapFileTest, WriteSharedOnReadOnlyFd) {
SyscallFailsWithErrno(EACCES));
}
-// MAP_SHARED PROT_READ not allowed on write-only FDs.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, ReadSharedOnWriteOnlyFd) {
- const FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY));
-
- uintptr_t addr;
- EXPECT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd.get(), 0),
- SyscallFailsWithErrno(EACCES));
-}
-
-// MAP_SHARED PROT_WRITE not allowed on write-only FDs.
-// The FD must always be readable.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, WriteSharedOnWriteOnlyFd) {
+// The FD must be readable.
+TEST_P(MMapFileParamTest, WriteOnlyFd) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(filename_, O_WRONLY));
uintptr_t addr;
- EXPECT_THAT(addr = Map(0, kPageSize, PROT_WRITE, MAP_SHARED, fd.get(), 0),
+ EXPECT_THAT(addr = Map(0, kPageSize, prot(), flags(), fd.get(), 0),
SyscallFailsWithErrno(EACCES));
}
@@ -1181,7 +1166,7 @@ TEST_F(MMapFileTest, ReadSharedTruncateDownThenUp) {
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
SyscallSucceeds());
- // Check that the memory contains he file data.
+ // Check that the memory contains the file data.
EXPECT_EQ(0, memcmp(reinterpret_cast<void*>(addr), buf.c_str(), kPageSize));
// Truncate down, then up.
@@ -1370,132 +1355,75 @@ TEST_F(MMapFileTest, WritePrivate) {
EqualsMemory(std::string(len, '\0')));
}
-// SIGBUS raised when writing past end of file to a private mapping.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, SigBusDeathWritePrivate) {
- SetupGvisorDeathTest();
-
- uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
- fd_.get(), 0),
- SyscallSucceeds());
-
- // MMapFileTest makes a file kPageSize/2 long. The entire first page will be
- // accessible. Write just beyond that.
- size_t len = strlen(kFileContents);
- EXPECT_EXIT(std::copy(kFileContents, kFileContents + len,
- reinterpret_cast<volatile char*>(addr + kPageSize)),
- ::testing::KilledBySignal(SIGBUS), "");
-}
-
-// SIGBUS raised when reading past end of file on a shared mapping.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, SigBusDeathReadShared) {
+// SIGBUS raised when reading or writing past end of a mapped file.
+TEST_P(MMapFileParamTest, SigBusDeath) {
SetupGvisorDeathTest();
uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
SyscallSucceeds());
- // MMapFileTest makes a file kPageSize/2 long. The entire first page will be
- // accessible. Read just beyond that.
- std::vector<char> in(kPageSize);
- EXPECT_EXIT(
- std::copy(reinterpret_cast<volatile char*>(addr + kPageSize),
- reinterpret_cast<volatile char*>(addr + kPageSize) + kPageSize,
- in.data()),
- ::testing::KilledBySignal(SIGBUS), "");
+ auto* start = reinterpret_cast<volatile char*>(addr + kPageSize);
+
+ // MMapFileTest makes a file kPageSize/2 long. The entire first page should be
+ // accessible, but anything beyond it should not.
+ if (prot() & PROT_WRITE) {
+ // Write beyond first page.
+ size_t len = strlen(kFileContents);
+ EXPECT_EXIT(std::copy(kFileContents, kFileContents + len, start),
+ ::testing::KilledBySignal(SIGBUS), "");
+ } else {
+ // Read beyond first page.
+ std::vector<char> in(kPageSize);
+ EXPECT_EXIT(std::copy(start, start + kPageSize, in.data()),
+ ::testing::KilledBySignal(SIGBUS), "");
+ }
}
-// SIGBUS raised when reading past end of file on a shared mapping.
+// Tests that SIGBUS is not raised when reading or writing to a file-mapped
+// page before EOF, even if part of the mapping extends beyond EOF.
//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, SigBusDeathWriteShared) {
- SetupGvisorDeathTest();
-
+// See b/27877699.
+TEST_P(MMapFileParamTest, NoSigBusOnPagesBeforeEOF) {
uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
- fd_.get(), 0),
- SyscallSucceeds());
-
- // MMapFileTest makes a file kPageSize/2 long. The entire first page will be
- // accessible. Write just beyond that.
- size_t len = strlen(kFileContents);
- EXPECT_EXIT(std::copy(kFileContents, kFileContents + len,
- reinterpret_cast<volatile char*>(addr + kPageSize)),
- ::testing::KilledBySignal(SIGBUS), "");
-}
-
-// Tests that SIGBUS is not raised when writing to a file-mapped page before
-// EOF, even if part of the mapping extends beyond EOF.
-TEST_F(MMapFileTest, NoSigBusOnPagesBeforeEOF) {
- uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
- fd_.get(), 0),
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
SyscallSucceeds());
// The test passes if this survives.
- size_t len = strlen(kFileContents);
- std::copy(kFileContents, kFileContents + len,
- reinterpret_cast<volatile char*>(addr));
-}
-
-// Tests that SIGBUS is not raised when writing to a file-mapped page containing
-// EOF, *after* the EOF for a private mapping.
-TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWritePrivate) {
- uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
- fd_.get(), 0),
- SyscallSucceeds());
-
- // The test passes if this survives. (Technically addr+kPageSize/2 is already
- // beyond EOF, but +1 to check for fencepost errors.)
- size_t len = strlen(kFileContents);
- std::copy(kFileContents, kFileContents + len,
- reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1));
-}
-
-// Tests that SIGBUS is not raised when reading from a file-mapped page
-// containing EOF, *after* the EOF for a shared mapping.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFReadShared) {
- uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
- SyscallSucceeds());
-
- // The test passes if this survives. (Technically addr+kPageSize/2 is already
- // beyond EOF, but +1 to check for fencepost errors.)
auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1);
size_t len = strlen(kFileContents);
- std::vector<char> in(len);
- std::copy(start, start + len, in.data());
+ if (prot() & PROT_WRITE) {
+ std::copy(kFileContents, kFileContents + len, start);
+ } else {
+ std::vector<char> in(len);
+ std::copy(start, start + len, in.data());
+ }
}
-// Tests that SIGBUS is not raised when writing to a file-mapped page containing
-// EOF, *after* the EOF for a shared mapping.
-//
-// FIXME(b/37222275): Parameterize.
-TEST_F(MMapFileTest, NoSigBusOnPageContainingEOFWriteShared) {
+// Tests that SIGBUS is not raised when reading or writing from a file-mapped
+// page containing EOF, *after* the EOF.
+TEST_P(MMapFileParamTest, NoSigBusOnPageContainingEOF) {
uintptr_t addr;
- ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
- fd_.get(), 0),
+ ASSERT_THAT(addr = Map(0, 2 * kPageSize, prot(), flags(), fd_.get(), 0),
SyscallSucceeds());
// The test passes if this survives. (Technically addr+kPageSize/2 is already
// beyond EOF, but +1 to check for fencepost errors.)
+ auto* start = reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1);
size_t len = strlen(kFileContents);
- std::copy(kFileContents, kFileContents + len,
- reinterpret_cast<volatile char*>(addr + (kPageSize / 2) + 1));
+ if (prot() & PROT_WRITE) {
+ std::copy(kFileContents, kFileContents + len, start);
+ } else {
+ std::vector<char> in(len);
+ std::copy(start, start + len, in.data());
+ }
}
// Tests that reading from writable shared file-mapped pages succeeds.
//
// On most platforms this is trivial, but when the file is mapped via the sentry
// page cache (which does not yet support writing to shared mappings), a bug
-// caused reads to fail unnecessarily on such mappings.
+// caused reads to fail unnecessarily on such mappings. See b/28913513.
TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
uintptr_t addr;
size_t len = strlen(kFileContents);
@@ -1512,7 +1440,7 @@ TEST_F(MMapFileTest, ReadingWritableSharedFilePageSucceeds) {
// Tests that EFAULT is returned when invoking a syscall that requires the OS to
// read past end of file (resulting in a fault in sentry context in the gVisor
-// case).
+// case). See b/28913513.
TEST_F(MMapFileTest, InternalSigBus) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, 2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE,
@@ -1655,7 +1583,7 @@ TEST_F(MMapFileTest, Bug38498194) {
}
// Tests that reading from a file to a memory mapping of the same file does not
-// deadlock.
+// deadlock. See b/34813270.
TEST_F(MMapFileTest, SelfRead) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED,
@@ -1667,7 +1595,7 @@ TEST_F(MMapFileTest, SelfRead) {
}
// Tests that writing to a file from a memory mapping of the same file does not
-// deadlock.
+// deadlock. Regression test for b/34813270.
TEST_F(MMapFileTest, SelfWrite) {
uintptr_t addr;
ASSERT_THAT(addr = Map(0, kPageSize, PROT_READ, MAP_SHARED, fd_.get(), 0),
@@ -1721,6 +1649,7 @@ TEST(MMapNoFixtureTest, MapReadOnlyAfterCreateWriteOnly) {
}
// Conditional on MAP_32BIT.
+// This flag is supported only on x86-64, for 64-bit programs.
#ifdef __x86_64__
TEST(MMapNoFixtureTest, Map32Bit) {
@@ -1732,6 +1661,15 @@ TEST(MMapNoFixtureTest, Map32Bit) {
#endif // defined(__x86_64__)
+INSTANTIATE_TEST_SUITE_P(
+ ReadWriteSharedPrivate, MMapFileParamTest,
+ ::testing::Combine(::testing::ValuesIn({
+ PROT_READ,
+ PROT_WRITE,
+ PROT_READ | PROT_WRITE,
+ }),
+ ::testing::ValuesIn({MAP_SHARED, MAP_PRIVATE})));
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index e35be3cab..46b6f38db 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -18,6 +18,7 @@
#include <sys/mount.h>
#include <sys/stat.h>
#include <unistd.h>
+
#include <functional>
#include <memory>
#include <string>
@@ -320,6 +321,42 @@ TEST(MountTest, RenameRemoveMountPoint) {
ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY));
}
+TEST(MountTest, MountFuseFilesystemNoDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // Before kernel version 4.16-rc6, FUSE mount is protected by
+ // capable(CAP_SYS_ADMIN). After this version, it uses
+ // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not
+ // allowed to mount fuse file systems without the global CAP_SYS_ADMIN.
+ int res = mount("", dir.path().c_str(), "fuse", 0, "");
+ SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM);
+
+ EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MountTest, MountFuseFilesystem) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
+
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
+ std::string mopts = "fd=" + std::to_string(fd.get());
+
+ auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ // See comments in MountFuseFilesystemNoDevice for the reason why we skip
+ // EPERM when running on Linux.
+ int res = mount("", dir.path().c_str(), "fuse", 0, "");
+ SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM);
+
+ auto const mount =
+ ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/msync.cc b/test/syscalls/linux/msync.cc
index ac7146017..2b2b6aef9 100644
--- a/test/syscalls/linux/msync.cc
+++ b/test/syscalls/linux/msync.cc
@@ -60,9 +60,7 @@ std::vector<std::function<PosixErrorOr<Mapping>()>> SyncableMappings() {
for (int const mflags : {MAP_PRIVATE, MAP_SHARED}) {
int const prot = PROT_READ | (writable ? PROT_WRITE : 0);
int const oflags = O_CREAT | (writable ? O_RDWR : O_RDONLY);
- funcs.push_back([=] {
- return MmapAnon(kPageSize, prot, mflags);
- });
+ funcs.push_back([=] { return MmapAnon(kPageSize, prot, mflags); });
funcs.push_back([=]() -> PosixErrorOr<Mapping> {
std::string const path = NewTempAbsPath();
ASSIGN_OR_RETURN_ERRNO(auto fd, Open(path, oflags, 0644));
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
new file mode 100644
index 000000000..133fdecf0
--- /dev/null
+++ b/test/syscalls/linux/network_namespace.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <net/if.h>
+#include <sched.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+TEST(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ScopedThread t([&] {
+ ASSERT_THAT(unshare(CLONE_NEWNET), SyscallSucceedsWithValue(0));
+
+ // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
+ // Check loopback device exists.
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+ struct ifreq ifr;
+ strncpy(ifr.ifr_name, "lo", IFNAMSIZ);
+ EXPECT_THAT(ioctl(sock, SIOCGIFINDEX, &ifr), SyscallSucceeds())
+ << "lo cannot be found";
+ });
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index 2b1df52ce..77f390f3c 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -27,6 +27,7 @@
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -73,6 +74,60 @@ class OpenTest : public FileTest {
const std::string test_data_ = "hello world\n";
};
+TEST_F(OpenTest, OTrunc) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_TRUNC, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, OTruncAndReadOnlyDir) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, OTruncAndReadOnlyFile) {
+ auto dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile");
+ const FileDescriptor existing =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dirpath.c_str(), O_RDWR | O_CREAT, 0666));
+ const FileDescriptor otrunc = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666));
+}
+
+TEST_F(OpenTest, OCreateDirectory) {
+ SKIP_IF(IsRunningWithVFS1());
+ auto dirpath = GetAbsoluteTestTmpdir();
+
+ // Normal case: existing directory.
+ ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EISDIR));
+ // Trailing separator on existing directory.
+ ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EISDIR));
+ // Trailing separator on non-existing directory.
+ ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(),
+ O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EISDIR));
+ // "." special case.
+ ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST_F(OpenTest, MustCreateExisting) {
+ auto dirPath = GetAbsoluteTestTmpdir();
+
+ // Existing directory.
+ ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666),
+ SyscallFailsWithErrno(EEXIST));
+
+ // Existing file.
+ auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath));
+ ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666),
+ SyscallFailsWithErrno(EEXIST));
+}
+
TEST_F(OpenTest, ReadOnly) {
char buf;
const FileDescriptor ro_file =
@@ -93,6 +148,26 @@ TEST_F(OpenTest, WriteOnly) {
EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1));
}
+TEST_F(OpenTest, CreateWithAppend) {
+ std::string data = "text";
+ std::string new_file = NewTempAbsPath();
+ const FileDescriptor file = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(new_file, O_WRONLY | O_APPEND | O_CREAT, 0666));
+ EXPECT_THAT(write(file.get(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+ EXPECT_THAT(lseek(file.get(), 0, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(write(file.get(), data.c_str(), data.size()),
+ SyscallSucceedsWithValue(data.size()));
+
+ // Check that the size of the file is correct and that the offset has been
+ // incremented to that size.
+ struct stat s0;
+ EXPECT_THAT(fstat(file.get(), &s0), SyscallSucceeds());
+ EXPECT_EQ(s0.st_size, 2 * data.size());
+ EXPECT_THAT(lseek(file.get(), 0, SEEK_CUR),
+ SyscallSucceedsWithValue(2 * data.size()));
+}
+
TEST_F(OpenTest, ReadWrite) {
char buf;
const FileDescriptor rw_file =
@@ -164,6 +239,28 @@ TEST_F(OpenTest, OpenNoFollowStillFollowsLinksInPath) {
ASSERT_NO_ERRNO_AND_VALUE(Open(path_via_symlink, O_RDONLY | O_NOFOLLOW));
}
+// Test that open(2) can follow symlinks that point back to the same tree.
+// Test sets up files as follows:
+// root/child/symlink => redirects to ../..
+// root/child/target => regular file
+//
+// open("root/child/symlink/root/child/file")
+TEST_F(OpenTest, SymlinkRecurse) {
+ auto root =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(GetAbsoluteTestTmpdir()));
+ auto child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(root.path()));
+ auto symlink = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(child.path(), "../.."));
+ auto target = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(child.path(), "abc", 0644));
+ auto path_via_symlink =
+ JoinPath(symlink.path(), Basename(root.path()), Basename(child.path()),
+ Basename(target.path()));
+ const auto contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents(path_via_symlink));
+ ASSERT_EQ(contents, "abc");
+}
+
TEST_F(OpenTest, Fault) {
char* totally_not_null = nullptr;
ASSERT_THAT(open(totally_not_null, O_RDONLY), SyscallFailsWithErrno(EFAULT));
@@ -191,7 +288,7 @@ TEST_F(OpenTest, AppendOnly) {
ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND));
EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
- // Then try to write to the first file and make sure the bytes are appended.
+ // Then try to write to the first fd and make sure the bytes are appended.
EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(buf.size()));
@@ -203,7 +300,7 @@ TEST_F(OpenTest, AppendOnly) {
EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR),
SyscallSucceedsWithValue(kBufSize * 2));
- // Then try to write to the second file and make sure the bytes are appended.
+ // Then try to write to the second fd and make sure the bytes are appended.
EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(buf.size()));
@@ -312,6 +409,13 @@ TEST_F(OpenTest, FileNotDirectory) {
SyscallFailsWithErrno(ENOTDIR));
}
+TEST_F(OpenTest, SymlinkDirectory) {
+ auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string link = NewTempAbsPath();
+ ASSERT_THAT(symlink(dir.path().c_str(), link.c_str()), SyscallSucceeds());
+ ASSERT_NO_ERRNO(Open(link, O_RDONLY | O_DIRECTORY));
+}
+
TEST_F(OpenTest, Null) {
char c = '\0';
ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT));
@@ -372,6 +476,35 @@ TEST_F(OpenTest, CanTruncateWriteOnlyNoReadPermission_NoRandomSave) {
EXPECT_EQ(stat.st_size, 0);
}
+TEST_F(OpenTest, CanTruncateWithStrangePermissions) {
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+ const DisableSave ds; // Permissions are dropped.
+ std::string path = NewTempAbsPath();
+ int fd;
+ // Create a file without user permissions.
+ EXPECT_THAT( // SAVE_BELOW
+ fd = open(path.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 055),
+ SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ // Cannot open file because we are owner and have no permissions set.
+ EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES));
+
+ // We *can* chmod the file, because we are the owner.
+ EXPECT_THAT(chmod(path.c_str(), 0755), SyscallSucceeds());
+
+ // Now we can open the file again.
+ EXPECT_THAT(fd = open(path.c_str(), O_RDWR), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) {
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string bad_path = file.path() + "/";
+ EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc
index e5a85ef9d..51eacf3f2 100644
--- a/test/syscalls/linux/open_create.cc
+++ b/test/syscalls/linux/open_create.cc
@@ -19,11 +19,11 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/temp_umask.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/temp_path.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -88,6 +88,30 @@ TEST(CreateTest, CreateExclusively) {
SyscallFailsWithErrno(EEXIST));
}
+TEST(CreateTeast, CreatWithOTrunc) {
+ std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) {
+ std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd");
+ ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666),
+ SyscallFailsWithErrno(EISDIR));
+}
+
+TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) {
+ std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile");
+ int dirfd;
+ ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666),
+ SyscallSucceeds());
+ ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666),
+ SyscallSucceeds());
+ ASSERT_THAT(close(dirfd), SyscallSucceeds());
+}
+
TEST(CreateTest, CreateFailsOnUnpermittedDir) {
// Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to
// always override directory permissions.
@@ -108,6 +132,7 @@ TEST(CreateTest, CreateFailsOnDirWithoutWritePerms) {
}
// A file originally created RW, but opened RO can later be opened RW.
+// Regression test for b/65385065.
TEST(CreateTest, OpenCreateROThenRW) {
TempPath file(NewTempAbsPath());
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 92ae55eec..861617ff7 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <arpa/inet.h>
+#include <ifaddrs.h>
#include <linux/capability.h>
#include <linux/if_arp.h>
#include <linux/if_packet.h>
@@ -163,16 +164,11 @@ int CookedPacketTest::GetLoopbackIndex() {
return ifr.ifr_ifindex;
}
-// Receive via a packet socket.
-TEST_P(CookedPacketTest, Receive) {
- // Let's use a simple IP payload: a UDP datagram.
- FileDescriptor udp_sock =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
- SendUDPMessage(udp_sock.get());
-
+// Receive and verify the message via packet socket on interface.
+void ReceiveMessage(int sock, int ifindex) {
// Wait for the socket to become readable.
struct pollfd pfd = {};
- pfd.fd = socket_;
+ pfd.fd = sock;
pfd.events = POLLIN;
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
@@ -182,20 +178,22 @@ TEST_P(CookedPacketTest, Receive) {
char buf[64];
struct sockaddr_ll src = {};
socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ ASSERT_THAT(recvfrom(sock, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
+
// sockaddr_ll ends with an 8 byte physical address field, but ethernet
// addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
// here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
// sizeof(sockaddr_ll).
ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
- // TODO(b/129292371): Verify protocol once we return it.
+ // TODO(gvisor.dev/issue/173): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
- EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
+ EXPECT_EQ(src.sll_ifindex, ifindex);
EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
EXPECT_EQ(src.sll_addr[i], 0);
@@ -222,9 +220,21 @@ TEST_P(CookedPacketTest, Receive) {
EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
}
+// Receive via a packet socket.
+TEST_P(CookedPacketTest, Receive) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ int loopback_index = GetLoopbackIndex();
+ ReceiveMessage(socket_, loopback_index);
+}
+
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- // TODO(b/129292371): Remove once we support packet socket writing.
+ // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing.
SKIP_IF(IsRunningOnGvisor());
// Let's send a UDP packet and receive it using a regular UDP socket.
@@ -313,6 +323,230 @@ TEST_P(CookedPacketTest, Send) {
EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Bind and receive via packet socket.
+TEST_P(CookedPacketTest, BindReceive) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+ SendUDPMessage(udp_sock.get());
+
+ // Receive and verify the data.
+ ReceiveMessage(socket_, bind_addr.sll_ifindex);
+}
+
+// Double Bind socket.
+TEST_P(CookedPacketTest, DoubleBindSucceeds) {
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Binding socket again should fail.
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ // Linux 4.09 returns EINVAL here, but some time before 4.19 it
+ // switched to EADDRINUSE.
+ SyscallSucceeds());
+}
+
+// Bind and verify we do not receive data on interface which is not bound
+TEST_P(CookedPacketTest, BindDrop) {
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+
+ // Bind to packet socket requires only family, protocol and ifindex.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = ifr.ifr_ifindex;
+
+ ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // Send to loopback interface.
+ struct sockaddr_in dest = {};
+ dest.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket never receives any data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
+}
+
+// Verify that we receive outbound packets. This test requires at least one
+// non loopback interface so that we can actually capture an outgoing packet.
+TEST_P(CookedPacketTest, ReceiveOutbound) {
+ // Only ETH_P_ALL sockets can receive outbound packets on linux.
+ SKIP_IF(GetParam() != ETH_P_ALL);
+
+ // Let's use a simple IP payload: a UDP datagram.
+ FileDescriptor udp_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ifaddrs* if_addr_list = nullptr;
+ auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); });
+
+ ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds());
+
+ // Get interface other than loopback.
+ struct ifreq ifr = {};
+ for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) {
+ if (strcmp(i->ifa_name, "lo") != 0) {
+ strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name));
+ break;
+ }
+ }
+
+ // Skip if no interface is available other than loopback.
+ if (strlen(ifr.ifr_name) == 0) {
+ GTEST_SKIP();
+ }
+
+ // Get interface index and name.
+ EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_NE(ifr.ifr_ifindex, 0);
+ int ifindex = ifr.ifr_ifindex;
+
+ constexpr int kMACSize = 6;
+ char hwaddr[kMACSize];
+ // Get interface address.
+ ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ ASSERT_THAT(ifr.ifr_hwaddr.sa_family,
+ AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER)));
+ memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize);
+
+ // Just send it to the google dns server 8.8.8.8. It's UDP we don't care
+ // if it actually gets to the DNS Server we just want to see that we receive
+ // it on our AF_PACKET socket.
+ //
+ // NOTE: We just want to pick an IP that is non-local to avoid having to
+ // handle ARP as this should cause the UDP packet to be sent to the default
+ // gateway configured for the system under test. Otherwise the only packet we
+ // will see is the ARP query unless we picked an IP which will actually
+ // resolve. The test is a bit brittle but this was the best compromise for
+ // now.
+ struct sockaddr_in dest = {};
+ ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1);
+ dest.sin_family = AF_INET;
+ dest.sin_port = kPort;
+ EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0,
+ reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+
+ // Wait and make sure the socket receives the data.
+ struct pollfd pfd = {};
+ pfd.fd = socket_;
+ pfd.events = POLLIN;
+ EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1));
+
+ // Now read and check that the packet is the one we just sent.
+ // Read and verify the data.
+ constexpr size_t packet_size =
+ sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage);
+ char buf[64];
+ struct sockaddr_ll src = {};
+ socklen_t src_len = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ reinterpret_cast<struct sockaddr*>(&src), &src_len),
+ SyscallSucceedsWithValue(packet_size));
+
+ // sockaddr_ll ends with an 8 byte physical address field, but ethernet
+ // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2
+ // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns
+ // sizeof(sockaddr_ll).
+ ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
+
+ // Verify the source address.
+ EXPECT_EQ(src.sll_family, AF_PACKET);
+ EXPECT_EQ(src.sll_ifindex, ifindex);
+ EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
+ EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING);
+ // Verify the link address of the interface matches that of the non
+ // non loopback interface address we stored above.
+ for (int i = 0; i < src.sll_halen; i++) {
+ EXPECT_EQ(src.sll_addr[i], hwaddr[i]);
+ }
+
+ // Verify the IP header.
+ struct iphdr ip = {};
+ memcpy(&ip, buf, sizeof(ip));
+ EXPECT_EQ(ip.ihl, 5);
+ EXPECT_EQ(ip.version, 4);
+ EXPECT_EQ(ip.tot_len, htons(packet_size));
+ EXPECT_EQ(ip.protocol, IPPROTO_UDP);
+ EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr);
+ EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK));
+
+ // Verify the UDP header.
+ struct udphdr udp = {};
+ memcpy(&udp, buf + sizeof(iphdr), sizeof(udp));
+ EXPECT_EQ(udp.dest, kPort);
+ EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage)));
+
+ // Verify the payload.
+ char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr));
+ EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0);
+}
+
+// Bind with invalid address.
+TEST_P(CookedPacketTest, BindFail) {
+ // Null address.
+ ASSERT_THAT(
+ bind(socket_, nullptr, sizeof(struct sockaddr)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EINVAL)));
+
+ // Address of size 1.
+ uint8_t addr = 0;
+ ASSERT_THAT(
+ bind(socket_, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, CookedPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index d258d353c..a11a03415 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -14,6 +14,9 @@
#include <arpa/inet.h>
#include <linux/capability.h>
+#ifndef __fuchsia__
+#include <linux/filter.h>
+#endif // __fuchsia__
#include <linux/if_arp.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
@@ -97,7 +100,7 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
int GetLoopbackIndex();
// The socket used for both reading and writing.
- int socket_;
+ int s_;
};
void RawPacketTest::SetUp() {
@@ -108,34 +111,58 @@ void RawPacketTest::SetUp() {
}
if (!IsRunningOnGvisor()) {
+ // Ensure that looped back packets aren't rejected by the kernel.
FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
- Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
+ Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDWR));
FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE(
- Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY));
+ Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDWR));
char enabled;
ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
SyscallSucceedsWithValue(1));
- ASSERT_EQ(enabled, '1');
+ if (enabled != '1') {
+ enabled = '1';
+ ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(write(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(acceptLocal.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
+
ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
SyscallSucceedsWithValue(1));
- ASSERT_EQ(enabled, '1');
+ if (enabled != '1') {
+ enabled = '1';
+ ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(write(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET),
+ SyscallSucceedsWithValue(0));
+ ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1),
+ SyscallSucceedsWithValue(1));
+ ASSERT_EQ(enabled, '1');
+ }
}
- ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ ASSERT_THAT(s_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
SyscallSucceeds());
}
void RawPacketTest::TearDown() {
// TearDown will be run even if we skip the test.
if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ EXPECT_THAT(close(s_), SyscallSucceeds());
}
}
int RawPacketTest::GetLoopbackIndex() {
struct ifreq ifr;
snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
- EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ EXPECT_THAT(ioctl(s_, SIOCGIFINDEX, &ifr), SyscallSucceeds());
EXPECT_NE(ifr.ifr_ifindex, 0);
return ifr.ifr_ifindex;
}
@@ -149,7 +176,7 @@ TEST_P(RawPacketTest, Receive) {
// Wait for the socket to become readable.
struct pollfd pfd = {};
- pfd.fd = socket_;
+ pfd.fd = s_;
pfd.events = POLLIN;
EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1));
@@ -159,7 +186,7 @@ TEST_P(RawPacketTest, Receive) {
char buf[64];
struct sockaddr_ll src = {};
socklen_t src_len = sizeof(src);
- ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0,
+ ASSERT_THAT(recvfrom(s_, buf, sizeof(buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_len),
SyscallSucceedsWithValue(packet_size));
// sockaddr_ll ends with an 8 byte physical address field, but ethernet
@@ -168,11 +195,12 @@ TEST_P(RawPacketTest, Receive) {
// sizeof(sockaddr_ll).
ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2)));
- // TODO(b/129292371): Verify protocol once we return it.
+ // TODO(gvisor.dev/issue/173): Verify protocol once we return it.
// Verify the source address.
EXPECT_EQ(src.sll_family, AF_PACKET);
EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex());
EXPECT_EQ(src.sll_halen, ETH_ALEN);
+ EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP);
// This came from the loopback device, so the address is all 0s.
for (int i = 0; i < src.sll_halen; i++) {
EXPECT_EQ(src.sll_addr[i], 0);
@@ -212,7 +240,7 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- // TODO(b/129292371): Remove once we support packet socket writing.
+ // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing.
SKIP_IF(IsRunningOnGvisor());
// Let's send a UDP packet and receive it using a regular UDP socket.
@@ -277,7 +305,7 @@ TEST_P(RawPacketTest, Send) {
sizeof(kMessage));
// Send it.
- ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0,
+ ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0,
reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)),
SyscallSucceedsWithValue(sizeof(send_buf)));
@@ -286,13 +314,13 @@ TEST_P(RawPacketTest, Send) {
pfd.fd = udp_sock.get();
pfd.events = POLLIN;
ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
- pfd.fd = socket_;
+ pfd.fd = s_;
pfd.events = POLLIN;
ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1));
// Receive on the packet socket.
char recv_buf[sizeof(send_buf)];
- ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0),
+ ASSERT_THAT(recv(s_, recv_buf, sizeof(recv_buf), 0),
SyscallSucceedsWithValue(sizeof(recv_buf)));
ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0);
@@ -309,6 +337,318 @@ TEST_P(RawPacketTest, Send) {
EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK));
}
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawPacketTest, SetSocketRecvBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum receive buf size by trying to set it to zero.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(RawPacketTest, SetSocketRecvBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover max buf size by trying to set the largest possible buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored.
+TEST_P(RawPacketTest, SetSocketRecvBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover max buf size by trying to set a really large buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set a zero size receive buffer
+ // size.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawPacketTest, SetSocketSendBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(RawPacketTest, SetSocketSendBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored.
+TEST_P(RawPacketTest, SetSocketSendBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack
+ // matches linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+
+TEST_P(RawPacketTest, GetSocketError) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(val, 0);
+}
+
+TEST_P(RawPacketTest, GetSocketErrorBind) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ {
+ // Bind to the loopback device.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = GetLoopbackIndex();
+
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallSucceeds());
+
+ // SO_ERROR should return no errors.
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(val, 0);
+ }
+
+ {
+ // Now try binding to an invalid interface.
+ struct sockaddr_ll bind_addr = {};
+ bind_addr.sll_family = AF_PACKET;
+ bind_addr.sll_protocol = htons(GetParam());
+ bind_addr.sll_ifindex = 0xffff; // Just pick a really large number.
+
+ // Binding should fail with EINVAL
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ sizeof(bind_addr)),
+ SyscallFailsWithErrno(ENODEV));
+
+ // SO_ERROR does not return error when the device is invalid.
+ // On Linux there is just one odd ball condition where this can return
+ // an error where the device was valid and then removed or disabled
+ // between the first check for index and the actual registration of
+ // the packet endpoint. On Netstack this is not possible as the stack
+ // global mutex is held during registration and check.
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(val, 0);
+ }
+}
+
+#ifndef __fuchsia__
+
+TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ //
+ // gVisor returns no error on SO_DETACH_FILTER even if there is no filter
+ // attached unlike linux which does return ENOENT in such cases. This is
+ // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns
+ // success.
+ if (IsRunningOnGvisor()) {
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+ return;
+ }
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(RawPacketTest, GetSocketDetachFilter) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest,
::testing::Values(ETH_P_IP, ETH_P_ALL));
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 33822ee57..df7129acc 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -18,7 +18,9 @@
#include <netinet/tcp.h>
#include <sys/mman.h>
#include <sys/socket.h>
+#include <sys/stat.h>
#include <sys/syscall.h>
+#include <sys/types.h>
#include <sys/uio.h>
#include <unistd.h>
@@ -62,9 +64,9 @@ class PartialBadBufferTest : public ::testing::Test {
// Write some initial data.
size_t size = sizeof(kMessage) - 1;
EXPECT_THAT(WriteFd(fd_, &kMessage, size), SyscallSucceedsWithValue(size));
-
ASSERT_THAT(lseek(fd_, 0, SEEK_SET), SyscallSucceeds());
+ // Map a useable buffer.
addr_ = mmap(0, 2 * kPageSize, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
ASSERT_NE(addr_, MAP_FAILED);
@@ -79,6 +81,15 @@ class PartialBadBufferTest : public ::testing::Test {
bad_buffer_ = buf + kPageSize - 1;
}
+ off_t Size() {
+ struct stat st;
+ int rc = fstat(fd_, &st);
+ if (rc < 0) {
+ return static_cast<off_t>(rc);
+ }
+ return st.st_size;
+ }
+
void TearDown() override {
EXPECT_THAT(munmap(addr_, 2 * kPageSize), SyscallSucceeds()) << addr_;
EXPECT_THAT(close(fd_), SyscallSucceeds());
@@ -165,97 +176,99 @@ TEST_F(PartialBadBufferTest, PreadvSmall) {
}
TEST_F(PartialBadBufferTest, WriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, kPageSize),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, kPageSize)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(write)(fd_, bad_buffer_, 10)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, kPageSize, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwriteSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwrite)(fd_, bad_buffer_, 10, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwrite)(fd_, bad_buffer_, 10, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, WritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(writev)(fd_, &vec, 1), SyscallFailsWithErrno(EFAULT));
+ ASSERT_THAT(lseek(fd_, orig_size, SEEK_SET), SyscallSucceeds());
+ EXPECT_THAT(
+ (n = RetryEINTR(writev)(fd_, &vec, 1)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevBig) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = kPageSize;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
TEST_F(PartialBadBufferTest, PwritevSmall) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
struct iovec vec;
vec.iov_base = bad_buffer_;
vec.iov_len = 10;
+ off_t orig_size = Size();
+ int n;
- EXPECT_THAT(RetryEINTR(pwritev)(fd_, &vec, 1, 0),
- SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(
+ (n = RetryEINTR(pwritev)(fd_, &vec, 1, orig_size)),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(1)));
+ EXPECT_EQ(Size(), orig_size + (n >= 0 ? n : 0));
}
// getdents returns EFAULT when the you claim the buffer is large enough, but
@@ -283,29 +296,6 @@ TEST_F(PartialBadBufferTest, GetdentsOneEntry) {
SyscallSucceedsWithValue(Gt(0)));
}
-// Verify that when write returns EFAULT the kernel hasn't silently written
-// the initial valid bytes.
-TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
- // FIXME(b/24788078): The sentry write syscalls will return immediately
- // if Access returns an error, but Access may not return an error
- // and the sentry will instead perform a partial write.
- SKIP_IF(IsRunningOnGvisor());
-
- bad_buffer_[0] = 'A';
- EXPECT_THAT(RetryEINTR(write)(fd_, bad_buffer_, 10),
- SyscallFailsWithErrno(EFAULT));
-
- size_t size = 255;
- char buf[255];
- memset(buf, 0, size);
-
- EXPECT_THAT(RetryEINTR(pread)(fd_, buf, size, 0),
- SyscallSucceedsWithValue(sizeof(kMessage) - 1));
-
- // 'A' has not been written.
- EXPECT_STREQ(buf, kMessage);
-}
-
PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
struct sockaddr_storage addr;
memset(&addr, 0, sizeof(addr));
diff --git a/test/syscalls/linux/ping_socket.cc b/test/syscalls/linux/ping_socket.cc
new file mode 100644
index 000000000..a9bfdb37b
--- /dev/null
+++ b/test/syscalls/linux/ping_socket.cc
@@ -0,0 +1,91 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+class PingSocket : public ::testing::Test {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // The loopback address.
+ struct sockaddr_in addr_;
+};
+
+void PingSocket::SetUp() {
+ // On some hosts ping sockets are restricted to specific groups using the
+ // sysctl "ping_group_range".
+ int s = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);
+ if (s < 0 && errno == EPERM) {
+ GTEST_SKIP();
+ }
+ close(s);
+
+ addr_ = {};
+ // Just a random port as the destination port number is irrelevant for ping
+ // sockets.
+ addr_.sin_port = 12345;
+ addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr_.sin_family = AF_INET;
+}
+
+void PingSocket::TearDown() {}
+
+// Test ICMP port exhaustion returns EAGAIN.
+//
+// We disable both random/cooperative S/R for this test as it makes way too many
+// syscalls.
+TEST_F(PingSocket, ICMPPortExhaustion_NoRandomSave) {
+ DisableSave ds;
+ std::vector<FileDescriptor> sockets;
+ constexpr int kSockets = 65536;
+ addr_.sin_port = 0;
+ for (int i = 0; i < kSockets; i++) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP));
+ int ret = connect(s.get(), reinterpret_cast<struct sockaddr*>(&addr_),
+ sizeof(addr_));
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index c0b354e65..34291850d 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -25,6 +25,7 @@
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -144,11 +145,10 @@ TEST_P(PipeTest, Flags) {
if (IsNamedPipe()) {
// May be stubbed to zero; define locally.
- constexpr int kLargefile = 0100000;
EXPECT_THAT(fcntl(rfd_.get(), F_GETFL),
- SyscallSucceedsWithValue(kLargefile | O_RDONLY));
+ SyscallSucceedsWithValue(kOLargeFile | O_RDONLY));
EXPECT_THAT(fcntl(wfd_.get(), F_GETFL),
- SyscallSucceedsWithValue(kLargefile | O_WRONLY));
+ SyscallSucceedsWithValue(kOLargeFile | O_WRONLY));
} else {
EXPECT_THAT(fcntl(rfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_RDONLY));
EXPECT_THAT(fcntl(wfd_.get(), F_GETFL), SyscallSucceedsWithValue(O_WRONLY));
@@ -212,6 +212,20 @@ TEST(Pipe2Test, BadOptions) {
EXPECT_THAT(pipe2(fds, 0xDEAD), SyscallFailsWithErrno(EINVAL));
}
+// Tests that opening named pipes with O_TRUNC shouldn't cause an error, but
+// calls to (f)truncate should.
+TEST(NamedPipeTest, Truncate) {
+ const std::string tmp_path = NewTempAbsPath();
+ SKIP_IF(mkfifo(tmp_path.c_str(), 0644) != 0);
+
+ ASSERT_THAT(open(tmp_path.c_str(), O_NONBLOCK | O_RDONLY), SyscallSucceeds());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(tmp_path.c_str(), O_RDWR | O_NONBLOCK | O_TRUNC));
+
+ ASSERT_THAT(truncate(tmp_path.c_str(), 0), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
TEST_P(PipeTest, Seek) {
SKIP_IF(!CreateBlocking());
@@ -251,6 +265,8 @@ TEST_P(PipeTest, OffsetCalls) {
SyscallFailsWithErrno(ESPIPE));
struct iovec iov;
+ iov.iov_base = &buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(preadv(wfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
EXPECT_THAT(pwritev(rfd_.get(), &iov, 1, 0), SyscallFailsWithErrno(ESPIPE));
}
@@ -615,11 +631,14 @@ INSTANTIATE_TEST_SUITE_P(
"namednonblocking",
[](int fds[2], bool* is_blocking, bool* is_namedpipe) {
// Create a new file-based pipe (non-blocking).
- auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
- SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0);
- fds[0] = open(file.path().c_str(), O_NONBLOCK | O_RDONLY);
- fds[1] = open(file.path().c_str(), O_NONBLOCK | O_WRONLY);
+ std::string path;
+ {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ path = file.path();
+ }
+ SKIP_IF(mkfifo(path.c_str(), 0644) != 0);
+ fds[0] = open(path.c_str(), O_NONBLOCK | O_RDONLY);
+ fds[1] = open(path.c_str(), O_NONBLOCK | O_WRONLY);
MaybeSave();
*is_blocking = false;
*is_namedpipe = true;
@@ -629,13 +648,15 @@ INSTANTIATE_TEST_SUITE_P(
"namedblocking",
[](int fds[2], bool* is_blocking, bool* is_namedpipe) {
// Create a new file-based pipe (blocking).
- auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds());
- SKIP_IF(mkfifo(file.path().c_str(), 0644) != 0);
- ScopedThread t([&file, &fds]() {
- fds[1] = open(file.path().c_str(), O_WRONLY);
- });
- fds[0] = open(file.path().c_str(), O_RDONLY);
+ std::string path;
+ {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ path = file.path();
+ }
+ SKIP_IF(mkfifo(path.c_str(), 0644) != 0);
+ ScopedThread t(
+ [&path, &fds]() { fds[1] = open(path.c_str(), O_WRONLY); });
+ fds[0] = open(path.c_str(), O_RDONLY);
t.Join();
MaybeSave();
*is_blocking = true;
diff --git a/test/syscalls/linux/poll.cc b/test/syscalls/linux/poll.cc
index 9e5aa7fd0..7a316427d 100644
--- a/test/syscalls/linux/poll.cc
+++ b/test/syscalls/linux/poll.cc
@@ -259,14 +259,14 @@ TEST_F(PollTest, Nfds) {
TEST_PCHECK(getrlimit(RLIMIT_NOFILE, &rlim) == 0);
// gVisor caps the number of FDs that epoll can use beyond RLIMIT_NOFILE.
- constexpr rlim_t gVisorMax = 1048576;
- if (rlim.rlim_cur > gVisorMax) {
- rlim.rlim_cur = gVisorMax;
+ constexpr rlim_t maxFD = 4096;
+ if (rlim.rlim_cur > maxFD) {
+ rlim.rlim_cur = maxFD;
TEST_PCHECK(setrlimit(RLIMIT_NOFILE, &rlim) == 0);
}
rlim_t max_fds = rlim.rlim_cur;
- std::cout << "Using limit: " << max_fds;
+ std::cout << "Using limit: " << max_fds << std::endl;
// Create an eventfd. Since its value is initially zero, it is writable.
FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD());
@@ -275,7 +275,8 @@ TEST_F(PollTest, Nfds) {
// Each entry in the 'fds' array refers to the eventfd and polls for
// "writable" events (events=POLLOUT). This essentially guarantees that the
// poll() is a no-op and allows negative testing of the 'nfds' parameter.
- std::vector<struct pollfd> fds(max_fds, {.fd = efd.get(), .events = POLLOUT});
+ std::vector<struct pollfd> fds(max_fds + 1,
+ {.fd = efd.get(), .events = POLLOUT});
// Verify that 'nfds' up to RLIMIT_NOFILE are allowed.
EXPECT_THAT(RetryEINTR(poll)(fds.data(), 1, 1), SyscallSucceedsWithValue(1));
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index d07571a5f..04c5161f5 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -226,5 +226,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 30f0d75b3..c4e9cf528 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -264,5 +264,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/pread64.cc b/test/syscalls/linux/pread64.cc
index 2cecf2e5f..bcdbbb044 100644
--- a/test/syscalls/linux/pread64.cc
+++ b/test/syscalls/linux/pread64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
@@ -118,6 +119,21 @@ TEST_F(Pread64Test, EndOfFile) {
EXPECT_THAT(pread64(fd.get(), buf, 1024, 0), SyscallSucceedsWithValue(0));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST_F(Pread64Test, Overflow) {
+ int f = memfd_create("negative", 0);
+ const FileDescriptor fd(f);
+
+ EXPECT_THAT(ftruncate(fd.get(), 0x7fffffffffffffffull), SyscallSucceeds());
+
+ char buf[10];
+ EXPECT_THAT(pread64(fd.get(), buf, sizeof(buf), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(Pread64TestNoTempFile, CantReadSocketPair_NoRandomSave) {
int sock_fds[2];
EXPECT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_fds), SyscallSucceeds());
diff --git a/test/syscalls/linux/preadv.cc b/test/syscalls/linux/preadv.cc
index f7ea44054..5b0743fe9 100644
--- a/test/syscalls/linux/preadv.cc
+++ b/test/syscalls/linux/preadv.cc
@@ -37,6 +37,7 @@ namespace testing {
namespace {
+// Stress copy-on-write. Attempts to reproduce b/38430174.
TEST(PreadvTest, MMConcurrencyStress) {
// Fill a one-page file with zeroes (the contents don't really matter).
const auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
diff --git a/test/syscalls/linux/preadv2.cc b/test/syscalls/linux/preadv2.cc
index c9246367d..4a9acd7ae 100644
--- a/test/syscalls/linux/preadv2.cc
+++ b/test/syscalls/linux/preadv2.cc
@@ -35,6 +35,8 @@ namespace {
#ifndef SYS_preadv2
#if defined(__x86_64__)
#define SYS_preadv2 327
+#elif defined(__aarch64__)
+#define SYS_preadv2 286
#else
#error "Unknown architecture"
#endif
@@ -202,7 +204,7 @@ TEST(Preadv2Test, TestInvalidOffset) {
iov[0].iov_len = 0;
EXPECT_THAT(preadv2(fd.get(), iov.get(), /*iovcnt=*/1, /*offset=*/-8,
- /*flags=*/RWF_HIPRI),
+ /*flags=*/0),
SyscallFailsWithErrno(EINVAL));
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index e4c030bbb..d6b875dbf 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -37,6 +37,7 @@
#include <map>
#include <memory>
#include <ostream>
+#include <regex>
#include <string>
#include <unordered_set>
#include <utility>
@@ -51,6 +52,7 @@
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
+#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/capability_util.h"
@@ -98,9 +100,39 @@ namespace {
#define SUID_DUMP_ROOT 2
#endif /* SUID_DUMP_ROOT */
-// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
-// because "it isn't needed", even though Linux can return it via F_GETFL.
-constexpr int kOLargeFile = 00100000;
+#if defined(__x86_64__) || defined(__i386__)
+// This list of "required" fields is taken from reading the file
+// arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally
+// printed by the kernel.
+static const char* required_fields[] = {
+ "processor",
+ "vendor_id",
+ "cpu family",
+ "model\t\t:",
+ "model name",
+ "stepping",
+ "cpu MHz",
+ "fpu\t\t:",
+ "fpu_exception",
+ "cpuid level",
+ "wp",
+ "bogomips",
+ "clflush size",
+ "cache_alignment",
+ "address sizes",
+ "power management",
+};
+#elif __aarch64__
+// This list of "required" fields is taken from reading the file
+// arch/arm64/kernel/cpuinfo.c and seeing which fields will be unconditionally
+// printed by the kernel.
+static const char* required_fields[] = {
+ "processor", "BogoMIPS", "Features", "CPU implementer",
+ "CPU architecture", "CPU variant", "CPU part", "CPU revision",
+};
+#else
+#error "Unknown architecture"
+#endif
// Takes the subprocess command line and pid.
// If it returns !OK, WithSubprocess returns immediately.
@@ -183,7 +215,8 @@ PosixError WithSubprocess(SubprocessCallback const& running,
siginfo_t info;
// Wait until the child process has exited (WEXITED flag) but don't
// reap the child (WNOWAIT flag).
- waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED);
+ EXPECT_THAT(waitid(P_PID, child_pid, &info, WNOWAIT | WEXITED),
+ SyscallSucceeds());
if (zombied) {
// Arg of "Z" refers to a Zombied Process.
@@ -714,28 +747,6 @@ TEST(ProcCpuinfo, RequiredFieldsArePresent) {
ASSERT_FALSE(proc_cpuinfo.empty());
std::vector<std::string> cpuinfo_fields = absl::StrSplit(proc_cpuinfo, '\n');
- // This list of "required" fields is taken from reading the file
- // arch/x86/kernel/cpu/proc.c and seeing which fields will be unconditionally
- // printed by the kernel.
- static const char* required_fields[] = {
- "processor",
- "vendor_id",
- "cpu family",
- "model\t\t:",
- "model name",
- "stepping",
- "cpu MHz",
- "fpu\t\t:",
- "fpu_exception",
- "cpuid level",
- "wp",
- "bogomips",
- "clflush size",
- "cache_alignment",
- "address sizes",
- "power management",
- };
-
// Check that the usual fields are there. We don't really care about the
// contents.
for (const std::string& field : required_fields) {
@@ -743,8 +754,53 @@ TEST(ProcCpuinfo, RequiredFieldsArePresent) {
}
}
-TEST(ProcCpuinfo, DeniesWrite) {
- EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES));
+TEST(ProcCpuinfo, DeniesWriteNonRoot) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER)));
+
+ // Do setuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test. After calling
+ // setuid(non-zero-UID), there is no way to get root privileges back.
+ ScopedThread([&] {
+ // Use syscall instead of glibc setuid wrapper because we want this setuid
+ // call to only apply to this task. POSIX threads, however, require that all
+ // threads have the same UIDs, so using the setuid wrapper sets all threads'
+ // real UID.
+ // Also drops capabilities.
+ constexpr int kNobody = 65534;
+ EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds());
+ EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES));
+ // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
+ // kernfs.
+ if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
+ EXPECT_THAT(truncate("/proc/cpuinfo", 123),
+ SyscallFailsWithErrno(EACCES));
+ }
+ });
+}
+
+// With root privileges, it is possible to open /proc/cpuinfo with write mode,
+// but all write operations will return EIO.
+TEST(ProcCpuinfo, DeniesWriteRoot) {
+ // VFS1 does not behave differently for root/non-root.
+ SKIP_IF(IsRunningWithVFS1());
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER)));
+
+ int fd;
+ EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds());
+ if (fd > 0) {
+ EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO));
+ EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO));
+ }
+ // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in
+ // kernfs.
+ if (!IsRunningOnGvisor() || IsRunningWithVFS1()) {
+ if (fd > 0) {
+ EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO));
+ }
+ EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO));
+ }
}
// Sanity checks that uptime is present.
@@ -983,7 +1039,7 @@ constexpr uint64_t kMappingSize = 100 << 20;
// Tolerance on RSS comparisons to account for background thread mappings,
// reclaimed pages, newly faulted pages, etc.
-constexpr uint64_t kRSSTolerance = 5 << 20;
+constexpr uint64_t kRSSTolerance = 10 << 20;
// Capture RSS before and after an anonymous mapping with passed prot.
void MapPopulateRSS(int prot, uint64_t* before, uint64_t* after) {
@@ -1315,8 +1371,6 @@ TEST(ProcPidSymlink, SubprocessRunning) {
SyscallSucceedsWithValue(sizeof(buf)));
}
-// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
-// on proc files.
TEST(ProcPidSymlink, SubprocessZombied) {
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
@@ -1326,7 +1380,7 @@ TEST(ProcPidSymlink, SubprocessZombied) {
int want = EACCES;
if (!IsRunningOnGvisor()) {
auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
- if (version.major == 4 && version.minor > 3) {
+ if (version.major > 4 || (version.major == 4 && version.minor > 3)) {
want = ENOENT;
}
}
@@ -1339,24 +1393,25 @@ TEST(ProcPidSymlink, SubprocessZombied) {
SyscallFailsWithErrno(want));
}
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1
- // EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between linux on proc
+ // files.
+ //
+ // ~4.3: Syscall fails with EACCES.
+ // 4.17: Syscall succeeds and returns 1.
+ //
+ if (!IsRunningOnGvisor()) {
+ return;
+ }
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
- // on proc files.
- // 4.17 & gVisor: Syscall succeeds and returns 1.
- // EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
- // SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(ReadlinkWhileZombied("ns/pid", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
+
+ EXPECT_THAT(ReadlinkWhileZombied("ns/user", buf, sizeof(buf)),
+ SyscallFailsWithErrno(want));
}
// Test whether /proc/PID/ symlinks can be read for an exited process.
TEST(ProcPidSymlink, SubprocessExited) {
- // FIXME(gvisor.dev/issue/164): These all succeed on gVisor.
- SKIP_IF(IsRunningOnGvisor());
-
char buf[1];
EXPECT_THAT(ReadlinkWhileExited("exe", buf, sizeof(buf)),
@@ -1414,14 +1469,24 @@ TEST(ProcPidFile, SubprocessRunning) {
EXPECT_THAT(ReadWhileRunning("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileRunning("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
}
// Test whether /proc/PID/ files can be read for a zombie process.
TEST(ProcPidFile, SubprocessZombie) {
char buf[1];
- // 4.17: Succeeds and returns 1
- // gVisor: Succeeds and returns 0
+ // FIXME(gvisor.dev/issue/164): Loosen requirement due to inconsistent
+ // behavior on different kernels.
+ //
+ // ~4.3: Succeds and returns 0.
+ // 4.17: Succeeds and returns 1.
+ // gVisor: Succeeds and returns 0.
EXPECT_THAT(ReadWhileZombied("auxv", buf, sizeof(buf)), SyscallSucceeds());
EXPECT_THAT(ReadWhileZombied("cmdline", buf, sizeof(buf)),
@@ -1445,9 +1510,18 @@ TEST(ProcPidFile, SubprocessZombie) {
EXPECT_THAT(ReadWhileZombied("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+ EXPECT_THAT(ReadWhileZombied("oom_score", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(ReadWhileZombied("oom_score_adj", buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
// FIXME(gvisor.dev/issue/164): Inconsistent behavior between gVisor and linux
// on proc files.
+ //
+ // ~4.3: Fails and returns EACCES.
// gVisor & 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileZombied("io", buf, sizeof(buf)),
// SyscallFailsWithErrno(EACCES));
}
@@ -1456,9 +1530,12 @@ TEST(ProcPidFile, SubprocessZombie) {
TEST(ProcPidFile, SubprocessExited) {
char buf[1];
- // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels
+ // FIXME(gvisor.dev/issue/164): Inconsistent behavior between kernels.
+ //
+ // ~4.3: Fails and returns ESRCH.
// gVisor: Fails with ESRCH.
// 4.17: Succeeds and returns 1.
+ //
// EXPECT_THAT(ReadWhileExited("auxv", buf, sizeof(buf)),
// SyscallFailsWithErrno(ESRCH));
@@ -1500,6 +1577,15 @@ TEST(ProcPidFile, SubprocessExited) {
EXPECT_THAT(ReadWhileExited("uid_map", buf, sizeof(buf)),
SyscallSucceedsWithValue(sizeof(buf)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(gvisor.dev/issue/164): Succeeds on gVisor.
+ EXPECT_THAT(ReadWhileExited("oom_score", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
+ }
+
+ EXPECT_THAT(ReadWhileExited("oom_score_adj", buf, sizeof(buf)),
+ SyscallFailsWithErrno(ESRCH));
}
PosixError DirContainsImpl(absl::string_view path,
@@ -1630,7 +1716,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
TaskFiles(initial, {child1.Tid()})));
- // Stat child1's task file.
+ // Stat child1's task file. Regression test for b/32097707.
struct stat statbuf;
const std::string child1_task_file =
absl::StrCat("/proc/self/task/", child1.Tid());
@@ -1658,7 +1744,7 @@ TEST(ProcTask, KilledThreadsDisappear) {
EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
"/proc/self/task", TaskFiles(initial, {child3.Tid(), child5.Tid()})));
- // Stat child1's task file again. This time it should fail.
+ // Stat child1's task file again. This time it should fail. See b/32097707.
EXPECT_THAT(stat(child1_task_file.c_str(), &statbuf),
SyscallFailsWithErrno(ENOENT));
@@ -1813,7 +1899,7 @@ TEST(ProcSysVmOvercommitMemory, HasNumericValue) {
}
// Check that link for proc fd entries point the target node, not the
-// symlink itself.
+// symlink itself. Regression test for b/31155070.
TEST(ProcTaskFd, FstatatFollowsSymlink) {
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const FileDescriptor fd =
@@ -1872,6 +1958,20 @@ TEST(ProcMounts, IsSymlink) {
EXPECT_EQ(link, "self/mounts");
}
+TEST(ProcSelfMountinfo, RequiredFieldsArePresent) {
+ auto mountinfo =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mountinfo"));
+ EXPECT_THAT(
+ mountinfo,
+ AllOf(
+ // Root mount.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ /\S* / (rw|ro).*- \S+ \S+ (rw|ro)\S*)"),
+ // Proc mount - always rw.
+ ContainsRegex(
+ R"([0-9]+ [0-9]+ [0-9]+:[0-9]+ / /proc rw.*- \S+ \S+ rw\S*)")));
+}
+
// Check that /proc/self/mounts looks something like a real mounts file.
TEST(ProcSelfMounts, RequiredFieldsArePresent) {
auto mounts = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/mounts"));
@@ -1884,43 +1984,77 @@ TEST(ProcSelfMounts, RequiredFieldsArePresent) {
}
void CheckDuplicatesRecursively(std::string path) {
- errno = 0;
- DIR* dir = opendir(path.c_str());
- if (dir == nullptr) {
- // Ignore any directories we can't read or missing directories as the
- // directory could have been deleted/mutated from the time the parent
- // directory contents were read.
- return;
- }
- auto dir_closer = Cleanup([&dir]() { closedir(dir); });
- std::unordered_set<std::string> children;
- while (true) {
- // Readdir(3): If the end of the directory stream is reached, NULL is
- // returned and errno is not changed. If an error occurs, NULL is returned
- // and errno is set appropriately. To distinguish end of stream and from an
- // error, set errno to zero before calling readdir() and then check the
- // value of errno if NULL is returned.
+ std::vector<std::string> child_dirs;
+
+ // There is the known issue of the linux procfs, that two consequent calls of
+ // readdir can return the same entry twice if between these calls one or more
+ // entries have been removed from this directory.
+ int max_attempts = 5;
+ for (int i = 0; i < max_attempts; i++) {
+ child_dirs.clear();
errno = 0;
- struct dirent* dp = readdir(dir);
- if (dp == nullptr) {
- ASSERT_EQ(errno, 0) << path;
- break; // We're done.
+ bool success = true;
+ DIR* dir = opendir(path.c_str());
+ if (dir == nullptr) {
+ // Ignore any directories we can't read or missing directories as the
+ // directory could have been deleted/mutated from the time the parent
+ // directory contents were read.
+ return;
}
+ auto dir_closer = Cleanup([&dir]() { closedir(dir); });
+ std::unordered_set<std::string> children;
+ while (true) {
+ // Readdir(3): If the end of the directory stream is reached, NULL is
+ // returned and errno is not changed. If an error occurs, NULL is
+ // returned and errno is set appropriately. To distinguish end of stream
+ // and from an error, set errno to zero before calling readdir() and then
+ // check the value of errno if NULL is returned.
+ errno = 0;
+ struct dirent* dp = readdir(dir);
+ if (dp == nullptr) {
+ // Linux will return EINVAL when calling getdents on a /proc/tid/net
+ // file corresponding to a zombie task.
+ // See fs/proc/proc_net.c:proc_tgid_net_readdir().
+ //
+ // We just ignore the directory in this case.
+ if (errno == EINVAL && absl::StartsWith(path, "/proc/") &&
+ absl::EndsWith(path, "/net")) {
+ break;
+ }
- if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) {
- continue;
+ // Otherwise, no errors are allowed.
+ ASSERT_EQ(errno, 0) << path;
+ break; // We're done.
+ }
+
+ const std::string name = dp->d_name;
+
+ if (name == "." || name == "..") {
+ continue;
+ }
+
+ // Ignore a duplicate entry if it isn't the last attempt.
+ if (i == max_attempts - 1) {
+ ASSERT_EQ(children.find(name), children.end())
+ << absl::StrCat(path, "/", name);
+ } else if (children.find(name) != children.end()) {
+ std::cerr << "Duplicate entry: " << i << ":"
+ << absl::StrCat(path, "/", name) << std::endl;
+ success = false;
+ break;
+ }
+ children.insert(name);
+
+ if (dp->d_type == DT_DIR) {
+ child_dirs.push_back(name);
+ }
}
-
- ASSERT_EQ(children.find(std::string(dp->d_name)), children.end())
- << dp->d_name;
- children.insert(std::string(dp->d_name));
-
- ASSERT_NE(dp->d_type, DT_UNKNOWN);
-
- if (dp->d_type != DT_DIR) {
- continue;
+ if (success) {
+ break;
}
- CheckDuplicatesRecursively(absl::StrCat(path, "/", dp->d_name));
+ }
+ for (auto dname = child_dirs.begin(); dname != child_dirs.end(); dname++) {
+ CheckDuplicatesRecursively(absl::StrCat(path, "/", *dname));
}
}
@@ -1983,10 +2117,48 @@ TEST(Proc, GetdentsEnoent) {
},
nullptr, nullptr));
char buf[1024];
- ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)),
+ ASSERT_THAT(syscall(SYS_getdents64, fd.get(), buf, sizeof(buf)),
SyscallFailsWithErrno(ENOENT));
}
+void CheckSyscwFromIOFile(const std::string& path, const std::string& regex) {
+ std::string output;
+ ASSERT_NO_ERRNO(GetContents(path, &output));
+ ASSERT_THAT(output, ContainsRegex(absl::StrCat("syscw:\\s+", regex, "\n")));
+}
+
+// Checks that there is variable accounting of IO between threads/tasks.
+TEST(Proc, PidTidIOAccounting) {
+ absl::Notification notification;
+
+ // Run a thread with a bunch of writes. Check that io account records exactly
+ // the number of write calls. File open/close is there to prevent buffering.
+ ScopedThread writer([&notification] {
+ const int num_writes = 100;
+ for (int i = 0; i < num_writes; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ ASSERT_NO_ERRNO(SetContents(path.path(), "a"));
+ }
+ notification.Notify();
+ const std::string& writer_dir =
+ absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io");
+
+ CheckSyscwFromIOFile(writer_dir, std::to_string(num_writes));
+ });
+
+ // Run a thread and do no writes. Check that no writes are recorded.
+ ScopedThread noop([&notification] {
+ notification.WaitForNotification();
+ const std::string& noop_dir =
+ absl::StrCat("/proc/", getpid(), "/task/", gettid(), "/io");
+
+ CheckSyscwFromIOFile(noop_dir, "0");
+ });
+
+ writer.Join();
+ noop.Join();
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
@@ -1997,5 +2169,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 897cf4950..b9a5a99bd 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -20,8 +20,13 @@
#include <sys/syscall.h>
#include <sys/types.h>
+#include <vector>
+
#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/capability_util.h"
@@ -33,6 +38,31 @@ namespace gvisor {
namespace testing {
namespace {
+constexpr const char kProcNet[] = "/proc/net";
+
+TEST(ProcNetSymlinkTarget, FileMode) {
+ struct stat s;
+ ASSERT_THAT(stat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFDIR);
+ EXPECT_EQ(s.st_mode & 0777, 0555);
+}
+
+TEST(ProcNetSymlink, FileMode) {
+ struct stat s;
+ ASSERT_THAT(lstat(kProcNet, &s), SyscallSucceeds());
+ EXPECT_EQ(s.st_mode & S_IFMT, S_IFLNK);
+ EXPECT_EQ(s.st_mode & 0777, 0777);
+}
+
+TEST(ProcNetSymlink, Contents) {
+ char buf[40] = {};
+ int n = readlink(kProcNet, buf, sizeof(buf));
+ ASSERT_THAT(n, SyscallSucceeds());
+
+ buf[n] = 0;
+ EXPECT_STREQ(buf, "self/net");
+}
+
TEST(ProcNetIfInet6, Format) {
auto ifinet6 = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/if_inet6"));
EXPECT_THAT(ifinet6,
@@ -67,9 +97,62 @@ TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
EXPECT_EQ(buf, to_write);
}
+// DeviceEntry is an entry in /proc/net/dev
+struct DeviceEntry {
+ std::string name;
+ uint64_t stats[16];
+};
+
+PosixErrorOr<std::vector<DeviceEntry>> GetDeviceMetricsFromProc(
+ const std::string dev) {
+ std::vector<std::string> lines = absl::StrSplit(dev, '\n');
+ std::vector<DeviceEntry> entries;
+
+ // /proc/net/dev prints 2 lines of headers followed by a line of metrics for
+ // each network interface.
+ for (unsigned i = 2; i < lines.size(); i++) {
+ // Ignore empty lines.
+ if (lines[i].empty()) {
+ continue;
+ }
+
+ std::vector<std::string> values =
+ absl::StrSplit(lines[i], ' ', absl::SkipWhitespace());
+
+ // Interface name + 16 values.
+ if (values.size() != 17) {
+ return PosixError(EINVAL, "invalid line: " + lines[i]);
+ }
+
+ DeviceEntry entry;
+ entry.name = values[0];
+ // Skip the interface name and read only the values.
+ for (unsigned j = 1; j < 17; j++) {
+ uint64_t num;
+ if (!absl::SimpleAtoi(values[j], &num)) {
+ return PosixError(EINVAL, "invalid value: " + values[j]);
+ }
+ entry.stats[j - 1] = num;
+ }
+
+ entries.push_back(entry);
+ }
+
+ return entries;
+}
+
+// TEST(ProcNetDev, Format) tests that /proc/net/dev is parsable and
+// contains at least one entry.
+TEST(ProcNetDev, Format) {
+ auto dev = ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/dev"));
+ auto entries = ASSERT_NO_ERRNO_AND_VALUE(GetDeviceMetricsFromProc(dev));
+
+ EXPECT_GT(entries.size(), 0);
+}
+
PosixErrorOr<uint64_t> GetSNMPMetricFromProc(const std::string snmp,
- const std::string &type,
- const std::string &item) {
+ const std::string& type,
+ const std::string& item) {
std::vector<std::string> snmp_vec = absl::StrSplit(snmp, '\n');
// /proc/net/snmp prints a line of headers followed by a line of metrics.
@@ -127,7 +210,7 @@ TEST(ProcNetSnmp, TcpReset_NoRandomSave) {
};
ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
- ASSERT_THAT(connect(s.get(), (struct sockaddr *)&sin, sizeof(sin)),
+ ASSERT_THAT(connect(s.get(), (struct sockaddr*)&sin, sizeof(sin)),
SyscallFailsWithErrno(ECONNREFUSED));
uint64_t newAttemptFails;
@@ -172,19 +255,19 @@ TEST(ProcNetSnmp, TcpEstab_NoRandomSave) {
};
ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
- ASSERT_THAT(bind(s_listen.get(), (struct sockaddr *)&sin, sizeof(sin)),
+ ASSERT_THAT(bind(s_listen.get(), (struct sockaddr*)&sin, sizeof(sin)),
SyscallSucceeds());
ASSERT_THAT(listen(s_listen.get(), 1), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = sizeof(sin);
ASSERT_THAT(
- getsockname(s_listen.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen),
+ getsockname(s_listen.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
SyscallSucceeds());
FileDescriptor s_connect =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, 0));
- ASSERT_THAT(connect(s_connect.get(), (struct sockaddr *)&sin, sizeof(sin)),
+ ASSERT_THAT(connect(s_connect.get(), (struct sockaddr*)&sin, sizeof(sin)),
SyscallSucceeds());
auto s_accept =
@@ -260,7 +343,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
.sin_port = htons(4444),
};
ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
- ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)),
+ ASSERT_THAT(sendto(s.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)),
SyscallSucceedsWithValue(1));
uint64_t newOutDatagrams;
@@ -275,7 +358,7 @@ TEST(ProcNetSnmp, UdpNoPorts_NoRandomSave) {
EXPECT_EQ(oldNoPorts, newNoPorts - 1);
}
-TEST(ProcNetSnmp, UdpIn) {
+TEST(ProcNetSnmp, UdpIn_NoRandomSave) {
// TODO(gvisor.dev/issue/866): epsocket metrics are not savable.
const DisableSave ds;
@@ -295,18 +378,18 @@ TEST(ProcNetSnmp, UdpIn) {
.sin_port = htons(0),
};
ASSERT_EQ(inet_pton(AF_INET, "127.0.0.1", &(sin.sin_addr)), 1);
- ASSERT_THAT(bind(server.get(), (struct sockaddr *)&sin, sizeof(sin)),
+ ASSERT_THAT(bind(server.get(), (struct sockaddr*)&sin, sizeof(sin)),
SyscallSucceeds());
// Get the port bound by the server socket.
socklen_t addrlen = sizeof(sin);
ASSERT_THAT(
- getsockname(server.get(), reinterpret_cast<sockaddr *>(&sin), &addrlen),
+ getsockname(server.get(), reinterpret_cast<sockaddr*>(&sin), &addrlen),
SyscallSucceeds());
FileDescriptor client =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
ASSERT_THAT(
- sendto(client.get(), "a", 1, 0, (struct sockaddr *)&sin, sizeof(sin)),
+ sendto(client.get(), "a", 1, 0, (struct sockaddr*)&sin, sizeof(sin)),
SyscallSucceedsWithValue(1));
char buf[128];
@@ -326,6 +409,113 @@ TEST(ProcNetSnmp, UdpIn) {
EXPECT_EQ(oldInDatagrams, newInDatagrams - 1);
}
+TEST(ProcNetSnmp, CheckNetStat) {
+ // TODO(b/155123175): SNMP and netstat don't work on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/netstat"));
+
+ int name_count = 0;
+ int value_count = 0;
+ std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
+ for (int i = 0; i + 1 < lines.size(); i += 2) {
+ std::vector<absl::string_view> names =
+ absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
+ std::vector<absl::string_view> values =
+ absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
+ EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
+ << "' and '" << lines[i + 1] << "'";
+ for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" ||
+ names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") {
+ ++name_count;
+ int64_t val;
+ if (absl::SimpleAtoi(values[j], &val)) {
+ ++value_count;
+ }
+ }
+ }
+ }
+ EXPECT_EQ(name_count, 4);
+ EXPECT_EQ(value_count, 4);
+}
+
+TEST(ProcNetSnmp, Stat) {
+ struct stat st = {};
+ ASSERT_THAT(stat("/proc/net/snmp", &st), SyscallSucceeds());
+}
+
+TEST(ProcNetSnmp, CheckSnmp) {
+ // TODO(b/155123175): SNMP and netstat don't work on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ std::string contents =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/snmp"));
+
+ int name_count = 0;
+ int value_count = 0;
+ std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n');
+ for (int i = 0; i + 1 < lines.size(); i += 2) {
+ std::vector<absl::string_view> names =
+ absl::StrSplit(lines[i], absl::ByAnyChar("\t "));
+ std::vector<absl::string_view> values =
+ absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t "));
+ EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i]
+ << "' and '" << lines[i + 1] << "'";
+ for (int j = 0; j < names.size() && j < values.size(); ++j) {
+ if (names[j] == "RetransSegs") {
+ ++name_count;
+ int64_t val;
+ if (absl::SimpleAtoi(values[j], &val)) {
+ ++value_count;
+ }
+ }
+ }
+ }
+ EXPECT_EQ(name_count, 1);
+ EXPECT_EQ(value_count, 1);
+}
+
+TEST(ProcSysNetIpv4Recovery, Exists) {
+ EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_recovery", O_RDONLY),
+ SyscallSucceeds());
+}
+
+TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) {
+ // TODO(b/162988252): Enable save/restore for this test after the bug is
+ // fixed.
+ DisableSave ds;
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE))));
+
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open("/proc/sys/net/ipv4/tcp_recovery", O_RDWR));
+
+ char buf[10] = {'\0'};
+ char to_write = '2';
+
+ // Check initial value is set to 1.
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(to_write) + 1));
+ EXPECT_EQ(strcmp(buf, "1\n"), 0);
+
+ // Set tcp_recovery to one of the allowed constants.
+ EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0),
+ SyscallSucceedsWithValue(sizeof(to_write)));
+ EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0),
+ SyscallSucceedsWithValue(sizeof(to_write) + 1));
+ EXPECT_EQ(strcmp(buf, "2\n"), 0);
+
+ // Set tcp_recovery to any random value.
+ char kMessage[] = "100";
+ EXPECT_THAT(PwriteFd(fd.get(), kMessage, strlen(kMessage), 0),
+ SyscallSucceedsWithValue(strlen(kMessage)));
+ EXPECT_THAT(PreadFd(fd.get(), buf, sizeof(kMessage), 0),
+ SyscallSucceedsWithValue(sizeof(kMessage)));
+ EXPECT_EQ(strcmp(buf, "100\n"), 0);
+}
+
TEST(ProcSysNetIpv4IpForward, Exists) {
auto fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR));
diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc
index 2659f6a98..5b6e3e3cd 100644
--- a/test/syscalls/linux/proc_net_tcp.cc
+++ b/test/syscalls/linux/proc_net_tcp.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc
index f06f1a24b..786b4b4af 100644
--- a/test/syscalls/linux/proc_net_udp.cc
+++ b/test/syscalls/linux/proc_net_udp.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <netinet/tcp.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index 66db0acaa..a63067586 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -106,7 +106,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
std::vector<UnixEntry> entries;
std::vector<std::string> lines = absl::StrSplit(content, '\n');
std::cerr << "<contents of /proc/net/unix>" << std::endl;
- for (std::string line : lines) {
+ for (const std::string& line : lines) {
// Emit the proc entry to the test output to provide context for the test
// results.
std::cerr << line << std::endl;
@@ -374,7 +374,7 @@ TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
@@ -403,7 +403,7 @@ TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) {
// corresponding entries, as they don't have an address yet.
if (IsRunningOnGvisor()) {
ASSERT_EQ(entries.size(), 2);
- for (auto e : entries) {
+ for (const auto& e : entries) {
ASSERT_EQ(e.state, SS_DISCONNECTING);
}
}
diff --git a/test/syscalls/linux/proc_pid_oomscore.cc b/test/syscalls/linux/proc_pid_oomscore.cc
new file mode 100644
index 000000000..707821a3f
--- /dev/null
+++ b/test/syscalls/linux/proc_pid_oomscore.cc
@@ -0,0 +1,72 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+
+#include <exception>
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+PosixErrorOr<int> ReadProcNumber(std::string path) {
+ ASSIGN_OR_RETURN_ERRNO(std::string contents, GetContents(path));
+ EXPECT_EQ(contents[contents.length() - 1], '\n');
+
+ int num;
+ if (!absl::SimpleAtoi(contents, &num)) {
+ return PosixError(EINVAL, "invalid value: " + contents);
+ }
+
+ return num;
+}
+
+TEST(ProcPidOomscoreTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score"));
+ EXPECT_LE(oom_score, 1000);
+ EXPECT_GE(oom_score, -1000);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicRead) {
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+
+ // oom_score_adj defaults to 0.
+ EXPECT_EQ(oom_score, 0);
+}
+
+TEST(ProcPidOomscoreAdjTest, BasicWrite) {
+ constexpr int test_value = 7;
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/oom_score_adj", O_WRONLY));
+ ASSERT_THAT(
+ RetryEINTR(write)(fd.get(), std::to_string(test_value).c_str(), 1),
+ SyscallSucceeds());
+
+ auto const oom_score =
+ ASSERT_NO_ERRNO_AND_VALUE(ReadProcNumber("/proc/self/oom_score_adj"));
+ EXPECT_EQ(oom_score, test_value);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc
index 7f2e8f203..9fb1b3a2c 100644
--- a/test/syscalls/linux/proc_pid_smaps.cc
+++ b/test/syscalls/linux/proc_pid_smaps.cc
@@ -173,7 +173,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
return;
}
unknown_fields.insert(std::string(key));
- std::cerr << "skipping unknown smaps field " << key;
+ std::cerr << "skipping unknown smaps field " << key << std::endl;
};
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
@@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps(
// amount of whitespace).
if (!entry) {
std::cerr << "smaps line not considered a maps line: "
- << maybe_maps_entry.error_message();
+ << maybe_maps_entry.error_message() << std::endl;
return PosixError(
EINVAL,
absl::StrCat("smaps field line without preceding maps line: ", l));
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index 8f3800380..926690eb8 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -32,6 +32,7 @@
#include "absl/time/time.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
+#include "test/util/platform_util.h"
#include "test/util/signal_util.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -178,7 +179,8 @@ TEST(PtraceTest, GetSigMask) {
// Install a signal handler for kBlockSignal to avoid termination and block
// it.
- TEST_PCHECK(signal(kBlockSignal, +[](int signo) {}) != SIG_ERR);
+ TEST_PCHECK(signal(
+ kBlockSignal, +[](int signo) {}) != SIG_ERR);
MaybeSave();
TEST_PCHECK(sigprocmask(SIG_SETMASK, &blocked, nullptr) == 0);
MaybeSave();
@@ -398,9 +400,11 @@ TEST(PtraceTest, GetRegSet) {
// Read exactly the full register set.
EXPECT_EQ(iov.iov_len, sizeof(regs));
-#ifdef __x86_64__
+#if defined(__x86_64__)
// Child called kill(2), with SIGSTOP as arg 2.
EXPECT_EQ(regs.rsi, SIGSTOP);
+#elif defined(__aarch64__)
+ EXPECT_EQ(regs.regs[1], SIGSTOP);
#endif
// Suppress SIGSTOP and resume the child.
@@ -750,15 +754,23 @@ TEST(PtraceTest,
SyscallSucceeds());
EXPECT_TRUE(siginfo.si_code == SIGTRAP || siginfo.si_code == (SIGTRAP | 0x80))
<< "si_code = " << siginfo.si_code;
-#ifdef __x86_64__
+
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_TRUE(regs.orig_rax == SYS_vfork || regs.orig_rax == SYS_clone)
<< "orig_rax = " << regs.orig_rax;
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_TRUE(regs.regs[8] == SYS_clone) << "regs[8] = " << regs.regs[8];
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// After this point, the child will be making wait4 syscalls that will be
// interrupted by saving, so saving is not permitted. Note that this is
@@ -803,14 +815,21 @@ TEST(PtraceTest,
SyscallSucceedsWithValue(child_pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == (SIGTRAP | 0x80))
<< " status " << status;
-#ifdef __x86_64__
{
struct user_regs_struct regs = {};
- ASSERT_THAT(ptrace(PTRACE_GETREGS, child_pid, 0, &regs), SyscallSucceeds());
+ struct iovec iov;
+ iov.iov_base = &regs;
+ iov.iov_len = sizeof(regs);
+ EXPECT_THAT(ptrace(PTRACE_GETREGSET, child_pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
+#if defined(__x86_64__)
EXPECT_EQ(SYS_wait4, regs.orig_rax);
EXPECT_EQ(grandchild_pid, regs.rax);
- }
+#elif defined(__aarch64__)
+ EXPECT_EQ(SYS_wait4, regs.regs[8]);
+ EXPECT_EQ(grandchild_pid, regs.regs[0]);
#endif // defined(__x86_64__)
+ }
// Detach from the child and wait for it to exit.
ASSERT_THAT(ptrace(PTRACE_DETACH, child_pid, 0, 0), SyscallSucceeds());
@@ -823,13 +842,8 @@ TEST(PtraceTest,
// These tests requires knowledge of architecture-specific syscall convention.
#ifdef __x86_64__
TEST(PtraceTest, Int3) {
- switch (GvisorPlatform()) {
- case Platform::kKVM:
- // TODO(b/124248694): int3 isn't handled properly.
- return;
- default:
- break;
- }
+ SKIP_IF(PlatformSupportInt3() == PlatformSupport::NotSupported);
+
pid_t const child_pid = fork();
if (child_pid == 0) {
// In child process.
@@ -1191,7 +1205,7 @@ TEST(PtraceTest, SeizeSetOptions) {
// gVisor is not susceptible to this race because
// kernel.Task.waitCollectTraceeStopLocked() checks specifically for an
// active ptraceStop, which is not initiated if SIGKILL is pending.
- std::cout << "Observed syscall-exit after SIGKILL";
+ std::cout << "Observed syscall-exit after SIGKILL" << std::endl;
ASSERT_THAT(waitpid(child_pid, &status, 0),
SyscallSucceedsWithValue(child_pid));
}
@@ -1211,5 +1225,5 @@ int main(int argc, char** argv) {
gvisor::testing::RunExecveChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index 99a0df235..f9392b9e0 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -70,6 +70,8 @@ constexpr absl::Duration kTimeout = absl::Seconds(20);
// The maximum line size in bytes returned per read from a pty file.
constexpr int kMaxLineSize = 4096;
+constexpr char kMasterPath[] = "/dev/ptmx";
+
// glibc defines its own, different, version of struct termios. We care about
// what the kernel does, not glibc.
#define KERNEL_NCCS 19
@@ -362,6 +364,12 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count,
ssize_t n =
ReadFd(fd, static_cast<char*>(buf) + completed, count - completed);
if (n < 0) {
+ if (errno == EAGAIN) {
+ // Linux sometimes returns EAGAIN from this read, despite the fact that
+ // poll returned success. Let's just do what do as we are told and try
+ // again.
+ continue;
+ }
return PosixError(errno, "read failed");
}
completed += n;
@@ -376,9 +384,25 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count,
return PosixError(ETIMEDOUT, "Poll timed out");
}
+TEST(PtyTrunc, Truncate) {
+ // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to
+ // (f)truncate should.
+ FileDescriptor master =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC));
+ int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master));
+ std::string spath = absl::StrCat("/dev/pts/", n);
+ FileDescriptor slave =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC));
+
+ EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL));
+}
+
TEST(BasicPtyTest, StatUnopenedMaster) {
struct stat s;
- ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds());
+ ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds());
EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor));
EXPECT_EQ(s.st_size, 0);
@@ -610,6 +634,11 @@ TEST_F(PtyTest, TermiosAffectsSlave) {
// Verify this by setting ICRNL (which rewrites input \r to \n) and verify that
// it has no effect on the master.
TEST_F(PtyTest, MasterTermiosUnchangable) {
+ struct kernel_termios master_termios = {};
+ EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds());
+ master_termios.c_lflag |= ICRNL;
+ EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds());
+
char c = '\r';
ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1));
@@ -1108,7 +1137,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) {
std::string kExpected = "GO\nBLUE\n!";
// Write each line.
- for (std::string input : kInputs) {
+ for (const std::string& input : kInputs) {
ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()),
SyscallSucceedsWithValue(input.size()));
}
diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc
index 14a4af980..1d7dbefdb 100644
--- a/test/syscalls/linux/pty_root.cc
+++ b/test/syscalls/linux/pty_root.cc
@@ -25,16 +25,26 @@
namespace gvisor {
namespace testing {
-// These tests should be run as root.
namespace {
+// StealTTY tests whether privileged processes can steal controlling terminals.
+// If the stealing process has CAP_SYS_ADMIN in the root user namespace, the
+// test ensures that stealing works. If it has non-root CAP_SYS_ADMIN, it
+// ensures stealing fails.
TEST(JobControlRootTest, StealTTY) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
- // Make this a session leader, which also drops the controlling terminal.
- // In the gVisor test environment, this test will be run as the session
- // leader already (as the sentry init process).
+ bool true_root = true;
if (!IsRunningOnGvisor()) {
+ // If running in Linux, we may only have CAP_SYS_ADMIN in a non-root user
+ // namespace (i.e. we are not truly root). We use init_module as a proxy for
+ // whether we are true root, as it returns EPERM immediately.
+ ASSERT_THAT(syscall(SYS_init_module, nullptr, 0, nullptr), SyscallFails());
+ true_root = errno != EPERM;
+
+ // Make this a session leader, which also drops the controlling terminal.
+ // In the gVisor test environment, this test will be run as the session
+ // leader already (as the sentry init process).
ASSERT_THAT(setsid(), SyscallSucceeds());
}
@@ -53,8 +63,8 @@ TEST(JobControlRootTest, StealTTY) {
ASSERT_THAT(setsid(), SyscallSucceeds());
// We shouldn't be able to steal the terminal with the wrong arg value.
TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0));
- // We should be able to steal it here.
- TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1));
+ // We should be able to steal it if we are true root.
+ TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1));
_exit(0);
}
diff --git a/test/syscalls/linux/pwrite64.cc b/test/syscalls/linux/pwrite64.cc
index b48fe540d..e69794910 100644
--- a/test/syscalls/linux/pwrite64.cc
+++ b/test/syscalls/linux/pwrite64.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
@@ -27,14 +28,7 @@ namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is not incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class Pwrite64 : public ::testing::Test {
void SetUp() override {
name_ = NewTempAbsPath();
@@ -72,6 +66,17 @@ TEST_F(Pwrite64, InvalidArgs) {
EXPECT_THAT(close(fd), SyscallSucceeds());
}
+TEST_F(Pwrite64, Overflow) {
+ int fd;
+ ASSERT_THAT(fd = open(name_.c_str(), O_APPEND | O_RDWR), SyscallSucceeds());
+ constexpr int64_t kBufSize = 1024;
+ std::vector<char> buf(kBufSize);
+ std::fill(buf.begin(), buf.end(), 'a');
+ EXPECT_THAT(PwriteFd(fd, buf.data(), buf.size(), 0x7fffffffffffffffull),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
index 1dbc0d6df..63b686c62 100644
--- a/test/syscalls/linux/pwritev2.cc
+++ b/test/syscalls/linux/pwritev2.cc
@@ -34,6 +34,8 @@ namespace {
#ifndef SYS_pwritev2
#if defined(__x86_64__)
#define SYS_pwritev2 328
+#elif defined(__aarch64__)
+#define SYS_pwritev2 287
#else
#error "Unknown architecture"
#endif
@@ -67,7 +69,7 @@ ssize_t pwritev2(unsigned long fd, const struct iovec* iov,
}
// This test is the base case where we call pwritev (no offset, no flags).
-TEST(Writev2Test, TestBaseCall) {
+TEST(Writev2Test, BaseCall) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
@@ -95,7 +97,7 @@ TEST(Writev2Test, TestBaseCall) {
}
// This test is where we call pwritev2 with a positive offset and no flags.
-TEST(Pwritev2Test, TestValidPositiveOffset) {
+TEST(Pwritev2Test, ValidPositiveOffset) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
std::string prefix(kBufSize, '0');
@@ -127,7 +129,7 @@ TEST(Pwritev2Test, TestValidPositiveOffset) {
// This test is the base case where we call writev by using -1 as the offset.
// The write should use the file offset, so the test increments the file offset
// prior to call pwritev2.
-TEST(Pwritev2Test, TestNegativeOneOffset) {
+TEST(Pwritev2Test, NegativeOneOffset) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const std::string prefix = "00";
@@ -162,7 +164,7 @@ TEST(Pwritev2Test, TestNegativeOneOffset) {
// pwritev2 requires if the RWF_HIPRI flag is passed, the fd must be opened with
// O_DIRECT. This test implements a correct call with the RWF_HIPRI flag.
-TEST(Pwritev2Test, TestCallWithRWF_HIPRI) {
+TEST(Pwritev2Test, CallWithRWF_HIPRI) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
@@ -187,47 +189,8 @@ TEST(Pwritev2Test, TestCallWithRWF_HIPRI) {
EXPECT_EQ(buf, content);
}
-// This test checks that pwritev2 can be called with valid flags
-TEST(Pwritev2Test, TestCallWithValidFlags) {
- SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
-
- const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
- GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode));
- const FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
-
- std::vector<char> content(kBufSize, '0');
- struct iovec iov;
- iov.iov_base = content.data();
- iov.iov_len = content.size();
-
- EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
- /*offset=*/0, /*flags=*/RWF_DSYNC),
- SyscallSucceedsWithValue(kBufSize));
-
- std::vector<char> buf(content.size());
- EXPECT_THAT(read(fd.get(), buf.data(), buf.size()),
- SyscallSucceedsWithValue(buf.size()));
-
- EXPECT_EQ(buf, content);
-
- SetContent(content);
-
- EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
- /*offset=*/0, /*flags=*/0x4),
- SyscallSucceedsWithValue(kBufSize));
-
- ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR),
- SyscallSucceedsWithValue(content.size()));
-
- EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0),
- SyscallSucceedsWithValue(buf.size()));
-
- EXPECT_EQ(buf, content);
-}
-
// This test calls pwritev2 with a bad file descriptor.
-TEST(Writev2Test, TestBadFile) {
+TEST(Writev2Test, BadFile) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
ASSERT_THAT(pwritev2(/*fd=*/-1, /*iov=*/nullptr, /*iovcnt=*/0,
/*offset=*/0, /*flags=*/0),
@@ -235,7 +198,7 @@ TEST(Writev2Test, TestBadFile) {
}
// This test calls pwrite2 with an invalid offset.
-TEST(Pwritev2Test, TestInvalidOffset) {
+TEST(Pwritev2Test, InvalidOffset) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
@@ -253,7 +216,7 @@ TEST(Pwritev2Test, TestInvalidOffset) {
SyscallFailsWithErrno(EINVAL));
}
-TEST(Pwritev2Test, TestUnseekableFileValid) {
+TEST(Pwritev2Test, UnseekableFileValid) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
int pipe_fds[2];
@@ -283,7 +246,7 @@ TEST(Pwritev2Test, TestUnseekableFileValid) {
// Calling pwritev2 with a non-negative offset calls pwritev. Calling pwritev
// with an unseekable file is not allowed. A pipe is used for an unseekable
// file.
-TEST(Pwritev2Test, TestUnseekableFileInValid) {
+TEST(Pwritev2Test, UnseekableFileInvalid) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
int pipe_fds[2];
@@ -302,7 +265,7 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) {
EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds());
}
-TEST(Pwritev2Test, TestReadOnlyFile) {
+TEST(Pwritev2Test, ReadOnlyFile) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
@@ -321,7 +284,7 @@ TEST(Pwritev2Test, TestReadOnlyFile) {
}
// This test calls pwritev2 with an invalid flag.
-TEST(Pwritev2Test, TestInvalidFlag) {
+TEST(Pwritev2Test, InvalidFlag) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
new file mode 100644
index 000000000..8d6e5c913
--- /dev/null
+++ b/test/syscalls/linux/raw_socket.cc
@@ -0,0 +1,869 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/capability.h>
+#ifndef __fuchsia__
+#include <linux/filter.h>
+#endif // __fuchsia__
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/ip_icmp.h>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <algorithm>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will
+// need to be configured to let the superuser create ping sockets (see icmp(7)).
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Fixture for tests parameterized by protocol.
+class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
+ protected:
+ // Creates a socket to be used in tests.
+ void SetUp() override;
+
+ // Closes the socket created by SetUp().
+ void TearDown() override;
+
+ // Sends buf via s_.
+ void SendBuf(const char* buf, int buf_len);
+
+ // Reads from s_ into recv_buf.
+ void ReceiveBuf(char* recv_buf, size_t recv_buf_len);
+
+ void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len);
+
+ int Protocol() { return std::get<0>(GetParam()); }
+
+ int Family() { return std::get<1>(GetParam()); }
+
+ socklen_t AddrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(sockaddr_in);
+ }
+ return sizeof(sockaddr_in6);
+ }
+
+ int HdrLen() {
+ if (Family() == AF_INET) {
+ return sizeof(struct iphdr);
+ }
+ // IPv6 raw sockets don't include the header.
+ return 0;
+ }
+
+ // The socket used for both reading and writing.
+ int s_;
+
+ // The loopback address.
+ struct sockaddr_storage addr_;
+};
+
+void RawSocketTest::SetUp() {
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
+ ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ addr_ = {};
+
+ // We don't set ports because raw sockets don't have a notion of ports.
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ } else {
+ struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_addr = in6addr_loopback;
+ }
+}
+
+void RawSocketTest::TearDown() {
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
+}
+
+// We should be able to create multiple raw sockets for the same protocol.
+// BasicRawSocket::Setup creates the first one, so we only have to create one
+// more here.
+TEST_P(RawSocketTest, MultipleCreation) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int s2;
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Test that shutting down an unconnected socket fails.
+TEST_P(RawSocketTest, FailShutdownWithoutConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Shutdown is a no-op for raw sockets (and datagram sockets in general).
+TEST_P(RawSocketTest, ShutdownWriteNoop) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "noop";
+ ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)),
+ SyscallSucceedsWithValue(sizeof(kBuf)));
+}
+
+// Shutdown is a no-op for raw sockets (and datagram sockets in general).
+TEST_P(RawSocketTest, ShutdownReadNoop) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "gdg";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ std::vector<char> c(sizeof(kBuf) + HdrLen());
+ ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size()));
+}
+
+// Test that listen() fails.
+TEST_P(RawSocketTest, FailListen) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that accept() fails.
+TEST_P(RawSocketTest, FailAccept) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen;
+ ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP));
+}
+
+// Test that getpeername() returns nothing before connect().
+TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr saddr;
+ socklen_t addrlen = sizeof(saddr);
+ ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+// Test that getpeername() returns something after connect().
+TEST_P(RawSocketTest, GetPeerName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ struct sockaddr saddr;
+ socklen_t addrlen = sizeof(saddr);
+ ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_GT(addrlen, 0);
+}
+
+// Test that the socket is writable immediately.
+TEST_P(RawSocketTest, PollWritableImmediately) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLOUT;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that the socket isn't readable before receiving anything.
+TEST_P(RawSocketTest, PollNotReadableInitially) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Try to receive data with MSG_DONTWAIT, which returns immediately if there's
+ // nothing to be read.
+ char buf[117];
+ ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// Test that the socket becomes readable once something is written to it.
+TEST_P(RawSocketTest, PollTriggeredOnWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Write something so that there's data to be read.
+ // Arbitrary.
+ constexpr char kBuf[] = "JP5";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ struct pollfd pfd = {};
+ pfd.fd = s_;
+ pfd.events = POLLIN;
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
+}
+
+// Test that we can connect() to a valid IP (loopback).
+TEST_P(RawSocketTest, ConnectToLoopback) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+}
+
+// Test that calling send() without connect() fails.
+TEST_P(RawSocketTest, SendWithoutConnectFails) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Arbitrary.
+ constexpr char kBuf[] = "Endgame was good";
+ ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+}
+
+// Wildcard Bind.
+TEST_P(RawSocketTest, BindToWildcard) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ struct sockaddr_storage addr;
+ addr = {};
+
+ // We don't set ports because raw sockets don't have a notion of ports.
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_family = AF_INET;
+ sin->sin_addr.s_addr = htonl(INADDR_ANY);
+ } else {
+ struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_addr = in6addr_any;
+ }
+
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+}
+
+// Bind to localhost.
+TEST_P(RawSocketTest, BindToLocalhost) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+}
+
+// Bind to a different address.
+TEST_P(RawSocketTest, BindToInvalid) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr_storage bind_addr = addr_;
+ if (Family() == AF_INET) {
+ struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr);
+ sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
+ } else {
+ struct sockaddr_in6* sin6 =
+ reinterpret_cast<struct sockaddr_in6*>(&bind_addr);
+ memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr));
+ sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to.
+ }
+ ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
+ AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL));
+}
+
+// Send and receive an packet.
+TEST_P(RawSocketTest, SendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Arbitrary.
+ constexpr char kBuf[] = "TB12";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// We should be able to create multiple raw sockets for the same protocol and
+// receive the same packet on both.
+TEST_P(RawSocketTest, MultipleSocketReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int s2;
+ ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "TB10";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive it on socket 1.
+ std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size()));
+
+ // Receive it on socket 2.
+ std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(),
+ recv_buf2.size()));
+
+ EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(),
+ recv_buf2.data() + HdrLen(), sizeof(kBuf)),
+ 0);
+
+ ASSERT_THAT(close(s2), SyscallSucceeds());
+}
+
+// Test that connect sends packets to the right place.
+TEST_P(RawSocketTest, SendAndReceiveViaConnect) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "JH4";
+ ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
+ SyscallSucceedsWithValue(sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Bind to localhost, then send and receive packets.
+TEST_P(RawSocketTest, BindSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "DR16";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Bind and connect to localhost and send/receive packets.
+TEST_P(RawSocketTest, BindConnectSendAndReceive) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ // Arbitrary.
+ constexpr char kBuf[] = "DG88";
+ ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(sizeof(kBuf) + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0);
+}
+
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawSocketTest, SetSocketRecvBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum receive buf size by trying to set it to zero.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(RawSocketTest, SetSocketRecvBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover max buf size by trying to set the largest possible buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored.
+TEST_P(RawSocketTest, SetSocketRecvBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover max buf size by trying to set a really large buffer size.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set a zero size receive buffer
+ // size.
+ // See:
+ // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(RawSocketTest, SetSocketSendBufBelowMin) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(RawSocketTest, SetSocketSendBufAboveMax) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored.
+TEST_P(RawSocketTest, SetSocketSendBuf) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maximum buffer size by trying to set it to a large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack
+ // matches linux behavior.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Test that receive buffer limits are not enforced when the recv buffer is
+// empty.
+TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Send data of size min and verify that it's received.
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
+ 0);
+ }
+
+ {
+ // Send data of size min + 1 and verify that its received. Both linux and
+ // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer
+ // is currently empty.
+ std::vector<char> buf(min + 1);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(
+ memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()),
+ 0);
+ }
+}
+
+TEST_P(RawSocketTest, RecvBufLimits) {
+ // TCP stack generates RSTs for unknown endpoints and it complicates the test
+ // as we have to deal with the RST packets as well. For testing the raw socket
+ // endpoints buffer limit enforcement we can just test for UDP.
+ //
+ // We don't use SKIP_IF here because root_test_runner explicitly fails if a
+ // test is skipped.
+ if (Protocol() == IPPROTO_TCP) {
+ return;
+ }
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ ASSERT_THAT(
+ bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()),
+ SyscallSucceeds());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(
+ setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ // Now set the limit to min * 2.
+ int new_rcv_buf_sz = min * 4;
+ if (!IsRunningOnGvisor()) {
+ // Linux doubles the value specified so just set to min.
+ new_rcv_buf_sz = min * 2;
+ }
+
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
+ sizeof(new_rcv_buf_sz)),
+ SyscallSucceeds());
+ int rcv_buf_sz = 0;
+ {
+ socklen_t rcv_buf_len = sizeof(rcv_buf_sz);
+ ASSERT_THAT(
+ getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len),
+ SyscallSucceeds());
+ }
+
+ // Set a receive timeout so that we don't block forever on reads if the test
+ // fails.
+ struct timeval tv {
+ .tv_sec = 1, .tv_usec = 0,
+ };
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ {
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ int sent = 4;
+ if (IsRunningOnGvisor()) {
+ // Linux seems to drop the 4th packet even though technically it should
+ // fit in the receive buffer.
+ ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size()));
+ sent++;
+ }
+
+ // Verify that the expected number of packets are available to be read.
+ for (int i = 0; i < sent - 1; i++) {
+ // Receive the packet and make sure it's identical.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size()));
+ EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(),
+ buf.size()),
+ 0);
+ }
+
+ // Assert that the last packet is dropped because the receive buffer should
+ // be full after the first four packets.
+ std::vector<char> recv_buf(buf.size() + HdrLen());
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data()));
+ iov.iov_len = buf.size();
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+void RawSocketTest::SendBuf(const char* buf, int buf_len) {
+ // It's safe to use const_cast here because sendmsg won't modify the iovec or
+ // address.
+ struct iovec iov = {};
+ iov.iov_base = static_cast<void*>(const_cast<char*>(buf));
+ iov.iov_len = static_cast<size_t>(buf_len);
+ struct msghdr msg = {};
+ msg.msg_name = static_cast<void*>(&addr_);
+ msg.msg_namelen = AddrLen();
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = NULL;
+ msg.msg_controllen = 0;
+ msg.msg_flags = 0;
+ ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len));
+}
+
+void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) {
+ ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len));
+}
+
+void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf,
+ size_t recv_buf_len) {
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len));
+}
+
+#ifndef __fuchsia__
+
+TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ if (IsRunningOnGvisor()) {
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+ return;
+ }
+
+ constexpr int val = 0;
+ ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(RawSocketTest, GetSocketDetachFilter) {
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
+// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to.
+TEST(RawSocketTest, IPv6ProtoRaw) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ int sock;
+ ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW),
+ SyscallSucceeds());
+
+ // Verify that writing yields EINVAL.
+ char buf[] = "This is such a weird little edge case";
+ struct sockaddr_in6 sin6 = {};
+ sin6.sin6_family = AF_INET6;
+ sin6.sin6_addr = in6addr_loopback;
+ ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */,
+ reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllInetTests, RawSocketTest,
+ ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP),
+ ::testing::Values(AF_INET, AF_INET6)));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index 0a27506aa..2f25aceb2 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) {
// nothing to be read.
char buf[117];
ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EINVAL));
+ SyscallFailsWithErrno(EAGAIN));
}
// Test that we can connect() to a valid IP (loopback).
@@ -178,6 +178,9 @@ TEST_F(RawHDRINCL, ConnectToLoopback) {
}
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
+ // FIXME(gvisor.dev/issue/3159): Test currently flaky.
+ SKIP_IF(true);
+
struct iphdr hdr = LoopbackHeader();
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
SyscallSucceedsWithValue(sizeof(hdr)));
@@ -273,14 +276,17 @@ TEST_F(RawHDRINCL, SendAndReceive) {
// The network stack should have set the source address.
EXPECT_EQ(src.sin_family, AF_INET);
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
- // The packet ID should be 0, as the packet is less than 68 bytes.
- struct iphdr iphdr = {};
- memcpy(&iphdr, recv_buf, sizeof(iphdr));
- EXPECT_EQ(iphdr.id, 0);
+ // The packet ID should not be 0, as the packet has DF=0.
+ struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf);
+ EXPECT_NE(iphdr->id, 0);
}
-// Send and receive a packet with nonzero IP ID.
-TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
+// Send and receive a packet where the sendto address is not the same as the
+// provided destination.
+TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
+ // FIXME(gvisor.dev/issue/3160): Test currently flaky.
+ SKIP_IF(true);
+
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -292,19 +298,24 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
- // Construct a packet with an IP header, UDP header, and payload. Make the
- // payload large enough to force an IP ID to be assigned.
- constexpr char kPayload[128] = {};
+ // Construct a packet with an IP header, UDP header, and payload.
+ constexpr char kPayload[] = "toto";
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
ASSERT_TRUE(
FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
+ // Overwrite the IP destination address with an IP we can't get to.
+ struct iphdr iphdr = {};
+ memcpy(&iphdr, packet, sizeof(iphdr));
+ iphdr.daddr = 42;
+ memcpy(packet, &iphdr, sizeof(iphdr));
socklen_t addrlen = sizeof(addr_);
ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
reinterpret_cast<struct sockaddr*>(&addr_),
addrlen));
- // Receive the payload.
+ // Receive the payload, since sendto should replace the bad destination with
+ // localhost.
char recv_buf[sizeof(packet)];
struct sockaddr_in src;
socklen_t src_size = sizeof(src);
@@ -318,47 +329,58 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
// The network stack should have set the source address.
EXPECT_EQ(src.sin_family, AF_INET);
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
- // The packet ID should not be 0, as the packet was more than 68 bytes.
- struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf);
- EXPECT_NE(iphdr->id, 0);
+ // The packet ID should not be 0, as the packet has DF=0.
+ struct iphdr recv_iphdr = {};
+ memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr));
+ EXPECT_NE(recv_iphdr.id, 0);
+ // The destination address should be localhost, not the bad IP we set
+ // initially.
+ EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK);
}
-// Send and receive a packet where the sendto address is not the same as the
-// provided destination.
-TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
+// Send and receive a packet w/ the IP_HDRINCL option set.
+TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) {
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
}
- // IPPROTO_RAW sockets are write-only. We'll have to open another socket to
- // read what we write.
- FileDescriptor udp_sock =
+ FileDescriptor recv_sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+
+ FileDescriptor send_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
+ // Enable IP_HDRINCL option so that we can build and send w/ an IP
+ // header.
+ constexpr int kSockOptOn = 1;
+ ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ // This is not strictly required but we do it to make sure that setting
+ // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving
+ // packets.
+ ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
// Construct a packet with an IP header, UDP header, and payload.
constexpr char kPayload[] = "toto";
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
ASSERT_TRUE(
FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
- // Overwrite the IP destination address with an IP we can't get to.
- struct iphdr iphdr = {};
- memcpy(&iphdr, packet, sizeof(iphdr));
- iphdr.daddr = 42;
- memcpy(packet, &iphdr, sizeof(iphdr));
socklen_t addrlen = sizeof(addr_);
- ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
+ ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0,
reinterpret_cast<struct sockaddr*>(&addr_),
addrlen));
- // Receive the payload, since sendto should replace the bad destination with
- // localhost.
+ // Receive the payload.
char recv_buf[sizeof(packet)];
struct sockaddr_in src;
socklen_t src_size = sizeof(src);
- ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
+ ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0,
reinterpret_cast<struct sockaddr*>(&src), &src_size),
SyscallSucceedsWithValue(sizeof(packet)));
EXPECT_EQ(
@@ -368,13 +390,20 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
// The network stack should have set the source address.
EXPECT_EQ(src.sin_family, AF_INET);
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
- // The packet ID should be 0, as the packet is less than 68 bytes.
- struct iphdr recv_iphdr = {};
- memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr));
- EXPECT_EQ(recv_iphdr.id, 0);
- // The destination address should be localhost, not the bad IP we set
- // initially.
- EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK);
+ struct iphdr iphdr = {};
+ memcpy(&iphdr, recv_buf, sizeof(iphdr));
+ EXPECT_NE(iphdr.id, 0);
+
+ // Also verify that the packet we just sent was not delivered to the
+ // IPPROTO_RAW socket.
+ {
+ char recv_buf[sizeof(packet)];
+ struct sockaddr_in src;
+ socklen_t src_size = sizeof(src);
+ ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT,
+ reinterpret_cast<struct sockaddr*>(&src), &src_size),
+ SyscallFailsWithErrno(EAGAIN));
+ }
}
} // namespace
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 8bcaba6f1..3de898df7 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -129,7 +129,7 @@ TEST_F(RawSocketICMPTest, SendAndReceiveBadChecksum) {
EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
-//
+
// Send and receive an ICMP packet.
TEST_F(RawSocketICMPTest, SendAndReceive) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
deleted file mode 100644
index cde2f07c9..000000000
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ /dev/null
@@ -1,392 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <linux/capability.h>
-#include <netinet/in.h>
-#include <netinet/ip.h>
-#include <netinet/ip_icmp.h>
-#include <poll.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <algorithm>
-
-#include "gtest/gtest.h"
-#include "test/syscalls/linux/socket_test_util.h"
-#include "test/syscalls/linux/unix_domain_socket_test_util.h"
-#include "test/util/capability_util.h"
-#include "test/util/file_descriptor.h"
-#include "test/util/test_util.h"
-
-// Note: in order to run these tests, /proc/sys/net/ipv4/ping_group_range will
-// need to be configured to let the superuser create ping sockets (see icmp(7)).
-
-namespace gvisor {
-namespace testing {
-
-namespace {
-
-// Fixture for tests parameterized by protocol.
-class RawSocketTest : public ::testing::TestWithParam<int> {
- protected:
- // Creates a socket to be used in tests.
- void SetUp() override;
-
- // Closes the socket created by SetUp().
- void TearDown() override;
-
- // Sends buf via s_.
- void SendBuf(const char* buf, int buf_len);
-
- // Sends buf to the provided address via the provided socket.
- void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf,
- int buf_len);
-
- // Reads from s_ into recv_buf.
- void ReceiveBuf(char* recv_buf, size_t recv_buf_len);
-
- int Protocol() { return GetParam(); }
-
- // The socket used for both reading and writing.
- int s_;
-
- // The loopback address.
- struct sockaddr_in addr_;
-};
-
-void RawSocketTest::SetUp() {
- if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
- ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
- SyscallFailsWithErrno(EPERM));
- GTEST_SKIP();
- }
-
- ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
-
- addr_ = {};
-
- // We don't set ports because raw sockets don't have a notion of ports.
- addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- addr_.sin_family = AF_INET;
-}
-
-void RawSocketTest::TearDown() {
- // TearDown will be run even if we skip the test.
- if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
- EXPECT_THAT(close(s_), SyscallSucceeds());
- }
-}
-
-// We should be able to create multiple raw sockets for the same protocol.
-// BasicRawSocket::Setup creates the first one, so we only have to create one
-// more here.
-TEST_P(RawSocketTest, MultipleCreation) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- int s2;
- ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
-
- ASSERT_THAT(close(s2), SyscallSucceeds());
-}
-
-// Test that shutting down an unconnected socket fails.
-TEST_P(RawSocketTest, FailShutdownWithoutConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
- ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
-}
-
-// Shutdown is a no-op for raw sockets (and datagram sockets in general).
-TEST_P(RawSocketTest, ShutdownWriteNoop) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
- ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "noop";
- ASSERT_THAT(RetryEINTR(write)(s_, kBuf, sizeof(kBuf)),
- SyscallSucceedsWithValue(sizeof(kBuf)));
-}
-
-// Shutdown is a no-op for raw sockets (and datagram sockets in general).
-TEST_P(RawSocketTest, ShutdownReadNoop) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
- ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "gdg";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr);
- char c[kReadSize];
- ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize));
-}
-
-// Test that listen() fails.
-TEST_P(RawSocketTest, FailListen) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(listen(s_, 1), SyscallFailsWithErrno(ENOTSUP));
-}
-
-// Test that accept() fails.
-TEST_P(RawSocketTest, FailAccept) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- struct sockaddr saddr;
- socklen_t addrlen;
- ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP));
-}
-
-// Test that getpeername() returns nothing before connect().
-TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- struct sockaddr saddr;
- socklen_t addrlen = sizeof(saddr);
- ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
-}
-
-// Test that getpeername() returns something after connect().
-TEST_P(RawSocketTest, GetPeerName) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
- struct sockaddr saddr;
- socklen_t addrlen = sizeof(saddr);
- ASSERT_THAT(getpeername(s_, &saddr, &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
- ASSERT_GT(addrlen, 0);
-}
-
-// Test that the socket is writable immediately.
-TEST_P(RawSocketTest, PollWritableImmediately) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- struct pollfd pfd = {};
- pfd.fd = s_;
- pfd.events = POLLOUT;
- ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
-}
-
-// Test that the socket isn't readable before receiving anything.
-TEST_P(RawSocketTest, PollNotReadableInitially) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- // Try to receive data with MSG_DONTWAIT, which returns immediately if there's
- // nothing to be read.
- char buf[117];
- ASSERT_THAT(RetryEINTR(recv)(s_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EAGAIN));
-}
-
-// Test that the socket becomes readable once something is written to it.
-TEST_P(RawSocketTest, PollTriggeredOnWrite) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- // Write something so that there's data to be read.
- // Arbitrary.
- constexpr char kBuf[] = "JP5";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- struct pollfd pfd = {};
- pfd.fd = s_;
- pfd.events = POLLIN;
- ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 10000), SyscallSucceedsWithValue(1));
-}
-
-// Test that we can connect() to a valid IP (loopback).
-TEST_P(RawSocketTest, ConnectToLoopback) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
-}
-
-// Test that calling send() without connect() fails.
-TEST_P(RawSocketTest, SendWithoutConnectFails) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- // Arbitrary.
- constexpr char kBuf[] = "Endgame was good";
- ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
- SyscallFailsWithErrno(EDESTADDRREQ));
-}
-
-// Bind to localhost.
-TEST_P(RawSocketTest, BindToLocalhost) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
-}
-
-// Bind to a different address.
-TEST_P(RawSocketTest, BindToInvalid) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- struct sockaddr_in bind_addr = {};
- bind_addr.sin_family = AF_INET;
- bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
- ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr),
- sizeof(bind_addr)),
- SyscallFailsWithErrno(EADDRNOTAVAIL));
-}
-
-// Send and receive an packet.
-TEST_P(RawSocketTest, SendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- // Arbitrary.
- constexpr char kBuf[] = "TB12";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- // Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
-}
-
-// We should be able to create multiple raw sockets for the same protocol and
-// receive the same packet on both.
-TEST_P(RawSocketTest, MultipleSocketReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- int s2;
- ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "TB10";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- // Receive it on socket 1.
- char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1)));
-
- // Receive it on socket 2.
- char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2)));
-
- EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr),
- recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)),
- 0);
-
- ASSERT_THAT(close(s2), SyscallSucceeds());
-}
-
-// Test that connect sends packets to the right place.
-TEST_P(RawSocketTest, SendAndReceiveViaConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "JH4";
- ASSERT_THAT(send(s_, kBuf, sizeof(kBuf), 0),
- SyscallSucceedsWithValue(sizeof(kBuf)));
-
- // Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
-}
-
-// Bind to localhost, then send and receive packets.
-TEST_P(RawSocketTest, BindSendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "DR16";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- // Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
-}
-
-// Bind and connect to localhost and send/receive packets.
-TEST_P(RawSocketTest, BindConnectSendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- ASSERT_THAT(
- bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
- ASSERT_THAT(
- connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
- SyscallSucceeds());
-
- // Arbitrary.
- constexpr char kBuf[] = "DG88";
- ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf)));
-
- // Receive the packet and make sure it's identical.
- char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)];
- ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf)));
- EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0);
-}
-
-void RawSocketTest::SendBuf(const char* buf, int buf_len) {
- ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len));
-}
-
-void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr,
- const char* buf, int buf_len) {
- // It's safe to use const_cast here because sendmsg won't modify the iovec or
- // address.
- struct iovec iov = {};
- iov.iov_base = static_cast<void*>(const_cast<char*>(buf));
- iov.iov_len = static_cast<size_t>(buf_len);
- struct msghdr msg = {};
- msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr));
- msg.msg_namelen = sizeof(addr);
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = NULL;
- msg.msg_controllen = 0;
- msg.msg_flags = 0;
- ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len));
-}
-
-void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) {
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len));
-}
-
-INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest,
- ::testing::Values(IPPROTO_TCP, IPPROTO_UDP));
-
-} // namespace
-
-} // namespace testing
-} // namespace gvisor
diff --git a/test/syscalls/linux/read.cc b/test/syscalls/linux/read.cc
index 4430fa3c2..2633ba31b 100644
--- a/test/syscalls/linux/read.cc
+++ b/test/syscalls/linux/read.cc
@@ -14,6 +14,7 @@
#include <fcntl.h>
#include <unistd.h>
+
#include <vector>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/readv.cc b/test/syscalls/linux/readv.cc
index 4069cbc7e..baaf9f757 100644
--- a/test/syscalls/linux/readv.cc
+++ b/test/syscalls/linux/readv.cc
@@ -254,7 +254,9 @@ TEST_F(ReadvTest, IovecOutsideTaskAddressRangeInNonemptyArray) {
// This test depends on the maximum extent of a single readv() syscall, so
// we can't tolerate interruption from saving.
TEST(ReadvTestNoFixture, TruncatedAtMax_NoRandomSave) {
- // Ensure that we won't be interrupted by ITIMER_PROF.
+ // Ensure that we won't be interrupted by ITIMER_PROF. This is particularly
+ // important in environments where automated profiling tools may start
+ // ITIMER_PROF automatically.
struct itimerval itv = {};
auto const cleanup_itimer =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_PROF, itv));
diff --git a/test/syscalls/linux/readv_common.cc b/test/syscalls/linux/readv_common.cc
index 9658f7d42..2694dc64f 100644
--- a/test/syscalls/linux/readv_common.cc
+++ b/test/syscalls/linux/readv_common.cc
@@ -19,12 +19,53 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/file_base.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+// MatchesStringLength checks that a tuple argument of (struct iovec *, int)
+// corresponding to an iovec array and its length, contains data that matches
+// the string length strlen.
+MATCHER_P(MatchesStringLength, strlen, "") {
+ struct iovec* iovs = arg.first;
+ int niov = arg.second;
+ int offset = 0;
+ for (int i = 0; i < niov; i++) {
+ offset += iovs[i].iov_len;
+ }
+ if (offset != static_cast<int>(strlen)) {
+ *result_listener << offset;
+ return false;
+ }
+ return true;
+}
+
+// MatchesStringValue checks that a tuple argument of (struct iovec *, int)
+// corresponding to an iovec array and its length, contains data that matches
+// the string value str.
+MATCHER_P(MatchesStringValue, str, "") {
+ struct iovec* iovs = arg.first;
+ int len = strlen(str);
+ int niov = arg.second;
+ int offset = 0;
+ for (int i = 0; i < niov; i++) {
+ struct iovec iov = iovs[i];
+ if (len < offset) {
+ *result_listener << "strlen " << len << " < offset " << offset;
+ return false;
+ }
+ if (strncmp(static_cast<char*>(iov.iov_base), &str[offset], iov.iov_len)) {
+ absl::string_view iovec_string(static_cast<char*>(iov.iov_base),
+ iov.iov_len);
+ *result_listener << iovec_string << " @offset " << offset;
+ return false;
+ }
+ offset += iov.iov_len;
+ }
+ return true;
+}
+
extern const char kReadvTestData[] =
"127.0.0.1 localhost"
""
@@ -113,7 +154,7 @@ void ReadBuffersOverlapping(int fd) {
char* expected_ptr = expected.data();
memcpy(expected_ptr, &kReadvTestData[overlap_bytes], overlap_bytes);
memcpy(&expected_ptr[overlap_bytes], &kReadvTestData[overlap_bytes],
- kReadvTestDataSize);
+ kReadvTestDataSize - overlap_bytes);
struct iovec iovs[2];
iovs[0].iov_base = buffer.data();
diff --git a/test/syscalls/linux/readv_socket.cc b/test/syscalls/linux/readv_socket.cc
index 9b6972201..dd6fb7008 100644
--- a/test/syscalls/linux/readv_socket.cc
+++ b/test/syscalls/linux/readv_socket.cc
@@ -19,7 +19,6 @@
#include <unistd.h>
#include "gtest/gtest.h"
-#include "test/syscalls/linux/file_base.h"
#include "test/syscalls/linux/readv_common.h"
#include "test/util/test_util.h"
@@ -28,9 +27,30 @@ namespace testing {
namespace {
-class ReadvSocketTest : public SocketTest {
+class ReadvSocketTest : public ::testing::Test {
+ public:
void SetUp() override {
- SocketTest::SetUp();
+ test_unix_stream_socket_[0] = -1;
+ test_unix_stream_socket_[1] = -1;
+ test_unix_dgram_socket_[0] = -1;
+ test_unix_dgram_socket_[1] = -1;
+ test_unix_seqpacket_socket_[0] = -1;
+ test_unix_seqpacket_socket_[1] = -1;
+
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, test_unix_stream_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_stream_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(socketpair(AF_UNIX, SOCK_DGRAM, 0, test_unix_dgram_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_dgram_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ socketpair(AF_UNIX, SOCK_SEQPACKET, 0, test_unix_seqpacket_socket_),
+ SyscallSucceeds());
+ ASSERT_THAT(fcntl(test_unix_seqpacket_socket_[0], F_SETFL, O_NONBLOCK),
+ SyscallSucceeds());
+
ASSERT_THAT(
write(test_unix_stream_socket_[1], kReadvTestData, kReadvTestDataSize),
SyscallSucceedsWithValue(kReadvTestDataSize));
@@ -40,11 +60,22 @@ class ReadvSocketTest : public SocketTest {
ASSERT_THAT(write(test_unix_seqpacket_socket_[1], kReadvTestData,
kReadvTestDataSize),
SyscallSucceedsWithValue(kReadvTestDataSize));
- // FIXME(b/69821513): Enable when possible.
- // ASSERT_THAT(write(test_tcp_socket_[1], kReadvTestData,
- // kReadvTestDataSize),
- // SyscallSucceedsWithValue(kReadvTestDataSize));
}
+
+ void TearDown() override {
+ close(test_unix_stream_socket_[0]);
+ close(test_unix_stream_socket_[1]);
+
+ close(test_unix_dgram_socket_[0]);
+ close(test_unix_dgram_socket_[1]);
+
+ close(test_unix_seqpacket_socket_[0]);
+ close(test_unix_seqpacket_socket_[1]);
+ }
+
+ int test_unix_stream_socket_[2];
+ int test_unix_dgram_socket_[2];
+ int test_unix_seqpacket_socket_[2];
};
TEST_F(ReadvSocketTest, ReadOneBufferPerByte_StreamSocket) {
diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc
index 5b474ff32..833c0dc4f 100644
--- a/test/syscalls/linux/rename.cc
+++ b/test/syscalls/linux/rename.cc
@@ -14,6 +14,7 @@
#include <fcntl.h>
#include <stdio.h>
+
#include <string>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc
new file mode 100644
index 000000000..4bfb1ff56
--- /dev/null
+++ b/test/syscalls/linux/rseq.cc
@@ -0,0 +1,198 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <signal.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/rseq/test.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+#include "test/util/logging.h"
+#include "test/util/multiprocess_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Syscall test for rseq (restartable sequences).
+//
+// We must be very careful about how these tests are written. Each thread may
+// only have one struct rseq registration, which may be done automatically at
+// thread start (as of 2019-11-13, glibc does *not* support rseq and thus does
+// not do so, but other libraries do).
+//
+// Testing of rseq is thus done primarily in a child process with no
+// registration. This means exec'ing a nostdlib binary, as rseq registration can
+// only be cleared by execve (or knowing the old rseq address), and glibc (based
+// on the current unmerged patches) register rseq before calling main()).
+
+int RSeq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) {
+ return syscall(kRseqSyscall, rseq, rseq_len, flags, sig);
+}
+
+// Returns true if this kernel supports the rseq syscall.
+PosixErrorOr<bool> RSeqSupported() {
+ // We have to be careful here, there are three possible cases:
+ //
+ // 1. rseq is not supported -> ENOSYS
+ // 2. rseq is supported and not registered -> success, but we should
+ // unregister.
+ // 3. rseq is supported and registered -> EINVAL (most likely).
+
+ // The only validation done on new registrations is that rseq is aligned and
+ // writable.
+ rseq rseq = {};
+ int ret = RSeq(&rseq, sizeof(rseq), 0, 0);
+ if (ret == 0) {
+ // Successfully registered, rseq is supported. Unregister.
+ ret = RSeq(&rseq, sizeof(rseq), kRseqFlagUnregister, 0);
+ if (ret != 0) {
+ return PosixError(errno);
+ }
+ return true;
+ }
+
+ switch (errno) {
+ case ENOSYS:
+ // Not supported.
+ return false;
+ case EINVAL:
+ // Supported, but already registered. EINVAL returned because we provided
+ // a different address.
+ return true;
+ default:
+ // Unknown error.
+ return PosixError(errno);
+ }
+}
+
+constexpr char kRseqBinary[] = "test/syscalls/linux/rseq/rseq";
+
+void RunChildTest(std::string test_case, int want_status) {
+ std::string path = RunfilePath(kRseqBinary);
+
+ pid_t child_pid = -1;
+ int execve_errno = 0;
+ auto cleanup = ASSERT_NO_ERRNO_AND_VALUE(
+ ForkAndExec(path, {path, test_case}, {}, &child_pid, &execve_errno));
+
+ ASSERT_GT(child_pid, 0);
+ ASSERT_EQ(execve_errno, 0);
+
+ int status = 0;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds());
+ ASSERT_EQ(status, want_status);
+}
+
+// Test that rseq must be aligned.
+TEST(RseqTest, Unaligned) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnaligned, 0);
+}
+
+// Sanity test that registration works.
+TEST(RseqTest, Register) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestRegister, 0);
+}
+
+// Registration can't be done twice.
+TEST(RseqTest, DoubleRegister) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestDoubleRegister, 0);
+}
+
+// Registration can be done again after unregister.
+TEST(RseqTest, RegisterUnregister) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestRegisterUnregister, 0);
+}
+
+// The pointer to rseq must match on register/unregister.
+TEST(RseqTest, UnregisterDifferentPtr) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnregisterDifferentPtr, 0);
+}
+
+// The signature must match on register/unregister.
+TEST(RseqTest, UnregisterDifferentSignature) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestUnregisterDifferentSignature, 0);
+}
+
+// The CPU ID is initialized.
+TEST(RseqTest, CPU) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestCPU, 0);
+}
+
+// Critical section is eventually aborted.
+TEST(RseqTest, Abort) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbort, 0);
+}
+
+// Abort may be before the critical section.
+TEST(RseqTest, AbortBefore) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortBefore, 0);
+}
+
+// Signature must match.
+TEST(RseqTest, AbortSignature) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortSignature, SIGSEGV);
+}
+
+// Abort must not be in the critical section.
+TEST(RseqTest, AbortPreCommit) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortPreCommit, SIGSEGV);
+}
+
+// rseq.rseq_cs is cleared on abort.
+TEST(RseqTest, AbortClearsCS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestAbortClearsCS, 0);
+}
+
+// rseq.rseq_cs is cleared on abort outside of critical section.
+TEST(RseqTest, InvalidAbortClearsCS) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(RSeqSupported()));
+
+ RunChildTest(kRseqTestInvalidAbortClearsCS, 0);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq/BUILD b/test/syscalls/linux/rseq/BUILD
new file mode 100644
index 000000000..853258b04
--- /dev/null
+++ b/test/syscalls/linux/rseq/BUILD
@@ -0,0 +1,61 @@
+# This package contains a standalone rseq test binary. This binary must not
+# depend on libc, which might use rseq itself.
+
+load("//tools:defs.bzl", "cc_flags_supplier", "cc_library", "cc_toolchain", "select_arch")
+
+package(licenses = ["notice"])
+
+genrule(
+ name = "rseq_binary",
+ srcs = [
+ "critical.h",
+ "critical_amd64.S",
+ "critical_arm64.S",
+ "rseq.cc",
+ "syscalls.h",
+ "start_amd64.S",
+ "start_arm64.S",
+ "test.h",
+ "types.h",
+ "uapi.h",
+ ],
+ outs = ["rseq"],
+ cmd = "$(CC) " +
+ "$(CC_FLAGS) " +
+ "-I. " +
+ "-Wall " +
+ "-Werror " +
+ "-O2 " +
+ "-std=c++17 " +
+ "-static " +
+ "-nostdlib " +
+ "-ffreestanding " +
+ "-o " +
+ "$(location rseq) " +
+ select_arch(
+ amd64 = "$(location critical_amd64.S) $(location start_amd64.S) ",
+ arm64 = "$(location critical_arm64.S) $(location start_arm64.S) ",
+ no_match_error = "unsupported architecture",
+ ) +
+ "$(location rseq.cc)",
+ toolchains = [
+ cc_toolchain,
+ ":no_pie_cc_flags",
+ ],
+ visibility = ["//:sandbox"],
+)
+
+cc_flags_supplier(
+ name = "no_pie_cc_flags",
+ features = ["-pie"],
+)
+
+cc_library(
+ name = "lib",
+ testonly = 1,
+ hdrs = [
+ "test.h",
+ "uapi.h",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/test/syscalls/linux/rseq/critical.h b/test/syscalls/linux/rseq/critical.h
new file mode 100644
index 000000000..ac987a25e
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical.h
@@ -0,0 +1,39 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
+
+#include "test/syscalls/linux/rseq/types.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+
+constexpr uint32_t kRseqSignature = 0x90909090;
+
+extern "C" {
+
+extern void rseq_loop(struct rseq* r, struct rseq_cs* cs);
+extern void* rseq_loop_early_abort;
+extern void* rseq_loop_start;
+extern void* rseq_loop_pre_commit;
+extern void* rseq_loop_post_commit;
+extern void* rseq_loop_abort;
+
+extern int rseq_getpid(struct rseq* r, struct rseq_cs* cs);
+extern void* rseq_getpid_start;
+extern void* rseq_getpid_post_commit;
+extern void* rseq_getpid_abort;
+
+} // extern "C"
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_CRITICAL_H_
diff --git a/test/syscalls/linux/rseq/critical_amd64.S b/test/syscalls/linux/rseq/critical_amd64.S
new file mode 100644
index 000000000..8c0687e6d
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical_amd64.S
@@ -0,0 +1,66 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Restartable sequences critical sections.
+
+// Loops continuously until aborted.
+//
+// void rseq_loop(struct rseq* r, struct rseq_cs* cs)
+
+ .text
+ .globl rseq_loop
+ .type rseq_loop, @function
+
+rseq_loop:
+ jmp begin
+
+ // Abort block before the critical section.
+ // Abort signature is 4 nops for simplicity.
+ .byte 0x90, 0x90, 0x90, 0x90
+ .globl rseq_loop_early_abort
+rseq_loop_early_abort:
+ ret
+
+begin:
+ // r->rseq_cs = cs
+ movq %rsi, 8(%rdi)
+
+ // N.B. rseq_cs will be cleared by any preempt, even outside the critical
+ // section. Thus it must be set in or immediately before the critical section
+ // to ensure it is not cleared before the section begins.
+ .globl rseq_loop_start
+rseq_loop_start:
+ jmp rseq_loop_start
+
+ // "Pre-commit": extra instructions inside the critical section. These are
+ // used as the abort point in TestAbortPreCommit, which is not valid.
+ .globl rseq_loop_pre_commit
+rseq_loop_pre_commit:
+ // Extra abort signature + nop for TestAbortPostCommit.
+ .byte 0x90, 0x90, 0x90, 0x90
+ nop
+
+ // "Post-commit": never reached in this case.
+ .globl rseq_loop_post_commit
+rseq_loop_post_commit:
+
+ // Abort signature is 4 nops for simplicity.
+ .byte 0x90, 0x90, 0x90, 0x90
+
+ .globl rseq_loop_abort
+rseq_loop_abort:
+ ret
+
+ .size rseq_loop,.-rseq_loop
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/critical_arm64.S b/test/syscalls/linux/rseq/critical_arm64.S
new file mode 100644
index 000000000..bfe7e8307
--- /dev/null
+++ b/test/syscalls/linux/rseq/critical_arm64.S
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Restartable sequences critical sections.
+
+// Loops continuously until aborted.
+//
+// void rseq_loop(struct rseq* r, struct rseq_cs* cs)
+
+ .text
+ .globl rseq_loop
+ .type rseq_loop, @function
+
+rseq_loop:
+ b begin
+
+ // Abort block before the critical section.
+ // Abort signature.
+ .byte 0x90, 0x90, 0x90, 0x90
+ .globl rseq_loop_early_abort
+rseq_loop_early_abort:
+ ret
+
+begin:
+ // r->rseq_cs = cs
+ str x1, [x0, #8]
+
+ // N.B. rseq_cs will be cleared by any preempt, even outside the critical
+ // section. Thus it must be set in or immediately before the critical section
+ // to ensure it is not cleared before the section begins.
+ .globl rseq_loop_start
+rseq_loop_start:
+ b rseq_loop_start
+
+ // "Pre-commit": extra instructions inside the critical section. These are
+ // used as the abort point in TestAbortPreCommit, which is not valid.
+ .globl rseq_loop_pre_commit
+rseq_loop_pre_commit:
+ // Extra abort signature + nop for TestAbortPostCommit.
+ .byte 0x90, 0x90, 0x90, 0x90
+ nop
+
+ // "Post-commit": never reached in this case.
+ .globl rseq_loop_post_commit
+rseq_loop_post_commit:
+
+ // Abort signature.
+ .byte 0x90, 0x90, 0x90, 0x90
+
+ .globl rseq_loop_abort
+rseq_loop_abort:
+ ret
+
+ .size rseq_loop,.-rseq_loop
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc
new file mode 100644
index 000000000..f036db26d
--- /dev/null
+++ b/test/syscalls/linux/rseq/rseq.cc
@@ -0,0 +1,366 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/rseq/critical.h"
+#include "test/syscalls/linux/rseq/syscalls.h"
+#include "test/syscalls/linux/rseq/test.h"
+#include "test/syscalls/linux/rseq/types.h"
+#include "test/syscalls/linux/rseq/uapi.h"
+
+namespace gvisor {
+namespace testing {
+
+extern "C" int main(int argc, char** argv, char** envp);
+
+// Standalone initialization before calling main().
+extern "C" void __init(uintptr_t* sp) {
+ int argc = sp[0];
+ char** argv = reinterpret_cast<char**>(&sp[1]);
+ char** envp = &argv[argc + 1];
+
+ // Call main() and exit.
+ sys_exit_group(main(argc, argv, envp));
+
+ // sys_exit_group does not return
+}
+
+int strcmp(const char* s1, const char* s2) {
+ const unsigned char* p1 = reinterpret_cast<const unsigned char*>(s1);
+ const unsigned char* p2 = reinterpret_cast<const unsigned char*>(s2);
+
+ while (*p1 == *p2) {
+ if (!*p1) {
+ return 0;
+ }
+ ++p1;
+ ++p2;
+ }
+ return static_cast<int>(*p1) - static_cast<int>(*p2);
+}
+
+int sys_rseq(struct rseq* rseq, uint32_t rseq_len, int flags, uint32_t sig) {
+ return raw_syscall(kRseqSyscall, rseq, rseq_len, flags, sig);
+}
+
+// Test that rseq must be aligned.
+int TestUnaligned() {
+ constexpr uintptr_t kRequiredAlignment = alignof(rseq);
+
+ char buf[2 * kRequiredAlignment] = {};
+ uintptr_t ptr = reinterpret_cast<uintptr_t>(&buf[0]);
+ if ((ptr & (kRequiredAlignment - 1)) == 0) {
+ // buf is already aligned. Misalign it.
+ ptr++;
+ }
+
+ int ret = sys_rseq(reinterpret_cast<rseq*>(ptr), sizeof(rseq), 0, 0);
+ if (sys_errno(ret) != EINVAL) {
+ return 1;
+ }
+ return 0;
+}
+
+// Sanity test that registration works.
+int TestRegister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+ return 0;
+};
+
+// Registration can't be done twice.
+int TestDoubleRegister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// Registration can be done again after unregister.
+int TestRegisterUnregister() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// The pointer to rseq must match on register/unregister.
+int TestUnregisterDifferentPtr() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq r2 = {};
+ if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0);
+ sys_errno(ret) != EINVAL) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// The signature must match on register/unregister.
+int TestUnregisterDifferentSignature() {
+ constexpr int kSignature = 0;
+
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1);
+ sys_errno(ret) != EPERM) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// The CPU ID is initialized.
+int TestCPU() {
+ struct rseq r = {};
+ r.cpu_id = kRseqCPUIDUninitialized;
+
+ if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ if (__atomic_load_n(&r.cpu_id, __ATOMIC_RELAXED) < 0) {
+ return 1;
+ }
+ if (__atomic_load_n(&r.cpu_id_start, __ATOMIC_RELAXED) < 0) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// Critical section is eventually aborted.
+int TestAbort() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ return 0;
+};
+
+// Abort may be before the critical section.
+int TestAbortBefore() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_early_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ return 0;
+};
+
+// Signature must match.
+int TestAbortSignature() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. This should SIGSEGV on abort.
+ rseq_loop(&r, &cs);
+
+ return 1;
+};
+
+// Abort must not be in the critical section.
+int TestAbortPreCommit() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_pre_commit);
+
+ // Loops until abort. This should SIGSEGV on abort.
+ rseq_loop(&r, &cs);
+
+ return 1;
+};
+
+// rseq.rseq_cs is cleared on abort.
+int TestAbortClearsCS() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ // Loops until abort. If this returns then abort occurred.
+ rseq_loop(&r, &cs);
+
+ if (__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) {
+ return 1;
+ }
+
+ return 0;
+};
+
+// rseq.rseq_cs is cleared on abort outside of critical section.
+int TestInvalidAbortClearsCS() {
+ struct rseq r = {};
+ if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature);
+ sys_errno(ret) != 0) {
+ return 1;
+ }
+
+ struct rseq_cs cs = {};
+ cs.version = 0;
+ cs.flags = 0;
+ cs.start_ip = reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.post_commit_offset = reinterpret_cast<uint64_t>(&rseq_loop_post_commit) -
+ reinterpret_cast<uint64_t>(&rseq_loop_start);
+ cs.abort_ip = reinterpret_cast<uint64_t>(&rseq_loop_abort);
+
+ __atomic_store_n(&r.rseq_cs, &cs, __ATOMIC_RELAXED);
+
+ // When the next abort condition occurs, the kernel will clear cs once it
+ // determines we aren't in the critical section.
+ while (1) {
+ if (!__atomic_load_n(&r.rseq_cs, __ATOMIC_RELAXED)) {
+ break;
+ }
+ }
+
+ return 0;
+};
+
+// Exit codes:
+// 0 - Pass
+// 1 - Fail
+// 2 - Missing argument
+// 3 - Unknown test case
+extern "C" int main(int argc, char** argv, char** envp) {
+ if (argc != 2) {
+ // Usage: rseq <test case>
+ return 2;
+ }
+
+ if (strcmp(argv[1], kRseqTestUnaligned) == 0) {
+ return TestUnaligned();
+ }
+ if (strcmp(argv[1], kRseqTestRegister) == 0) {
+ return TestRegister();
+ }
+ if (strcmp(argv[1], kRseqTestDoubleRegister) == 0) {
+ return TestDoubleRegister();
+ }
+ if (strcmp(argv[1], kRseqTestRegisterUnregister) == 0) {
+ return TestRegisterUnregister();
+ }
+ if (strcmp(argv[1], kRseqTestUnregisterDifferentPtr) == 0) {
+ return TestUnregisterDifferentPtr();
+ }
+ if (strcmp(argv[1], kRseqTestUnregisterDifferentSignature) == 0) {
+ return TestUnregisterDifferentSignature();
+ }
+ if (strcmp(argv[1], kRseqTestCPU) == 0) {
+ return TestCPU();
+ }
+ if (strcmp(argv[1], kRseqTestAbort) == 0) {
+ return TestAbort();
+ }
+ if (strcmp(argv[1], kRseqTestAbortBefore) == 0) {
+ return TestAbortBefore();
+ }
+ if (strcmp(argv[1], kRseqTestAbortSignature) == 0) {
+ return TestAbortSignature();
+ }
+ if (strcmp(argv[1], kRseqTestAbortPreCommit) == 0) {
+ return TestAbortPreCommit();
+ }
+ if (strcmp(argv[1], kRseqTestAbortClearsCS) == 0) {
+ return TestAbortClearsCS();
+ }
+ if (strcmp(argv[1], kRseqTestInvalidAbortClearsCS) == 0) {
+ return TestInvalidAbortClearsCS();
+ }
+
+ return 3;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq/start_amd64.S b/test/syscalls/linux/rseq/start_amd64.S
new file mode 100644
index 000000000..b9611b276
--- /dev/null
+++ b/test/syscalls/linux/rseq/start_amd64.S
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+ .text
+ .align 4
+ .type _start,@function
+ .globl _start
+
+_start:
+ movq %rsp,%rdi
+ call __init
+ hlt
+
+ .size _start,.-_start
+ .section .note.GNU-stack,"",@progbits
+
+ .text
+ .globl raw_syscall
+ .type raw_syscall, @function
+
+raw_syscall:
+ mov %rdi,%rax // syscall #
+ mov %rsi,%rdi // arg0
+ mov %rdx,%rsi // arg1
+ mov %rcx,%rdx // arg2
+ mov %r8,%r10 // arg3 (goes in r10 instead of rcx for system calls)
+ mov %r9,%r8 // arg4
+ mov 0x8(%rsp),%r9 // arg5
+ syscall
+ ret
+
+ .size raw_syscall,.-raw_syscall
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/start_arm64.S b/test/syscalls/linux/rseq/start_arm64.S
new file mode 100644
index 000000000..693c1c6eb
--- /dev/null
+++ b/test/syscalls/linux/rseq/start_arm64.S
@@ -0,0 +1,45 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+ .text
+ .align 4
+ .type _start,@function
+ .globl _start
+
+_start:
+ mov x29, sp
+ bl __init
+ wfi
+
+ .size _start,.-_start
+ .section .note.GNU-stack,"",@progbits
+
+ .text
+ .globl raw_syscall
+ .type raw_syscall, @function
+
+raw_syscall:
+ mov x8,x0 // syscall #
+ mov x0,x1 // arg0
+ mov x1,x2 // arg1
+ mov x2,x3 // arg2
+ mov x3,x4 // arg3
+ mov x4,x5 // arg4
+ mov x5,x6 // arg5
+ svc #0
+ ret
+
+ .size raw_syscall,.-raw_syscall
+ .section .note.GNU-stack,"",@progbits
diff --git a/test/syscalls/linux/rseq/syscalls.h b/test/syscalls/linux/rseq/syscalls.h
new file mode 100644
index 000000000..c4118e6c5
--- /dev/null
+++ b/test/syscalls/linux/rseq/syscalls.h
@@ -0,0 +1,69 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
+
+#include "test/syscalls/linux/rseq/types.h"
+
+// Syscall numbers.
+#if defined(__x86_64__)
+constexpr int kGetpid = 39;
+constexpr int kExitGroup = 231;
+#elif defined(__aarch64__)
+constexpr int kGetpid = 172;
+constexpr int kExitGroup = 94;
+#else
+#error "Unknown architecture"
+#endif
+
+namespace gvisor {
+namespace testing {
+
+// Standalone system call interfaces.
+// Note that these are all "raw" system call interfaces which encode
+// errors by setting the return value to a small negative number.
+// Use sys_errno() to check system call return values for errors.
+
+// Maximum Linux error number.
+constexpr int kMaxErrno = 4095;
+
+// Errno values.
+#define EPERM 1
+#define EFAULT 14
+#define EBUSY 16
+#define EINVAL 22
+
+// Get the error number from a raw system call return value.
+// Returns a positive error number or 0 if there was no error.
+static inline int sys_errno(uintptr_t rval) {
+ if (rval >= static_cast<uintptr_t>(-kMaxErrno)) {
+ return -static_cast<int>(rval);
+ }
+ return 0;
+}
+
+extern "C" uintptr_t raw_syscall(int number, ...);
+
+static inline void sys_exit_group(int status) {
+ raw_syscall(kExitGroup, status);
+}
+static inline int sys_getpid() {
+ return static_cast<int>(raw_syscall(kGetpid));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_SYSCALLS_H_
diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h
new file mode 100644
index 000000000..3b7bb74b1
--- /dev/null
+++ b/test/syscalls/linux/rseq/test.h
@@ -0,0 +1,43 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
+
+namespace gvisor {
+namespace testing {
+
+// Test cases supported by rseq binary.
+
+inline constexpr char kRseqTestUnaligned[] = "unaligned";
+inline constexpr char kRseqTestRegister[] = "register";
+inline constexpr char kRseqTestDoubleRegister[] = "double-register";
+inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister";
+inline constexpr char kRseqTestUnregisterDifferentPtr[] =
+ "unregister-different-ptr";
+inline constexpr char kRseqTestUnregisterDifferentSignature[] =
+ "unregister-different-signature";
+inline constexpr char kRseqTestCPU[] = "cpu";
+inline constexpr char kRseqTestAbort[] = "abort";
+inline constexpr char kRseqTestAbortBefore[] = "abort-before";
+inline constexpr char kRseqTestAbortSignature[] = "abort-signature";
+inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit";
+inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs";
+inline constexpr char kRseqTestInvalidAbortClearsCS[] =
+ "invalid-abort-clears-cs";
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TEST_H_
diff --git a/test/syscalls/linux/rseq/types.h b/test/syscalls/linux/rseq/types.h
new file mode 100644
index 000000000..b6afe9817
--- /dev/null
+++ b/test/syscalls/linux/rseq/types.h
@@ -0,0 +1,31 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
+
+using size_t = __SIZE_TYPE__;
+using uintptr_t = __UINTPTR_TYPE__;
+
+using uint8_t = __UINT8_TYPE__;
+using uint16_t = __UINT16_TYPE__;
+using uint32_t = __UINT32_TYPE__;
+using uint64_t = __UINT64_TYPE__;
+
+using int8_t = __INT8_TYPE__;
+using int16_t = __INT16_TYPE__;
+using int32_t = __INT32_TYPE__;
+using int64_t = __INT64_TYPE__;
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_TYPES_H_
diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h
new file mode 100644
index 000000000..d3e60d0a4
--- /dev/null
+++ b/test/syscalls/linux/rseq/uapi.h
@@ -0,0 +1,51 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
+
+#include <stdint.h>
+
+// User-kernel ABI for restartable sequences.
+
+// Syscall numbers.
+#if defined(__x86_64__)
+constexpr int kRseqSyscall = 334;
+#elif defined(__aarch64__)
+constexpr int kRseqSyscall = 293;
+#else
+#error "Unknown architecture"
+#endif // __x86_64__
+
+struct rseq_cs {
+ uint32_t version;
+ uint32_t flags;
+ uint64_t start_ip;
+ uint64_t post_commit_offset;
+ uint64_t abort_ip;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
+
+// N.B. alignment is enforced by the kernel.
+struct rseq {
+ uint32_t cpu_id_start;
+ uint32_t cpu_id;
+ struct rseq_cs* rseq_cs;
+ uint32_t flags;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
+
+constexpr int kRseqFlagUnregister = 1 << 0;
+
+constexpr int kRseqCPUIDUninitialized = -1;
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc
index 81d193ffd..ed27e2566 100644
--- a/test/syscalls/linux/rtsignal.cc
+++ b/test/syscalls/linux/rtsignal.cc
@@ -167,6 +167,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
index e77586852..ce88d90dd 100644
--- a/test/syscalls/linux/seccomp.cc
+++ b/test/syscalls/linux/seccomp.cc
@@ -25,6 +25,7 @@
#include <time.h>
#include <ucontext.h>
#include <unistd.h>
+
#include <atomic>
#include "gmock/gmock.h"
@@ -48,7 +49,12 @@ namespace testing {
namespace {
// A syscall not implemented by Linux that we don't expect to be called.
+#ifdef __x86_64__
constexpr uint32_t kFilteredSyscall = SYS_vserver;
+#elif __aarch64__
+// Use the last of arch_specific_syscalls which are not implemented on arm64.
+constexpr uint32_t kFilteredSyscall = __NR_arch_specific_syscall + 15;
+#endif
// Applies a seccomp-bpf filter that returns `filtered_result` for
// `sysno` and allows all other syscalls. Async-signal-safe.
@@ -64,20 +70,27 @@ void ApplySeccompFilter(uint32_t sysno, uint32_t filtered_result,
MaybeSave();
struct sock_filter filter[] = {
- // A = seccomp_data.arch
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
- // if (A != AUDIT_ARCH_X86_64) goto kill
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
- // A = seccomp_data.nr
- BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
- // if (A != sysno) goto allow
- BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
- // return filtered_result
- BPF_STMT(BPF_RET | BPF_K, filtered_result),
- // allow: return SECCOMP_RET_ALLOW
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
- // kill: return SECCOMP_RET_KILL
- BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
+ // A = seccomp_data.arch
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 4),
+#if defined(__x86_64__)
+ // if (A != AUDIT_ARCH_X86_64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_X86_64, 0, 4),
+#elif defined(__aarch64__)
+ // if (A != AUDIT_ARCH_AARCH64) goto kill
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, AUDIT_ARCH_AARCH64, 0, 4),
+#else
+#error "Unknown architecture"
+#endif
+ // A = seccomp_data.nr
+ BPF_STMT(BPF_LD | BPF_ABS | BPF_W, 0),
+ // if (A != sysno) goto allow
+ BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, sysno, 0, 1),
+ // return filtered_result
+ BPF_STMT(BPF_RET | BPF_K, filtered_result),
+ // allow: return SECCOMP_RET_ALLOW
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_ALLOW),
+ // kill: return SECCOMP_RET_KILL
+ BPF_STMT(BPF_RET | BPF_K, SECCOMP_RET_KILL),
};
struct sock_fprog prog;
prog.len = ABSL_ARRAYSIZE(filter);
@@ -112,7 +125,8 @@ TEST(SeccompTest, RetKillCausesDeathBySIGSYS) {
pid_t const pid = fork();
if (pid == 0) {
// Register a signal handler for SIGSYS that we don't expect to be invoked.
- RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
syscall(kFilteredSyscall);
TEST_CHECK_MSG(false, "Survived invocation of test syscall");
@@ -131,7 +145,8 @@ TEST(SeccompTest, RetKillOnlyKillsOneThread) {
pid_t const pid = fork();
if (pid == 0) {
// Register a signal handler for SIGSYS that we don't expect to be invoked.
- RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
// Pass CLONE_VFORK to block the original thread in the child process until
// the clone thread exits with SIGSYS.
@@ -171,9 +186,12 @@ TEST(SeccompTest, RetTrapCausesSIGSYS) {
TEST_CHECK(info->si_errno == kTrapValue);
TEST_CHECK(info->si_call_addr != nullptr);
TEST_CHECK(info->si_syscall == kFilteredSyscall);
-#ifdef __x86_64__
+#if defined(__x86_64__)
TEST_CHECK(info->si_arch == AUDIT_ARCH_X86_64);
TEST_CHECK(uc->uc_mcontext.gregs[REG_RAX] == kFilteredSyscall);
+#elif defined(__aarch64__)
+ TEST_CHECK(info->si_arch == AUDIT_ARCH_AARCH64);
+ TEST_CHECK(uc->uc_mcontext.regs[8] == kFilteredSyscall);
#endif // defined(__x86_64__)
_exit(0);
});
@@ -345,7 +363,8 @@ TEST(SeccompTest, LeastPermissiveFilterReturnValueApplies) {
// one that causes the kill that should be ignored.
pid_t const pid = fork();
if (pid == 0) {
- RegisterSignalHandler(SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
+ RegisterSignalHandler(
+ SIGSYS, +[](int, siginfo_t*, void*) { _exit(1); });
ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_TRACE);
ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_KILL);
ApplySeccompFilter(kFilteredSyscall, SECCOMP_RET_ERRNO | ENOTNAM);
@@ -402,5 +421,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/select.cc b/test/syscalls/linux/select.cc
index e06a2666d..be2364fb8 100644
--- a/test/syscalls/linux/select.cc
+++ b/test/syscalls/linux/select.cc
@@ -16,6 +16,7 @@
#include <sys/resource.h>
#include <sys/select.h>
#include <sys/time.h>
+
#include <climits>
#include <csignal>
#include <cstdio>
@@ -145,7 +146,7 @@ TEST_F(SelectTest, IgnoreBitsAboveNfds) {
// This test illustrates Linux's behavior of 'select' calls passing after
// setrlimit RLIMIT_NOFILE is called. In particular, versions of sshd rely on
-// this behavior.
+// this behavior. See b/122318458.
TEST_F(SelectTest, SetrlimitCallNOFILE) {
fd_set read_set;
FD_ZERO(&read_set);
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 40c57f543..e9b131ca9 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -447,9 +447,8 @@ TEST(SemaphoreTest, SemCtlGetPidFork) {
const pid_t child_pid = fork();
if (child_pid == 0) {
- ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
- ASSERT_THAT(semctl(sem.get(), 0, GETPID),
- SyscallSucceedsWithValue(getpid()));
+ TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0);
+ TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid());
_exit(0);
}
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 580ab5193..64123e904 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/sendfile.h>
#include <unistd.h>
@@ -70,6 +71,28 @@ TEST(SendFileTest, InvalidOffset) {
SyscallFailsWithErrno(EINVAL));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SendFileTest, Overflow) {
+ // Create input file.
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor outf(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(sendfile(outf.get(), inf.get(), &out_offset, kSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(SendFileTest, SendTrivially) {
// Create temp files.
constexpr char kData[] = "To be, or not to be, that is the question:";
@@ -530,6 +553,34 @@ TEST(SendFileTest, SendToSpecialFile) {
SyscallSucceedsWithValue(kSize & (~7)));
}
+TEST(SendFileTest, SendFileToPipe) {
+ // Create temp file.
+ constexpr char kData[] = "<insert-quote-here>";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Create a pipe for sending to a pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Expect to read up to the given size.
+ std::vector<char> buf(kDataSize);
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kDataSize));
+ });
+
+ // Send with twice the size of the file, which should hit EOF.
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize * 2),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sendfile_socket.cc b/test/syscalls/linux/sendfile_socket.cc
index 3331288b7..c101fe9d2 100644
--- a/test/syscalls/linux/sendfile_socket.cc
+++ b/test/syscalls/linux/sendfile_socket.cc
@@ -23,6 +23,7 @@
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
@@ -35,61 +36,39 @@ namespace {
class SendFileTest : public ::testing::TestWithParam<int> {
protected:
- PosixErrorOr<std::tuple<int, int>> Sockets() {
+ PosixErrorOr<std::unique_ptr<SocketPair>> Sockets(int type) {
// Bind a server socket.
int family = GetParam();
- struct sockaddr server_addr = {};
switch (family) {
case AF_INET: {
- struct sockaddr_in *server_addr_in =
- reinterpret_cast<struct sockaddr_in *>(&server_addr);
- server_addr_in->sin_family = family;
- server_addr_in->sin_addr.s_addr = INADDR_ANY;
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "TCP", AF_INET, type, 0,
+ TCPAcceptBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UDP", AF_INET, type, 0,
+ UDPBidirectionalBindSocketPairCreator(AF_INET, type, 0, false)}
+ .Create();
+ }
}
case AF_UNIX: {
- struct sockaddr_un *server_addr_un =
- reinterpret_cast<struct sockaddr_un *>(&server_addr);
- server_addr_un->sun_family = family;
- server_addr_un->sun_path[0] = '\0';
- break;
+ if (type == SOCK_STREAM) {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemAcceptBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ } else {
+ return SocketPairKind{
+ "UNIX", AF_UNIX, type, 0,
+ FilesystemBidirectionalBindSocketPairCreator(AF_UNIX, type, 0)}
+ .Create();
+ }
}
default:
return PosixError(EINVAL);
}
- int server = socket(family, SOCK_STREAM, 0);
- if (bind(server, &server_addr, sizeof(server_addr)) < 0) {
- return PosixError(errno);
- }
- if (listen(server, 1) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Fetch the address; both are anonymous.
- socklen_t length = sizeof(server_addr);
- if (getsockname(server, &server_addr, &length) < 0) {
- close(server);
- return PosixError(errno);
- }
-
- // Connect the client.
- int client = socket(family, SOCK_STREAM, 0);
- if (connect(client, &server_addr, length) < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
-
- // Accept on the server.
- int server_client = accept(server, nullptr, 0);
- if (server_client < 0) {
- close(server);
- close(client);
- return PosixError(errno);
- }
- close(server);
- return std::make_tuple(client, server_client);
}
};
@@ -106,9 +85,7 @@ TEST_P(SendFileTest, SendMultiple) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Create sockets.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor server(std::get<0>(fds));
- FileDescriptor client(std::get<1>(fds)); // non-const, reset is used.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// Thread that reads data from socket and dumps to a file.
ScopedThread th([&] {
@@ -118,7 +95,7 @@ TEST_P(SendFileTest, SendMultiple) {
// Read until socket is closed.
char buf[10240];
for (int cnt = 0;; cnt++) {
- int r = RetryEINTR(read)(server.get(), buf, sizeof(buf));
+ int r = RetryEINTR(read)(socks->first_fd(), buf, sizeof(buf));
// We cannot afford to save on every read() call.
if (cnt % 1000 == 0) {
ASSERT_THAT(r, SyscallSucceeds());
@@ -149,10 +126,10 @@ TEST_P(SendFileTest, SendMultiple) {
for (size_t sent = 0; sent < data.size(); cnt++) {
const size_t remain = data.size() - sent;
std::cout << "sendfile, size=" << data.size() << ", sent=" << sent
- << ", remain=" << remain;
+ << ", remain=" << remain << std::endl;
// Send data and verify that sendfile returns the correct value.
- int res = sendfile(client.get(), inf.get(), nullptr, remain);
+ int res = sendfile(socks->second_fd(), inf.get(), nullptr, remain);
// We cannot afford to save on every sendfile() call.
if (cnt % 120 == 0) {
MaybeSave();
@@ -169,7 +146,7 @@ TEST_P(SendFileTest, SendMultiple) {
}
// Close socket to stop thread.
- client.reset();
+ close(socks->release_second_fd());
th.Join();
// Verify that the output file has the correct data.
@@ -183,9 +160,7 @@ TEST_P(SendFileTest, SendMultiple) {
TEST_P(SendFileTest, Shutdown) {
// Create a socket.
- std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
- const FileDescriptor client(std::get<0>(fds));
- FileDescriptor server(std::get<1>(fds)); // non-const, reset below.
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_STREAM));
// If this is a TCP socket, then turn off linger.
if (GetParam() == AF_INET) {
@@ -193,7 +168,7 @@ TEST_P(SendFileTest, Shutdown) {
sl.l_onoff = 1;
sl.l_linger = 0;
ASSERT_THAT(
- setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
+ setsockopt(socks->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
SyscallSucceeds());
}
@@ -212,12 +187,12 @@ TEST_P(SendFileTest, Shutdown) {
ScopedThread t([&]() {
size_t done = 0;
while (done < data.size()) {
- int n = RetryEINTR(read)(server.get(), data.data(), data.size());
+ int n = RetryEINTR(read)(socks->first_fd(), data.data(), data.size());
ASSERT_THAT(n, SyscallSucceeds());
done += n;
}
// Close the server side socket.
- server.reset();
+ close(socks->release_first_fd());
});
// Continuously stream from the file to the socket. Note we do not assert
@@ -225,7 +200,7 @@ TEST_P(SendFileTest, Shutdown) {
// data is written. Eventually, we should get a connection reset error.
while (1) {
off_t offset = 0; // Always read from the start.
- int n = sendfile(client.get(), inf.get(), &offset, data.size());
+ int n = sendfile(socks->second_fd(), inf.get(), &offset, data.size());
EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET),
SyscallFailsWithErrno(EPIPE), SyscallSucceeds()));
if (n <= 0) {
@@ -234,6 +209,20 @@ TEST_P(SendFileTest, Shutdown) {
}
}
+TEST_P(SendFileTest, SendpageFromEmptyFileToUDP) {
+ auto socks = ASSERT_NO_ERRNO_AND_VALUE(Sockets(SOCK_DGRAM));
+
+ TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+
+ // The value to the count argument has to be so that it is impossible to
+ // allocate a buffer of this size. In Linux, sendfile transfer at most
+ // 0x7ffff000 (MAX_RW_COUNT) bytes.
+ EXPECT_THAT(sendfile(socks->first_fd(), fd.get(), 0x0, 0x8000000000004),
+ SyscallSucceedsWithValue(0));
+}
+
INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest,
::testing::Values(AF_UNIX, AF_INET));
diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc
index eb7a3966f..c7fdbb924 100644
--- a/test/syscalls/linux/shm.cc
+++ b/test/syscalls/linux/shm.cc
@@ -13,7 +13,6 @@
// limitations under the License.
#include <stdio.h>
-
#include <sys/ipc.h>
#include <sys/mman.h>
#include <sys/shm.h>
@@ -474,7 +473,7 @@ TEST(ShmTest, PartialUnmap) {
}
// Check that sentry does not panic when asked for a zero-length private shm
-// segment.
+// segment. Regression test for b/110694797.
TEST(ShmTest, GracefullyFailOnZeroLenSegmentCreation) {
EXPECT_THAT(Shmget(IPC_PRIVATE, 0, 0), PosixErrorIs(EINVAL, _));
}
diff --git a/test/syscalls/linux/sigaction.cc b/test/syscalls/linux/sigaction.cc
index 9a53fd3e0..9d9dd57a8 100644
--- a/test/syscalls/linux/sigaction.cc
+++ b/test/syscalls/linux/sigaction.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <signal.h>
+#include <sys/syscall.h>
#include "gtest/gtest.h"
#include "test/util/test_util.h"
@@ -23,45 +24,53 @@ namespace testing {
namespace {
TEST(SigactionTest, GetLessThanOrEqualToZeroFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(-1, NULL, &act), SyscallFailsWithErrno(EINVAL));
- ASSERT_THAT(sigaction(0, NULL, &act), SyscallFailsWithErrno(EINVAL));
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(-1, nullptr, &act), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(sigaction(0, nullptr, &act), SyscallFailsWithErrno(EINVAL));
}
TEST(SigactionTest, SetLessThanOrEqualToZeroFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(0, &act, NULL), SyscallFailsWithErrno(EINVAL));
- ASSERT_THAT(sigaction(0, &act, NULL), SyscallFailsWithErrno(EINVAL));
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(sigaction(0, &act, nullptr), SyscallFailsWithErrno(EINVAL));
}
TEST(SigactionTest, GetGreaterThanMaxFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(SIGRTMAX + 1, NULL, &act),
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGRTMAX + 1, nullptr, &act),
SyscallFailsWithErrno(EINVAL));
}
TEST(SigactionTest, SetGreaterThanMaxFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, NULL),
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGRTMAX + 1, &act, nullptr),
SyscallFailsWithErrno(EINVAL));
}
TEST(SigactionTest, SetSigkillFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(SIGKILL, NULL, &act), SyscallSucceeds());
- ASSERT_THAT(sigaction(SIGKILL, &act, NULL), SyscallFailsWithErrno(EINVAL));
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGKILL, nullptr, &act), SyscallSucceeds());
+ ASSERT_THAT(sigaction(SIGKILL, &act, nullptr), SyscallFailsWithErrno(EINVAL));
}
TEST(SigactionTest, SetSigstopFails) {
- struct sigaction act;
- memset(&act, 0, sizeof(act));
- ASSERT_THAT(sigaction(SIGSTOP, NULL, &act), SyscallSucceeds());
- ASSERT_THAT(sigaction(SIGSTOP, &act, NULL), SyscallFailsWithErrno(EINVAL));
+ struct sigaction act = {};
+ ASSERT_THAT(sigaction(SIGSTOP, nullptr, &act), SyscallSucceeds());
+ ASSERT_THAT(sigaction(SIGSTOP, &act, nullptr), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SigactionTest, BadSigsetFails) {
+ constexpr size_t kWrongSigSetSize = 43;
+
+ struct sigaction act = {};
+
+ // The syscall itself (rather than the libc wrapper) takes the sigset_t size.
+ ASSERT_THAT(
+ syscall(SYS_rt_sigaction, SIGTERM, nullptr, &act, kWrongSigSetSize),
+ SyscallFailsWithErrno(EINVAL));
+ ASSERT_THAT(
+ syscall(SYS_rt_sigaction, SIGTERM, &act, nullptr, kWrongSigSetSize),
+ SyscallFailsWithErrno(EINVAL));
}
} // namespace
diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc
index 6fd3989a4..24e7c4960 100644
--- a/test/syscalls/linux/sigaltstack.cc
+++ b/test/syscalls/linux/sigaltstack.cc
@@ -95,13 +95,7 @@ TEST(SigaltstackTest, ResetByExecve) {
auto const cleanup_sigstack =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaltstack(stack));
- std::string full_path;
- char* test_src = getenv("TEST_SRCDIR");
- if (test_src) {
- full_path = JoinPath(test_src, "../../linux/sigaltstack_check");
- }
-
- ASSERT_FALSE(full_path.empty());
+ std::string full_path = RunfilePath("test/syscalls/linux/sigaltstack_check");
pid_t child_pid = -1;
int execve_errno = 0;
@@ -120,7 +114,7 @@ TEST(SigaltstackTest, ResetByExecve) {
volatile bool badhandler_on_sigaltstack = true; // Set by the handler.
char* volatile badhandler_low_water_mark = nullptr; // Set by the handler.
-volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler.
+volatile uint8_t badhandler_recursive_faults = 0; // Consumed by the handler.
void badhandler(int sig, siginfo_t* siginfo, void* arg) {
char stack_var = 0;
@@ -174,8 +168,8 @@ TEST(SigaltstackTest, WalksOffBottom) {
// Trigger a single fault.
badhandler_low_water_mark =
- static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top.
- badhandler_recursive_faults = 0; // Disable refault.
+ static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top.
+ badhandler_recursive_faults = 0; // Disable refault.
Fault();
EXPECT_TRUE(badhandler_on_sigaltstack);
EXPECT_THAT(sigaltstack(nullptr, &stack), SyscallSucceeds());
diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc
index a47c781ea..6227774a4 100644
--- a/test/syscalls/linux/sigiret.cc
+++ b/test/syscalls/linux/sigiret.cc
@@ -78,8 +78,8 @@ TEST(SigIretTest, CheckRcxR11) {
"1: pause; cmpl $0, %[gotvtalrm]; je 1b;" // while (!gotvtalrm);
"movq %%rcx, %[rcx];" // rcx = %rcx
"movq %%r11, %[r11];" // r11 = %r11
- : [ready] "=m"(ready), [rcx] "+m"(rcx), [r11] "+m"(r11)
- : [gotvtalrm] "m"(gotvtalrm)
+ : [ ready ] "=m"(ready), [ rcx ] "+m"(rcx), [ r11 ] "+m"(r11)
+ : [ gotvtalrm ] "m"(gotvtalrm)
: "cc", "memory", "rcx", "r11");
// If sigreturn(2) returns via 'sysret' then %rcx and %r11 will be
@@ -132,6 +132,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
index 09ecad34a..389e5fca2 100644
--- a/test/syscalls/linux/signalfd.cc
+++ b/test/syscalls/linux/signalfd.cc
@@ -39,6 +39,7 @@ namespace testing {
namespace {
constexpr int kSigno = SIGUSR1;
+constexpr int kSignoMax = 64; // SIGRTMAX
constexpr int kSignoAlt = SIGUSR2;
// Returns a new signalfd.
@@ -51,41 +52,45 @@ inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) {
return FileDescriptor(fd);
}
-TEST(Signalfd, Basic) {
+class SignalfdTest : public ::testing::TestWithParam<int> {};
+
+TEST_P(SignalfdTest, Basic) {
+ int signo = GetParam();
// Create the signalfd.
sigset_t mask;
sigemptyset(&mask);
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
// Deliver the blocked signal.
const auto scoped_sigmask =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
- ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
// We should now read the signal.
struct signalfd_siginfo rbuf;
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
}
-TEST(Signalfd, MaskWorks) {
+TEST_P(SignalfdTest, MaskWorks) {
+ int signo = GetParam();
// Create two signalfds with different masks.
sigset_t mask1, mask2;
sigemptyset(&mask1);
sigemptyset(&mask2);
- sigaddset(&mask1, kSigno);
+ sigaddset(&mask1, signo);
sigaddset(&mask2, kSignoAlt);
FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0));
FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0));
// Deliver the two signals.
const auto scoped_sigmask1 =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
const auto scoped_sigmask2 =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt));
- ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds());
// We should see the signals on the appropriate signalfds.
@@ -98,7 +103,7 @@ TEST(Signalfd, MaskWorks) {
EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt);
ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)),
SyscallSucceedsWithValue(sizeof(rbuf1)));
- EXPECT_EQ(rbuf1.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf1.ssi_signo, signo);
}
TEST(Signalfd, Cloexec) {
@@ -111,11 +116,12 @@ TEST(Signalfd, Cloexec) {
EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
}
-TEST(Signalfd, Blocking) {
+TEST_P(SignalfdTest, Blocking) {
+ int signo = GetParam();
// Create the signalfd in blocking mode.
sigset_t mask;
sigemptyset(&mask);
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
// Shared tid variable.
@@ -136,7 +142,7 @@ TEST(Signalfd, Blocking) {
struct signalfd_siginfo rbuf;
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
});
// Wait until blocked.
@@ -149,20 +155,21 @@ TEST(Signalfd, Blocking) {
//
// See gvisor.dev/issue/139.
if (IsRunningOnGvisor()) {
- ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
} else {
- ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), tid, signo), SyscallSucceeds());
}
// Ensure that it was received.
t.Join();
}
-TEST(Signalfd, ThreadGroup) {
+TEST_P(SignalfdTest, ThreadGroup) {
+ int signo = GetParam();
// Create the signalfd in blocking mode.
sigset_t mask;
sigemptyset(&mask);
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
// Shared variable.
@@ -176,7 +183,7 @@ TEST(Signalfd, ThreadGroup) {
struct signalfd_siginfo rbuf;
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
// Wait for the other thread.
absl::MutexLock ml(&mu);
@@ -185,7 +192,7 @@ TEST(Signalfd, ThreadGroup) {
});
// Deliver the signal to the threadgroup.
- ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds());
// Wait for the first thread to process.
{
@@ -194,13 +201,13 @@ TEST(Signalfd, ThreadGroup) {
}
// Deliver to the thread group again (other thread still exists).
- ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(kill(getpid(), signo), SyscallSucceeds());
// Ensure that we can also receive it.
struct signalfd_siginfo rbuf;
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
// Mark the test as done.
{
@@ -212,11 +219,12 @@ TEST(Signalfd, ThreadGroup) {
t.Join();
}
-TEST(Signalfd, Nonblock) {
+TEST_P(SignalfdTest, Nonblock) {
+ int signo = GetParam();
// Create the signalfd in non-blocking mode.
sigset_t mask;
sigemptyset(&mask);
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
@@ -227,20 +235,21 @@ TEST(Signalfd, Nonblock) {
// Block and deliver the signal.
const auto scoped_sigmask =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
- ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
// Ensure that a read actually works.
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
// Should block again.
EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallFailsWithErrno(EWOULDBLOCK));
}
-TEST(Signalfd, SetMask) {
+TEST_P(SignalfdTest, SetMask) {
+ int signo = GetParam();
// Create the signalfd matching nothing.
sigset_t mask;
sigemptyset(&mask);
@@ -249,8 +258,8 @@ TEST(Signalfd, SetMask) {
// Block and deliver a signal.
const auto scoped_sigmask =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
- ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
+ ASSERT_THAT(tgkill(getpid(), gettid(), signo), SyscallSucceeds());
// We should have nothing.
struct signalfd_siginfo rbuf;
@@ -258,29 +267,30 @@ TEST(Signalfd, SetMask) {
SyscallFailsWithErrno(EWOULDBLOCK));
// Change the signal mask.
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds());
// We should now have the signal.
ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
SyscallSucceedsWithValue(sizeof(rbuf)));
- EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ EXPECT_EQ(rbuf.ssi_signo, signo);
}
-TEST(Signalfd, Poll) {
+TEST_P(SignalfdTest, Poll) {
+ int signo = GetParam();
// Create the signalfd.
sigset_t mask;
sigemptyset(&mask);
- sigaddset(&mask, kSigno);
+ sigaddset(&mask, signo);
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
// Block the signal, and start a thread to deliver it.
const auto scoped_sigmask =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, signo));
pid_t orig_tid = gettid();
ScopedThread t([&] {
absl::SleepFor(absl::Seconds(5));
- ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), orig_tid, signo), SyscallSucceeds());
});
// Start polling for the signal. We expect that it is not available at the
@@ -297,19 +307,18 @@ TEST(Signalfd, Poll) {
SyscallSucceedsWithValue(sizeof(rbuf)));
}
-TEST(Signalfd, KillStillKills) {
- sigset_t mask;
- sigemptyset(&mask);
- sigaddset(&mask, SIGKILL);
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
-
- // Just because there is a signalfd, we shouldn't see any change in behavior
- // for unblockable signals. It's easier to test this with SIGKILL.
- const auto scoped_sigmask =
- ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
- EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+std::string PrintSigno(::testing::TestParamInfo<int> info) {
+ switch (info.param) {
+ case kSigno:
+ return "kSigno";
+ case kSignoMax:
+ return "kSignoMax";
+ default:
+ return absl::StrCat(info.param);
+ }
}
+INSTANTIATE_TEST_SUITE_P(Signalfd, SignalfdTest,
+ ::testing::Values(kSigno, kSignoMax), PrintSigno);
TEST(Signalfd, Ppoll) {
sigset_t mask;
@@ -328,6 +337,20 @@ TEST(Signalfd, Ppoll) {
SyscallSucceedsWithValue(0));
}
+TEST(Signalfd, KillStillKills) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Just because there is a signalfd, we shouldn't see any change in behavior
+ // for unblockable signals. It's easier to test this with SIGKILL.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
+ EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+}
+
} // namespace
} // namespace testing
@@ -340,10 +363,11 @@ int main(int argc, char** argv) {
sigset_t set;
sigemptyset(&set);
sigaddset(&set, gvisor::testing::kSigno);
+ sigaddset(&set, gvisor::testing::kSignoMax);
sigaddset(&set, gvisor::testing::kSignoAlt);
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigprocmask.cc b/test/syscalls/linux/sigprocmask.cc
index 654c6a47f..a603fc1d1 100644
--- a/test/syscalls/linux/sigprocmask.cc
+++ b/test/syscalls/linux/sigprocmask.cc
@@ -237,7 +237,7 @@ TEST_F(SigProcMaskTest, SignalHandler) {
}
// Check that sigprocmask correctly handles aliasing of the set and oldset
-// pointers.
+// pointers. Regression test for b/30502311.
TEST_F(SigProcMaskTest, AliasedSets) {
sigset_t mask;
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 7db57d968..b2fcedd62 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -147,5 +147,5 @@ int main(int argc, char** argv) {
return 1;
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc
index 1e5bf5942..4f8afff15 100644
--- a/test/syscalls/linux/sigtimedwait.cc
+++ b/test/syscalls/linux/sigtimedwait.cc
@@ -319,6 +319,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index 3a07ac8d2..c20cd3fcc 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -13,11 +13,14 @@
// limitations under the License.
#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
#include <unistd.h>
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
+#include "test/util/temp_umask.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -58,12 +61,45 @@ TEST(SocketTest, ProtocolInet) {
}
}
+TEST(SocketTest, UnixSocketStat) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ // The permissions of the file created with bind(2) should be defined by the
+ // permissions of the bound socket and the umask.
+ mode_t sock_perm = 0765, mask = 0123;
+ ASSERT_THAT(fchmod(bound.get(), sock_perm), SyscallSucceeds());
+ TempUmask m(mask);
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ struct stat statbuf = {};
+ ASSERT_THAT(stat(addr.sun_path, &statbuf), SyscallSucceeds());
+
+ // Mode should be S_IFSOCK.
+ EXPECT_EQ(statbuf.st_mode, S_IFSOCK | sock_perm & ~mask);
+
+ // Timestamps should be equal and non-zero.
+ // TODO(b/158882152): Sockets currently don't implement timestamps.
+ if (!IsRunningOnGvisor()) {
+ EXPECT_NE(statbuf.st_atime, 0);
+ EXPECT_EQ(statbuf.st_atime, statbuf.st_mtime);
+ EXPECT_EQ(statbuf.st_atime, statbuf.st_ctime);
+ }
+}
+
using SocketOpenTest = ::testing::TestWithParam<int>;
// UDS cannot be opened.
TEST_P(SocketOpenTest, Unix) {
// FIXME(b/142001530): Open incorrectly succeeds on gVisor.
- SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(IsRunningWithVFS1());
FileDescriptor bound =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
diff --git a/test/syscalls/linux/socket_abstract.cc b/test/syscalls/linux/socket_abstract.cc
index 715d87b76..00999f192 100644
--- a/test/syscalls/linux/socket_abstract.cc
+++ b/test/syscalls/linux/socket_abstract.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P(
AbstractUnixSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
index 5767181a1..5ed57625c 100644
--- a/test/syscalls/linux/socket_bind_to_device_distribution.cc
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -183,7 +183,14 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -198,15 +205,29 @@ TEST_P(BindToDeviceDistributionTest, Tcp) {
}
for (int i = 0; i < kConnectAttempts; i++) {
- FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
ASSERT_THAT(
RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
// Join threads to be sure that all connections have been counted.
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
index e4641c62e..d3cc71dbf 100644
--- a/test/syscalls/linux/socket_bind_to_device_sequence.cc
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -33,6 +33,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/container/node_hash_map.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_bind_to_device_util.h"
#include "test/syscalls/linux/socket_test_util.h"
@@ -66,7 +67,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
// Gets a device by device_id. If the device_id has been seen before, returns
// the previously returned device. If not, finds or creates a new device.
// Returns an empty string on failure.
- void GetDevice(int device_id, string *device_name) {
+ void GetDevice(int device_id, string* device_name) {
auto device = devices_.find(device_id);
if (device != devices_.end()) {
*device_name = device->second;
@@ -97,12 +98,22 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
sockets_to_close_.erase(socket_id);
}
- // Bind a socket with the reuse option and bind_to_device options. Checks
+ // SetDevice changes the bind_to_device option. It does not bind or re-bind.
+ void SetDevice(int socket_id, int device_id) {
+ auto socket_fd = sockets_to_close_[socket_id]->get();
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // Bind a socket with the reuse options and bind_to_device options. Checks
// that all steps succeed and that the bind command's error matches want.
// Sets the socket_id to uniquely identify the socket bound if it is not
// nullptr.
- void BindSocket(bool reuse, int device_id = 0, int want = 0,
- int *socket_id = nullptr) {
+ void BindSocket(bool reuse_port, bool reuse_addr, int device_id = 0,
+ int want = 0, int* socket_id = nullptr) {
next_socket_id_++;
sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto socket_fd = sockets_to_close_[next_socket_id_]->get();
@@ -110,13 +121,20 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
*socket_id = next_socket_id_;
}
- // If reuse is indicated, do that.
- if (reuse) {
+ // If reuse_port is indicated, do that.
+ if (reuse_port) {
EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
sizeof(kSockOptOn)),
SyscallSucceedsWithValue(0));
}
+ // If reuse_addr is indicated, do that.
+ if (reuse_addr) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
// If the device is non-zero, bind to that device.
if (device_id != 0) {
string device_name;
@@ -137,12 +155,12 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
addr.sin_port = port_;
if (want == 0) {
ASSERT_THAT(
- bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr),
sizeof(addr)),
SyscallSucceeds());
} else {
ASSERT_THAT(
- bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ bind(socket_fd, reinterpret_cast<const struct sockaddr*>(&addr),
sizeof(addr)),
SyscallFailsWithErrno(want));
}
@@ -152,7 +170,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
// remember it for future commands.
socklen_t addr_size = sizeof(addr);
ASSERT_THAT(
- getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr),
+ getsockname(socket_fd, reinterpret_cast<struct sockaddr*>(&addr),
&addr_size),
SyscallSucceeds());
port_ = addr.sin_port;
@@ -162,7 +180,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
private:
SocketKind socket_factory_;
// devices maps from the device id in the test case to the name of the device.
- std::unordered_map<int, string> devices_;
+ absl::node_hash_map<int, string> devices_;
// These are the tunnels that were created for the test and will be destroyed
// by the destructor.
vector<std::unique_ptr<Tunnel>> tunnels_;
@@ -175,136 +193,316 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
in_port_t port_ = 0;
// sockets_to_close_ is a map from action index to the socket that was
// created.
- std::unordered_map<int,
- std::unique_ptr<gvisor::testing::FileDescriptor>>
+ absl::node_hash_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
sockets_to_close_;
int next_socket_id_ = 0;
};
TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 3));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, BindToDevice) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 1));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 2));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 2));
}
TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
- ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
- ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123));
ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0));
+ BindSocket(/* reusePort */ true, /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0));
}
TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 456));
- ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 789));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse_port */ true, /* reuse_addr */ false));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 999, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 999, 0));
}
TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 456));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
}
TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
int to_release;
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 345, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 789));
// Release the bind to device 0 and try again.
ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 345));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 345));
}
TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ false, /* bind_to_device */ 123));
- ASSERT_NO_FATAL_FAILURE(
- BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reusePort */ false, /* reuse_addr */ true));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest,
+ CannotBindTo0AfterMixingReuseAddrAndNotReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReusePortThenReuseAddrReusePort) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ true, /* reuse_addr */ false, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindReuseAddrThenReuseAddr) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest,
+ BindReuseAddrThenReuseAddrReusePortThenReuseAddr) {
+ // The behavior described in this test seems like a Linux bug. It doesn't
+ // make any sense and it is unlikely that any applications rely on it.
+ //
+ // Both SO_REUSEADDR and SO_REUSEPORT allow binding multiple UDP sockets to
+ // the same address and deliver each packet to exactly one of the bound
+ // sockets. If both are enabled, one of the strategies is selected to route
+ // packets. The strategy is selected dynamically based on the settings of the
+ // currently bound sockets. Usually, the strategy is selected based on the
+ // common setting (SO_REUSEADDR or SO_REUSEPORT) amongst the sockets, but for
+ // some reason, Linux allows binding sets of sockets with no overlapping
+ // settings in some situations. In this case, it is not obvious which strategy
+ // would be selected as the configured setting is a contradiction.
+ SKIP_IF(IsRunningOnGvisor());
+
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ true, /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ true,
+ /* bind_to_device */ 0));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 0));
+}
+
+// Repro test for gvisor.dev/issue/1217. Not replicated in ports_test.go as this
+// test is different from the others and wouldn't fit well there.
+TEST_P(BindToDeviceSequenceTest, BindAndReleaseDifferentDevice) {
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false,
+ /* reuse_addr */ false,
+ /* bind_to_device */ 3, EADDRINUSE));
+ // Change the device. Since the socket was already bound, this should have no
+ // effect.
+ SetDevice(to_release, 2);
+ // Release the bind to device 3 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(
+ /* reuse_port */ false, /* reuse_addr */ false, /* bind_to_device */ 3));
}
INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest,
diff --git a/test/syscalls/linux/socket_blocking.cc b/test/syscalls/linux/socket_blocking.cc
index d7ce57566..7e88aa2d9 100644
--- a/test/syscalls/linux/socket_blocking.cc
+++ b/test/syscalls/linux/socket_blocking.cc
@@ -17,6 +17,7 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
+
#include <cstdio>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/socket_capability.cc b/test/syscalls/linux/socket_capability.cc
new file mode 100644
index 000000000..84b5b2b21
--- /dev/null
+++ b/test/syscalls/linux/socket_capability.cc
@@ -0,0 +1,61 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Subset of socket tests that need Linux-specific headers (compared to POSIX
+// headers).
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST(SocketTest, UnixConnectNeedsWritePerm) {
+ SKIP_IF(IsRunningWithVFS1());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(bound.get(), 1), SyscallSucceeds());
+
+ // Drop capabilites that allow us to override permision checks. Otherwise if
+ // the test is run as root, the connect below will bypass permission checks
+ // and succeed unexpectedly.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+
+ // Connect should fail without write perms.
+ ASSERT_THAT(chmod(addr.sun_path, 0500), SyscallSucceeds());
+ FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+ ASSERT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(EACCES));
+
+ // Connect should succeed with write perms.
+ ASSERT_THAT(chmod(addr.sun_path, 0200), SyscallSucceeds());
+ EXPECT_THAT(connect(client.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_filesystem.cc b/test/syscalls/linux/socket_filesystem.cc
index 74e262959..287359363 100644
--- a/test/syscalls/linux/socket_filesystem.cc
+++ b/test/syscalls/linux/socket_filesystem.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -43,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P(
FilesystemUnixSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index e8f24a59e..a6182f0ac 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -447,6 +447,62 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // tv_usec should be a multiple of 4000 to work on most systems.
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, tv.tv_sec);
+ EXPECT_EQ(actual_tv.tv_usec, tv.tv_usec);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ // tv_usec should be a multiple of 4000 to work on most systems.
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 124000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, tv_extra.tv.tv_sec);
+ EXPECT_EQ(actual_tv.tv.tv_usec, tv_extra.tv.tv_usec);
+}
+
TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -491,18 +547,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- struct timeval tv {
- .tv_sec = 0, .tv_usec = 35
- };
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
EXPECT_THAT(
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct timeval_with_extra {
@@ -510,13 +584,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
int64_t extra_data;
} ABSL_ATTRIBUTE_PACKED;
- timeval_with_extra tv_extra;
- tv_extra.tv.tv_sec = 0;
- tv_extra.tv.tv_usec = 25;
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
&tv_extra, sizeof(tv_extra)),
SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
}
TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc
new file mode 100644
index 000000000..19239e9e9
--- /dev/null
+++ b/test/syscalls/linux/socket_generic_stress.cc
@@ -0,0 +1,130 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <poll.h>
+#include <stdio.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Test fixture for tests that apply to pairs of connected sockets.
+using ConnectStressTest = SocketPairTest;
+
+TEST_P(ConnectStressTest, Reset65kTimes) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Send some data to ensure that the connection gets reset and the port gets
+ // released immediately. This avoids either end entering TIME-WAIT.
+ char sent_data[100] = {};
+ ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+ // Poll the other FD to make sure that the data is in the receive buffer
+ // before closing it to ensure a RST is triggered.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->second_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, ConnectStressTest,
+ ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv6TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn)(IPv4TCPAcceptBindSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindSocketPair(0))));
+
+// Test fixture for tests that apply to pairs of connected sockets created with
+// a persistent listener (if applicable).
+using PersistentListenerConnectStressTest = SocketPairTest;
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+ if (GetParam().type == SOCK_STREAM) {
+ // Poll the other FD to make sure that we see the FIN from the other
+ // side before closing the second_fd. This ensures that the first_fd
+ // enters TIME-WAIT and not second_fd.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->second_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ }
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
+ }
+}
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
+ if (GetParam().type == SOCK_STREAM) {
+ // Poll the other FD to make sure that we see the FIN from the other
+ // side before closing the first_fd. This ensures that the second_fd
+ // enters TIME-WAIT and not first_fd.
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = sockets->first_fd(),
+ .events = POLL_IN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ }
+ ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
+ }
+}
+
+TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) {
+ for (int i = 0; i < 1 << 16; ++i) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllConnectedSockets, PersistentListenerConnectStressTest,
+ ::testing::Values(
+ IPv6UDPBidirectionalBindSocketPair(0),
+ IPv4UDPBidirectionalBindSocketPair(0),
+ DualStackUDPBidirectionalBindSocketPair(0),
+
+ // Without REUSEADDR, we get port exhaustion on Linux.
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv6TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ IPv4TCPAcceptBindPersistentListenerSocketPair(0)),
+ SetSockOpt(SOL_SOCKET, SO_REUSEADDR, &kSockOptOn)(
+ DualStackTCPAcceptBindPersistentListenerSocketPair(0))));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 322ee07ad..c3b42682f 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -14,6 +14,7 @@
#include <arpa/inet.h>
#include <netinet/in.h>
+#include <netinet/tcp.h>
#include <poll.h>
#include <string.h>
#include <sys/socket.h>
@@ -30,7 +31,9 @@
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -43,6 +46,8 @@ namespace testing {
namespace {
+using ::testing::Gt;
+
PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
switch (family) {
case AF_INET:
@@ -99,19 +104,172 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) {
SyscallFailsWithErrno(EAFNOSUPPORT));
}
-TEST_P(SocketInetLoopbackTest, TCP) {
- auto const& param = GetParam();
+enum class Operation {
+ Bind,
+ Connect,
+ SendTo,
+};
- TestAddress const& listener = param.listener;
- TestAddress const& connector = param.connector;
+std::string OperationToString(Operation operation) {
+ switch (operation) {
+ case Operation::Bind:
+ return "Bind";
+ case Operation::Connect:
+ return "Connect";
+ case Operation::SendTo:
+ return "SendTo";
+ }
+}
+
+using OperationSequence = std::vector<Operation>;
+
+using DualStackSocketTest =
+ ::testing::TestWithParam<std::tuple<TestAddress, OperationSequence>>;
+
+TEST_P(DualStackSocketTest, AddressOperations) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_DGRAM, 0));
+
+ const TestAddress& addr = std::get<0>(GetParam());
+ const OperationSequence& operations = std::get<1>(GetParam());
+
+ auto addr_in = reinterpret_cast<const sockaddr*>(&addr.addr);
+
+ // sockets may only be bound once. Both `connect` and `sendto` cause a socket
+ // to be bound.
+ bool bound = false;
+ for (const Operation& operation : operations) {
+ bool sockname = false;
+ bool peername = false;
+ switch (operation) {
+ case Operation::Bind: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 0));
+
+ int bind_ret = bind(fd.get(), addr_in, addr.addr_len);
+
+ // Dual stack sockets may only be bound to AF_INET6.
+ if (!bound && addr.family() == AF_INET6) {
+ EXPECT_THAT(bind_ret, SyscallSucceeds());
+ bound = true;
+
+ sockname = true;
+ } else {
+ EXPECT_THAT(bind_ret, SyscallFailsWithErrno(EINVAL));
+ }
+ break;
+ }
+ case Operation::Connect: {
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ EXPECT_THAT(RetryEINTR(connect)(fd.get(), addr_in, addr.addr_len),
+ SyscallSucceeds())
+ << GetAddrStr(addr_in);
+ bound = true;
+
+ sockname = true;
+ peername = true;
+
+ break;
+ }
+ case Operation::SendTo: {
+ const char payload[] = "hello";
+ ASSERT_NO_ERRNO(SetAddrPort(
+ addr.family(), const_cast<sockaddr_storage*>(&addr.addr), 1337));
+
+ ssize_t sendto_ret = sendto(fd.get(), &payload, sizeof(payload), 0,
+ addr_in, addr.addr_len);
+
+ EXPECT_THAT(sendto_ret, SyscallSucceedsWithValue(sizeof(payload)));
+ sockname = !bound;
+ bound = true;
+ break;
+ }
+ }
+
+ if (sockname) {
+ sockaddr_storage sock_addr;
+ socklen_t addrlen = sizeof(sock_addr);
+ ASSERT_THAT(getsockname(fd.get(), reinterpret_cast<sockaddr*>(&sock_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ auto sock_addr_in6 = reinterpret_cast<const sockaddr_in6*>(&sock_addr);
+
+ if (operation == Operation::SendTo) {
+ EXPECT_EQ(sock_addr_in6->sin6_family, AF_INET6);
+ EXPECT_TRUE(IN6_IS_ADDR_UNSPECIFIED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsocknam="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+
+ EXPECT_NE(sock_addr_in6->sin6_port, 0);
+ } else if (IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(sock_addr_in6->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getsocknam="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&sock_addr));
+ }
+ }
+
+ if (peername) {
+ sockaddr_storage peer_addr;
+ socklen_t addrlen = sizeof(peer_addr);
+ ASSERT_THAT(getpeername(fd.get(), reinterpret_cast<sockaddr*>(&peer_addr),
+ &addrlen),
+ SyscallSucceeds());
+ ASSERT_EQ(addrlen, sizeof(struct sockaddr_in6));
+
+ if (addr.family() == AF_INET ||
+ IN6_IS_ADDR_V4MAPPED(reinterpret_cast<const sockaddr_in6*>(addr_in)
+ ->sin6_addr.s6_addr32)) {
+ EXPECT_TRUE(IN6_IS_ADDR_V4MAPPED(
+ reinterpret_cast<const sockaddr_in6*>(&peer_addr)
+ ->sin6_addr.s6_addr32))
+ << OperationToString(operation) << " getpeername="
+ << GetAddrStr(reinterpret_cast<sockaddr*>(&peer_addr));
+ }
+ }
+ }
+}
+
+// TODO(gvisor.dev/issue/1556): uncomment V4MappedAny.
+INSTANTIATE_TEST_SUITE_P(
+ All, DualStackSocketTest,
+ ::testing::Combine(
+ ::testing::Values(V4Any(), V4Loopback(), /*V4MappedAny(),*/
+ V4MappedLoopback(), V6Any(), V6Loopback()),
+ ::testing::ValuesIn<OperationSequence>(
+ {{Operation::Bind, Operation::Connect, Operation::SendTo},
+ {Operation::Bind, Operation::SendTo, Operation::Connect},
+ {Operation::Connect, Operation::Bind, Operation::SendTo},
+ {Operation::Connect, Operation::SendTo, Operation::Bind},
+ {Operation::SendTo, Operation::Bind, Operation::Connect},
+ {Operation::SendTo, Operation::Connect, Operation::Bind}})),
+ [](::testing::TestParamInfo<
+ std::tuple<TestAddress, OperationSequence>> const& info) {
+ const TestAddress& addr = std::get<0>(info.param);
+ const OperationSequence& operations = std::get<1>(info.param);
+ std::string s = addr.description;
+ for (const Operation& operation : operations) {
+ absl::StrAppend(&s, OperationToString(operation));
+ }
+ return s;
+ });
+void tcpSimpleConnectTest(TestAddress const& listener,
+ TestAddress const& connector, bool unbound) {
// Create the listening socket.
const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
+ if (!unbound) {
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ }
ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
// Get the port bound by the listening socket.
@@ -145,12 +303,31 @@ TEST_P(SocketInetLoopbackTest, TCP) {
ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
}
-TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+TEST_P(SocketInetLoopbackTest, TCP) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ tcpSimpleConnectTest(listener, connector, true);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenUnbound) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
+ tcpSimpleConnectTest(listener, connector, false);
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) {
+ const auto& param = GetParam();
+
+ const TestAddress& listener = param.listener;
+ const TestAddress& connector = param.connector;
+
+ constexpr int kBacklog = 5;
+
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -158,7 +335,52 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), 1001), SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+
+ for (int i = 0; i < kBacklog; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ }
+ for (int i = 0; i < kBacklog; i++) {
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdown) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kFDs = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
socklen_t addrlen = listener.addr_len;
@@ -168,42 +390,169 @@ TEST_P(SocketInetLoopbackTest, TCPListenClose) {
uint16_t const port =
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- DisableSave ds; // Too many system calls.
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- constexpr int kFDs = 2048;
- constexpr int kThreadCount = 4;
- constexpr int kFDsPerThread = kFDs / kThreadCount;
- FileDescriptor clients[kFDs];
- std::unique_ptr<ScopedThread> threads[kThreadCount];
+
+ // Shutdown the write of the listener, expect to not have any effect.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_WR), SyscallSucceeds());
+
for (int i = 0; i < kFDs; i++) {
- clients[i] = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr), SyscallSucceeds());
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i] = absl::make_unique<ScopedThread>([&connector, &conn_addr,
- &clients, i]() {
- for (int j = 0; j < kFDsPerThread; j++) {
- int k = i * kFDsPerThread + j;
- int ret =
- connect(clients[k].get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- }
- }
- });
+
+ // Shutdown the read of the listener, expect to fail subsequent
+ // server accepts, binds and client connects.
+ ASSERT_THAT(shutdown(listen_fd.get(), SHUT_RD), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that shutdown did not release the port.
+ FileDescriptor new_listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ bind(new_listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Check that subsequent connection attempts receive a RST.
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(RetryEINTR(connect)(client.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallFailsWithErrno(ECONNREFUSED));
}
- for (int i = 0; i < kThreadCount; i++) {
- threads[i]->Join();
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenClose) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kAcceptCount = 2;
+ constexpr int kBacklog = kAcceptCount + 2;
+ constexpr int kFDs = kBacklog * 3;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kFDs; i++) {
+ auto client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ }
+ clients.push_back(std::move(client));
}
- for (int i = 0; i < 32; i++) {
+ for (int i = 0; i < kAcceptCount; i++) {
auto accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
}
- // TODO(b/138400178): Fix cooperative S/R failure when ds.reset() is invoked
- // before function end.
- // ds.reset()
+}
+
+void TestListenWhileConnect(const TestParam& param,
+ void (*stopListen)(FileDescriptor&)) {
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 2;
+ constexpr int kClients = kBacklog + 1;
+
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
+ clients.push_back(std::move(client));
+ }
+ }
+
+ stopListen(listen_fd);
+
+ for (auto& client : clients) {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = client.get(),
+ .events = POLLIN,
+ };
+ // When the listening socket is closed, then we expect the remote to reset
+ // the connection.
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ char c;
+ // Subsequent read can fail with:
+ // ECONNRESET: If the client connection was established and was reset by the
+ // remote.
+ // ECONNREFUSED: If the client connection failed to be established.
+ ASSERT_THAT(read(client.get(), &c, sizeof(c)),
+ AnyOf(SyscallFailsWithErrno(ECONNRESET),
+ SyscallFailsWithErrno(ECONNREFUSED)));
+ }
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenCloseWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(close(f.release()), SyscallSucceeds());
+ });
+}
+
+TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
+ TestListenWhileConnect(GetParam(), [](FileDescriptor& f) {
+ ASSERT_THAT(shutdown(f.get(), SHUT_RD), SyscallSucceeds());
+ });
}
TEST_P(SocketInetLoopbackTest, TCPbacklog) {
@@ -266,6 +615,649 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) {
}
}
+// TCPFinWait2Test creates a pair of connected sockets then closes one end to
+// trigger FIN_WAIT2 state for the closed endpoint. Then it binds the same local
+// IP/port on a new socket and tries to connect. The connect should fail w/
+// an EADDRINUSE. Then we wait till the FIN_WAIT2 timeout is over and try the
+// connect again with a new socket and this time it should succeed.
+//
+// TCP timers are not S/R today, this can cause this test to be flaky when run
+// under random S/R due to timer being reset on a restore.
+TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Lower FIN_WAIT2 state to 5 seconds for test.
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
+ conn_fd.reset();
+
+ // Now bind and connect a new socket.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Disable cooperative saves after this point. As a save between the first
+ // bind/connect and the second one can cause the linger timeout timer to
+ // be restarted causing the final bind/connect to fail.
+ DisableSave ds;
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Sleep for a little over the linger timeout to reduce flakiness in
+ // save/restore tests.
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 2));
+
+ ds.reset();
+
+ if (!IsRunningOnGvisor()) {
+ ASSERT_THAT(
+ bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+ }
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
+
+// TCPLinger2TimeoutAfterClose creates a pair of connected sockets
+// then closes one end to trigger FIN_WAIT2 state for the closed endpont.
+// It then sleeps for the TCP_LINGER2 timeout and verifies that bind/
+// connecting the same address succeeds.
+//
+// TCP timers are not S/R today, this can cause this test to be flaky when run
+// under random S/R due to timer being reset on a restore.
+TEST_P(SocketInetLoopbackTest, TCPLinger2TimeoutAfterClose_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // Disable cooperative saves after this point as TCP timers are not restored
+ // across a S/R.
+ {
+ DisableSave ds;
+ constexpr int kTCPLingerTimeout = 5;
+ EXPECT_THAT(setsockopt(conn_fd.get(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
+ conn_fd.reset();
+
+ absl::SleepFor(absl::Seconds(kTCPLingerTimeout + 1));
+
+ // ds going out of scope will Re-enable S/R's since at this point the timer
+ // must have fired and cleaned up the endpoint.
+ }
+
+ // Now bind and connect a new socket and verify that we can immediately
+ // rebind the address bound by the conn_fd as it never entered TIME_WAIT.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
+
+// TCPResetAfterClose creates a pair of connected sockets then closes
+// one end to trigger FIN_WAIT2 state for the closed endpoint verifies
+// that we generate RSTs for any new data after the socket is fully
+// closed.
+TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // close the connecting FD to trigger FIN_WAIT2 on the connected fd.
+ conn_fd.reset();
+
+ int data = 1234;
+
+ // Now send data which should trigger a RST as the other end should
+ // have timed out and closed the socket.
+ EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceeds());
+ // Sleep for a shortwhile to get a RST back.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Try writing again and we should get an EPIPE back.
+ EXPECT_THAT(RetryEINTR(send)(accepted.get(), &data, sizeof(data), 0),
+ SyscallFailsWithErrno(EPIPE));
+
+ // Trying to read should return zero as the other end did send
+ // us a FIN. We do it twice to verify that the RST does not cause an
+ // ECONNRESET on the read after EOF has been read by applicaiton.
+ EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(RetryEINTR(recv)(accepted.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+// This test is disabled under random save as the the restore run
+// results in the stack.Seed() being different which can cause
+// sequence number of final connect to be one that is considered
+// old and can cause the test to be flaky.
+TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // We disable saves after this point as a S/R causes the netstack seed
+ // to be regenerated which changes what ports/ISN is picked for a given
+ // tuple (src ip,src port, dst ip, dst port). This can cause the final
+ // SYN to use a sequence number that looks like one from the current
+ // connection in TIME_WAIT and will not be accepted causing the test
+ // to timeout.
+ //
+ // TODO(gvisor.dev/issue/940): S/R portSeed/portHint
+ DisableSave ds;
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // shutdown the accept FD to trigger TIME_WAIT on the accepted socket which
+ // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of
+ // TIME_WAIT.
+ ASSERT_THAT(shutdown(accepted.get(), SHUT_RDWR), SyscallSucceeds());
+ {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = conn_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ }
+
+ conn_fd.reset();
+ // This sleep is required to give conn_fd time to transition to TIME-WAIT.
+ absl::SleepFor(absl::Seconds(1));
+
+ // At this point conn_fd should be the one that moved to CLOSE_WAIT and
+ // eventually to CLOSED.
+
+ // Now bind and connect a new socket and verify that we can immediately
+ // rebind the address bound by the conn_fd as it never entered TIME_WAIT.
+ const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ conn_addrlen),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ // We disable saves after this point as a S/R causes the netstack seed
+ // to be regenerated which changes what ports/ISN is picked for a given
+ // tuple (src ip,src port, dst ip, dst port). This can cause the final
+ // SYN to use a sequence number that looks like one from the current
+ // connection in TIME_WAIT and will not be accepted causing the test
+ // to timeout.
+ //
+ // TODO(gvisor.dev/issue/940): S/R portSeed/portHint
+ DisableSave ds;
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Get the address/port bound by the connecting socket.
+ sockaddr_storage conn_bound_addr;
+ socklen_t conn_addrlen = connector.addr_len;
+ ASSERT_THAT(
+ getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ &conn_addrlen),
+ SyscallSucceeds());
+
+ // shutdown the conn FD to trigger TIME_WAIT on the connect socket.
+ ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds());
+ {
+ const int kTimeout = 10000;
+ struct pollfd pfd = {
+ .fd = accepted.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ }
+ ScopedThread t([&]() {
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = conn_fd.get(),
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ });
+
+ accepted.reset();
+ t.Join();
+ conn_fd.reset();
+
+ // Now bind and connect a new socket and verify that we can't immediately
+ // rebind the address bound by the conn_fd as it is in TIME_WAIT.
+ conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
+ conn_addrlen),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the userTimeout on the listening socket.
+ constexpr int kUserTimeout = 10;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kUserTimeout, sizeof(kUserTimeout)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ // Verify that the accepted socket inherited the user timeout set on
+ // listening socket.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kUserTimeout);
+}
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now write some data to the socket.
+ int data = 0;
+ ASSERT_THAT(RetryEINTR(write)(conn_fd.get(), &data, sizeof(data)),
+ SyscallSucceedsWithValue(sizeof(data)));
+
+ // This should now cause the connection to complete and be delivered to the
+ // accept socket.
+
+ // Accept the connection.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+
+ // Verify that the accepted socket returns the data written.
+ int get = -1;
+ ASSERT_THAT(RetryEINTR(recv)(accepted.get(), &get, sizeof(get), 0),
+ SyscallSucceedsWithValue(sizeof(get)));
+
+ EXPECT_EQ(get, data);
+}
+
+// TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+// saved. Enable S/R once issue is fixed.
+TEST_P(SocketInetLoopbackTest, TCPDeferAcceptTimeout_NoRandomSave) {
+ // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not
+ // saved. Enable S/R once issue is fixed.
+ DisableSave ds;
+
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+
+ const uint16_t port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Set the TCP_DEFER_ACCEPT on the listening socket.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+
+ // Connect to the listening socket.
+ FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Set the listening socket to nonblock so that we can verify that there is no
+ // connection in queue despite the connect above succeeding since the peer has
+ // sent no data and TCP_DEFER_ACCEPT is set on the listening socket. Set the
+ // FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(listen_fd.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Verify that there is no acceptable connection before TCP_DEFER_ACCEPT
+ // timeout is hit.
+ absl::SleepFor(absl::Seconds(kTCPDeferAccept - 1));
+ ASSERT_THAT(accept(listen_fd.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Set FD back to blocking.
+ opts &= ~O_NONBLOCK;
+ ASSERT_THAT(fcntl(listen_fd.get(), F_SETFL, opts), SyscallSucceeds());
+
+ // Now sleep for a little over the TCP_DEFER_ACCEPT duration. When the timeout
+ // is hit a SYN-ACK should be retransmitted by the listener as a last ditch
+ // attempt to complete the connection with or without data.
+ absl::SleepFor(absl::Seconds(2));
+
+ // Verify that we have a connection that can be accepted even though no
+ // data was written.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetLoopbackTest,
::testing::Values(
@@ -298,7 +1290,9 @@ INSTANTIATE_TEST_SUITE_P(
using SocketInetReusePortTest = ::testing::TestWithParam<TestParam>;
-TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
+// TODO(gvisor.dev/issue/940): Remove _NoRandomSave when portHint/stack.Seed is
+// saved/restored.
+TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -306,6 +1300,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
sockaddr_storage listen_addr = listener.addr;
sockaddr_storage conn_addr = connector.addr;
constexpr int kThreadCount = 3;
+ constexpr int kConnectAttempts = 10000;
// Create the listening socket.
FileDescriptor listener_fds[kThreadCount];
@@ -339,7 +1334,6 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
}
- constexpr int kConnectAttempts = 10000;
std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
std::unique_ptr<ScopedThread> listen_thread[kThreadCount];
int accept_counts[kThreadCount] = {};
@@ -357,6 +1351,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
if (connects_received >= kConnectAttempts) {
// Another thread have shutdown our read side causing the
// accept to fail.
+ ASSERT_EQ(errno, EINVAL);
break;
}
ASSERT_NO_ERRNO(fd);
@@ -364,7 +1359,14 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
}
// Receive some data from a socket to be sure that the connect()
// system call has been completed on another side.
- int data;
+ // Do a short read and then close the socket to trigger a RST. This
+ // ensures that both ends of the connection are cleaned up and no
+ // goroutines hang around in TIME-WAIT. We do this so that this test
+ // does not timeout under gotsan runs where lots of goroutines can
+ // cause the test to use absurd amounts of memory.
+ //
+ // See: https://tools.ietf.org/html/rfc2525#page-50 section 2.17
+ uint16_t data;
EXPECT_THAT(
RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
SyscallSucceedsWithValue(sizeof(data)));
@@ -387,8 +1389,22 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
connector.addr_len),
SyscallSucceeds());
+ // Do two separate sends to ensure two segments are received. This is
+ // required for netstack where read is incorrectly assuming a whole
+ // segment is read when endpoint.Read() is called which is technically
+ // incorrect as the syscall that invoked endpoint.Read() may only
+ // consume it partially. This results in a case where a close() of
+ // such a socket does not trigger a RST in netstack due to the
+ // endpoint assuming that the endpoint has no unread data.
EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
SyscallSucceedsWithValue(sizeof(i)));
+
+ // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly
+ // generates a RST.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
}
});
@@ -403,7 +1419,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
-TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -516,6 +1532,115 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort_NoRandomSave) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // TODO(b/141211329): endpointsByNic.seed has to be saved/restored.
+ const DisableSave ds141211329;
+
+ // Create listening sockets.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10;
+ FileDescriptor client_fds[kConnectAttempts];
+
+ // Do the first run without save/restore.
+ DisableSave ds;
+ for (int i = 0; i < kConnectAttempts; i++) {
+ client_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ ds.reset();
+
+ // Check that a mapping of client and server sockets has
+ // not been change after save/restore.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ struct pollfd pollfds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ pollfds[i].fd = listener_fds[i].get();
+ pollfds[i].events = POLLIN;
+ }
+
+ std::map<uint16_t, int> portToFD;
+
+ int received = 0;
+ while (received < kConnectAttempts * 2) {
+ ASSERT_THAT(poll(pollfds, kThreadCount, -1),
+ SyscallSucceedsWithValue(Gt(0)));
+
+ for (int i = 0; i < kThreadCount; i++) {
+ if ((pollfds[i].revents & POLLIN) == 0) {
+ continue;
+ }
+
+ received++;
+
+ const int fd = pollfds[i].fd;
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+ EXPECT_THAT(RetryEINTR(recvfrom)(
+ fd, &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr));
+ auto prev_port = portToFD.find(port);
+ // Check that all packets from one client have been delivered to the
+ // same server socket.
+ if (prev_port == portToFD.end()) {
+ portToFD[port] = fd;
+ } else {
+ EXPECT_EQ(portToFD[port], fd);
+ }
+ }
+ }
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetReusePortTest,
::testing::Values(
@@ -702,6 +1827,171 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) {
ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
test_addr_v4.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 any on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyReuseAddrDoesNotReserveV4Any) {
+ auto const& param = GetParam();
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v4 any on the same port with a v4 socket succeeds.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyReuseAddrListenReservesV4Any) {
+ auto const& param = GetParam();
+
+ // Only TCP sockets are supported.
+ SKIP_IF((param.type & SOCK_STREAM) == 0);
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v4 any on the same port with a v4 socket succeeds.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ DualStackV6AnyWithListenReservesEverything) {
+ auto const& param = GetParam();
+
+ // Only TCP sockets are supported.
+ SKIP_IF((param.type & SOCK_STREAM) == 0);
+
+ // Bind the v6 any on a dual stack socket.
+ TestAddress const& test_addr_dual = V6Any();
+ sockaddr_storage addr_dual = test_addr_dual.addr;
+ const FileDescriptor fd_dual =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
+ test_addr_dual.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds());
+
+ // Get the port that we bound.
+ socklen_t addrlen = test_addr_dual.addr_len;
+ ASSERT_THAT(getsockname(fd_dual.get(),
+ reinterpret_cast<sockaddr*>(&addr_dual), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual));
+
+ // Verify that binding the v6 loopback with the same port fails.
+ TestAddress const& test_addr_v6 = V6Loopback();
+ sockaddr_storage addr_v6 = test_addr_v6.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port));
+ const FileDescriptor fd_v6 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6),
+ test_addr_v6.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v6 socket
+ // fails.
+ TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
+ sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
+ ASSERT_NO_ERRNO(
+ SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port));
+ const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_mapped.family(), param.type, 0));
+ ASSERT_THAT(
+ bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped),
+ test_addr_v4_mapped.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 loopback on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4 = V4Loopback();
+ sockaddr_storage addr_v4 = test_addr_v4.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port));
+ const FileDescriptor fd_v4 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4),
+ test_addr_v4.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+
+ // Verify that binding the v4 any on the same port with a v4 socket
+ // fails.
+ TestAddress const& test_addr_v4_any = V4Any();
+ sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port));
+ const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(test_addr_v4_any.family(), param.type, 0));
+ ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
+ test_addr_v4_any.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
}
TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
@@ -713,10 +2003,9 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
sockaddr_storage addr_dual = test_addr_dual.addr;
const FileDescriptor fd_dual = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_dual.family(), param.type, 0));
- int one = 1;
- EXPECT_THAT(
- setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY, &one, sizeof(one)),
- SyscallSucceeds());
+ EXPECT_THAT(setsockopt(fd_dual.get(), IPPROTO_IPV6, IPV6_V6ONLY,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual),
test_addr_dual.addr_len),
SyscallSucceeds());
@@ -764,9 +2053,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) {
TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
auto const& param = GetParam();
- // FIXME(b/114268588)
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM);
-
for (int i = 0; true; i++) {
// Bind the v6 loopback on a dual stack socket.
TestAddress const& test_addr = V6Loopback();
@@ -792,10 +2078,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(
- connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- bound_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
@@ -829,17 +2115,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
test_addr_v6.addr_len),
SyscallFailsWithErrno(EADDRINUSE));
- // Verify that binding the v4 any with the same port fails.
- TestAddress const& test_addr_v4_any = V4Any();
- sockaddr_storage addr_v4_any = test_addr_v4_any.addr;
- ASSERT_NO_ERRNO(
- SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, ephemeral_port));
- const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(test_addr_v4_any.family(), param.type, 0));
- ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any),
- test_addr_v4_any.addr_len),
- SyscallFailsWithErrno(EADDRINUSE));
-
// Verify that we can still bind the v4 loopback on the same port.
TestAddress const& test_addr_v4_mapped = V4MappedLoopback();
sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr;
@@ -862,11 +2137,71 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) {
}
}
-TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
+TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) {
auto const& param = GetParam();
- // FIXME(b/114268588)
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM);
+ // Bind the v6 loopback on a dual stack socket.
+ TestAddress const& test_addr = V6Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
+ auto const& param = GetParam();
for (int i = 0; true; i++) {
// Bind the v4 loopback on a dual stack socket.
@@ -893,10 +2228,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(
- connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- bound_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
@@ -965,9 +2300,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
// v6-only socket.
const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6_any.family(), param.type, 0));
- int one = 1;
EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
- &one, sizeof(one)),
+ &kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
ret =
bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
@@ -986,11 +2320,73 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) {
}
}
-TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
+TEST_P(SocketMultiProtocolInetLoopbackTest,
+ V4MappedEphemeralPortReservedResueAddr) {
auto const& param = GetParam();
- // FIXME(b/114268588)
- SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM);
+ // Bind the v4 loopback on a dual stack socket.
+ TestAddress const& test_addr = V4MappedLoopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
+ auto const& param = GetParam();
for (int i = 0; true; i++) {
// Bind the v4 loopback on a v4 socket.
@@ -1017,10 +2413,10 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
// Connect to bind an ephemeral port.
const FileDescriptor connected_fd =
ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
- ASSERT_THAT(
- connect(connected_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
- bound_addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
// Get the ephemeral port.
sockaddr_storage connected_addr = {};
@@ -1090,9 +2486,8 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
// v6-only socket.
const FileDescriptor fd_v6_only_any = ASSERT_NO_ERRNO_AND_VALUE(
Socket(test_addr_v6_any.family(), param.type, 0));
- int one = 1;
EXPECT_THAT(setsockopt(fd_v6_only_any.get(), IPPROTO_IPV6, IPV6_V6ONLY,
- &one, sizeof(one)),
+ &kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
ret =
bind(fd_v6_only_any.get(), reinterpret_cast<sockaddr*>(&addr_v6_any),
@@ -1111,6 +2506,73 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) {
}
}
+TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) {
+ auto const& param = GetParam();
+
+ // Bind the v4 loopback on a v4 socket.
+ TestAddress const& test_addr = V4Loopback();
+ sockaddr_storage bound_addr = test_addr.addr;
+ const FileDescriptor bound_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+
+ ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ test_addr.addr_len),
+ SyscallSucceeds());
+
+ // Listen iff TCP.
+ if (param.type == SOCK_STREAM) {
+ ASSERT_THAT(listen(bound_fd.get(), SOMAXCONN), SyscallSucceeds());
+ }
+
+ // Get the port that we bound.
+ socklen_t bound_addr_len = test_addr.addr_len;
+ ASSERT_THAT(
+ getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr),
+ &bound_addr_len),
+ SyscallSucceeds());
+
+ // Connect to bind an ephemeral port.
+ const FileDescriptor connected_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+
+ ASSERT_THAT(setsockopt(connected_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(RetryEINTR(connect)(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&bound_addr),
+ bound_addr_len),
+ SyscallSucceeds());
+
+ // Get the ephemeral port.
+ sockaddr_storage connected_addr = {};
+ socklen_t connected_addr_len = sizeof(connected_addr);
+ ASSERT_THAT(getsockname(connected_fd.get(),
+ reinterpret_cast<sockaddr*>(&connected_addr),
+ &connected_addr_len),
+ SyscallSucceeds());
+ uint16_t const ephemeral_port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr.family(), connected_addr));
+
+ // Verify that we actually got an ephemeral port.
+ ASSERT_NE(ephemeral_port, 0);
+
+ // Verify that the ephemeral port is not reserved.
+ const FileDescriptor checking_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0));
+ ASSERT_THAT(setsockopt(checking_fd.get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ bind(checking_fd.get(), reinterpret_cast<sockaddr*>(&connected_addr),
+ connected_addr_len),
+ SyscallSucceeds());
+}
+
TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
auto const& param = GetParam();
TestAddress const& test_addr = V4Loopback();
@@ -1148,7 +2610,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) {
setsockopt(fd2, SOL_SOCKET, SO_REUSEPORT, &portreuse2, sizeof(int)),
SyscallSucceeds());
- std::cout << portreuse1 << " " << portreuse2;
+ std::cout << portreuse1 << " " << portreuse2 << std::endl;
int ret = bind(fd2, reinterpret_cast<sockaddr*>(&addr), addrlen);
// Verify that two sockets can be bound to the same port only if
@@ -1197,7 +2659,7 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, NoReusePortFollowingReusePort) {
}
INSTANTIATE_TEST_SUITE_P(
- AllFamlies, SocketMultiProtocolInetLoopbackTest,
+ AllFamilies, SocketMultiProtocolInetLoopbackTest,
::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM},
ProtocolTestParam{"UDP", SOCK_DGRAM}),
DescribeProtocolTestParam);
diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
new file mode 100644
index 000000000..791e2bd51
--- /dev/null
+++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc
@@ -0,0 +1,174 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <string.h>
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using ::testing::Gt;
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+struct TestParam {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+std::string DescribeTestParam(::testing::TestParamInfo<TestParam> const& info) {
+ return absl::StrCat("Listen", info.param.listener.description, "_Connect",
+ info.param.connector.description);
+}
+
+using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>;
+
+// This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral
+// ports are already in use for a given destination ip/port.
+//
+// We disable S/R because this test creates a large number of sockets.
+//
+// FIXME(b/162475855): This test is failing reliably.
+TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ constexpr int kBacklog = 10;
+ constexpr int kClients = 65536;
+
+ // Create the listening socket.
+ auto listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+
+ // Now we keep opening connections till we run out of local ephemeral ports.
+ // and assert the error we get back.
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ std::vector<FileDescriptor> clients;
+ std::vector<FileDescriptor> servers;
+
+ for (int i = 0; i < kClients; i++) {
+ FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret == 0) {
+ clients.push_back(std::move(client));
+ FileDescriptor server =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ servers.push_back(std::move(server));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRNOTAVAIL));
+ break;
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ All, SocketInetLoopbackTest,
+ ::testing::Values(
+ // Listeners bound to IPv4 addresses refuse connections using IPv6
+ // addresses.
+ TestParam{V4Any(), V4Any()}, TestParam{V4Any(), V4Loopback()},
+ TestParam{V4Any(), V4MappedAny()},
+ TestParam{V4Any(), V4MappedLoopback()},
+ TestParam{V4Loopback(), V4Any()}, TestParam{V4Loopback(), V4Loopback()},
+ TestParam{V4Loopback(), V4MappedLoopback()},
+ TestParam{V4MappedAny(), V4Any()},
+ TestParam{V4MappedAny(), V4Loopback()},
+ TestParam{V4MappedAny(), V4MappedAny()},
+ TestParam{V4MappedAny(), V4MappedLoopback()},
+ TestParam{V4MappedLoopback(), V4Any()},
+ TestParam{V4MappedLoopback(), V4Loopback()},
+ TestParam{V4MappedLoopback(), V4MappedLoopback()},
+
+ // Listeners bound to IN6ADDR_ANY accept all connections.
+ TestParam{V6Any(), V4Any()}, TestParam{V6Any(), V4Loopback()},
+ TestParam{V6Any(), V4MappedAny()},
+ TestParam{V6Any(), V4MappedLoopback()}, TestParam{V6Any(), V6Any()},
+ TestParam{V6Any(), V6Loopback()},
+
+ // Listeners bound to IN6ADDR_LOOPBACK refuse connections using IPv4
+ // addresses.
+ TestParam{V6Loopback(), V6Any()},
+ TestParam{V6Loopback(), V6Loopback()}),
+ DescribeTestParam);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_loopback_blocking.cc b/test/syscalls/linux/socket_ip_loopback_blocking.cc
index d7fc9715b..fda252dd7 100644
--- a/test/syscalls/linux/socket_ip_loopback_blocking.cc
+++ b/test/syscalls/linux/socket_ip_loopback_blocking.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <netinet/tcp.h>
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
@@ -22,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -42,5 +44,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingIPSockets, BlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 7e0deda05..53c076787 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -24,13 +24,20 @@
#include <sys/un.h>
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
-TEST_P(TCPSocketPairTest, TcpInfoSucceedes) {
+using ::testing::AnyOf;
+using ::testing::Eq;
+
+TEST_P(TCPSocketPairTest, TcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
@@ -39,7 +46,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) {
SyscallSucceeds());
}
-TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) {
+TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
@@ -48,7 +55,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) {
SyscallSucceeds());
}
-TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) {
+TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
@@ -243,6 +250,31 @@ TEST_P(TCPSocketPairTest, ShutdownRdAllowsReadOfReceivedDataBeforeEOF) {
SyscallSucceedsWithValue(0));
}
+// This test verifies that a shutdown(wr) by the server after sending
+// data allows the client to still read() the queued data and a client
+// close after sending response allows server to read the incoming
+// response.
+TEST_P(TCPSocketPairTest, ShutdownWrServerClientClose) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ char buf[10] = {};
+ ScopedThread t([&]() {
+ ASSERT_THAT(RetryEINTR(read)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(close(sockets->release_first_fd()),
+ SyscallSucceedsWithValue(0));
+ });
+ ASSERT_THAT(RetryEINTR(write)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(shutdown)(sockets->second_fd(), SHUT_WR),
+ SyscallSucceedsWithValue(0));
+ t.Join();
+
+ ASSERT_THAT(RetryEINTR(read)(sockets->second_fd(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
TEST_P(TCPSocketPairTest, ClosedReadNonBlockingSocket) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -495,6 +527,7 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlZero) {
// Copied from include/net/tcp.h.
constexpr int MAX_TCP_KEEPIDLE = 32767;
constexpr int MAX_TCP_KEEPINTVL = 32767;
+constexpr int MAX_TCP_KEEPCNT = 127;
TEST_P(TCPSocketPairTest, SetTCPKeepidleAboveMax) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -546,6 +579,78 @@ TEST_P(TCPSocketPairTest, SetTCPKeepintvlToMax) {
EXPECT_EQ(get, MAX_TCP_KEEPINTVL);
}
+TEST_P(TCPSocketPairTest, TCPKeepcountDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 9); // 9 keepalive probes.
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &kZero,
+ sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAboveMax = MAX_TCP_KEEPCNT + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &kAboveMax, sizeof(kAboveMax)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &MAX_TCP_KEEPCNT, sizeof(MAX_TCP_KEEPCNT)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, MAX_TCP_KEEPCNT);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToOne) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, keepaliveCount);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPKeepcountToNegative) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int keepaliveCount = -5;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_KEEPCNT,
+ &keepaliveCount, sizeof(keepaliveCount)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST_P(TCPSocketPairTest, SetOOBInline) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -696,5 +801,265 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) {
EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc)));
}
+// Linux and Netstack both default to a 60s TCP_LINGER2 timeout.
+constexpr int kDefaultTCPLingerTimeout = 60;
+// On Linux, the maximum linger2 timeout was changed from 60sec to 120sec.
+constexpr int kMaxTCPLingerTimeout = 120;
+constexpr int kOldMaxTCPLingerTimeout = 60;
+
+TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kDefaultTCPLingerTimeout);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero,
+ sizeof(kZero)),
+ SyscallSucceedsWithValue(0));
+
+ constexpr int kNegative = -1234;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kNegative, sizeof(kNegative)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout
+ // on linux (defaults to 60 seconds on linux).
+ constexpr int kAboveDefault = kMaxTCPLingerTimeout + 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kAboveDefault, sizeof(kAboveDefault)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ if (IsRunningOnGvisor()) {
+ EXPECT_EQ(get, kMaxTCPLingerTimeout);
+ } else {
+ EXPECT_THAT(get,
+ AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout)));
+ }
+}
+
+TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout
+ // on linux (defaults to 60 seconds on linux).
+ constexpr int kTCPLingerTimeout = kDefaultTCPLingerTimeout - 1;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2,
+ &kTCPLingerTimeout, sizeof(kTCPLingerTimeout)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPLingerTimeout);
+}
+
+TEST_P(TCPSocketPairTest, TestTCPCloseWithData) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ScopedThread t([&]() {
+ // Close one end to trigger sending of a FIN.
+ ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_WR), SyscallSucceeds());
+ char buf[3];
+ ASSERT_THAT(read(sockets->second_fd(), buf, 3),
+ SyscallSucceedsWithValue(3));
+ absl::SleepFor(absl::Milliseconds(50));
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ });
+
+ absl::SleepFor(absl::Milliseconds(50));
+ // Send some data then close.
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->first_fd(), kStr, 3), SyscallSucceedsWithValue(3));
+ t.Join();
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+}
+
+TEST_P(TCPSocketPairTest, TCPUserTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kZero = 0;
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutBelowZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kNeg = -10;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kNeg, sizeof(kNeg)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0); // 0 ms (disabled).
+}
+
+TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ constexpr int kAbove = 10;
+ ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kAbove, sizeof(kAbove)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kAbove);
+}
+
+TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ // Discover minimum receive buf by setting a really low value
+ // for the receive buffer.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &kZero,
+ sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ // Setting TCP_WINDOW_CLAMP to zero for a connected socket is not permitted.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &kZero, sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Non-zero clamp values below MIN_SO_RCVBUF/2 should result in the clamp
+ // being set to MIN_SO_RCVBUF/2.
+ int below_half_min_so_rcvbuf = min_so_rcvbuf / 2 - 1;
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &below_half_min_so_rcvbuf, sizeof(below_half_min_so_rcvbuf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(min_so_rcvbuf / 2, get);
+ }
+}
+
+TEST_P(TCPSocketPairTest, IpMulticastTtlDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_GT(get, 0);
+}
+
+TEST_P(TCPSocketPairTest, IpMulticastLoopDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 1);
+}
+
+TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) {
+ DisableSave ds; // Too many syscalls.
+ constexpr int kThreadCount = 1000;
+ std::unique_ptr<ScopedThread> instances[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i] = absl::make_unique<ScopedThread>([&]() {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ScopedThread t([&]() {
+ // Close one end to trigger sending of a FIN.
+ struct pollfd poll_fd = {sockets->second_fd(), POLLIN | POLLHUP, 0};
+ // Wait up to 20 seconds for the data.
+ constexpr int kPollTimeoutMs = 20000;
+ ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs),
+ SyscallSucceedsWithValue(1));
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ });
+
+ // Send some data then close.
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(write(sockets->first_fd(), kStr, 3),
+ SyscallSucceedsWithValue(3));
+ absl::SleepFor(absl::Milliseconds(10));
+ ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ t.Join();
+ });
+ }
+ for (int i = 0; i < kThreadCount; i++) {
+ instances[i]->Join();
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
index 0dc274e2d..4e79d21f4 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic_loopback.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <netinet/tcp.h>
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
@@ -22,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -38,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
AllTCPSockets, TCPSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback.cc b/test/syscalls/linux/socket_ip_tcp_loopback.cc
index 831de53b8..9db3037bc 100644
--- a/test/syscalls/linux/socket_ip_tcp_loopback.cc
+++ b/test/syscalls/linux/socket_ip_tcp_loopback.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -34,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, AllSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
index cd3ad97d0..f996b93d2 100644
--- a/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_blocking.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <netinet/tcp.h>
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
@@ -22,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -38,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingTCPSockets, BlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
index 1acdecc17..ffa377210 100644
--- a/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
+++ b/test/syscalls/linux/socket_ip_tcp_loopback_nonblock.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <netinet/tcp.h>
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
@@ -22,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVecToVec<SocketPairKind>(
@@ -37,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingTCPSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 2a4ed04a5..edb86aded 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -14,6 +14,7 @@
#include "test/syscalls/linux/socket_ip_udp_generic.h"
+#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -35,7 +36,7 @@ TEST_P(UDPSocketPairTest, MulticastTTLDefault) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -52,7 +53,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMin) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -69,7 +70,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLMax) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -91,7 +92,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLNegativeOne) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -126,7 +127,7 @@ TEST_P(UDPSocketPairTest, SetUDPMulticastTTLChar) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_TTL,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -147,7 +148,7 @@ TEST_P(UDPSocketPairTest, MulticastLoopDefault) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -163,7 +164,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -173,7 +174,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoop) {
&kSockOptOn, sizeof(kSockOptOn)),
SyscallSucceeds());
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -192,7 +193,7 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
int get = -1;
socklen_t get_len = sizeof(get);
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
@@ -202,12 +203,250 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
&kSockOptOnChar, sizeof(kSockOptOnChar)),
SyscallSucceeds());
- EXPECT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
&get, &get_len),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_len, sizeof(get));
EXPECT_EQ(get, kSockOptOn);
}
+TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReuseAddr) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, ReusePortDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReusePort) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOff, sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEADDR, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_REUSEPORT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test getsockopt for a socket which is not set with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, IPPKTINFODefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ // Check getsockopt before IP_PKTINFO is set.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
+
+// Holds TOS or TClass information for IPv4 or IPv6 respectively.
+struct RecvTosOption {
+ int level;
+ int option;
+};
+
+RecvTosOption GetRecvTosOption(int domain) {
+ TEST_CHECK(domain == AF_INET || domain == AF_INET6);
+ RecvTosOption opt;
+ switch (domain) {
+ case AF_INET:
+ opt.level = IPPROTO_IP;
+ opt.option = IP_RECVTOS;
+ break;
+ case AF_INET6:
+ opt.level = IPPROTO_IPV6;
+ opt.option = IPV6_RECVTCLASS;
+ break;
+ }
+ return opt;
+}
+
+// Ensure that Receiving TOS or TCLASS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as
+// expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this
+// mirrors behavior in linux.
+TEST_P(UDPSocketPairTest, TOSRecvMismatch) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(AF_INET);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+}
+
+// Test that an IPv4 socket does not support the IPv6 TClass option.
+TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
+ // This should only test AF_INET sockets for the mismatch behavior.
+ SKIP_IF(GetParam().domain != AF_INET);
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS,
+ &get, &get_len),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback.cc b/test/syscalls/linux/socket_ip_udp_loopback.cc
index 1df74a348..c7fa44884 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -44,5 +45,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUDPSockets, UDPSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
index 1e259efa7..d6925a8df 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback_blocking.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingUDPSockets, BlockingNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
index 74cbd326d..d675eddc6 100644
--- a/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
+++ b/test/syscalls/linux/socket_ip_udp_loopback_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUDPSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc
index b02872308..1c7b0cf90 100644
--- a/test/syscalls/linux/socket_ip_unbound.cc
+++ b/test/syscalls/linux/socket_ip_unbound.cc
@@ -40,7 +40,7 @@ TEST_P(IPUnboundSocketTest, TtlDefault) {
socklen_t get_sz = sizeof(get);
EXPECT_THAT(getsockopt(socket->get(), IPPROTO_IP, IP_TTL, &get, &get_sz),
SyscallSucceedsWithValue(0));
- EXPECT_EQ(get, 64);
+ EXPECT_TRUE(get == 64 || get == 127);
EXPECT_EQ(get_sz, sizeof(get));
}
@@ -129,6 +129,7 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTtl) {
struct TOSOption {
int level;
int option;
+ int cmsg_level;
};
constexpr int INET_ECN_MASK = 3;
@@ -139,10 +140,12 @@ static TOSOption GetTOSOption(int domain) {
case AF_INET:
opt.level = IPPROTO_IP;
opt.option = IP_TOS;
+ opt.cmsg_level = SOL_IP;
break;
case AF_INET6:
opt.level = IPPROTO_IPV6;
opt.option = IPV6_TCLASS;
+ opt.cmsg_level = SOL_IPV6;
break;
}
return opt;
@@ -154,7 +157,7 @@ TEST_P(IPUnboundSocketTest, TOSDefault) {
int get = -1;
socklen_t get_sz = sizeof(get);
constexpr int kDefaultTOS = 0;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, kDefaultTOS);
@@ -170,7 +173,7 @@ TEST_P(IPUnboundSocketTest, SetTOS) {
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, set);
@@ -185,7 +188,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOS) {
SyscallSucceedsWithValue(0));
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, set);
@@ -207,7 +210,7 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, kDefaultTOS);
@@ -226,7 +229,7 @@ TEST_P(IPUnboundSocketTest, CheckSkipECN) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
@@ -246,7 +249,7 @@ TEST_P(IPUnboundSocketTest, ZeroTOSOptionSize) {
}
int get = -1;
socklen_t get_sz = 0;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, 0);
EXPECT_EQ(get, -1);
@@ -273,7 +276,7 @@ TEST_P(IPUnboundSocketTest, SmallTOSOptionSize) {
}
uint get = -1;
socklen_t get_sz = i;
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, expect_sz);
// Account for partial copies by getsockopt, retrieve the lower
@@ -294,7 +297,7 @@ TEST_P(IPUnboundSocketTest, LargeTOSOptionSize) {
// We expect the system call handler to only copy atmost sizeof(int) bytes
// as asserted by the check below. Hence, we do not expect the copy to
// overflow in getsockopt.
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(int));
EXPECT_EQ(get, set);
@@ -322,7 +325,7 @@ TEST_P(IPUnboundSocketTest, NegativeTOS) {
}
int get = -1;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
@@ -335,25 +338,118 @@ TEST_P(IPUnboundSocketTest, InvalidNegativeTOS) {
TOSOption t = GetTOSOption(GetParam().domain);
int expect;
if (GetParam().domain == AF_INET) {
- EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
SyscallSucceedsWithValue(0));
expect = static_cast<uint8_t>(set);
if (GetParam().protocol == IPPROTO_TCP) {
expect &= ~INET_ECN_MASK;
}
} else {
- EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
+ ASSERT_THAT(setsockopt(socket->get(), t.level, t.option, &set, set_sz),
SyscallFailsWithErrno(EINVAL));
expect = 0;
}
int get = 0;
socklen_t get_sz = sizeof(get);
- EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
+ ASSERT_THAT(getsockopt(socket->get(), t.level, t.option, &get, &get_sz),
SyscallSucceedsWithValue(0));
EXPECT_EQ(get_sz, sizeof(get));
EXPECT_EQ(get, expect);
}
+TEST_P(IPUnboundSocketTest, NullTOS) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+ int set_sz = sizeof(int);
+ if (GetParam().domain == AF_INET) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ SyscallFailsWithErrno(EFAULT));
+ } else { // AF_INET6
+ // The AF_INET6 behavior is not yet compatible. gVisor will try to read
+ // optval from user memory at syscall handler, it needs substantial
+ // refactoring to implement this behavior just for IPv6.
+ if (IsRunningOnGvisor()) {
+ EXPECT_THAT(setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ SyscallFailsWithErrno(EFAULT));
+ } else {
+ // Linux's IPv6 stack treats nullptr optval as input of 0, so the call
+ // succeeds. (net/ipv6/ipv6_sockglue.c, do_ipv6_setsockopt())
+ //
+ // Linux's implementation would need fixing as passing a nullptr as optval
+ // and non-zero optlen may not be valid.
+ // TODO(b/158666797): Combine the gVisor and linux cases for IPv6.
+ // Some kernel versions return EFAULT, so we handle both.
+ EXPECT_THAT(
+ setsockopt(socket->get(), t.level, t.option, nullptr, set_sz),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallSucceedsWithValue(0)));
+ }
+ }
+ socklen_t get_sz = sizeof(int);
+ EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, nullptr, &get_sz),
+ SyscallFailsWithErrno(EFAULT));
+ int get = -1;
+ EXPECT_THAT(getsockopt(socket->get(), t.level, t.option, &get, nullptr),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) {
+ SKIP_IF(GetParam().protocol == IPPROTO_TCP);
+
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ TOSOption t = GetTOSOption(GetParam().domain);
+
+ in_addr addr4;
+ in6_addr addr6;
+ ASSERT_THAT(inet_pton(AF_INET, "127.0.0.1", &addr4), ::testing::Eq(1));
+ ASSERT_THAT(inet_pton(AF_INET6, "fe80::", &addr6), ::testing::Eq(1));
+
+ cmsghdr cmsg = {};
+ cmsg.cmsg_len = sizeof(cmsg);
+ cmsg.cmsg_level = t.cmsg_level;
+ cmsg.cmsg_type = t.option;
+
+ msghdr msg = {};
+ msg.msg_control = &cmsg;
+ msg.msg_controllen = sizeof(cmsg);
+ if (GetParam().domain == AF_INET) {
+ msg.msg_name = &addr4;
+ msg.msg_namelen = sizeof(addr4);
+ } else {
+ msg.msg_name = &addr6;
+ msg.msg_namelen = sizeof(addr6);
+ }
+
+ EXPECT_THAT(sendmsg(socket->get(), &msg, 0), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(IPUnboundSocketTest, ReuseAddrDefault) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
+TEST_P(IPUnboundSocketTest, SetReuseAddr) {
+ auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ASSERT_THAT(setsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ int get = -1;
+ socklen_t get_sz = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(socket->get(), SOL_SOCKET, SO_REUSEADDR, &get, &get_sz),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_sz, sizeof(get));
+}
+
INSTANTIATE_TEST_SUITE_P(
IPUnboundSockets, IPUnboundSocketTest,
::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>(
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc
index 3c3712b50..80f12b0a9 100644
--- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.cc
@@ -18,6 +18,7 @@
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
+
#include <cstdio>
#include <cstring>
diff --git a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
index 92f03e045..797c4174e 100644
--- a/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
+++ b/test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking_test.cc
@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h"
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/syscalls/linux/socket_ipv4_tcp_unbound_external_networking.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketKind> GetSockets() {
return ApplyVec<SocketKind>(
@@ -31,5 +33,7 @@ std::vector<SocketKind> GetSockets() {
INSTANTIATE_TEST_SUITE_P(IPv4TCPUnboundSockets,
IPv4TCPUnboundExternalNetworkingSocketTest,
::testing::ValuesIn(GetSockets()));
+
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index b828b6844..bc005e2bb 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -15,12 +15,16 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
#include <arpa/inet.h>
+#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
+#include <sys/types.h>
#include <sys/un.h>
+
#include <cstdio>
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/test_util.h"
@@ -28,49 +32,29 @@
namespace gvisor {
namespace testing {
-constexpr char kMulticastAddress[] = "224.0.2.1";
-constexpr char kBroadcastAddress[] = "255.255.255.255";
-
-TestAddress V4Multicast() {
- TestAddress t("V4Multicast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kMulticastAddress);
- return t;
-}
-
-TestAddress V4Broadcast() {
- TestAddress t("V4Broadcast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kBroadcastAddress);
- return t;
-}
-
// Check that packets are not received without a group membership. Default send
// interface configured by bind.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address. If multicast worked like unicast,
// this would ensure that we get the packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -82,33 +66,33 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNoGroup) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
// Check that not setting a default send interface prevents multicast packets
// from being sent. Group membership interface configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the second FD to the v4 any address to ensure that we can receive any
// unicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -118,8 +102,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -128,27 +112,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddrNoDefaultSendIf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallFailsWithErrno(ENETUNREACH));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
}
// Check that not setting a default send interface prevents multicast packets
// from being sent. Group membership interface configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNicNoDefaultSendIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the second FD to the v4 any address to ensure that we can receive any
// unicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -158,8 +142,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -168,35 +152,35 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNicNoDefaultSendIf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallFailsWithErrno(ENETUNREACH));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallFailsWithErrno(ENETUNREACH));
}
// Check that multicast works when the default send interface is configured by
// bind and the group membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
ASSERT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -206,8 +190,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -216,43 +200,42 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
// Check that multicast works when the default send interface is configured by
// bind and the group membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
ASSERT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -262,8 +245,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -272,17 +255,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -290,25 +271,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackNic) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -318,8 +300,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -328,17 +310,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -346,25 +326,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddr) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreqn iface = {};
iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -374,8 +355,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -384,17 +365,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -402,25 +381,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNic) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in connect, and the group
// membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -430,8 +410,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -439,22 +419,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(sockets->first_fd(),
+ RetryEINTR(connect)(socket1->get(),
reinterpret_cast<sockaddr*>(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -462,25 +440,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrConnect) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in connect, and the group
// membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreqn iface = {};
iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -490,8 +469,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -499,22 +478,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(sockets->first_fd(),
+ RetryEINTR(connect)(socket1->get(),
reinterpret_cast<sockaddr*>(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -522,25 +499,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicConnect) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -550,8 +528,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -560,17 +538,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -578,25 +554,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelf) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreqn iface = {};
iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -606,8 +583,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -616,17 +593,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -634,25 +609,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelf) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in connect, and the group
// membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -662,8 +638,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -671,20 +647,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
EXPECT_THAT(
- RetryEINTR(connect)(sockets->first_fd(),
+ RetryEINTR(connect)(socket1->get(),
reinterpret_cast<sockaddr*>(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf),
+ EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
@@ -692,25 +667,26 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfConnect) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in connect, and the group
// membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreqn iface = {};
iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -720,8 +696,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -729,20 +705,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) {
reinterpret_cast<sockaddr_in*>(&connect_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
ASSERT_THAT(
- RetryEINTR(connect)(sockets->first_fd(),
+ RetryEINTR(connect)(socket1->get(),
reinterpret_cast<sockaddr*>(&connect_addr.addr),
connect_addr.addr_len),
SyscallSucceeds());
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(send)(sockets->first_fd(), send_buf, sizeof(send_buf), 0),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(send)(socket1->get(), send_buf, sizeof(send_buf), 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf),
+ EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
@@ -750,29 +725,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfConnect) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
&kSockOptOff, sizeof(kSockOptOff)),
SyscallSucceeds());
// Bind the first FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -782,8 +758,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -792,17 +768,15 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
@@ -810,29 +784,30 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfAddrSelfNoLoop) {
// Check that multicast works when the default send interface is configured by
// IP_MULTICAST_IF, the send address is specified in sendto, and the group
// membership is configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Set the default send interface.
ip_mreqn iface = {};
iface.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_LOOP,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_LOOP,
&kSockOptOff, sizeof(kSockOptOff)),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -842,8 +817,8 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -852,57 +827,57 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastLoopbackIfNicSelfNoLoop) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(
- RetryEINTR(recv)(sockets->first_fd(), recv_buf, sizeof(recv_buf), 0),
- SyscallSucceedsWithValue(sizeof(recv_buf)));
+ ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0),
+ SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
// Check that dropping a group membership that does not exist fails.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastInvalidDrop) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastInvalidDrop) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Unregister from a membership that we didn't have.
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallFailsWithErrno(EADDRNOTAVAIL));
}
// Check that dropping a group membership prevents multicast packets from being
// delivered. Default send address configured by bind and group membership
// interface configured by address.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -912,11 +887,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) {
ip_mreq group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -925,15 +900,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
@@ -941,26 +915,27 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropAddr) {
// Check that dropping a group membership prevents multicast packets from being
// delivered. Default send address configured by bind and group membership
// interface configured by NIC ID.
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind the first FD to the loopback. This is an alternative to
// IP_MULTICAST_IF for setting the default send interface.
auto sender_addr = V4Loopback();
EXPECT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
// Bind the second FD to the v4 any address to ensure that we can receive the
// multicast packet.
auto receiver_addr = V4Any();
- EXPECT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ EXPECT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -970,11 +945,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet.
@@ -983,50 +958,53 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastDropNic) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- EXPECT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&send_addr.addr),
- send_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we did not receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- EXPECT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfZero) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn iface = {};
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidNic) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn iface = {};
iface.imr_ifindex = -1;
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallFailsWithErrno(EADDRNOTAVAIL));
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfInvalidAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfInvalidAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreq iface = {};
iface.imr_interface.s_addr = inet_addr("255.255.255");
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallFailsWithErrno(EADDRNOTAVAIL));
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetShort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Create a valid full-sized request.
ip_mreqn iface = {};
@@ -1034,29 +1012,31 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetShort) {
// Send an optlen of 1 to check that optlen is enforced.
EXPECT_THAT(
- setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1),
+ setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface, 1),
SyscallFailsWithErrno(EINVAL));
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefault) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefault) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
in_addr get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
EXPECT_EQ(size, sizeof(get));
EXPECT_EQ(get.s_addr, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfDefaultReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
// getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
@@ -1071,19 +1051,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfDefaultReqn) {
EXPECT_EQ(get.imr_ifindex, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
in_addr set = {};
set.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
ip_mreqn get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
// getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
@@ -1095,19 +1076,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddrGetReqn) {
EXPECT_EQ(get.imr_ifindex, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddrGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreq set = {};
set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
ip_mreqn get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
// getsockopt(IP_MULTICAST_IF) can only return an in_addr, so it treats the
@@ -1119,19 +1101,20 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddrGetReqn) {
EXPECT_EQ(get.imr_ifindex, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNicGetReqn) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn set = {};
set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
ip_mreqn get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
EXPECT_EQ(size, sizeof(in_addr));
EXPECT_EQ(get.imr_multiaddr.s_addr, 0);
@@ -1139,87 +1122,93 @@ TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNicGetReqn) {
EXPECT_EQ(get.imr_ifindex, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
in_addr set = {};
set.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
in_addr get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
EXPECT_EQ(size, sizeof(get));
EXPECT_EQ(get.s_addr, set.s_addr);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetReqAddr) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetReqAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreq set = {};
set.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
in_addr get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
EXPECT_EQ(size, sizeof(get));
EXPECT_EQ(get.s_addr, set.imr_interface.s_addr);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, IpMulticastIfSetNic) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfSetNic) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn set = {};
set.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &set,
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &set,
sizeof(set)),
SyscallSucceeds());
in_addr get = {};
socklen_t size = sizeof(get);
ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
+ getsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &get, &size),
SyscallSucceeds());
EXPECT_EQ(size, sizeof(get));
EXPECT_EQ(get.s_addr, 0);
}
-TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupNoIf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupNoIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallFailsWithErrno(ENODEV));
}
-TEST_P(IPv4UDPUnboundSocketPairTest, TestJoinGroupInvalidIf) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestJoinGroupInvalidIf) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn group = {};
group.imr_address.s_addr = inet_addr("255.255.255");
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallFailsWithErrno(ENODEV));
}
// Check that multiple memberships are not allowed on the same socket.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- auto fd = sockets->first_fd();
+TEST_P(IPv4UDPUnboundSocketTest, TestMultipleJoinsOnSingleSocket) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto fd = socket1->get();
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
@@ -1234,41 +1223,44 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMultipleJoinsOnSingleSocket) {
}
// Check that two sockets can join the same multicast group at the same time.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestTwoSocketsJoinSameMulticastGroup) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestTwoSocketsJoinSameMulticastGroup) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Drop the membership twice on each socket, the second call for each socket
// should fail.
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
- EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallFailsWithErrno(EADDRNOTAVAIL));
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
- EXPECT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_DROP_MEMBERSHIP,
- &group, sizeof(group)),
+ EXPECT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_DROP_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallFailsWithErrno(EADDRNOTAVAIL));
}
// Check that two sockets can join the same multicast group at the same time,
// and both will receive data on it.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) {
+TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) {
std::unique_ptr<SocketPair> socket_pairs[2] = {
- ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()),
- ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())};
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())),
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))};
ip_mreq iface = {}, group = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
@@ -1338,11 +1330,12 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestMcastReceptionOnTwoSockets) {
// Check that on two sockets that joined a group and listen on ANY, dropping
// memberships one by one will continue to deliver packets to both sockets until
// both memberships have been dropped.
-TEST_P(IPv4UDPUnboundSocketPairTest,
- TestMcastReceptionWhenDroppingMemberships) {
+TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) {
std::unique_ptr<SocketPair> socket_pairs[2] = {
- ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()),
- ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair())};
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket())),
+ absl::make_unique<FDSocketPair>(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocket()))};
ip_mreq iface = {}, group = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
@@ -1437,18 +1430,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest,
// Check that a receiving socket can bind to the multicast address before
// joining the group and receive data once the group has been joined.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind second socket (receiver) to the multicast address.
auto receiver_addr = V4Multicast();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Update receiver_addr with the correct port number.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -1458,30 +1452,29 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) {
ip_mreqn group = {};
group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
- ASSERT_THAT(setsockopt(sockets->second_fd(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
- &group, sizeof(group)),
+ ASSERT_THAT(setsockopt(socket2->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, &group,
+ sizeof(group)),
SyscallSucceeds());
// Send a multicast packet on the first socket out the loopback interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
auto sendto_addr = V4Multicast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -1489,18 +1482,19 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenJoinThenReceive) {
// Check that a receiving socket can bind to the multicast address and won't
// receive multicast data if it hasn't joined the group.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind second socket (receiver) to the multicast address.
auto receiver_addr = V4Multicast();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
// Update receiver_addr with the correct port number.
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -1509,40 +1503,40 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenNoJoinThenNoReceive) {
// Send a multicast packet on the first socket out the loopback interface.
ip_mreq iface = {};
iface.imr_interface.s_addr = htonl(INADDR_LOOPBACK);
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_MULTICAST_IF,
- &iface, sizeof(iface)),
+ ASSERT_THAT(setsockopt(socket1->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
+ sizeof(iface)),
SyscallSucceeds());
auto sendto_addr = V4Multicast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we don't receive the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallFailsWithErrno(EAGAIN));
}
// Check that a socket can bind to a multicast address and still send out
// packets.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind second socket (receiver) to the ANY address.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -1551,11 +1545,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) {
// Bind the first socket (sender) to the multicast address.
auto sender_addr = V4Multicast();
ASSERT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
socklen_t sender_addr_len = sender_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&sender_addr.addr),
&sender_addr_len),
SyscallSucceeds());
@@ -1567,15 +1561,14 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -1583,46 +1576,46 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToMcastThenSend) {
// Check that a receiving socket can bind to the broadcast address and receive
// broadcast packets.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind second socket (receiver) to the broadcast address.
auto receiver_addr = V4Broadcast();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
// Send a broadcast packet on the first socket out the loopback interface.
- EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST,
- &kSockOptOn, sizeof(kSockOptOn)),
+ EXPECT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn,
+ sizeof(kSockOptOn)),
SyscallSucceedsWithValue(0));
// Note: Binding to the loopback interface makes the broadcast go out of it.
auto sender_bind_addr = V4Loopback();
- ASSERT_THAT(bind(sockets->first_fd(),
- reinterpret_cast<sockaddr*>(&sender_bind_addr.addr),
- sender_bind_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_bind_addr.addr),
+ sender_bind_addr.addr_len),
+ SyscallSucceeds());
auto sendto_addr = V4Broadcast();
reinterpret_cast<sockaddr_in*>(&sendto_addr.addr)->sin_port =
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the multicast packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
@@ -1630,17 +1623,18 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenReceive) {
// Check that a socket can bind to the broadcast address and still send out
// packets.
-TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Bind second socket (receiver) to the ANY address.
auto receiver_addr = V4Any();
- ASSERT_THAT(bind(sockets->second_fd(),
- reinterpret_cast<sockaddr*>(&receiver_addr.addr),
- receiver_addr.addr_len),
- SyscallSucceeds());
+ ASSERT_THAT(
+ bind(socket2->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
socklen_t receiver_addr_len = receiver_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->second_fd(),
+ ASSERT_THAT(getsockname(socket2->get(),
reinterpret_cast<sockaddr*>(&receiver_addr.addr),
&receiver_addr_len),
SyscallSucceeds());
@@ -1649,11 +1643,11 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) {
// Bind the first socket (sender) to the broadcast address.
auto sender_addr = V4Broadcast();
ASSERT_THAT(
- bind(sockets->first_fd(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ bind(socket1->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
sender_addr.addr_len),
SyscallSucceeds());
socklen_t sender_addr_len = sender_addr.addr_len;
- ASSERT_THAT(getsockname(sockets->first_fd(),
+ ASSERT_THAT(getsockname(socket1->get(),
reinterpret_cast<sockaddr*>(&sender_addr.addr),
&sender_addr_len),
SyscallSucceeds());
@@ -1665,19 +1659,898 @@ TEST_P(IPv4UDPUnboundSocketPairTest, TestBindToBcastThenSend) {
reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
char send_buf[200];
RandomizeBuffer(send_buf, sizeof(send_buf));
- ASSERT_THAT(
- RetryEINTR(sendto)(sockets->first_fd(), send_buf, sizeof(send_buf), 0,
- reinterpret_cast<sockaddr*>(&sendto_addr.addr),
- sendto_addr.addr_len),
- SyscallSucceedsWithValue(sizeof(send_buf)));
+ ASSERT_THAT(RetryEINTR(sendto)(socket1->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&sendto_addr.addr),
+ sendto_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
// Check that we received the packet.
char recv_buf[sizeof(send_buf)] = {};
- ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf),
+ ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf),
MSG_DONTWAIT),
SyscallSucceedsWithValue(sizeof(recv_buf)));
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
}
+// Check that SO_REUSEADDR always delivers to the most recently bound socket.
+//
+// FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly. Enable
+// random and co-op save (below) once that is fixed.
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) {
+ std::vector<std::unique_ptr<FileDescriptor>> sockets;
+ sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
+
+ ASSERT_THAT(setsockopt(sockets[0]->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(sockets[0]->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(sockets[0]->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ constexpr int kMessageSize = 200;
+
+ // FIXME(gvisor.dev/issue/873): Endpoint order is not restored correctly.
+ const DisableSave ds;
+
+ for (int i = 0; i < 10; i++) {
+ // Add a new receiver.
+ sockets.emplace_back(ASSERT_NO_ERRNO_AND_VALUE(NewSocket()));
+ auto& last = sockets.back();
+ ASSERT_THAT(setsockopt(last->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(last->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Send a new message to the SO_REUSEADDR group. We use a new socket each
+ // time so that a new ephemeral port will be used each time. This ensures
+ // that we aren't doing REUSEPORT-like hash load blancing.
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ char send_buf[kMessageSize];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Verify that the most recent socket got the message. We don't expect any
+ // of the other sockets to have received it, but we will check that later.
+ char recv_buf[sizeof(send_buf)] = {};
+ EXPECT_THAT(
+ RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+ }
+
+ // Verify that no other messages were received.
+ for (auto& socket : sockets) {
+ char recv_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrThenReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReusePortThenReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConvertibleToReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable1) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
+ socket2->reset();
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindReuseAddrReusePortConversionReversable2) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Close socket2 to revert to just socket1 with REUSEADDR and REUSEPORT.
+ socket2->reset();
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReusePort) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEPORT.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, BindDoubleReuseAddrReusePortThenReuseAddr) {
+ auto socket1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket3 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind socket1 with REUSEADDR and REUSEPORT.
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket1->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(socket1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(socket1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind socket2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(socket2->get(), SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(bind(socket2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind socket3 to the same address as socket1, only with REUSEADDR.
+ ASSERT_THAT(setsockopt(socket3->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(socket3->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+}
+
+// Check that REUSEPORT takes precedence over REUSEADDR.
+TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(receiver1->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Bind receiver2 to the same address as socket1, also with REUSEADDR and
+ // REUSEPORT.
+ ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEADDR,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(receiver2->get(), SOL_SOCKET, SO_REUSEPORT,
+ &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(bind(receiver2->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+
+ constexpr int kMessageSize = 10;
+
+ for (int i = 0; i < 100; ++i) {
+ // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket
+ // each time so that a new ephemerial port will be used each time. This
+ // ensures that we cycle through hashes.
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ char send_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(sendto)(sender->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ }
+
+ // Check that both receivers got messages. This checks that we are using load
+ // balancing (REUSEPORT) instead of the most recently bound socket
+ // (REUSEADDR).
+ char recv_buf[kMessageSize] = {};
+ EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(kMessageSize));
+ EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf),
+ MSG_DONTWAIT),
+ SyscallSucceedsWithValue(kMessageSize));
+}
+
+// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports.
+// We disable S/R because this test creates a large number of sockets.
+TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) {
+ auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ constexpr int kClients = 65536;
+ // Bind the first socket to the loopback and take note of the selected port.
+ auto addr = V4Loopback();
+ ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len),
+ SyscallSucceeds());
+ socklen_t addr_len = addr.addr_len;
+ ASSERT_THAT(getsockname(receiver1->get(),
+ reinterpret_cast<sockaddr*>(&addr.addr), &addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(addr_len, addr.addr_len);
+
+ // Disable cooperative S/R as we are making too many syscalls.
+ DisableSave ds;
+ std::vector<std::unique_ptr<FileDescriptor>> sockets;
+ for (int i = 0; i < kClients; i++) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr),
+ addr.addr_len);
+ if (ret == 0) {
+ sockets.push_back(std::move(s));
+ continue;
+ }
+ ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN));
+ break;
+ }
+}
+
+// Test that socket will receive packet info control message.
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF((IsRunningWithHostinet()));
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto sender_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ auto receiver_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port;
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {};
+ size_t cmsg_data_len = sizeof(in_pktinfo);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Get loopback index.
+ ifreq ifr = {};
+ absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ ASSERT_NE(ifr.ifr_ifindex, 0);
+
+ // Check the data
+ in_pktinfo received_pktinfo = {};
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
+
+// Check that setting SO_RCVBUF below min is clamped to the minimum
+// receive buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &below_min,
+ sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_RCVBUF above max is clamped to the maximum
+// receive buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufAboveMax) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover maxmimum buffer size by setting to a really large value.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &above_max,
+ sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_RCVBUF min <= rcvBufSz <= max is honored.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBuf) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maxmimum buffer size by setting to a really large value.
+ constexpr int kRcvBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &quarter_sz,
+ sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_RCVBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+ ASSERT_EQ(quarter_sz, val);
+}
+
+// Check that setting SO_SNDBUF below min is clamped to the minimum
+// send buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufBelowMin) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value so let's use a value that when doubled will still
+ // be smaller than min.
+ int below_min = min / 2 - 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &below_min,
+ sizeof(below_min)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ ASSERT_EQ(min, val);
+}
+
+// Check that setting SO_SNDBUF above max is clamped to the maximum
+// send buffer size.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBufAboveMax) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Discover maxmimum buffer size by setting to a really large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ int max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+
+ int above_max = max + 1;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &above_max,
+ sizeof(above_max)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+ ASSERT_EQ(max, val);
+}
+
+// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored.
+TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ int max = 0;
+ int min = 0;
+ {
+ // Discover maxmimum buffer size by setting to a really large value.
+ constexpr int kSndBufSz = 0xffffffff;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ max = 0;
+ socklen_t max_len = sizeof(max);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &max, &max_len),
+ SyscallSucceeds());
+ }
+
+ {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kSndBufSz = 0;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &kSndBufSz,
+ sizeof(kSndBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ int quarter_sz = min + (max - min) / 4;
+ ASSERT_THAT(setsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &quarter_sz,
+ sizeof(quarter_sz)),
+ SyscallSucceeds());
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s->get(), SOL_SOCKET, SO_SNDBUF, &val, &val_len),
+ SyscallSucceeds());
+
+ // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF.
+ if (!IsRunningOnGvisor()) {
+ quarter_sz *= 2;
+ }
+
+ ASSERT_EQ(quarter_sz, val);
+}
+
+TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) {
+ auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+
+ // Bind the first FD to the loopback. This is an alternative to
+ // IP_MULTICAST_IF for setting the default send interface.
+ auto sender_addr = V4Loopback();
+ ASSERT_THAT(
+ bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+
+ // Bind the second FD to the v4 any address to ensure that we can receive the
+ // multicast packet.
+ auto receiver_addr = V4Any();
+ ASSERT_THAT(bind(receiver_socket->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t receiver_addr_len = receiver_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver_socket->get(),
+ reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ &receiver_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len);
+
+ // Register to receive multicast packets.
+ ip_mreqn group = {};
+ group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress);
+ group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo"));
+ ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP,
+ &group, sizeof(group)),
+ SyscallSucceeds());
+
+ // Register to receive IP packet info.
+ const int one = 1;
+ ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_PKTINFO, &one,
+ sizeof(one)),
+ SyscallSucceeds());
+
+ // Send a multicast packet.
+ auto send_addr = V4Multicast();
+ reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port;
+ char send_buf[200];
+ RandomizeBuffer(send_buf, sizeof(send_buf));
+ ASSERT_THAT(
+ RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0,
+ reinterpret_cast<sockaddr*>(&send_addr.addr),
+ send_addr.addr_len),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+
+ // Check that we received the multicast packet.
+ msghdr recv_msg = {};
+ iovec recv_iov = {};
+ char recv_buf[sizeof(send_buf)];
+ char recv_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {};
+ size_t cmsg_data_len = sizeof(in_pktinfo);
+ recv_iov.iov_base = recv_buf;
+ recv_iov.iov_len = sizeof(recv_buf);
+ recv_msg.msg_iov = &recv_iov;
+ recv_msg.msg_iovlen = 1;
+ recv_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ recv_msg.msg_control = recv_cmsg_buf;
+ ASSERT_THAT(RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0),
+ SyscallSucceedsWithValue(sizeof(send_buf)));
+ EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
+
+ // Check the IP_PKTINFO control message.
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP);
+ EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO);
+
+ // Get loopback index.
+ ifreq ifr = {};
+ absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr),
+ SyscallSucceeds());
+ ASSERT_NE(ifr.ifr_ifindex, 0);
+
+ in_pktinfo received_pktinfo = {};
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ if (IsRunningOnGvisor()) {
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, group.imr_multiaddr.s_addr);
+ } else {
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
+ }
+ EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, group.imr_multiaddr.s_addr);
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.h b/test/syscalls/linux/socket_ipv4_udp_unbound.h
index 8e07bfbbf..f64c57645 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.h
@@ -20,8 +20,8 @@
namespace gvisor {
namespace testing {
-// Test fixture for tests that apply to pairs of IPv4 UDP sockets.
-using IPv4UDPUnboundSocketPairTest = SocketPairTest;
+// Test fixture for tests that apply to IPv4 UDP sockets.
+using IPv4UDPUnboundSocketTest = SimpleSocketTest;
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
index 98ae414f3..b206137eb 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc
@@ -41,63 +41,39 @@ TestAddress V4EmptyAddress() {
return t;
}
-constexpr char kMulticastAddress[] = "224.0.2.1";
-
-TestAddress V4Multicast() {
- TestAddress t("V4Multicast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- inet_addr(kMulticastAddress);
- return t;
-}
-
-TestAddress V4Broadcast() {
- TestAddress t("V4Broadcast");
- t.addr.ss_family = AF_INET;
- t.addr_len = sizeof(sockaddr_in);
- reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
- htonl(INADDR_BROADCAST);
- return t;
-}
-
void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() {
- got_if_infos_ = false;
+ // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
+ // IPv4 address on eth0.
+ found_net_interfaces_ = false;
// Get interface list.
- std::vector<std::string> if_names;
ASSERT_NO_ERRNO(if_helper_.Load());
- if_names = if_helper_.InterfaceList(AF_INET);
+ std::vector<std::string> if_names = if_helper_.InterfaceList(AF_INET);
if (if_names.size() != 2) {
return;
}
// Figure out which interface is where.
- int lo = 0, eth = 1;
- if (if_names[lo] != "lo") {
- lo = 1;
- eth = 0;
- }
-
- if (if_names[lo] != "lo") {
- return;
- }
-
- lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[lo]));
- lo_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[lo]);
- if (lo_if_addr_ == nullptr) {
+ std::string lo = if_names[0];
+ std::string eth = if_names[1];
+ if (lo != "lo") std::swap(lo, eth);
+ if (lo != "lo") return;
+
+ lo_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(lo));
+ auto lo_if_addr = if_helper_.GetAddr(AF_INET, lo);
+ if (lo_if_addr == nullptr) {
return;
}
- lo_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(lo_if_addr_)->sin_addr;
+ lo_if_addr_ = *reinterpret_cast<const sockaddr_in*>(lo_if_addr);
- eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(if_names[eth]));
- eth_if_addr_ = if_helper_.GetAddr(AF_INET, if_names[eth]);
- if (eth_if_addr_ == nullptr) {
+ eth_if_idx_ = ASSERT_NO_ERRNO_AND_VALUE(if_helper_.GetIndex(eth));
+ auto eth_if_addr = if_helper_.GetAddr(AF_INET, eth);
+ if (eth_if_addr == nullptr) {
return;
}
- eth_if_sin_addr_ = reinterpret_cast<sockaddr_in*>(eth_if_addr_)->sin_addr;
+ eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr);
- got_if_infos_ = true;
+ found_net_interfaces_ = true;
}
// Verifies that a newly instantiated UDP socket does not have the
@@ -136,6 +112,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) {
// the destination port number.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastReceivedOnExpectedPort) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -211,9 +188,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// not a unicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastReceivedOnExpectedAddresses) {
- // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
- // IPv4 address on eth0.
- SKIP_IF(!got_if_infos_);
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -262,7 +237,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Bind the non-receiving socket to the unicast ethernet address.
auto norecv_addr = rcv1_addr;
reinterpret_cast<sockaddr_in*>(&norecv_addr.addr)->sin_addr =
- eth_if_sin_addr_;
+ eth_if_addr_.sin_addr;
ASSERT_THAT(bind(norcv->get(), reinterpret_cast<sockaddr*>(&norecv_addr.addr),
norecv_addr.addr_len),
SyscallSucceedsWithValue(0));
@@ -298,6 +273,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// (UDPBroadcastSendRecvOnSocketBoundToAny).
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastSendRecvOnSocketBoundToBroadcast) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Enable SO_BROADCAST.
@@ -339,6 +315,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// (UDPBroadcastSendRecvOnSocketBoundToBroadcast).
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
UDPBroadcastSendRecvOnSocketBoundToAny) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Enable SO_BROADCAST.
@@ -377,6 +354,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST
// disabled.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
// Broadcast a test message without having enabled SO_BROADCAST on the sending
@@ -427,6 +405,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// multicast on gVisor.
SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(!found_net_interfaces_);
+
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -461,6 +441,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Check that multicast packets will be delivered to the sending socket without
// setting an interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
+ SKIP_IF(!found_net_interfaces_);
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -504,6 +485,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) {
// set interface and IP_MULTICAST_LOOP disabled.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastSelfLoopOff) {
+ SKIP_IF(!found_net_interfaces_);
auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto bind_addr = V4Any();
@@ -554,6 +536,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
// multicast on gVisor.
SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(!found_net_interfaces_);
+
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -592,6 +576,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) {
// Check that multicast packets will be delivered to another socket without
// setting an interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -639,6 +624,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) {
// set interface and IP_MULTICAST_LOOP disabled on the sending socket.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastSenderNoLoop) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -690,6 +676,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastReceiverNoLoop) {
+ SKIP_IF(!found_net_interfaces_);
+
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -742,6 +730,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// and both will receive data on it when bound to the ANY address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToAny) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -808,6 +797,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// and both will receive data on it when bound to the multicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToMulticastAddress) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -877,6 +867,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// multicast address, both will receive data.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
TestSendMulticastToTwoBoundToAnyAndMulticastAddress) {
+ SKIP_IF(!found_net_interfaces_);
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
std::unique_ptr<FileDescriptor> receivers[2] = {
ASSERT_NO_ERRNO_AND_VALUE(NewSocket()),
@@ -950,6 +941,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// is not a multicast address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackFromAddr) {
+ SKIP_IF(!found_net_interfaces_);
+
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1017,9 +1010,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// interface, a multicast packet sent out uses the latter as its source address.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackIfNicAndAddr) {
- // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
- // IPv4 address on eth0.
- SKIP_IF(!got_if_infos_);
+ SKIP_IF(!found_net_interfaces_);
// Create receiver, bind to ANY and join the multicast group.
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -1048,7 +1039,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
ip_mreqn iface = {};
iface.imr_ifindex = lo_if_idx_;
- iface.imr_address = eth_if_sin_addr_;
+ iface.imr_address = eth_if_addr_.sin_addr;
ASSERT_THAT(setsockopt(sender->get(), IPPROTO_IP, IP_MULTICAST_IF, &iface,
sizeof(iface)),
SyscallSucceeds());
@@ -1078,16 +1069,14 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
SKIP_IF(IsRunningOnGvisor());
// Verify the received source address.
- EXPECT_EQ(eth_if_sin_addr_.s_addr, src_addr_in->sin_addr.s_addr);
+ EXPECT_EQ(eth_if_addr_.sin_addr.s_addr, src_addr_in->sin_addr.s_addr);
}
// Check that when we are bound to one interface we can set IP_MULTICAST_IF to
// another interface.
TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) {
- // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its
- // IPv4 address on eth0.
- SKIP_IF(!got_if_infos_);
+ SKIP_IF(!found_net_interfaces_);
// FIXME (b/137790511): When bound to one interface it is not possible to set
// IP_MULTICAST_IF to a different interface.
@@ -1095,7 +1084,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
// Create sender and bind to eth interface.
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
- ASSERT_THAT(bind(sender->get(), eth_if_addr_, sizeof(sockaddr_in)),
+ ASSERT_THAT(bind(sender->get(), reinterpret_cast<sockaddr*>(&eth_if_addr_),
+ sizeof(eth_if_addr_)),
SyscallSucceeds());
// Run through all possible combinations of index and address for
@@ -1105,9 +1095,9 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest,
struct in_addr imr_address;
} test_data[] = {
{lo_if_idx_, {}},
- {0, lo_if_sin_addr_},
- {lo_if_idx_, lo_if_sin_addr_},
- {lo_if_idx_, eth_if_sin_addr_},
+ {0, lo_if_addr_.sin_addr},
+ {lo_if_idx_, lo_if_addr_.sin_addr},
+ {lo_if_idx_, eth_if_addr_.sin_addr},
};
for (auto t : test_data) {
ip_mreqn iface = {};
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
index bec2e96ee..0e9e70e8e 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h
@@ -29,17 +29,15 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest {
IfAddrHelper if_helper_;
- // got_if_infos_ is set to false if SetUp() could not obtain all interface
- // infos that we need.
- bool got_if_infos_;
+ // found_net_interfaces_ is set to false if SetUp() could not obtain
+ // all interface infos that we need.
+ bool found_net_interfaces_;
// Interface infos.
int lo_if_idx_;
int eth_if_idx_;
- sockaddr* lo_if_addr_;
- sockaddr* eth_if_addr_;
- in_addr lo_if_sin_addr_;
- in_addr eth_if_sin_addr_;
+ sockaddr_in lo_if_addr_;
+ sockaddr_in eth_if_addr_;
};
} // namespace testing
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
index 9d4e1ab97..f6e64c157 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking_test.cc
@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
+
#include <vector>
#include "test/syscalls/linux/ip_socket_test_util.h"
-#include "test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketKind> GetSockets() {
return ApplyVec<SocketKind>(
@@ -31,5 +33,7 @@ std::vector<SocketKind> GetSockets() {
INSTANTIATE_TEST_SUITE_P(IPv4UDPUnboundSockets,
IPv4UDPUnboundExternalNetworkingSocketTest,
::testing::ValuesIn(GetSockets()));
+
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
index cb0105471..f121c044d 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback.cc
@@ -22,14 +22,11 @@
namespace gvisor {
namespace testing {
-std::vector<SocketPairKind> GetSocketPairs() {
- return ApplyVec<SocketPairKind>(
- IPv4UDPUnboundSocketPair,
- AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK}));
-}
-
-INSTANTIATE_TEST_SUITE_P(IPv4UDPSockets, IPv4UDPUnboundSocketPairTest,
- ::testing::ValuesIn(GetSocketPairs()));
+INSTANTIATE_TEST_SUITE_P(
+ IPv4UDPSockets, IPv4UDPUnboundSocketTest,
+ ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket,
+ AllBitwiseCombinations(List<int>{
+ 0, SOCK_NONBLOCK}))));
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc
index 765f8e0e4..5f8d7f981 100644
--- a/test/syscalls/linux/socket_netdevice.cc
+++ b/test/syscalls/linux/socket_netdevice.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <linux/ethtool.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/sockios.h>
@@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) {
// Check that the loopback is zero hardware address.
ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds());
+ EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0);
EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0);
@@ -68,7 +70,8 @@ TEST(NetdeviceTest, Netmask) {
// Use a netlink socket to get the netmask, which we'll then compare to the
// netmask obtained via ioctl.
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
struct request {
@@ -90,7 +93,7 @@ TEST(NetdeviceTest, Netmask) {
int prefixlen = -1;
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
- [&](const struct nlmsghdr *hdr) {
+ [&](const struct nlmsghdr* hdr) {
EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE)));
EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI)
@@ -106,8 +109,8 @@ TEST(NetdeviceTest, Netmask) {
// RTM_NEWADDR contains at least the header and ifaddrmsg.
EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg));
- struct ifaddrmsg *ifaddrmsg =
- reinterpret_cast<struct ifaddrmsg *>(NLMSG_DATA(hdr));
+ struct ifaddrmsg* ifaddrmsg =
+ reinterpret_cast<struct ifaddrmsg*>(NLMSG_DATA(hdr));
if (ifaddrmsg->ifa_index == static_cast<uint32_t>(ifr.ifr_ifindex) &&
ifaddrmsg->ifa_family == AF_INET) {
prefixlen = ifaddrmsg->ifa_prefixlen;
@@ -126,8 +129,8 @@ TEST(NetdeviceTest, Netmask) {
snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
ASSERT_THAT(ioctl(sock.get(), SIOCGIFNETMASK, &ifr), SyscallSucceeds());
EXPECT_EQ(ifr.ifr_netmask.sa_family, AF_INET);
- struct sockaddr_in *sin =
- reinterpret_cast<struct sockaddr_in *>(&ifr.ifr_netmask);
+ struct sockaddr_in* sin =
+ reinterpret_cast<struct sockaddr_in*>(&ifr.ifr_netmask);
EXPECT_EQ(sin->sin_addr.s_addr, mask);
}
@@ -177,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) {
EXPECT_GT(ifr.ifr_mtu, 0);
}
+TEST(NetdeviceTest, EthtoolGetTSInfo) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
+
+ struct ethtool_ts_info tsi = {};
+ tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities.
+
+ // Prepare the request.
+ struct ifreq ifr = {};
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ ifr.ifr_data = (void*)&tsi;
+
+ // Check that SIOCGIFMTU returns a nonzero MTU.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ return;
+ }
+ ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink.cc b/test/syscalls/linux/socket_netlink.cc
new file mode 100644
index 000000000..4ec0fd4fa
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink.cc
@@ -0,0 +1,153 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/netlink.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for all netlink socket protocols.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// NetlinkTest parameter is the protocol to test.
+using NetlinkTest = ::testing::TestWithParam<int>;
+
+// Netlink sockets must be SOCK_DGRAM or SOCK_RAW.
+TEST_P(NetlinkTest, Types) {
+ const int protocol = GetParam();
+
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+ EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, protocol),
+ SyscallFailsWithErrno(ESOCKTNOSUPPORT));
+
+ int fd;
+ EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, protocol), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+
+ EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, protocol), SyscallSucceeds());
+ EXPECT_THAT(close(fd), SyscallSucceeds());
+}
+
+TEST_P(NetlinkTest, AutomaticPort) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ addr.nl_family = AF_NETLINK;
+
+ EXPECT_THAT(
+ bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
+ SyscallSucceeds());
+
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+ // This is the only netlink socket in the process, so it should get the PID as
+ // the port id.
+ //
+ // N.B. Another process could theoretically have explicitly reserved our pid
+ // as a port ID, but that is very unlikely.
+ EXPECT_EQ(addr.nl_pid, getpid());
+}
+
+// Calling connect automatically binds to an automatic port.
+TEST_P(NetlinkTest, ConnectBinds) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ addr.nl_family = AF_NETLINK;
+
+ EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+
+ // Each test is running in a pid namespace, so another process can explicitly
+ // reserve our pid as a port ID. In this case, a negative portid value will be
+ // set.
+ if (static_cast<pid_t>(addr.nl_pid) > 0) {
+ EXPECT_EQ(addr.nl_pid, getpid());
+ }
+
+ memset(&addr, 0, sizeof(addr));
+ addr.nl_family = AF_NETLINK;
+
+ // Connecting again is allowed, but keeps the same port.
+ EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, sizeof(addr));
+ EXPECT_EQ(addr.nl_pid, getpid());
+}
+
+TEST_P(NetlinkTest, GetPeerName) {
+ const int protocol = GetParam();
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, protocol));
+
+ struct sockaddr_nl addr = {};
+ socklen_t addrlen = sizeof(addr);
+
+ EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, sizeof(addr));
+ EXPECT_EQ(addr.nl_family, AF_NETLINK);
+ // Peer is the kernel if we didn't connect elsewhere.
+ EXPECT_EQ(addr.nl_pid, 0);
+}
+
+INSTANTIATE_TEST_SUITE_P(ProtocolTest, NetlinkTest,
+ ::testing::Values(NETLINK_ROUTE,
+ NETLINK_KOBJECT_UEVENT));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index dd4a11655..b3fcf8e7c 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -14,6 +14,7 @@
#include <arpa/inet.h>
#include <ifaddrs.h>
+#include <linux/if.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
@@ -25,8 +26,10 @@
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
#include "test/syscalls/linux/socket_netlink_util.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
@@ -38,115 +41,12 @@ namespace testing {
namespace {
+constexpr uint32_t kSeq = 12345;
+
using ::testing::AnyOf;
using ::testing::Eq;
-// Netlink sockets must be SOCK_DGRAM or SOCK_RAW.
-TEST(NetlinkRouteTest, Types) {
- EXPECT_THAT(socket(AF_NETLINK, SOCK_STREAM, NETLINK_ROUTE),
- SyscallFailsWithErrno(ESOCKTNOSUPPORT));
- EXPECT_THAT(socket(AF_NETLINK, SOCK_SEQPACKET, NETLINK_ROUTE),
- SyscallFailsWithErrno(ESOCKTNOSUPPORT));
- EXPECT_THAT(socket(AF_NETLINK, SOCK_RDM, NETLINK_ROUTE),
- SyscallFailsWithErrno(ESOCKTNOSUPPORT));
- EXPECT_THAT(socket(AF_NETLINK, SOCK_DCCP, NETLINK_ROUTE),
- SyscallFailsWithErrno(ESOCKTNOSUPPORT));
- EXPECT_THAT(socket(AF_NETLINK, SOCK_PACKET, NETLINK_ROUTE),
- SyscallFailsWithErrno(ESOCKTNOSUPPORT));
-
- int fd;
- EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE),
- SyscallSucceeds());
- EXPECT_THAT(close(fd), SyscallSucceeds());
-
- EXPECT_THAT(fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE),
- SyscallSucceeds());
- EXPECT_THAT(close(fd), SyscallSucceeds());
-}
-
-TEST(NetlinkRouteTest, AutomaticPort) {
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
-
- struct sockaddr_nl addr = {};
- addr.nl_family = AF_NETLINK;
-
- EXPECT_THAT(
- bind(fd.get(), reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)),
- SyscallSucceeds());
-
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, sizeof(addr));
- // This is the only netlink socket in the process, so it should get the PID as
- // the port id.
- //
- // N.B. Another process could theoretically have explicitly reserved our pid
- // as a port ID, but that is very unlikely.
- EXPECT_EQ(addr.nl_pid, getpid());
-}
-
-// Calling connect automatically binds to an automatic port.
-TEST(NetlinkRouteTest, ConnectBinds) {
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
-
- struct sockaddr_nl addr = {};
- addr.nl_family = AF_NETLINK;
-
- EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- sizeof(addr)),
- SyscallSucceeds());
-
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, sizeof(addr));
-
- // Each test is running in a pid namespace, so another process can explicitly
- // reserve our pid as a port ID. In this case, a negative portid value will be
- // set.
- if (static_cast<pid_t>(addr.nl_pid) > 0) {
- EXPECT_EQ(addr.nl_pid, getpid());
- }
-
- memset(&addr, 0, sizeof(addr));
- addr.nl_family = AF_NETLINK;
-
- // Connecting again is allowed, but keeps the same port.
- EXPECT_THAT(connect(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- sizeof(addr)),
- SyscallSucceeds());
-
- addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, sizeof(addr));
- EXPECT_EQ(addr.nl_pid, getpid());
-}
-
-TEST(NetlinkRouteTest, GetPeerName) {
- FileDescriptor fd =
- ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
-
- struct sockaddr_nl addr = {};
- socklen_t addrlen = sizeof(addr);
-
- EXPECT_THAT(getpeername(fd.get(), reinterpret_cast<struct sockaddr*>(&addr),
- &addrlen),
- SyscallSucceeds());
-
- EXPECT_EQ(addrlen, sizeof(addr));
- EXPECT_EQ(addr.nl_family, AF_NETLINK);
- // Peer is the kernel if we didn't connect elsewhere.
- EXPECT_EQ(addr.nl_pid, 0);
-}
-
-// Parameters for GetSockOpt test. They are:
+// Parameters for SockOptTest. They are:
// 0: Socket option to query.
// 1: A predicate to run on the returned sockopt value. Should return true if
// the value is considered ok.
@@ -195,7 +95,8 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK),
absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)),
std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE),
- absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE))));
+ absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)),
+ std::make_tuple(SO_PASSCRED, IsEqual(0), "0")));
// Validates the reponses to RTM_GETLINK + NLM_F_DUMP.
void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
@@ -218,55 +119,170 @@ void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
}
TEST(NetlinkRouteTest, GetLinkDump) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+ // Loopback is common among all tests, check that it's found.
+ bool loopbackFound = false;
+ ASSERT_NO_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ CheckGetLinkResponse(hdr, kSeq, port);
+ if (hdr->nlmsg_type != RTM_NEWLINK) {
+ return;
+ }
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ std::cout << "Found interface idx=" << msg->ifi_index
+ << ", type=" << std::hex << msg->ifi_type << std::endl;
+ if (msg->ifi_type == ARPHRD_LOOPBACK) {
+ loopbackFound = true;
+ EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
+ }
+ }));
+ EXPECT_TRUE(loopbackFound);
+}
+
+// CheckLinkMsg checks a netlink message against an expected link.
+void CheckLinkMsg(const struct nlmsghdr* hdr, const Link& link) {
+ ASSERT_THAT(hdr->nlmsg_type, Eq(RTM_NEWLINK));
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ EXPECT_EQ(msg->ifi_index, link.index);
+
+ const struct rtattr* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ EXPECT_NE(nullptr, rta) << "IFLA_IFNAME not found in message.";
+ if (rta != nullptr) {
+ std::string name(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ EXPECT_EQ(name, link.name);
+ }
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndex) {
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
- req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = loopback_link.index;
- // Loopback is common among all tests, check that it's found.
- bool loopbackFound = false;
+ bool found = false;
ASSERT_NO_ERRNO(NetlinkRequestResponse(
fd, &req, sizeof(req),
[&](const struct nlmsghdr* hdr) {
- CheckGetLinkResponse(hdr, kSeq, port);
- if (hdr->nlmsg_type != RTM_NEWLINK) {
- return;
- }
- ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct ifinfomsg)));
- const struct ifinfomsg* msg =
- reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
- std::cout << "Found interface idx=" << msg->ifi_index
- << ", type=" << std::hex << msg->ifi_type;
- if (msg->ifi_type == ARPHRD_LOOPBACK) {
- loopbackFound = true;
- EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0);
- }
+ CheckLinkMsg(hdr, loopback_link);
+ found = true;
},
false));
- EXPECT_TRUE(loopbackFound);
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
}
-TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+TEST(NetlinkRouteTest, GetLinkByName) {
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
};
- constexpr uint32_t kSeq = 12345;
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1);
+ strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ bool found = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponse(
+ fd, &req, sizeof(req),
+ [&](const struct nlmsghdr* hdr) {
+ CheckLinkMsg(hdr, loopback_link);
+ found = true;
+ },
+ false));
+ EXPECT_TRUE(found) << "Netlink response does not contain any links.";
+}
+
+TEST(NetlinkRouteTest, GetLinkByIndexNotFound) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = 1234590;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, GetLinkByNameNotFound) {
+ const std::string name = "nodevice?!";
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(name.size() + 1);
+ strncpy(req.ifname, name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, ::testing::_));
+}
+
+TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
@@ -277,30 +293,19 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
req.hdr.nlmsg_seq = kSeq;
req.ifm.ifi_family = AF_UNSPEC;
- ASSERT_NO_ERRNO(NetlinkRequestResponse(
- fd, &req, sizeof(req),
- [&](const struct nlmsghdr* hdr) {
- EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR));
- EXPECT_EQ(hdr->nlmsg_seq, kSeq);
- EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
-
- const struct nlmsgerr* msg =
- reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
- EXPECT_EQ(msg->error, -EOPNOTSUPP);
- },
- true));
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(EOPNOTSUPP, ::testing::_));
}
TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
@@ -331,15 +336,14 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
}
TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETLINK;
@@ -372,7 +376,8 @@ TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) {
}
TEST(NetlinkRouteTest, ControlMessageIgnored) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
struct request {
@@ -381,8 +386,6 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) {
struct ifinfomsg ifm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
// This control message is ignored. We still receive a response for the
@@ -407,7 +410,8 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) {
}
TEST(NetlinkRouteTest, GetAddrDump) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
struct request {
@@ -415,8 +419,6 @@ TEST(NetlinkRouteTest, GetAddrDump) {
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -465,9 +467,59 @@ TEST(NetlinkRouteTest, LookupAll) {
ASSERT_GT(count, 0);
}
+TEST(NetlinkRouteTest, AddAddr) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifa;
+ struct rtattr rtattr;
+ struct in_addr addr;
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifa.ifa_family = AF_INET;
+ req.ifa.ifa_prefixlen = 24;
+ req.ifa.ifa_flags = 0;
+ req.ifa.ifa_scope = 0;
+ req.ifa.ifa_index = loopback_link.index;
+ req.rtattr.rta_type = IFA_LOCAL;
+ req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr));
+ inet_pton(AF_INET, "10.0.0.1", &req.addr);
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ // Create should succeed, as no such address in kernel.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // Replace an existing address should succeed.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_NO_ERRNO(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len));
+
+ // Create exclusive should fail, as we created the address above.
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK;
+ req.hdr.nlmsg_seq++;
+ EXPECT_THAT(
+ NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len),
+ PosixErrorIs(EEXIST, ::testing::_));
+}
+
// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request.
TEST(NetlinkRouteTest, GetRouteDump) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
struct request {
@@ -475,8 +527,6 @@ TEST(NetlinkRouteTest, GetRouteDump) {
struct rtmsg rtm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req = {};
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETROUTE;
@@ -527,7 +577,10 @@ TEST(NetlinkRouteTest, GetRouteDump) {
std::cout << std::endl;
- if (msg->rtm_table == RT_TABLE_MAIN) {
+ // If the test is running in a new network namespace, it will have only
+ // the local route table.
+ if (msg->rtm_table == RT_TABLE_MAIN ||
+ (!IsRunningOnGvisor() && msg->rtm_table == RT_TABLE_LOCAL)) {
routeFound = true;
dstFound = rtDstFound && dstFound;
}
@@ -539,19 +592,102 @@ TEST(NetlinkRouteTest, GetRouteDump) {
EXPECT_TRUE(dstFound);
}
+// GetRouteRequest tests a RTM_GETROUTE request with RTM_F_LOOKUP_TABLE flag.
+TEST(NetlinkRouteTest, GetRouteRequest) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+ uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get()));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtmsg rtm;
+ struct nlattr nla;
+ struct in_addr sin_addr;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETROUTE;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+
+ req.rtm.rtm_family = AF_INET;
+ req.rtm.rtm_dst_len = 32;
+ req.rtm.rtm_src_len = 0;
+ req.rtm.rtm_tos = 0;
+ req.rtm.rtm_table = RT_TABLE_UNSPEC;
+ req.rtm.rtm_protocol = RTPROT_UNSPEC;
+ req.rtm.rtm_scope = RT_SCOPE_UNIVERSE;
+ req.rtm.rtm_type = RTN_UNSPEC;
+ req.rtm.rtm_flags = RTM_F_LOOKUP_TABLE;
+
+ req.nla.nla_len = 8;
+ req.nla.nla_type = RTA_DST;
+ inet_aton("127.0.0.2", &req.sin_addr);
+
+ bool rtDstFound = false;
+ ASSERT_NO_ERRNO(NetlinkRequestResponseSingle(
+ fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) {
+ // Validate the reponse to RTM_GETROUTE request with RTM_F_LOOKUP_TABLE
+ // flag.
+ EXPECT_THAT(hdr->nlmsg_type, RTM_NEWROUTE);
+
+ EXPECT_TRUE(hdr->nlmsg_flags == 0) << std::hex << hdr->nlmsg_flags;
+
+ EXPECT_EQ(hdr->nlmsg_seq, kSeq);
+ EXPECT_EQ(hdr->nlmsg_pid, port);
+
+ // RTM_NEWROUTE contains at least the header and rtmsg.
+ ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct rtmsg)));
+ const struct rtmsg* msg =
+ reinterpret_cast<const struct rtmsg*>(NLMSG_DATA(hdr));
+
+ // NOTE: rtmsg fields are char fields.
+ std::cout << "Found route table=" << static_cast<int>(msg->rtm_table)
+ << ", protocol=" << static_cast<int>(msg->rtm_protocol)
+ << ", scope=" << static_cast<int>(msg->rtm_scope)
+ << ", type=" << static_cast<int>(msg->rtm_type);
+
+ EXPECT_EQ(msg->rtm_family, AF_INET);
+ EXPECT_EQ(msg->rtm_dst_len, 32);
+ EXPECT_TRUE((msg->rtm_flags & RTM_F_CLONED) == RTM_F_CLONED)
+ << std::hex << msg->rtm_flags;
+
+ int len = RTM_PAYLOAD(hdr);
+ std::cout << ", len=" << len;
+ for (struct rtattr* attr = RTM_RTA(msg); RTA_OK(attr, len);
+ attr = RTA_NEXT(attr, len)) {
+ if (attr->rta_type == RTA_DST) {
+ char address[INET_ADDRSTRLEN] = {};
+ inet_ntop(AF_INET, RTA_DATA(attr), address, sizeof(address));
+ std::cout << ", dst=" << address;
+ rtDstFound = true;
+ } else if (attr->rta_type == RTA_OIF) {
+ const char* oif = reinterpret_cast<const char*>(RTA_DATA(attr));
+ std::cout << ", oif=" << oif;
+ }
+ }
+
+ std::cout << std::endl;
+ }));
+ // Found RTA_DST for RTM_F_LOOKUP_TABLE.
+ EXPECT_TRUE(rtDstFound);
+}
+
// RecvmsgTrunc tests the recvmsg MSG_TRUNC flag with zero length output
// buffer. MSG_TRUNC with a zero length buffer should consume subsequent
// messages off the socket.
TEST(NetlinkRouteTest, RecvmsgTrunc) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -619,15 +755,14 @@ TEST(NetlinkRouteTest, RecvmsgTrunc) {
// it, so a properly sized buffer can be allocated to store the message. This
// test tests that scenario.
TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
- FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
struct request {
struct nlmsghdr hdr;
struct rtgenmsg rgm;
};
- constexpr uint32_t kSeq = 12345;
-
struct request req;
req.hdr.nlmsg_len = sizeof(req);
req.hdr.nlmsg_type = RTM_GETADDR;
@@ -692,6 +827,111 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
} while (type != NLMSG_DONE && type != NLMSG_ERROR);
}
+// No SCM_CREDENTIALS are received without SO_PASSCRED set.
+TEST(NetlinkRouteTest, NoPasscredNoCreds) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // No control messages.
+ EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr);
+}
+
+// SCM_CREDENTIALS are received with SO_PASSCRED set.
+TEST(NetlinkRouteTest, PasscredCreds) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ struct ucred creds;
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+
+ memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds));
+
+ // The peer is the kernel, which is "PID" 0.
+ EXPECT_EQ(creds.pid, 0);
+ // The kernel identifies as root. Also allow nobody in case this test is
+ // running in a userns without root mapped.
+ EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534)));
+ EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534)));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
new file mode 100644
index 000000000..bde1dbb4d
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -0,0 +1,162 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+} // namespace
+
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
+PosixErrorOr<std::vector<Link>> DumpLinks() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ std::vector<Link> links;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ links.emplace_back();
+ links.back().index = msg->ifi_index;
+ links.back().type = msg->ifi_type;
+ links.back().name =
+ std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }));
+ return links;
+}
+
+PosixErrorOr<Link> LoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.type == ARPHRD_LOOPBACK) {
+ return link;
+ }
+ }
+ return PosixError(ENOENT, "loopback link not found");
+}
+
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char pad[NLMSG_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+ req.ifinfo.ifi_flags = flags;
+ req.ifinfo.ifi_change = change;
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFLA_ADDRESS;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
new file mode 100644
index 000000000..149c4a7f6
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include <vector>
+
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+
+struct Link {
+ int index;
+ int16_t type;
+ std::string name;
+};
+
+PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+PosixErrorOr<std::vector<Link>> DumpLinks();
+
+// Returns the loopback link on the system. ENOENT if not found.
+PosixErrorOr<Link> LoopbackLink();
+
+// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkChangeFlags changes interface flags. E.g. IFF_UP.
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
+
+// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface.
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
diff --git a/test/syscalls/linux/socket_netlink_uevent.cc b/test/syscalls/linux/socket_netlink_uevent.cc
new file mode 100644
index 000000000..da425bed4
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_uevent.cc
@@ -0,0 +1,83 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/filter.h>
+#include <linux/netlink.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+// Tests for NETLINK_KOBJECT_UEVENT sockets.
+//
+// gVisor never sends any messages on these sockets, so we don't test the events
+// themselves.
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// SO_PASSCRED can be enabled. Since no messages are sent in gVisor, we don't
+// actually test receiving credentials.
+TEST(NetlinkUeventTest, PassCred) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ EXPECT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+}
+
+// SO_DETACH_FILTER fails without a filter already installed.
+TEST(NetlinkUeventTest, DetachNoFilter) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ int opt;
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+// We can attach a BPF filter.
+TEST(NetlinkUeventTest, AttachFilter) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_KOBJECT_UEVENT));
+
+ // Minimal BPF program: a single ret.
+ struct sock_filter filter = {0x6, 0, 0, 0};
+ struct sock_fprog prog = {};
+ prog.len = 1;
+ prog.filter = &filter;
+
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_ATTACH_FILTER, &prog, sizeof(prog)),
+ SyscallSucceeds());
+
+ int opt;
+ EXPECT_THAT(
+ setsockopt(fd.get(), SOL_SOCKET, SO_DETACH_FILTER, &opt, sizeof(opt)),
+ SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc
index fcb8f8a88..952eecfe8 100644
--- a/test/syscalls/linux/socket_netlink_util.cc
+++ b/test/syscalls/linux/socket_netlink_util.cc
@@ -12,24 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <sys/socket.h>
+#include "test/syscalls/linux/socket_netlink_util.h"
#include <linux/if_arp.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
+#include <sys/socket.h>
#include <vector>
#include "absl/strings/str_cat.h"
-#include "test/syscalls/linux/socket_netlink_util.h"
#include "test/syscalls/linux/socket_test_util.h"
namespace gvisor {
namespace testing {
-PosixErrorOr<FileDescriptor> NetlinkBoundSocket() {
+PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) {
FileDescriptor fd;
- ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE));
+ ASSIGN_OR_RETURN_ERRNO(fd, Socket(AF_NETLINK, SOCK_RAW, protocol));
struct sockaddr_nl addr = {};
addr.nl_family = AF_NETLINK;
@@ -72,9 +72,10 @@ PosixError NetlinkRequestResponse(
iov.iov_base = buf.data();
iov.iov_len = buf.size();
- // Response is a series of NLM_F_MULTI messages, ending with a NLMSG_DONE
- // message.
+ // If NLM_F_MULTI is set, response is a series of messages that ends with a
+ // NLMSG_DONE message.
int type = -1;
+ int flags = 0;
do {
int len;
RETURN_ERROR_IF_SYSCALL_FAIL(len = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
@@ -90,6 +91,7 @@ PosixError NetlinkRequestResponse(
for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
NLMSG_OK(hdr, len); hdr = NLMSG_NEXT(hdr, len)) {
fn(hdr);
+ flags = hdr->nlmsg_flags;
type = hdr->nlmsg_type;
// Done should include an integer payload for dump_done_errno.
// See net/netlink/af_netlink.c:netlink_dump
@@ -99,15 +101,87 @@ PosixError NetlinkRequestResponse(
EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
}
}
- } while (type != NLMSG_DONE && type != NLMSG_ERROR);
+ } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR);
if (expect_nlmsgerr) {
EXPECT_EQ(type, NLMSG_ERROR);
- } else {
+ } else if (flags & NLM_F_MULTI) {
EXPECT_EQ(type, NLMSG_DONE);
}
return NoError();
}
+PosixError NetlinkRequestResponseSingle(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct iovec iov = {};
+ iov.iov_base = request;
+ iov.iov_len = len;
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ // No destination required; it defaults to pid 0, the kernel.
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));
+
+ constexpr size_t kBufferSize = 4096;
+ std::vector<char> buf(kBufferSize);
+ iov.iov_base = buf.data();
+ iov.iov_len = buf.size();
+
+ int ret;
+ RETURN_ERROR_IF_SYSCALL_FAIL(ret = RetryEINTR(recvmsg)(fd.get(), &msg, 0));
+
+ // We don't bother with the complexity of dealing with truncated messages.
+ // We must allocate a large enough buffer up front.
+ if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
+ return PosixError(
+ EIO,
+ absl::StrCat("Received truncated message with flags: ", msg.msg_flags));
+ }
+
+ for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
+ NLMSG_OK(hdr, ret); hdr = NLMSG_NEXT(hdr, ret)) {
+ fn(hdr);
+ }
+
+ return NoError();
+}
+
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len) {
+ // Dummy negative number for no error message received.
+ // We won't get a negative error number so there will be no confusion.
+ int err = -42;
+ RETURN_IF_ERRNO(NetlinkRequestResponse(
+ fd, request, len,
+ [&](const struct nlmsghdr* hdr) {
+ EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type);
+ EXPECT_EQ(hdr->nlmsg_seq, seq);
+ EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
+
+ const struct nlmsgerr* msg =
+ reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
+ err = -msg->error;
+ },
+ true));
+ return PosixError(err);
+}
+
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr) {
+ const int ifi_space = NLMSG_SPACE(sizeof(*msg));
+ int attrlen = hdr->nlmsg_len - ifi_space;
+ const struct rtattr* rta = reinterpret_cast<const struct rtattr*>(
+ reinterpret_cast<const uint8_t*>(hdr) + NLMSG_ALIGN(ifi_space));
+ for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
+ if (rta->rta_type == attr) {
+ return rta;
+ }
+ }
+ return nullptr;
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h
index db8639a2f..e13ead406 100644
--- a/test/syscalls/linux/socket_netlink_util.h
+++ b/test/syscalls/linux/socket_netlink_util.h
@@ -15,6 +15,8 @@
#ifndef GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_
#define GVISOR_TEST_SYSCALLS_SOCKET_NETLINK_UTIL_H_
+#include <sys/socket.h>
+// socket.h has to be included before if_arp.h.
#include <linux/if_arp.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
@@ -25,18 +27,35 @@
namespace gvisor {
namespace testing {
-// Returns a bound NETLINK_ROUTE socket.
-PosixErrorOr<FileDescriptor> NetlinkBoundSocket();
+// Returns a bound netlink socket.
+PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol);
// Returns the port ID of the passed socket.
PosixErrorOr<uint32_t> NetlinkPortID(int fd);
-// Send the passed request and call fn will all response netlink messages.
+// Send the passed request and call fn on all response netlink messages.
+//
+// To be used on requests with NLM_F_MULTI reponses.
PosixError NetlinkRequestResponse(
const FileDescriptor& fd, void* request, size_t len,
const std::function<void(const struct nlmsghdr* hdr)>& fn,
bool expect_nlmsgerr);
+// Send the passed request and call fn on all response netlink messages.
+//
+// To be used on requests without NLM_F_MULTI reponses.
+PosixError NetlinkRequestResponseSingle(
+ const FileDescriptor& fd, void* request, size_t len,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+// Send the passed request then expect and return an ack or error.
+PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
+ void* request, size_t len);
+
+// Find rtnetlink attribute in message.
+const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
+ const struct ifinfomsg* msg, int16_t attr);
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream.cc b/test/syscalls/linux/socket_non_stream.cc
index d91c5ed39..c61817f14 100644
--- a/test/syscalls/linux/socket_non_stream.cc
+++ b/test/syscalls/linux/socket_non_stream.cc
@@ -113,7 +113,7 @@ TEST_P(NonStreamSocketPairTest, RecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
// Stream sockets allow data sent with multiple sends to be peeked at in a
@@ -193,7 +193,7 @@ TEST_P(NonStreamSocketPairTest, MsgTruncTruncationRecvmsgMsghdrFlagMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were updated.
- EXPECT_EQ(msg.msg_flags, MSG_TRUNC);
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
}
TEST_P(NonStreamSocketPairTest, MsgTruncSameSize) {
@@ -224,5 +224,114 @@ TEST_P(NonStreamSocketPairTest, MsgTruncNotFull) {
EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC and a zero length
+// receive buffer. The user should be able to get the message length.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags.
+ EXPECT_EQ(msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer. The user should be able to get the message length
+// without reading data off the socket.
+TEST_P(NonStreamSocketPairTest, RecvmsgMsgTruncMsgPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // The syscall succeeds returning the full size of the message on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is set on msghdr flags because the receive buffer is
+ // smaller than the message size.
+ EXPECT_EQ(peek_msg.msg_flags & MSG_TRUNC, MSG_TRUNC);
+
+ char received_data[sizeof(sent_data)] = {};
+
+ struct iovec received_iov;
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = sizeof(received_data);
+ struct msghdr received_msg = {};
+ received_msg.msg_flags = -1;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+
+ // Next we can read the actual data.
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &received_msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ EXPECT_EQ(0, memcmp(sent_data, received_data, sizeof(sent_data)));
+
+ // Check that MSG_TRUNC is not set on msghdr flags because we read the whole
+ // message.
+ EXPECT_EQ(received_msg.msg_flags & MSG_TRUNC, 0);
+}
+
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer and MSG_DONTWAIT. The user should be able to get an
+// EAGAIN or EWOULDBLOCK error response.
+TEST_P(NonStreamSocketPairTest, RecvmsgTruncPeekDontwaitZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't send any data on the socket.
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ // recvmsg fails with EAGAIN because no data is available on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc
index 62d87c1af..b052f6e61 100644
--- a/test/syscalls/linux/socket_non_stream_blocking.cc
+++ b/test/syscalls/linux/socket_non_stream_blocking.cc
@@ -25,6 +25,7 @@
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -44,5 +45,41 @@ TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
SyscallSucceedsWithValue(sizeof(sent_data)));
}
+// This test tests reading from a socket with MSG_TRUNC | MSG_PEEK and a zero
+// length receive buffer and MSG_DONTWAIT. The recvmsg call should block on
+// reading the data.
+TEST_P(BlockingNonStreamSocketPairTest,
+ RecvmsgTruncPeekDontwaitZeroLenBlocking) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ // NOTE: We don't initially send any data on the socket.
+ const int data_size = 10;
+ char sent_data[data_size];
+ RandomizeBuffer(sent_data, data_size);
+
+ // The receive buffer is of zero length.
+ char peek_data[0] = {};
+
+ struct iovec peek_iov;
+ peek_iov.iov_base = peek_data;
+ peek_iov.iov_len = sizeof(peek_data);
+ struct msghdr peek_msg = {};
+ peek_msg.msg_flags = -1;
+ peek_msg.msg_iov = &peek_iov;
+ peek_msg.msg_iovlen = 1;
+
+ ScopedThread t([&]() {
+ // The syscall succeeds returning the full size of the message on the
+ // socket. This should block until there is data on the socket.
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &peek_msg,
+ MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(data_size));
+ });
+
+ absl::SleepFor(absl::Seconds(1));
+ ASSERT_THAT(RetryEINTR(send)(sockets->first_fd(), sent_data, data_size, 0),
+ SyscallSucceedsWithValue(data_size));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_stream.cc b/test/syscalls/linux/socket_stream.cc
index 346443f96..6522b2e01 100644
--- a/test/syscalls/linux/socket_stream.cc
+++ b/test/syscalls/linux/socket_stream.cc
@@ -104,7 +104,60 @@ TEST_P(StreamSocketPairTest, RecvmsgMsghdrFlagsNoMsgTrunc) {
EXPECT_EQ(0, memcmp(received_data, sent_data, sizeof(received_data)));
// Check that msghdr flags were cleared (MSG_TRUNC was not set).
- EXPECT_EQ(msg.msg_flags, 0);
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
+}
+
+TEST_P(StreamSocketPairTest, RecvmsgTruncPeekZeroLen) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[10];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+ ASSERT_THAT(
+ RetryEINTR(send)(sockets->first_fd(), sent_data, sizeof(sent_data), 0),
+ SyscallSucceedsWithValue(sizeof(sent_data)));
+
+ char received_data[0] = {};
+
+ struct iovec iov;
+ iov.iov_base = received_data;
+ iov.iov_len = sizeof(received_data);
+ struct msghdr msg = {};
+ msg.msg_flags = -1;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(
+ RetryEINTR(recvmsg)(sockets->second_fd(), &msg, MSG_TRUNC | MSG_PEEK),
+ SyscallSucceedsWithValue(0));
+
+ // Check that msghdr flags were cleared (MSG_TRUNC was not set).
+ ASSERT_EQ(msg.msg_flags & MSG_TRUNC, 0);
}
TEST_P(StreamSocketPairTest, MsgTrunc) {
diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc
index e9cc082bf..538ee2268 100644
--- a/test/syscalls/linux/socket_stream_blocking.cc
+++ b/test/syscalls/linux/socket_stream_blocking.cc
@@ -32,38 +32,38 @@ namespace gvisor {
namespace testing {
TEST_P(BlockingStreamSocketPairTest, BlockPartialWriteClosed) {
- // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it
- // enforce any limit; it will write arbitrary amounts of data without
- // blocking.
- SKIP_IF(IsRunningOnGvisor());
-
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
-
- int buffer_size;
- socklen_t length = sizeof(buffer_size);
- ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
- &buffer_size, &length),
- SyscallSucceeds());
-
- int wfd = sockets->first_fd();
- ScopedThread t([wfd, buffer_size]() {
- std::vector<char> buf(2 * buffer_size);
- // Write more than fits in the buffer. Blocks then returns partial write
- // when the other end is closed. The next call returns EPIPE.
- //
- // N.B. writes occur in chunks, so we may see less than buffer_size from
- // the first call.
- ASSERT_THAT(write(wfd, buf.data(), buf.size()),
- SyscallSucceedsWithValue(::testing::Gt(0)));
- ASSERT_THAT(write(wfd, buf.data(), buf.size()),
- ::testing::AnyOf(SyscallFailsWithErrno(EPIPE),
- SyscallFailsWithErrno(ECONNRESET)));
- });
-
- // Leave time for write to become blocked.
- absl::SleepFor(absl::Seconds(1));
-
- ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ // FIXME(b/35921550): gVisor doesn't support SO_SNDBUF on UDS, nor does it
+ // enforce any limit; it will write arbitrary amounts of data without
+ // blocking.
+ SKIP_IF(IsRunningOnGvisor());
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int buffer_size;
+ socklen_t length = sizeof(buffer_size);
+ ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDBUF,
+ &buffer_size, &length),
+ SyscallSucceeds());
+
+ int wfd = sockets->first_fd();
+ ScopedThread t([wfd, buffer_size]() {
+ std::vector<char> buf(2 * buffer_size);
+ // Write more than fits in the buffer. Blocks then returns partial write
+ // when the other end is closed. The next call returns EPIPE.
+ //
+ // N.B. writes occur in chunks, so we may see less than buffer_size from
+ // the first call.
+ ASSERT_THAT(write(wfd, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(::testing::Gt(0)));
+ ASSERT_THAT(write(wfd, buf.data(), buf.size()),
+ ::testing::AnyOf(SyscallFailsWithErrno(EPIPE),
+ SyscallFailsWithErrno(ECONNRESET)));
+ });
+
+ // Leave time for write to become blocked.
+ absl::SleepFor(absl::Seconds(1));
+
+ ASSERT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
}
// Random save may interrupt the call to sendmsg() in SendLargeSendMsg(),
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index eff7d577e..53b678e94 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -18,10 +18,13 @@
#include <poll.h>
#include <sys/socket.h>
+#include <memory>
+
#include "gtest/gtest.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
+#include "absl/types/optional.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
@@ -109,7 +112,10 @@ Creator<SocketPair> AcceptBindSocketPairCreator(bool abstract, int domain,
MaybeSave(); // Unlinked path.
}
- return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
+ // accepted is before connected to destruct connected before accepted.
+ // Destructors for nonstatic member objects are called in the reverse order
+ // in which they appear in the class declaration.
+ return absl::make_unique<AddrFDSocketPair>(accepted, connected, bind_addr,
extra_addr);
};
}
@@ -311,11 +317,16 @@ PosixErrorOr<T> BindIP(int fd, bool dual_stack) {
}
template <typename T>
-PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
- int bound, int connected, int type, bool dual_stack) {
- ASSIGN_OR_RETURN_ERRNO(T bind_addr, BindIP<T>(bound, dual_stack));
- RETURN_ERROR_IF_SYSCALL_FAIL(listen(bound, /* backlog = */ 5));
+PosixErrorOr<T> TCPBindAndListen(int fd, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T addr, BindIP<T>(fd, dual_stack));
+ RETURN_ERROR_IF_SYSCALL_FAIL(listen(fd, /* backlog = */ 5));
+ return addr;
+}
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>>
+CreateTCPConnectAcceptSocketPair(int bound, int connected, int type,
+ bool dual_stack, T bind_addr) {
int connect_result = 0;
RETURN_ERROR_IF_SYSCALL_FAIL(
(connect_result = RetryEINTR(connect)(
@@ -353,19 +364,25 @@ PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
}
MaybeSave(); // Successful accept.
- // FIXME(b/110484944)
- if (connect_result == -1) {
- absl::SleepFor(absl::Seconds(1));
- }
+ T extra_addr = {};
+ LocalhostAddr(&extra_addr, dual_stack);
+ return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
+ extra_addr);
+}
+
+template <typename T>
+PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateTCPAcceptBindSocketPair(
+ int bound, int connected, int type, bool dual_stack) {
+ ASSIGN_OR_RETURN_ERRNO(T bind_addr, TCPBindAndListen<T>(bound, dual_stack));
+
+ auto result = CreateTCPConnectAcceptSocketPair(bound, connected, type,
+ dual_stack, bind_addr);
// Cleanup no longer needed resources.
RETURN_ERROR_IF_SYSCALL_FAIL(close(bound));
MaybeSave(); // Successful close.
- T extra_addr = {};
- LocalhostAddr(&extra_addr, dual_stack);
- return absl::make_unique<AddrFDSocketPair>(connected, accepted, bind_addr,
- extra_addr);
+ return result;
}
Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
@@ -389,6 +406,63 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
};
}
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack) {
+ // These are lazily initialized below, on the first call to the returned
+ // lambda. These values are private to each returned lambda, but shared across
+ // invocations of a specific lambda.
+ //
+ // The sharing allows pairs created with the same parameters to share a
+ // listener. This prevents future connects from failing if the connecting
+ // socket selects a port which had previously been used by a listening socket
+ // that still has some connections in TIME-WAIT.
+ //
+ // The lazy initialization is to avoid creating sockets during parameter
+ // enumeration. This is important because parameters are enumerated during the
+ // build process where networking may not be available.
+ auto listener = std::make_shared<absl::optional<int>>(absl::optional<int>());
+ auto addr4 = std::make_shared<absl::optional<sockaddr_in>>(
+ absl::optional<sockaddr_in>());
+ auto addr6 = std::make_shared<absl::optional<sockaddr_in6>>(
+ absl::optional<sockaddr_in6>());
+
+ return [=]() -> PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> {
+ int connected;
+ RETURN_ERROR_IF_SYSCALL_FAIL(connected = socket(domain, type, protocol));
+ MaybeSave(); // Successful socket creation.
+
+ // Share the listener across invocations.
+ if (!listener->has_value()) {
+ int fd = socket(domain, type, protocol);
+ if (fd < 0) {
+ return PosixError(errno, absl::StrCat("socket(", domain, ", ", type,
+ ", ", protocol, ")"));
+ }
+ listener->emplace(fd);
+ MaybeSave(); // Successful socket creation.
+ }
+
+ // Bind the listener once, but create a new connect/accept pair each
+ // time.
+ if (domain == AF_INET) {
+ if (!addr4->has_value()) {
+ addr4->emplace(
+ TCPBindAndListen<sockaddr_in>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected,
+ type, dual_stack, addr4->value());
+ }
+ if (!addr6->has_value()) {
+ addr6->emplace(
+ TCPBindAndListen<sockaddr_in6>(listener->value(), dual_stack)
+ .ValueOrDie());
+ }
+ return CreateTCPConnectAcceptSocketPair(listener->value(), connected, type,
+ dual_stack, addr6->value());
+ };
+}
+
template <typename T>
PosixErrorOr<std::unique_ptr<AddrFDSocketPair>> CreateUDPBoundSocketPair(
int sock1, int sock2, int type, bool dual_stack) {
@@ -518,8 +592,8 @@ size_t CalculateUnixSockAddrLen(const char* sun_path) {
if (sun_path[0] == 0) {
return sizeof(sockaddr_un);
}
- // Filesystem addresses use the address length plus the 2 byte sun_family and
- // null terminator.
+ // Filesystem addresses use the address length plus the 2 byte sun_family
+ // and null terminator.
return strlen(sun_path) + 3;
}
@@ -726,6 +800,24 @@ TestAddress V4MappedLoopback() {
return t;
}
+TestAddress V4Multicast() {
+ TestAddress t("V4Multicast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ inet_addr(kMulticastAddress);
+ return t;
+}
+
+TestAddress V4Broadcast() {
+ TestAddress t("V4Broadcast");
+ t.addr.ss_family = AF_INET;
+ t.addr_len = sizeof(sockaddr_in);
+ reinterpret_cast<sockaddr_in*>(&t.addr)->sin_addr.s_addr =
+ htonl(INADDR_BROADCAST);
+ return t;
+}
+
TestAddress V6Any() {
TestAddress t("V6Any");
t.addr.ss_family = AF_INET6;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index be38907c2..734b48b96 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -114,6 +114,9 @@ class FDSocketPair : public SocketPair {
public:
FDSocketPair(int first_fd, int second_fd)
: first_(first_fd), second_(second_fd) {}
+ FDSocketPair(std::unique_ptr<FileDescriptor> first_fd,
+ std::unique_ptr<FileDescriptor> second_fd)
+ : first_(first_fd->release()), second_(second_fd->release()) {}
int first_fd() const override { return first_.get(); }
int second_fd() const override { return second_.get(); }
@@ -270,6 +273,12 @@ Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
int protocol,
bool dual_stack);
+// TCPAcceptBindPersistentListenerSocketPairCreator is like
+// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to
+// create all SocketPairs.
+Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
+ int domain, int type, int protocol, bool dual_stack);
+
// UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the bind() and connect() syscalls on UDP
// sockets.
@@ -475,10 +484,15 @@ struct TestAddress {
: description(std::move(description)), addr(), addr_len() {}
};
+constexpr char kMulticastAddress[] = "224.0.2.1";
+constexpr char kBroadcastAddress[] = "255.255.255.255";
+
TestAddress V4Any();
+TestAddress V4Broadcast();
TestAddress V4Loopback();
TestAddress V4MappedAny();
TestAddress V4MappedLoopback();
+TestAddress V4Multicast();
TestAddress V6Any();
TestAddress V6Loopback();
diff --git a/test/syscalls/linux/socket_unix.cc b/test/syscalls/linux/socket_unix.cc
index 8a28202a8..591cab3fd 100644
--- a/test/syscalls/linux/socket_unix.cc
+++ b/test/syscalls/linux/socket_unix.cc
@@ -65,6 +65,21 @@ TEST_P(UnixSocketPairTest, BindToBadName) {
SyscallFailsWithErrno(ENOENT));
}
+TEST_P(UnixSocketPairTest, BindToBadFamily) {
+ auto pair =
+ ASSERT_NO_ERRNO_AND_VALUE(UnixDomainSocketPair(SOCK_SEQPACKET).Create());
+
+ constexpr char kBadName[] = "/some/path/that/does/not/exist";
+ sockaddr_un sockaddr;
+ sockaddr.sun_family = AF_INET;
+ memcpy(sockaddr.sun_path, kBadName, sizeof(kBadName));
+
+ EXPECT_THAT(
+ bind(pair->first_fd(), reinterpret_cast<struct sockaddr*>(&sockaddr),
+ sizeof(sockaddr)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST_P(UnixSocketPairTest, RecvmmsgTimeoutAfterRecv) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[10];
@@ -241,8 +256,9 @@ TEST_P(UnixSocketPairTest, ShutdownWrite) {
}
TEST_P(UnixSocketPairTest, SocketReopenFromProcfs) {
- // TODO(b/122310852): We should be returning ENXIO and NOT EIO.
- SKIP_IF(IsRunningOnGvisor());
+ // TODO(gvisor.dev/issue/1624): In VFS1, we return EIO instead of ENXIO (see
+ // b/122310852). Remove this skip once VFS1 is deleted.
+ SKIP_IF(IsRunningWithVFS1());
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
// Opening a socket pair via /proc/self/fd/X is a ENXIO.
diff --git a/test/syscalls/linux/socket_unix_abstract_nonblock.cc b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
index be31ab2a7..8bef76b67 100644
--- a/test/syscalls/linux/socket_unix_abstract_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_abstract_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingAbstractUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_blocking_local.cc b/test/syscalls/linux/socket_unix_blocking_local.cc
index 1994139e6..77cb8c6d6 100644
--- a/test/syscalls/linux/socket_unix_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_blocking_local.cc
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test/syscalls/linux/socket_blocking.h"
-
#include <vector>
+#include "test/syscalls/linux/socket_blocking.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -40,5 +40,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUnixDomainSockets, BlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_cmsg.cc b/test/syscalls/linux/socket_unix_cmsg.cc
index 1159c5229..a16899493 100644
--- a/test/syscalls/linux/socket_unix_cmsg.cc
+++ b/test/syscalls/linux/socket_unix_cmsg.cc
@@ -149,6 +149,35 @@ TEST_P(UnixSocketPairCmsgTest, BadFDPass) {
SyscallFailsWithErrno(EBADF));
}
+TEST_P(UnixSocketPairCmsgTest, ShortCmsg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char sent_data[20];
+ RandomizeBuffer(sent_data, sizeof(sent_data));
+
+ int sent_fd = -1;
+
+ struct msghdr msg = {};
+ char control[CMSG_SPACE(sizeof(sent_fd))];
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_len = 1;
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ memcpy(CMSG_DATA(cmsg), &sent_fd, sizeof(sent_fd));
+
+ struct iovec iov;
+ iov.iov_base = sent_data;
+ iov.iov_len = sizeof(sent_data);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
// BasicFDPassNoSpace starts off by sending a single FD just like BasicFDPass.
// The difference is that when calling recvmsg, no space for FDs is provided,
// only space for the cmsg header.
diff --git a/test/syscalls/linux/socket_unix_dgram.cc b/test/syscalls/linux/socket_unix_dgram.cc
index 3245cf7c9..af0df4fb4 100644
--- a/test/syscalls/linux/socket_unix_dgram.cc
+++ b/test/syscalls/linux/socket_unix_dgram.cc
@@ -16,6 +16,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/socket_unix_dgram_local.cc b/test/syscalls/linux/socket_unix_dgram_local.cc
index 9134fcdf7..31d2d5216 100644
--- a/test/syscalls/linux/socket_unix_dgram_local.cc
+++ b/test/syscalls/linux/socket_unix_dgram_local.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
@@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P(
DgramUnixSockets, NonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc
index cd4fba25c..2db8b68d3 100644
--- a/test/syscalls/linux/socket_unix_dgram_non_blocking.cc
+++ b/test/syscalls/linux/socket_unix_dgram_non_blocking.cc
@@ -14,6 +14,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/socket_unix_domain.cc b/test/syscalls/linux/socket_unix_domain.cc
index fa3efc7f8..f7dff8b4d 100644
--- a/test/syscalls/linux/socket_unix_domain.cc
+++ b/test/syscalls/linux/socket_unix_domain.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, AllSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
index 8ba7af971..6700b4d90 100644
--- a/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_filesystem_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingFilesystemUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_non_stream.cc b/test/syscalls/linux/socket_unix_non_stream.cc
index 276a94eb8..884319e1d 100644
--- a/test/syscalls/linux/socket_unix_non_stream.cc
+++ b/test/syscalls/linux/socket_unix_non_stream.cc
@@ -109,7 +109,7 @@ PosixErrorOr<std::vector<Mapping>> CreateFragmentedRegion(const int size,
}
// A contiguous iov that is heavily fragmented in FileMem can still be sent
-// successfully.
+// successfully. See b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -165,7 +165,7 @@ TEST_P(UnixNonStreamSocketPairTest, FragmentedSendMsg) {
}
// A contiguous iov that is heavily fragmented in FileMem can still be received
-// into successfully.
+// into successfully. Regression test for b/115833655.
TEST_P(UnixNonStreamSocketPairTest, FragmentedRecvMsg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
diff --git a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
index da762cd83..fddcdf1c5 100644
--- a/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_non_stream_blocking_local.cc
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test/syscalls/linux/socket_non_stream_blocking.h"
-
#include <vector>
+#include "test/syscalls/linux/socket_non_stream_blocking.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -37,5 +37,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingNonStreamUnixSockets, BlockingNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc
index 411fb4518..85999db04 100644
--- a/test/syscalls/linux/socket_unix_pair.cc
+++ b/test/syscalls/linux/socket_unix_pair.cc
@@ -22,6 +22,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(ApplyVec<SocketPairKind>(
@@ -38,5 +39,6 @@ INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, UnixSocketPairCmsgTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_pair_nonblock.cc b/test/syscalls/linux/socket_unix_pair_nonblock.cc
index 3135d325f..281410a9a 100644
--- a/test/syscalls/linux/socket_unix_pair_nonblock.cc
+++ b/test/syscalls/linux/socket_unix_pair_nonblock.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return ApplyVec<SocketPairKind>(
@@ -33,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingUnixSockets, NonBlockingSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_seqpacket.cc b/test/syscalls/linux/socket_unix_seqpacket.cc
index 60fa9e38a..6d03df4d9 100644
--- a/test/syscalls/linux/socket_unix_seqpacket.cc
+++ b/test/syscalls/linux/socket_unix_seqpacket.cc
@@ -16,6 +16,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
@@ -42,6 +43,24 @@ TEST_P(SeqpacketUnixSocketPairTest, ReadOneSideClosed) {
SyscallSucceedsWithValue(0));
}
+TEST_P(SeqpacketUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallSucceedsWithValue(3));
+
+ char data[10] = {};
+ ASSERT_THAT(read(sockets->first_fd(), data, sizeof(data)),
+ SyscallSucceedsWithValue(3));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/socket_unix_seqpacket_local.cc b/test/syscalls/linux/socket_unix_seqpacket_local.cc
index dff75a532..69a5f150d 100644
--- a/test/syscalls/linux/socket_unix_seqpacket_local.cc
+++ b/test/syscalls/linux/socket_unix_seqpacket_local.cc
@@ -23,6 +23,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(VecCat<SocketPairKind>(
@@ -52,5 +53,6 @@ INSTANTIATE_TEST_SUITE_P(
SeqpacketUnixSockets, UnixNonStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc
index 563467365..99e77b89e 100644
--- a/test/syscalls/linux/socket_unix_stream.cc
+++ b/test/syscalls/linux/socket_unix_stream.cc
@@ -89,6 +89,20 @@ TEST_P(StreamUnixSocketPairTest, ReadOneSideClosedWithUnreadData) {
SyscallFailsWithErrno(ECONNRESET));
}
+TEST_P(StreamUnixSocketPairTest, Sendto) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct sockaddr_un addr = {};
+ addr.sun_family = AF_UNIX;
+ constexpr char kPath[] = "\0nonexistent";
+ memcpy(addr.sun_path, kPath, sizeof(kPath));
+
+ constexpr char kStr[] = "abc";
+ ASSERT_THAT(sendto(sockets->second_fd(), kStr, 3, 0, (struct sockaddr*)&addr,
+ sizeof(addr)),
+ SyscallFailsWithErrno(EISCONN));
+}
+
INSTANTIATE_TEST_SUITE_P(
AllUnixDomainSockets, StreamUnixSocketPairTest,
::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>(
diff --git a/test/syscalls/linux/socket_unix_stream_blocking_local.cc b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
index fa0a9d367..8429bd429 100644
--- a/test/syscalls/linux/socket_unix_stream_blocking_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_blocking_local.cc
@@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test/syscalls/linux/socket_stream_blocking.h"
-
#include <vector>
+#include "test/syscalls/linux/socket_stream_blocking.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -35,5 +35,6 @@ INSTANTIATE_TEST_SUITE_P(
BlockingStreamUnixSockets, BlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_local.cc b/test/syscalls/linux/socket_unix_stream_local.cc
index 65eef1a81..a7e3449a9 100644
--- a/test/syscalls/linux/socket_unix_stream_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_local.cc
@@ -21,6 +21,7 @@
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return VecCat<SocketPairKind>(
@@ -42,5 +43,6 @@ INSTANTIATE_TEST_SUITE_P(
StreamUnixSockets, StreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
index ec777c59f..4b763c8e2 100644
--- a/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
+++ b/test/syscalls/linux/socket_unix_stream_nonblock_local.cc
@@ -11,16 +11,16 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "test/syscalls/linux/socket_stream_nonblock.h"
-
#include <vector>
+#include "test/syscalls/linux/socket_stream_nonblock.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include "test/util/test_util.h"
namespace gvisor {
namespace testing {
+namespace {
std::vector<SocketPairKind> GetSocketPairs() {
return {
@@ -34,5 +34,6 @@ INSTANTIATE_TEST_SUITE_P(
NonBlockingStreamUnixSockets, NonBlockingStreamSocketPairTest,
::testing::ValuesIn(IncludeReversals(GetSocketPairs())));
+} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_unix_unbound_abstract.cc b/test/syscalls/linux/socket_unix_unbound_abstract.cc
index 7f5816ace..8b1762000 100644
--- a/test/syscalls/linux/socket_unix_unbound_abstract.cc
+++ b/test/syscalls/linux/socket_unix_unbound_abstract.cc
@@ -14,6 +14,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
index b14f24086..cab912152 100644
--- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc
+++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc
@@ -14,6 +14,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc
index 50ffa1d04..cb99030f5 100644
--- a/test/syscalls/linux/socket_unix_unbound_seqpacket.cc
+++ b/test/syscalls/linux/socket_unix_unbound_seqpacket.cc
@@ -14,6 +14,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/socket_unix_unbound_stream.cc b/test/syscalls/linux/socket_unix_unbound_stream.cc
index 344918c34..f185dded3 100644
--- a/test/syscalls/linux/socket_unix_unbound_stream.cc
+++ b/test/syscalls/linux/socket_unix_unbound_stream.cc
@@ -14,6 +14,7 @@
#include <stdio.h>
#include <sys/un.h>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index 85232cb1f..08fc4b1b7 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/unistd.h>
#include <sys/eventfd.h>
#include <sys/resource.h>
#include <sys/sendfile.h>
@@ -60,6 +61,62 @@ TEST(SpliceTest, TwoRegularFiles) {
SyscallFailsWithErrno(EINVAL));
}
+int memfd_create(const std::string& name, unsigned int flags) {
+ return syscall(__NR_memfd_create, name.c_str(), flags);
+}
+
+TEST(SpliceTest, NegativeOffset) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file as write only.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("negative", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
+ loff_t out_offset = 0xffffffffffffffffull;
+ constexpr int kSize = 2;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Write offset + size overflows int64.
+//
+// This is a regression test for b/148041624.
+TEST(SpliceTest, WriteOverflow) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill the pipe.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the output file.
+ int fd;
+ EXPECT_THAT(fd = memfd_create("overflow", 0), SyscallSucceeds());
+ const FileDescriptor out_fd(fd);
+
+ // out_offset + kSize overflows INT64_MAX.
+ loff_t out_offset = 0x7ffffffffffffffeull;
+ constexpr int kSize = 3;
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kSize, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
TEST(SpliceTest, SamePipe) {
// Create a new pipe.
int fds[2];
@@ -373,6 +430,55 @@ TEST(SpliceTest, TwoPipes) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(SpliceTest, TwoPipesCircular) {
+ // This test deadlocks the sentry on VFS1 because VFS1 splice ordering is
+ // based on fs.File.UniqueID, which does not prevent circular ordering between
+ // e.g. inode-level locks taken by fs.FileOperations.
+ SKIP_IF(IsRunningWithVFS1());
+
+ // Create two pipes.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor first_rfd(fds[0]);
+ const FileDescriptor first_wfd(fds[1]);
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor second_rfd(fds[0]);
+ const FileDescriptor second_wfd(fds[1]);
+
+ // On Linux, each pipe is normally limited to
+ // include/linux/pipe_fs_i.h:PIPE_DEF_BUFFERS buffers worth of data.
+ constexpr size_t PIPE_DEF_BUFFERS = 16;
+
+ // Write some data to each pipe. Below we splice 1 byte at a time between
+ // pipes, which very quickly causes each byte to be stored in a separate
+ // buffer, so we must ensure that the total amount of data in the system is <=
+ // PIPE_DEF_BUFFERS bytes.
+ std::vector<char> buf(PIPE_DEF_BUFFERS / 2);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(first_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(write(second_wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(buf.size()));
+
+ // Have another thread splice from the second pipe to the first, while we
+ // splice from the first to the second. The test passes if this does not
+ // deadlock.
+ const int kIterations = 1000;
+ DisableSave ds;
+ ScopedThread t([&]() {
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(second_rfd.get(), nullptr, first_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+ });
+ for (int i = 0; i < kIterations; i++) {
+ ASSERT_THAT(
+ splice(first_rfd.get(), nullptr, second_wfd.get(), nullptr, 1, 0),
+ SyscallSucceedsWithValue(1));
+ }
+}
+
TEST(SpliceTest, Blocking) {
// Create two new pipes.
int first[2], second[2];
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 30de2f8ff..2503960f3 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -34,6 +34,13 @@
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#ifndef AT_STATX_FORCE_SYNC
+#define AT_STATX_FORCE_SYNC 0x2000
+#endif
+#ifndef AT_STATX_DONT_SYNC
+#define AT_STATX_DONT_SYNC 0x4000
+#endif
+
namespace gvisor {
namespace testing {
@@ -557,6 +564,8 @@ TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) {
#ifndef SYS_statx
#if defined(__x86_64__)
#define SYS_statx 332
+#elif defined(__aarch64__)
+#define SYS_statx 291
#else
#error "Unknown architecture"
#endif
@@ -599,13 +608,13 @@ struct kernel_statx {
uint64_t __spare2[14];
};
-int statx(int dirfd, const char *pathname, int flags, unsigned int mask,
- struct kernel_statx *statxbuf) {
+int statx(int dirfd, const char* pathname, int flags, unsigned int mask,
+ struct kernel_statx* statxbuf) {
return syscall(SYS_statx, dirfd, pathname, flags, mask, statxbuf);
}
TEST_F(StatTest, StatxAbsPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -615,7 +624,7 @@ TEST_F(StatTest, StatxAbsPath) {
}
TEST_F(StatTest, StatxRelPathDirFD) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
struct kernel_statx stx;
@@ -629,7 +638,7 @@ TEST_F(StatTest, StatxRelPathDirFD) {
}
TEST_F(StatTest, StatxRelPathCwd) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds());
@@ -641,7 +650,7 @@ TEST_F(StatTest, StatxRelPathCwd) {
}
TEST_F(StatTest, StatxEmptyPath) {
- SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, 0) < 0 &&
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
const auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDONLY));
@@ -651,6 +660,60 @@ TEST_F(StatTest, StatxEmptyPath) {
EXPECT_TRUE(S_ISREG(stx.stx_mode));
}
+TEST_F(StatTest, StatxDoesNotRejectExtraneousMaskBits) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set all mask bits except for STATX__RESERVED.
+ uint mask = 0xffffffff & ~0x80000000;
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, mask, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxRejectsReservedMaskBit) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ // Set STATX__RESERVED in the mask.
+ EXPECT_THAT(statx(-1, test_file_name_.c_str(), 0, 0x80000000, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(StatTest, StatxSymlink) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ std::string parent_dir = "/tmp";
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(parent_dir, test_file_name_));
+ std::string p = link.path();
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), AT_SYMLINK_NOFOLLOW, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISLNK(stx.stx_mode));
+ EXPECT_THAT(statx(AT_FDCWD, p.c_str(), 0, STATX_ALL, &stx),
+ SyscallSucceeds());
+ EXPECT_TRUE(S_ISREG(stx.stx_mode));
+}
+
+TEST_F(StatTest, StatxInvalidFlags) {
+ SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
+ errno == ENOSYS);
+
+ struct kernel_statx stx;
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(), 12345, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Sync flags are mutually exclusive.
+ EXPECT_THAT(statx(AT_FDCWD, test_file_name_.c_str(),
+ AT_STATX_FORCE_SYNC | AT_STATX_DONT_SYNC, 0, &stx),
+ SyscallFailsWithErrno(EINVAL));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 7e73325bf..4afed6d08 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -40,10 +40,17 @@ namespace {
TEST(StickyTest, StickyBitPermDenied) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -61,17 +68,31 @@ TEST(StickyTest, StickyBitPermDenied) {
syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
SyscallSucceeds());
- EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file", 0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0),
+ SyscallFailsWithErrno(EPERM));
});
}
TEST(StickyTest, StickyBitSameUID) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -87,17 +108,29 @@ TEST(StickyTest, StickyBitSameUID) {
SyscallSucceeds());
// We still have the same EUID.
- EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds());
});
}
TEST(StickyTest, StickyBitCapFOWNER) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
- EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
- std::string path = JoinPath(dir.path(), "NewDir");
- ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds());
+ const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds());
+
+ // After changing credentials below, we need to use an open fd to make
+ // modifications in the parent dir, because there is no guarantee that we will
+ // still have the ability to open it.
+ const FileDescriptor parent_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY));
+ ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds());
+ ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds());
+ ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds());
// Drop privileges and change IDs only in child thread, or else this parent
// thread won't be able to open some log files after the test ends.
@@ -114,7 +147,12 @@ TEST(StickyTest, StickyBitCapFOWNER) {
SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
- EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
+ EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR),
+ SyscallSucceeds());
+ EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds());
});
}
} // namespace
diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc
index b249ff91f..a17ff62e9 100644
--- a/test/syscalls/linux/symlink.cc
+++ b/test/syscalls/linux/symlink.cc
@@ -20,6 +20,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/time/clock.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
@@ -38,7 +39,7 @@ mode_t FilePermission(const std::string& path) {
}
// Test that name collisions are checked on the new link path, not the source
-// path.
+// path. Regression test for b/31782115.
TEST(SymlinkTest, CanCreateSymlinkWithCachedSourceDirent) {
const std::string srcname = NewTempAbsPath();
const std::string newname = NewTempAbsPath();
@@ -272,6 +273,30 @@ TEST(SymlinkTest, ChmodSymlink) {
EXPECT_EQ(FilePermission(newpath), 0777);
}
+// Test that following a symlink updates the atime on the symlink.
+TEST(SymlinkTest, FollowUpdatesATime) {
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const std::string link = NewTempAbsPath();
+ EXPECT_THAT(symlink(file.path().c_str(), link.c_str()), SyscallSucceeds());
+
+ // Lstat the symlink.
+ struct stat st_before_follow;
+ ASSERT_THAT(lstat(link.c_str(), &st_before_follow), SyscallSucceeds());
+
+ // Let the clock advance.
+ absl::SleepFor(absl::Seconds(1));
+
+ // Open the file via the symlink.
+ int fd;
+ ASSERT_THAT(fd = open(link.c_str(), O_RDWR, 0666), SyscallSucceeds());
+ FileDescriptor fd_closer(fd);
+
+ // Lstat the symlink again, and check that atime is updated.
+ struct stat st_after_follow;
+ ASSERT_THAT(lstat(link.c_str(), &st_after_follow), SyscallSucceeds());
+ EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime);
+}
+
class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {};
// Test that creating an existing symlink with creat will create the target.
diff --git a/test/syscalls/linux/sync.cc b/test/syscalls/linux/sync.cc
index fe479390d..8aa2525a9 100644
--- a/test/syscalls/linux/sync.cc
+++ b/test/syscalls/linux/sync.cc
@@ -14,10 +14,9 @@
#include <fcntl.h>
#include <stdio.h>
-#include <unistd.h>
-
#include <sys/syscall.h>
#include <unistd.h>
+
#include <string>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/sysret.cc b/test/syscalls/linux/sysret.cc
index 819fa655a..19ffbd85b 100644
--- a/test/syscalls/linux/sysret.cc
+++ b/test/syscalls/linux/sysret.cc
@@ -14,6 +14,8 @@
// Tests to verify that the behavior of linux and gvisor matches when
// 'sysret' returns to bad (aka non-canonical) %rip or %rsp.
+
+#include <linux/elf.h>
#include <sys/ptrace.h>
#include <sys/user.h>
@@ -32,6 +34,7 @@ constexpr uint64_t kNonCanonicalRsp = 0xFFFF000000000000;
class SysretTest : public ::testing::Test {
protected:
struct user_regs_struct regs_;
+ struct iovec iov;
pid_t child_;
void SetUp() override {
@@ -48,10 +51,15 @@ class SysretTest : public ::testing::Test {
// Parent.
int status;
+ memset(&iov, 0, sizeof(iov));
ASSERT_THAT(pid, SyscallSucceeds()); // Might still be < 0.
ASSERT_THAT(waitpid(pid, &status, 0), SyscallSucceedsWithValue(pid));
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
- ASSERT_THAT(ptrace(PTRACE_GETREGS, pid, 0, &regs_), SyscallSucceeds());
+
+ iov.iov_base = &regs_;
+ iov.iov_len = sizeof(regs_);
+ ASSERT_THAT(ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
child_ = pid;
}
@@ -61,13 +69,27 @@ class SysretTest : public ::testing::Test {
}
void SetRip(uint64_t newrip) {
+#if defined(__x86_64__)
regs_.rip = newrip;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.pc = newrip;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
void SetRsp(uint64_t newrsp) {
+#if defined(__x86_64__)
regs_.rsp = newrsp;
- ASSERT_THAT(ptrace(PTRACE_SETREGS, child_, 0, &regs_), SyscallSucceeds());
+#elif defined(__aarch64__)
+ regs_.sp = newrsp;
+#else
+#error "Unknown architecture"
+#endif
+ ASSERT_THAT(ptrace(PTRACE_SETREGSET, child_, NT_PRSTATUS, &iov),
+ SyscallSucceeds());
}
// Wait waits for the child pid and returns the exit status.
@@ -104,8 +126,15 @@ TEST_F(SysretTest, BadRsp) {
SetRsp(kNonCanonicalRsp);
Detach();
int status = Wait();
+#if defined(__x86_64__)
EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGBUS)
<< "status = " << status;
+#elif defined(__aarch64__)
+ EXPECT_TRUE(WIFSIGNALED(status) && WTERMSIG(status) == SIGSEGV)
+ << "status = " << status;
+#else
+#error "Unknown architecture"
+#endif
}
} // namespace
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index bfa031bce..a6325a761 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -13,6 +13,9 @@
// limitations under the License.
#include <fcntl.h>
+#ifndef __fuchsia__
+#include <linux/filter.h>
+#endif // __fuchsia__
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -130,6 +133,33 @@ void TcpSocketTest::TearDown() {
}
}
+TEST_P(TcpSocketTest, ConnectOnEstablishedConnection) {
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ ASSERT_THAT(
+ connect(s_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+ ASSERT_THAT(
+ connect(t_, reinterpret_cast<const struct sockaddr*>(&addr), addrlen),
+ SyscallFailsWithErrno(EISCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInTimeWait) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(s_, SHUT_RDWR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter TIME_WAIT.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(TcpSocketTest, ShutdownWriteInFinWait1) {
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+ absl::SleepFor(absl::Seconds(1)); // Wait to enter FIN-WAIT2.
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceeds());
+}
+
TEST_P(TcpSocketTest, DataCoalesced) {
char buf[10];
@@ -231,7 +261,8 @@ TEST_P(TcpSocketTest, ZeroWriteAllowed) {
}
// Test that a non-blocking write with a buffer that is larger than the send
-// buffer size will not actually write the whole thing at once.
+// buffer size will not actually write the whole thing at once. Regression test
+// for b/64438887.
TEST_P(TcpSocketTest, NonblockingLargeWrite) {
// Set the FD to O_NONBLOCK.
int opts;
@@ -394,8 +425,15 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) {
sizeof(tcp_nodelay_flag)),
SyscallSucceeds());
+ // Set a 256KB send/receive buffer.
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(setsockopt(t_, SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+
// Create a large buffer that will be used for sending.
- std::vector<char> buf(10 * sendbuf_size_);
+ std::vector<char> buf(1 << 16);
// Write until we receive an error.
while (RetryEINTR(send)(s_, buf.data(), buf.size(), 0) != -1) {
@@ -405,6 +443,11 @@ TEST_P(TcpSocketTest, PollWithFullBufferBlocks) {
}
// The last error should have been EWOULDBLOCK.
ASSERT_EQ(errno, EWOULDBLOCK);
+
+ // Now polling on the FD with a timeout should return 0 corresponding to no
+ // FDs ready.
+ struct pollfd poll_fd = {s_, POLLOUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10), SyscallSucceedsWithValue(0));
}
TEST_P(TcpSocketTest, MsgTrunc) {
@@ -677,6 +720,30 @@ TEST_P(TcpSocketTest, TcpSCMPriority) {
ASSERT_EQ(cmsg, nullptr);
}
+TEST_P(TcpSocketTest, TimeWaitPollHUP) {
+ shutdown(s_, SHUT_RDWR);
+ ScopedThread t([&]() {
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = s_,
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ });
+ shutdown(t_, SHUT_RDWR);
+ t.Join();
+ // At this point s_ should be in TIME-WAIT and polling for POLLHUP should
+ // return with 1 FD.
+ constexpr int kTimeout = 10000;
+ constexpr int16_t want_events = POLLHUP;
+ struct pollfd pfd = {
+ .fd = s_,
+ .events = want_events,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+}
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
@@ -789,6 +856,20 @@ TEST_P(TcpSocketTest, FullBuffer) {
t_ = -1;
}
+TEST_P(TcpSocketTest, PollAfterShutdown) {
+ ScopedThread client_thread([this]() {
+ EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {s_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+ });
+
+ EXPECT_THAT(shutdown(t_, SHUT_WR), SyscallSucceedsWithValue(0));
+ struct pollfd poll_fd = {t_, POLLIN | POLLERR | POLLHUP, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+}
+
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
// Initialize address to the loopback one.
sockaddr_storage addr =
@@ -942,6 +1023,78 @@ TEST_P(SimpleTcpSocketTest, BlockingConnectRefused) {
EXPECT_THAT(close(s.release()), SyscallSucceeds());
}
+// Test that connecting to a non-listening port and thus receiving a RST is
+// handled appropriately by the socket - the port that the socket was bound to
+// is released and the expected error is returned.
+TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
+ // Create a socket that is known to not be listening. As is it bound but not
+ // listening, when another socket connects to the port, it will refuse..
+ FileDescriptor bound_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ sockaddr_storage bound_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t bound_addrlen = sizeof(bound_addr);
+
+ ASSERT_THAT(
+ bind(bound_s.get(), reinterpret_cast<struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallSucceeds());
+
+ // Get the addresses the socket is bound to because the port is chosen by the
+ // stack.
+ ASSERT_THAT(getsockname(bound_s.get(),
+ reinterpret_cast<struct sockaddr*>(&bound_addr),
+ &bound_addrlen),
+ SyscallSucceeds());
+
+ // Create, initialize, and bind the socket that is used to test connecting to
+ // the non-listening port.
+ FileDescriptor client_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // Initialize client address to the loopback one.
+ sockaddr_storage client_addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t client_addrlen = sizeof(client_addr);
+
+ ASSERT_THAT(
+ bind(client_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
+ client_addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(getsockname(client_s.get(),
+ reinterpret_cast<struct sockaddr*>(&client_addr),
+ &client_addrlen),
+ SyscallSucceeds());
+
+ // Now the test: connect to the bound but not listening socket with the
+ // client socket. The bound socket should return a RST and cause the client
+ // socket to return an error and clean itself up immediately.
+ // The error being ECONNREFUSED diverges with RFC 793, page 37, but does what
+ // Linux does.
+ ASSERT_THAT(connect(client_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+
+ FileDescriptor new_s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Test binding to the address from the client socket. This should be okay
+ // if it was dropped correctly.
+ ASSERT_THAT(
+ bind(new_s.get(), reinterpret_cast<struct sockaddr*>(&client_addr),
+ client_addrlen),
+ SyscallSucceeds());
+
+ // Attempt #2, with the new socket and reused addr our connect should fail in
+ // the same way as before, not with an EADDRINUSE.
+ ASSERT_THAT(connect(client_s.get(),
+ reinterpret_cast<const struct sockaddr*>(&bound_addr),
+ bound_addrlen),
+ SyscallFailsWithErrno(ECONNREFUSED));
+}
+
// Test that we get an ECONNREFUSED with a nonblocking socket.
TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) {
FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
@@ -1150,6 +1303,346 @@ TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) {
}
}
+TEST_P(SimpleTcpSocketTest, SetTCPUserTimeout) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ {
+ constexpr int kTCPUserTimeout = -1;
+ EXPECT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kTCPUserTimeout, sizeof(kTCPUserTimeout)),
+ SyscallFailsWithErrno(EINVAL));
+ }
+
+ // kTCPUserTimeout is in milliseconds.
+ constexpr int kTCPUserTimeout = 100;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT,
+ &kTCPUserTimeout, sizeof(kTCPUserTimeout)),
+ SyscallSucceedsWithValue(0));
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPUserTimeout);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptNeg) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // -ve TCP_DEFER_ACCEPT is same as setting it to zero.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &kNeg, sizeof(kNeg)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPDeferAcceptDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, 0);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPDeferAcceptGreaterThanZero) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // kTCPDeferAccept is in seconds.
+ // NOTE: linux translates seconds to # of retries and back from
+ // #of retries to seconds. Which means only certain values
+ // translate back exactly. That's why we use 3 here, a value of
+ // 5 will result in us getting back 7 instead of 5 in the
+ // getsockopt.
+ constexpr int kTCPDeferAccept = 3;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT,
+ &kTCPDeferAccept, sizeof(kTCPDeferAccept)),
+ SyscallSucceeds());
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_DEFER_ACCEPT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPDeferAccept);
+}
+
+TEST_P(SimpleTcpSocketTest, RecvOnClosedSocket) {
+ auto s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ char buf[1];
+ EXPECT_THAT(recv(s.get(), buf, 0, 0), SyscallFailsWithErrno(ENOTCONN));
+ EXPECT_THAT(recv(s.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(SimpleTcpSocketTest, TCPConnectSoRcvBufRace) {
+ auto s = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(GetParam(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ sockaddr_storage addr =
+ ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
+ socklen_t addrlen = sizeof(addr);
+
+ RetryEINTR(connect)(s.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ addrlen);
+ int buf_sz = 1 << 18;
+ EXPECT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &buf_sz, sizeof(buf_sz)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntLessThanOne) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int default_syn_cnt = get;
+
+ {
+ // TCP_SYNCNT less than 1 should be rejected with an EINVAL.
+ constexpr int kZero = 0;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kZero, sizeof(kZero)),
+ SyscallFailsWithErrno(EINVAL));
+
+ // TCP_SYNCNT less than 1 should be rejected with an EINVAL.
+ constexpr int kNeg = -1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kNeg, sizeof(kNeg)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(default_syn_cnt, get);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, GetTCPSynCntDefault) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ constexpr int kDefaultSynCnt = 6;
+
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kDefaultSynCnt);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntGreaterThanOne) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int kTCPSynCnt = 20;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt,
+ sizeof(kTCPSynCnt)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kTCPSynCnt);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPSynCntAboveMax) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int default_syn_cnt = get;
+ {
+ constexpr int kTCPSynCnt = 256;
+ ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &kTCPSynCnt,
+ sizeof(kTCPSynCnt)),
+ SyscallFailsWithErrno(EINVAL));
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_SYNCNT, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, default_syn_cnt);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampBelowMinRcvBuf) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Discover minimum receive buf by setting a really low value
+ // for the receive buffer.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ // TCP_WINDOW_CLAMP less than min_so_rcvbuf/2 should be set to
+ // min_so_rcvbuf/2.
+ int below_half_min_rcvbuf = min_so_rcvbuf / 2 - 1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &below_half_min_rcvbuf, sizeof(below_half_min_rcvbuf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(min_so_rcvbuf / 2, get);
+ }
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampZeroClosedSocket) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int kZero = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceeds());
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kZero);
+}
+
+TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ // Discover minimum receive buf by setting a really low value
+ // for the receive buffer.
+ constexpr int kZero = 0;
+ EXPECT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &kZero, sizeof(kZero)),
+ SyscallSucceeds());
+
+ // Now retrieve the minimum value for SO_RCVBUF as the set above should
+ // have caused SO_RCVBUF for the socket to be set to the minimum.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_RCVBUF, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ int min_so_rcvbuf = get;
+
+ {
+ int above_half_min_rcv_buf = min_so_rcvbuf / 2 + 1;
+ EXPECT_THAT(
+ setsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP,
+ &above_half_min_rcv_buf, sizeof(above_half_min_rcv_buf)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(s.get(), IPPROTO_TCP, TCP_WINDOW_CLAMP, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(above_half_min_rcv_buf, get);
+ }
+}
+
+#ifndef __fuchsia__
+
+// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+// gVisor currently silently ignores attaching a filter.
+TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd
+ struct sock_filter code[] = {
+ {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd},
+ {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006},
+ {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2},
+ {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2},
+ {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017},
+ {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014},
+ {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e},
+ {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2},
+ {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2},
+ {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000},
+ };
+ struct sock_fprog bpf = {
+ .len = ABSL_ARRAYSIZE(code),
+ .filter = code,
+ };
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)),
+ SyscallSucceeds());
+
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+}
+
+TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ SKIP_IF(IsRunningOnGvisor());
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) {
+ FileDescriptor s =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
+
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest,
::testing::Values(AF_INET, AF_INET6));
diff --git a/test/syscalls/linux/time.cc b/test/syscalls/linux/time.cc
index c7eead17e..e75bba669 100644
--- a/test/syscalls/linux/time.cc
+++ b/test/syscalls/linux/time.cc
@@ -26,6 +26,7 @@ namespace {
constexpr long kFudgeSeconds = 5;
+#if defined(__x86_64__) || defined(__i386__)
// Mimics the time(2) wrapper from glibc prior to 2.15.
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
@@ -62,6 +63,7 @@ TEST(TimeTest, VsyscallTime_InvalidAddressSIGSEGV) {
::testing::KilledBySignal(SIGSEGV), "");
}
+// Mimics the gettimeofday(2) wrapper from the Go runtime <= 1.2.
int vsyscall_gettimeofday(struct timeval* tv, struct timezone* tz) {
constexpr uint64_t kVsyscallGettimeofdayEntry = 0xffffffffff600000;
return reinterpret_cast<int (*)(struct timeval*, struct timezone*)>(
@@ -97,6 +99,7 @@ TEST(TimeTest, VsyscallGettimeofday_InvalidAddressSIGSEGV) {
reinterpret_cast<struct timezone*>(0x1)),
::testing::KilledBySignal(SIGSEGV), "");
}
+#endif
} // namespace
diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc
index 86ed87b7c..c4f8fdd7a 100644
--- a/test/syscalls/linux/timerfd.cc
+++ b/test/syscalls/linux/timerfd.cc
@@ -204,16 +204,33 @@ TEST_P(TimerfdTest, SetAbsoluteTime) {
EXPECT_EQ(1, val);
}
-TEST_P(TimerfdTest, IllegalReadWrite) {
+TEST_P(TimerfdTest, IllegalSeek) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ if (!IsRunningWithVFS1()) {
+ EXPECT_THAT(lseek(tfd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE));
+ }
+}
+
+TEST_P(TimerfdTest, IllegalPread) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ int val;
+ EXPECT_THAT(pread(tfd.get(), &val, sizeof(val), 0),
+ SyscallFailsWithErrno(ESPIPE));
+}
+
+TEST_P(TimerfdTest, IllegalPwrite) {
+ auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0));
+ EXPECT_THAT(pwrite(tfd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE));
+ if (!IsRunningWithVFS1()) {
+ }
+}
+
+TEST_P(TimerfdTest, IllegalWrite) {
auto const tfd =
ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK));
uint64_t val = 0;
- EXPECT_THAT(PreadFd(tfd.get(), &val, sizeof(val), 0),
- SyscallFailsWithErrno(ESPIPE));
- EXPECT_THAT(WriteFd(tfd.get(), &val, sizeof(val)),
+ EXPECT_THAT(write(tfd.get(), &val, sizeof(val)),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(PwriteFd(tfd.get(), &val, sizeof(val), 0),
- SyscallFailsWithErrno(ESPIPE));
}
std::string PrintClockId(::testing::TestParamInfo<int> info) {
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index 3db18d7ac..4b3c44527 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -297,9 +297,13 @@ class IntervalTimer {
PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid,
const struct sigevent& sev) {
int timerid;
- if (syscall(SYS_timer_create, clockid, &sev, &timerid) < 0) {
+ int ret = syscall(SYS_timer_create, clockid, &sev, &timerid);
+ if (ret < 0) {
return PosixError(errno, "timer_create");
}
+ if (ret > 0) {
+ return PosixError(EINVAL, "timer_create should never return positive");
+ }
MaybeSave();
return IntervalTimer(timerid);
}
@@ -317,6 +321,18 @@ TEST(IntervalTimerTest, IsInitiallyStopped) {
EXPECT_EQ(0, its.it_value.tv_nsec);
}
+// Kernel can create multiple timers without issue.
+//
+// Regression test for gvisor.dev/issue/1738.
+TEST(IntervalTimerTest, MultipleTimers) {
+ struct sigevent sev = {};
+ sev.sigev_notify = SIGEV_NONE;
+ const auto timer1 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+ const auto timer2 =
+ ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev));
+}
+
TEST(IntervalTimerTest, SingleShotSilent) {
struct sigevent sev = {};
sev.sigev_notify = SIGEV_NONE;
@@ -642,5 +658,5 @@ int main(int argc, char** argv) {
}
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/tkill.cc b/test/syscalls/linux/tkill.cc
index bae377c69..8d8ebbb24 100644
--- a/test/syscalls/linux/tkill.cc
+++ b/test/syscalls/linux/tkill.cc
@@ -54,7 +54,7 @@ void SigHandler(int sig, siginfo_t* info, void* context) {
TEST_CHECK(info->si_code == SI_TKILL);
}
-// Test with a real signal.
+// Test with a real signal. Regression test for b/24790092.
TEST(TkillTest, ValidTIDAndRealSignal) {
struct sigaction sa;
sa.sa_sigaction = SigHandler;
diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc
index e5cc5d97c..c988c6380 100644
--- a/test/syscalls/linux/truncate.cc
+++ b/test/syscalls/linux/truncate.cc
@@ -19,6 +19,7 @@
#include <sys/vfs.h>
#include <time.h>
#include <unistd.h>
+
#include <iostream>
#include <string>
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
new file mode 100644
index 000000000..97d554e72
--- /dev/null
+++ b/test/syscalls/linux/tuntap.cc
@@ -0,0 +1,422 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_ether.h>
+#include <linux/if_tun.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr int kIPLen = 4;
+
+constexpr const char kDevNetTun[] = "/dev/net/tun";
+constexpr const char kTapName[] = "tap0";
+
+constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
+constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB};
+
+PosixErrorOr<std::set<std::string>> DumpLinkNames() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ std::set<std::string> names;
+ for (const auto& link : links) {
+ names.emplace(link.name);
+ }
+ return names;
+}
+
+PosixErrorOr<Link> GetLinkByName(const std::string& name) {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.name == name) {
+ return link;
+ }
+ }
+ return PosixError(ENOENT, "interface not found");
+}
+
+struct pihdr {
+ uint16_t pi_flags;
+ uint16_t pi_protocol;
+} __attribute__((packed));
+
+struct ping_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct iphdr ip;
+ struct icmphdr icmp;
+ char payload[64];
+} __attribute__((packed));
+
+ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ ping_pkt pkt = {};
+
+ pkt.pi.pi_protocol = htons(ETH_P_IP);
+
+ memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest));
+ memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source));
+ pkt.eth.h_proto = htons(ETH_P_IP);
+
+ pkt.ip.ihl = 5;
+ pkt.ip.version = 4;
+ pkt.ip.tos = 0;
+ pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) +
+ sizeof(pkt.payload));
+ pkt.ip.id = 1;
+ pkt.ip.frag_off = 1 << 6; // Do not fragment
+ pkt.ip.ttl = 64;
+ pkt.ip.protocol = IPPROTO_ICMP;
+ inet_pton(AF_INET, dstip, &pkt.ip.daddr);
+ inet_pton(AF_INET, srcip, &pkt.ip.saddr);
+ pkt.ip.check = IPChecksum(pkt.ip);
+
+ pkt.icmp.type = ICMP_ECHO;
+ pkt.icmp.code = 0;
+ pkt.icmp.checksum = 0;
+ pkt.icmp.un.echo.sequence = 1;
+ pkt.icmp.un.echo.id = 1;
+
+ strncpy(pkt.payload, "abcd", sizeof(pkt.payload));
+ pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload));
+
+ return pkt;
+}
+
+struct arp_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct arphdr arp;
+ uint8_t arp_sha[ETH_ALEN];
+ uint8_t arp_spa[kIPLen];
+ uint8_t arp_tha[ETH_ALEN];
+ uint8_t arp_tpa[kIPLen];
+} __attribute__((packed));
+
+std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ std::string buffer;
+ buffer.resize(sizeof(arp_pkt));
+
+ arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]);
+ {
+ pkt->pi.pi_protocol = htons(ETH_P_ARP);
+
+ memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest));
+ memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source));
+ pkt->eth.h_proto = htons(ETH_P_ARP);
+
+ pkt->arp.ar_hrd = htons(ARPHRD_ETHER);
+ pkt->arp.ar_pro = htons(ETH_P_IP);
+ pkt->arp.ar_hln = ETH_ALEN;
+ pkt->arp.ar_pln = kIPLen;
+ pkt->arp.ar_op = htons(ARPOP_REPLY);
+
+ memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha));
+ inet_pton(AF_INET, srcip, pkt->arp_spa);
+ memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha));
+ inet_pton(AF_INET, dstip, pkt->arp_tpa);
+ }
+ return buffer;
+}
+
+} // namespace
+
+TEST(TuntapStaticTest, NetTunExists) {
+ struct stat statbuf;
+ ASSERT_THAT(stat(kDevNetTun, &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
+class TuntapTest : public ::testing::Test {
+ protected:
+ void TearDown() override {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ // Bring back capability if we had dropped it in test case.
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
+ }
+ }
+};
+
+TEST_F(TuntapTest, CreateInterfaceNoCap) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ strncpy(ifr.ifr_name, kTapName, IFNAMSIZ);
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(TuntapTest, CreateFixedNameInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_expect = ifr_set;
+ // See __tun_chr_ioctl() in net/drivers/tun.c.
+ ifr_expect.ifr_flags |= IFF_NOFILTER;
+
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(kTapName)));
+ EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0));
+}
+
+TEST_F(TuntapTest, CreateInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ // Empty ifr.ifr_name. Let kernel assign.
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ std::string ifname = ifr_get.ifr_name;
+ EXPECT_THAT(ifname, ::testing::StartsWith("tap"));
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(ifname)));
+}
+
+TEST_F(TuntapTest, InvalidReadWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ char buf[128] = {};
+ EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+}
+
+TEST_F(TuntapTest, WriteToDownDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ // Device created should be down by default.
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ char buf[128] = {};
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
+}
+
+PosixErrorOr<FileDescriptor> OpenAndAttachTap(
+ const std::string& dev_name, const std::string& dev_ipv4_addr) {
+ // Interface creation.
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, dev_name.c_str(), IFNAMSIZ);
+ if (ioctl(fd.get(), TUNSETIFF, &ifr_set) < 0) {
+ return PosixError(errno);
+ }
+
+ ASSIGN_OR_RETURN_ERRNO(auto link, GetLinkByName(dev_name));
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, dev_ipv4_addr.c_str(), &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link.index, AF_INET, /*prefixlen=*/24, &addr,
+ sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME(b/110961832): gVisor doesn't support setting MAC address on
+ // interfaces yet.
+ RETURN_IF_ERRNO(LinkSetMacAddr(link.index, kMacA, sizeof(kMacA)));
+
+ // FIXME(b/110961832): gVisor always creates enabled/up'd interfaces.
+ RETURN_IF_ERRNO(LinkChangeFlags(link.index, IFF_UP, IFF_UP));
+ }
+
+ return fd;
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with canned reply and resends
+// the
+// ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+ ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+ std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+
+ // Send ping, this would trigger an ARP request on Linux.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+
+ // Receive loop to process inbound packets.
+ struct inpkt {
+ union {
+ pihdr pi;
+ ping_pkt ping;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ // Process ARP packet.
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ // Respond with canned ARP reply.
+ EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()),
+ SyscallSucceedsWithValue(arp_rep.size()));
+ // First ping request might have been dropped due to mac address not in
+ // ARP cache. Send it again.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+ }
+
+ // Process ping response packet.
+ if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol &&
+ r.ping.ip.protocol == ping_req.ip.protocol &&
+ !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) &&
+ !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) &&
+ r.ping.icmp.type == 0 && r.ping.icmp.code == 0) {
+ // Ends and passes the test.
+ break;
+ }
+ }
+}
+
+TEST_F(TuntapTest, SendUdpTriggersArpResolution) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ // Send a UDP packet to remote.
+ int sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.2", &remote.sin_addr);
+ int ret = sendto(sock, "hello", 5, 0, reinterpret_cast<sockaddr*>(&remote),
+ sizeof(remote));
+ ASSERT_THAT(ret, ::testing::AnyOf(SyscallSucceeds(),
+ SyscallFailsWithErrno(EHOSTDOWN)));
+
+ struct inpkt {
+ union {
+ pihdr pi;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ break;
+ }
+ }
+}
+
+// Write hang bug found by syskaller: b/155928773
+// https://syzkaller.appspot.com/bug?id=065b893bd8d1d04a4e0a1d53c578537cde1efe99
+TEST_F(TuntapTest, WriteHangBug155928773) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(OpenAndAttachTap(kTapName, "10.0.0.1"));
+
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ ASSERT_THAT(sock, SyscallSucceeds());
+
+ struct sockaddr_in remote = {};
+ remote.sin_family = AF_INET;
+ remote.sin_port = htons(42);
+ inet_pton(AF_INET, "10.0.0.1", &remote.sin_addr);
+ // Return values do not matter in this test.
+ connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote));
+ write(sock, "hello", 5);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/tuntap_hostinet.cc b/test/syscalls/linux/tuntap_hostinet.cc
new file mode 100644
index 000000000..1513fb9d5
--- /dev/null
+++ b/test/syscalls/linux/tuntap_hostinet.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(TuntapHostInetTest, NoNetTun) {
+ SKIP_IF(!IsRunningOnGvisor());
+ SKIP_IF(!IsRunningWithHostinet());
+
+ struct stat statbuf;
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallFailsWithErrno(ENOENT));
+}
+
+} // namespace
+} // namespace testing
+
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc
index 111dbacdf..7a8ac30a4 100644
--- a/test/syscalls/linux/udp_socket.cc
+++ b/test/syscalls/linux/udp_socket.cc
@@ -12,1332 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <arpa/inet.h>
-#include <fcntl.h>
-#include <linux/errqueue.h>
-#include <netinet/in.h>
-#include <sys/ioctl.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-
-#include "gtest/gtest.h"
-#include "absl/base/macros.h"
-#include "absl/time/clock.h"
-#include "absl/time/time.h"
-#include "test/syscalls/linux/socket_test_util.h"
-#include "test/syscalls/linux/unix_domain_socket_test_util.h"
-#include "test/util/test_util.h"
-#include "test/util/thread_util.h"
+#include "test/syscalls/linux/udp_socket_test_cases.h"
namespace gvisor {
namespace testing {
namespace {
-// The initial port to be be used on gvisor.
-constexpr int TestPort = 40000;
-
-// Fixture for tests parameterized by the address family to use (AF_INET and
-// AF_INET6) when creating sockets.
-class UdpSocketTest : public ::testing::TestWithParam<AddressFamily> {
- protected:
- // Creates two sockets that will be used by test cases.
- void SetUp() override;
-
- // Closes the sockets created by SetUp().
- void TearDown() override {
- EXPECT_THAT(close(s_), SyscallSucceeds());
- EXPECT_THAT(close(t_), SyscallSucceeds());
-
- for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) {
- ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i]));
- }
- }
-
- // First UDP socket.
- int s_;
-
- // Second UDP socket.
- int t_;
-
- // The length of the socket address.
- socklen_t addrlen_;
-
- // Initialized address pointing to loopback and port TestPort+i.
- struct sockaddr* addr_[3];
-
- // Initialize "any" address.
- struct sockaddr* anyaddr_;
-
- // Used ports.
- int ports_[3];
-
- private:
- // Storage for the loopback addresses.
- struct sockaddr_storage addr_storage_[3];
-
- // Storage for the "any" address.
- struct sockaddr_storage anyaddr_storage_;
-};
-
-// Gets a pointer to the port component of the given address.
-uint16_t* Port(struct sockaddr_storage* addr) {
- switch (addr->ss_family) {
- case AF_INET: {
- auto sin = reinterpret_cast<struct sockaddr_in*>(addr);
- return &sin->sin_port;
- }
- case AF_INET6: {
- auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr);
- return &sin6->sin6_port;
- }
- }
-
- return nullptr;
-}
-
-void UdpSocketTest::SetUp() {
- int type;
- if (GetParam() == AddressFamily::kIpv4) {
- type = AF_INET;
- auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_);
- addrlen_ = sizeof(*sin);
- sin->sin_addr.s_addr = htonl(INADDR_ANY);
- } else {
- type = AF_INET6;
- auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_);
- addrlen_ = sizeof(*sin6);
- if (GetParam() == AddressFamily::kIpv6) {
- sin6->sin6_addr = IN6ADDR_ANY_INIT;
- } else {
- TestAddress const& v4_mapped_any = V4MappedAny();
- sin6->sin6_addr =
- reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
- ->sin6_addr;
- }
- }
- ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
-
- ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
-
- memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_));
- anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_);
- anyaddr_->sa_family = type;
-
- if (gvisor::testing::IsRunningOnGvisor()) {
- for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) {
- ports_[i] = TestPort + i;
- }
- } else {
- // When not under gvisor, use utility function to pick port. Assert that
- // all ports are different.
- std::string error;
- for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) {
- // Find an unused port, we specify port 0 to allow the kernel to provide
- // the port.
- bool unique = true;
- do {
- ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable(
- 0, AddressFamily::kDualStack, SocketType::kUdp, false));
- ASSERT_GT(ports_[i], 0);
- for (size_t j = 0; j < i; ++j) {
- if (ports_[j] == ports_[i]) {
- unique = false;
- break;
- }
- }
- } while (!unique);
- }
- }
-
- // Initialize the sockaddrs.
- for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) {
- memset(&addr_storage_[i], 0, sizeof(addr_storage_[i]));
-
- addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]);
- addr_[i]->sa_family = type;
-
- switch (type) {
- case AF_INET: {
- auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]);
- sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
- sin->sin_port = htons(ports_[i]);
- break;
- }
- case AF_INET6: {
- auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]);
- sin6->sin6_addr = in6addr_loopback;
- sin6->sin6_port = htons(ports_[i]);
- break;
- }
- }
- }
-}
-
-TEST_P(UdpSocketTest, Creation) {
- int type = AF_INET6;
- if (GetParam() == AddressFamily::kIpv4) {
- type = AF_INET;
- }
-
- int s_;
-
- ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds());
- EXPECT_THAT(close(s_), SyscallSucceeds());
-
- ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds());
- EXPECT_THAT(close(s_), SyscallSucceeds());
-
- ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails());
-}
-
-TEST_P(UdpSocketTest, Getsockname) {
- // Check that we're not bound.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0);
-
- // Bind, then check that we get the right address.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0);
-}
-
-TEST_P(UdpSocketTest, Getpeername) {
- // Check that we're not connected.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
-
- // Connect, then check that we get the right address.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0);
-}
-
-TEST_P(UdpSocketTest, SendNotConnected) {
- // Do send & write, they must fail.
- char buf[512];
- EXPECT_THAT(send(s_, buf, sizeof(buf), 0),
- SyscallFailsWithErrno(EDESTADDRREQ));
-
- EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ));
-
- // Use sendto.
- ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Check that we're bound now.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_NE(*Port(&addr), 0);
-}
-
-TEST_P(UdpSocketTest, ConnectBinds) {
- // Connect the socket.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Check that we're bound now.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_NE(*Port(&addr), 0);
-}
-
-TEST_P(UdpSocketTest, ReceiveNotBound) {
- char buf[512];
- EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-}
-
-TEST_P(UdpSocketTest, Bind) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Try to bind again.
- EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL));
-
- // Check that we're still bound to the original address.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0);
-}
-
-TEST_P(UdpSocketTest, BindInUse) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Try to bind again.
- EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE));
-}
-
-TEST_P(UdpSocketTest, ReceiveAfterConnect) {
- // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Get the address s_ was bound to during connect.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
-
- // Send from t_ to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Receive the data.
- char received[sizeof(buf)];
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-}
-
-TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
- // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Get the address s_ was bound to during connect.
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
-
- for (int i = 0; i < 2; i++) {
- // Send from t_ to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0,
- reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Receive the data.
- char received[sizeof(buf)];
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-
- // Disconnect s_.
- struct sockaddr addr = {};
- addr.sa_family = AF_UNSPEC;
- ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds());
- // Connect s_ loopback:TestPort.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
- }
-}
-
-TEST_P(UdpSocketTest, Connect) {
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Check that we're connected to the right peer.
- struct sockaddr_storage peer;
- socklen_t peerlen = sizeof(peer);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
- EXPECT_EQ(peerlen, addrlen_);
- EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0);
-
- // Try to bind after connect.
- EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL));
-
- // Try to connect again.
- EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
-
- // Check that peer name changed.
- peerlen = sizeof(peer);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
- EXPECT_EQ(peerlen, addrlen_);
- EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
-}
-
-void ConnectAny(AddressFamily family, int sockfd, uint16_t port) {
- struct sockaddr_storage addr = {};
-
- // Precondition check.
- {
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
-
- if (family == AddressFamily::kIpv4) {
- auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
- } else {
- auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- struct in6_addr any = IN6ADDR_ANY_INIT;
- EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0);
- }
-
- {
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
- }
-
- struct sockaddr_storage baddr = {};
- if (family == AddressFamily::kIpv4) {
- auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
- addrlen = sizeof(*addr_in);
- addr_in->sin_family = AF_INET;
- addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
- addr_in->sin_port = port;
- } else {
- auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
- addrlen = sizeof(*addr_in);
- addr_in->sin6_family = AF_INET6;
- addr_in->sin6_port = port;
- if (family == AddressFamily::kIpv6) {
- addr_in->sin6_addr = IN6ADDR_ANY_INIT;
- } else {
- TestAddress const& v4_mapped_any = V4MappedAny();
- addr_in->sin6_addr =
- reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
- ->sin6_addr;
- }
- }
-
- // TODO(b/138658473): gVisor doesn't allow connecting to the zero port.
- if (port == 0) {
- SKIP_IF(IsRunningOnGvisor());
- }
-
- ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen),
- SyscallSucceeds());
- }
-
- // Postcondition check.
- {
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(
- getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
-
- if (family == AddressFamily::kIpv4) {
- auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
- } else {
- auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- struct in6_addr loopback;
- if (family == AddressFamily::kIpv6) {
- loopback = IN6ADDR_LOOPBACK_INIT;
- } else {
- TestAddress const& v4_mapped_loopback = V4MappedLoopback();
- loopback = reinterpret_cast<const struct sockaddr_in6*>(
- &v4_mapped_loopback.addr)
- ->sin6_addr;
- }
-
- EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
- }
-
- addrlen = sizeof(addr);
- if (port == 0) {
- EXPECT_THAT(
- getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
- } else {
- EXPECT_THAT(
- getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- }
- }
-}
-
-TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); }
-
-TEST_P(UdpSocketTest, ConnectAnyWithPort) {
- auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
- ConnectAny(GetParam(), s_, port);
-}
-
-void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) {
- struct sockaddr_storage addr = {};
-
- socklen_t addrlen = sizeof(addr);
- struct sockaddr_storage baddr = {};
- if (family == AddressFamily::kIpv4) {
- auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
- addrlen = sizeof(*addr_in);
- addr_in->sin_family = AF_INET;
- addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
- addr_in->sin_port = port;
- } else {
- auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
- addrlen = sizeof(*addr_in);
- addr_in->sin6_family = AF_INET6;
- addr_in->sin6_port = port;
- if (family == AddressFamily::kIpv6) {
- addr_in->sin6_addr = IN6ADDR_ANY_INIT;
- } else {
- TestAddress const& v4_mapped_any = V4MappedAny();
- addr_in->sin6_addr =
- reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr)
- ->sin6_addr;
- }
- }
-
- // TODO(b/138658473): gVisor doesn't allow connecting to the zero port.
- if (port == 0) {
- SKIP_IF(IsRunningOnGvisor());
- }
-
- ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen),
- SyscallSucceeds());
- // Now the socket is bound to the loopback address.
-
- // Disconnect
- addrlen = sizeof(addr);
- addr.ss_family = AF_UNSPEC;
- ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen),
- SyscallSucceeds());
-
- // Check that after disconnect the socket is bound to the ANY address.
- EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- if (family == AddressFamily::kIpv4) {
- auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
- } else {
- auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr);
- EXPECT_EQ(addrlen, sizeof(*addr_out));
- struct in6_addr loopback = IN6ADDR_ANY_INIT;
-
- EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
- }
-}
-
-TEST_P(UdpSocketTest, DisconnectAfterConnectAny) {
- DisconnectAfterConnectAny(GetParam(), s_, 0);
-}
-
-TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) {
- auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
- DisconnectAfterConnectAny(GetParam(), s_, port);
-}
-
-TEST_P(UdpSocketTest, DisconnectAfterBind) {
- ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
- // Connect the socket.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- struct sockaddr_storage addr = {};
- addr.ss_family = AF_UNSPEC;
- EXPECT_THAT(
- connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
- SyscallSucceeds());
-
- // Check that we're still bound.
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
-
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0);
-
- addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
-}
-
-TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
- struct sockaddr_storage baddr = {};
- socklen_t addrlen;
- auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1]));
- if (GetParam() == AddressFamily::kIpv4) {
- auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr);
- addr_in->sin_family = AF_INET;
- addr_in->sin_port = port;
- addr_in->sin_addr.s_addr = htonl(INADDR_ANY);
- } else {
- auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr);
- addr_in->sin6_family = AF_INET6;
- addr_in->sin6_port = port;
- addr_in->sin6_scope_id = 0;
- addr_in->sin6_addr = IN6ADDR_ANY_INIT;
- }
- ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_),
- SyscallSucceeds());
- // Connect the socket.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- struct sockaddr_storage addr = {};
- addr.ss_family = AF_UNSPEC;
- EXPECT_THAT(
- connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
- SyscallSucceeds());
-
- // Check that we're still bound.
- addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
-
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0);
-
- addrlen = sizeof(addr);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallFailsWithErrno(ENOTCONN));
-}
-
-TEST_P(UdpSocketTest, Disconnect) {
- for (int i = 0; i < 2; i++) {
- // Try to connect again.
- EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
-
- // Check that we're connected to the right peer.
- struct sockaddr_storage peer;
- socklen_t peerlen = sizeof(peer);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallSucceeds());
- EXPECT_EQ(peerlen, addrlen_);
- EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0);
-
- // Try to disconnect.
- struct sockaddr_storage addr = {};
- addr.ss_family = AF_UNSPEC;
- EXPECT_THAT(
- connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)),
- SyscallSucceeds());
-
- peerlen = sizeof(peer);
- EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen),
- SyscallFailsWithErrno(ENOTCONN));
-
- // Check that we're still bound.
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceeds());
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(*Port(&addr), 0);
- }
-}
-
-TEST_P(UdpSocketTest, ConnectBadAddress) {
- struct sockaddr addr = {};
- addr.sa_family = addr_[0]->sa_family;
- ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)),
- SyscallFailsWithErrno(EINVAL));
-}
-
-TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send to a different destination than we're connected to.
- char buf[512];
- EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-}
-
-TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
- // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+1.
- ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- char buf[3];
- // Send zero length packet from s_ to t_.
- ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0));
- // Receive the packet.
- char received[3];
- EXPECT_THAT(read(t_, received, sizeof(received)),
- SyscallSucceedsWithValue(0));
-}
-
-TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) {
- // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+1.
- ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Set t_ to non-blocking.
- int opts = 0;
- ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds());
- ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds());
-
- char buf[3];
- // Send zero length packet from s_ to t_.
- ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0));
- // Receive the packet.
- char received[3];
- EXPECT_THAT(read(t_, received, sizeof(received)),
- SyscallSucceedsWithValue(0));
- EXPECT_THAT(read(t_, received, sizeof(received)),
- SyscallFailsWithErrno(EAGAIN));
-}
-
-TEST_P(UdpSocketTest, SendAndReceiveNotConnected) {
- // Bind s_ to loopback.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send some data to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
-
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Receive the data.
- char received[sizeof(buf)];
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-}
-
-TEST_P(UdpSocketTest, SendAndReceiveConnected) {
- // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+1.
- ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Send some data from t_ to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
-
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Receive the data.
- char received[sizeof(buf)];
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-}
-
-TEST_P(UdpSocketTest, ReceiveFromNotConnected) {
- // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+2.
- ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds());
-
- // Send some data from t_ to s_.
- char buf[512];
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Check that the data isn't_ received because it was sent from a different
- // address than we're connected.
- EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-}
-
-TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+2.
- ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds());
-
- // Send some data from t_ to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
-
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Connect to loopback:TestPort+1.
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Receive the data. It works because it was sent before the connect.
- char received[sizeof(buf)];
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-
- // Send again. This time it should not be received.
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-}
-
-TEST_P(UdpSocketTest, ReceiveFrom) {
- // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Bind t_ to loopback:TestPort+1.
- ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Send some data from t_ to s_.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
-
- ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // Receive the data and sender address.
- char received[sizeof(buf)];
- struct sockaddr_storage addr;
- socklen_t addrlen = sizeof(addr);
- EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0,
- reinterpret_cast<sockaddr*>(&addr), &addrlen),
- SyscallSucceedsWithValue(sizeof(received)));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
- EXPECT_EQ(addrlen, addrlen_);
- EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0);
-}
-
-TEST_P(UdpSocketTest, Listen) {
- ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP));
-}
-
-TEST_P(UdpSocketTest, Accept) {
- ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP));
-}
-
-// This test validates that a read shutdown with pending data allows the read
-// to proceed with the data before returning EAGAIN.
-TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
- char received[512];
-
- // Bind t_ to loopback:TestPort+2.
- ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds());
-
- // Connect the socket, then try to shutdown again.
- ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
-
- // Verify that we get EWOULDBLOCK when there is nothing to read.
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- const char* buf = "abc";
- EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3));
-
- int opts = 0;
- ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds());
- ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds());
- ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds());
- ASSERT_NE(opts & O_NONBLOCK, 0);
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- // We should get the data even though read has been shutdown.
- EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2));
-
- // Because we read less than the entire packet length, since it's a packet
- // based socket any subsequent reads should return EWOULDBLOCK.
- EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK));
-}
-
-// This test is validating that even after a socket is shutdown if it's
-// reconnected it will reset the shutdown state.
-TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) {
- char received[512];
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
-
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- // Connect the socket, then try to shutdown again.
- ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds());
-
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-}
-
-TEST_P(UdpSocketTest, ReadShutdown) {
- char received[512];
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
-
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- // Connect the socket, then try to shutdown again.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- EXPECT_THAT(recv(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(0));
-}
-
-TEST_P(UdpSocketTest, ReadShutdownDifferentThread) {
- char received[512];
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- // Connect the socket, then shutdown from another thread.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- ScopedThread t([&] {
- absl::SleepFor(absl::Milliseconds(200));
- EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds());
- });
- EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(0));
- t.Join();
-
- EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(0));
-}
-
-TEST_P(UdpSocketTest, WriteShutdown) {
- EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
- EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds());
-}
-
-TEST_P(UdpSocketTest, SynchronousReceive) {
- // Bind s_ to loopback.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send some data to s_ from another thread.
- char buf[512];
- RandomizeBuffer(buf, sizeof(buf));
-
- // Receive the data prior to actually starting the other thread.
- char received[512];
- EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT),
- SyscallFailsWithErrno(EWOULDBLOCK));
-
- // Start the thread.
- ScopedThread t([&] {
- absl::SleepFor(absl::Milliseconds(200));
- ASSERT_THAT(
- sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_),
- SyscallSucceedsWithValue(sizeof(buf)));
- });
-
- EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0),
- SyscallSucceedsWithValue(512));
- EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
-}
-
-TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send 3 packets from t_ to s_.
- constexpr int psize = 100;
- char buf[3 * psize];
- RandomizeBuffer(buf, sizeof(buf));
-
- for (int i = 0; i < 3; ++i) {
- ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(psize));
- }
-
- // Receive the data as 3 separate packets.
- char received[6 * psize];
- for (int i = 0; i < 3; ++i) {
- EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0),
- SyscallSucceedsWithValue(psize));
- }
- EXPECT_EQ(memcmp(buf, received, 3 * psize), 0);
-}
-
-TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Direct writes from t_ to s_.
- ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send 2 packets from t_ to s_, where each packet's data consists of 2
- // discontiguous iovecs.
- constexpr size_t kPieceSize = 100;
- char buf[4 * kPieceSize];
- RandomizeBuffer(buf, sizeof(buf));
-
- for (int i = 0; i < 2; i++) {
- struct iovec iov[2];
- for (int j = 0; j < 2; j++) {
- iov[j].iov_base = reinterpret_cast<void*>(
- reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
- iov[j].iov_len = kPieceSize;
- }
- ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize));
- }
-
- // Receive the data as 2 separate packets.
- char received[6 * kPieceSize];
- for (int i = 0; i < 2; i++) {
- struct iovec iov[3];
- for (int j = 0; j < 3; j++) {
- iov[j].iov_base = reinterpret_cast<void*>(
- reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
- iov[j].iov_len = kPieceSize;
- }
- ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize));
- }
- EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
-}
-
-TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Send 2 packets from t_ to s_, where each packet's data consists of 2
- // discontiguous iovecs.
- constexpr size_t kPieceSize = 100;
- char buf[4 * kPieceSize];
- RandomizeBuffer(buf, sizeof(buf));
-
- for (int i = 0; i < 2; i++) {
- struct iovec iov[2];
- for (int j = 0; j < 2; j++) {
- iov[j].iov_base = reinterpret_cast<void*>(
- reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
- iov[j].iov_len = kPieceSize;
- }
- struct msghdr msg = {};
- msg.msg_name = addr_[0];
- msg.msg_namelen = addrlen_;
- msg.msg_iov = iov;
- msg.msg_iovlen = 2;
- ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize));
- }
-
- // Receive the data as 2 separate packets.
- char received[6 * kPieceSize];
- for (int i = 0; i < 2; i++) {
- struct iovec iov[3];
- for (int j = 0; j < 3; j++) {
- iov[j].iov_base = reinterpret_cast<void*>(
- reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
- iov[j].iov_len = kPieceSize;
- }
- struct msghdr msg = {};
- msg.msg_iov = iov;
- msg.msg_iovlen = 3;
- ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize));
- }
- EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
-}
-
-TEST_P(UdpSocketTest, FIONREADShutdown) {
- int n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- // A UDP socket must be connected before it can be shutdown.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-}
-
-TEST_P(UdpSocketTest, FIONREADWriteShutdown) {
- int n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // A UDP socket must be connected before it can be shutdown.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- const char str[] = "abc";
- ASSERT_THAT(send(s_, str, sizeof(str), 0),
- SyscallSucceedsWithValue(sizeof(str)));
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, sizeof(str));
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, sizeof(str));
-}
-
-TEST_P(UdpSocketTest, FIONREAD) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Check that the bound socket with an empty buffer reports an empty first
- // packet.
- int n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- // Send 3 packets from t_ to s_.
- constexpr int psize = 100;
- char buf[3 * psize];
- RandomizeBuffer(buf, sizeof(buf));
-
- for (int i = 0; i < 3; ++i) {
- ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(psize));
-
- // Check that regardless of how many packets are in the queue, the size
- // reported is that of a single packet.
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, psize);
- }
-}
-
-TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) {
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // Check that the bound socket with an empty buffer reports an empty first
- // packet.
- int n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- // Send 3 packets from t_ to s_.
- constexpr int psize = 100;
- char buf[3 * psize];
- RandomizeBuffer(buf, sizeof(buf));
-
- for (int i = 0; i < 3; ++i) {
- ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_),
- SyscallSucceedsWithValue(0));
-
- // Check that regardless of how many packets are in the queue, the size
- // reported is that of a single packet.
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
- }
-}
-
-TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) {
- int n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- // Bind s_ to loopback:TestPort.
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- // A UDP socket must be connected before it can be shutdown.
- ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- const char str[] = "abc";
- ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0));
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-
- EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds());
-
- n = -1;
- EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0));
- EXPECT_EQ(n, 0);
-}
-
-TEST_P(UdpSocketTest, ErrorQueue) {
- char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))];
- msghdr msg;
- memset(&msg, 0, sizeof(msg));
- iovec iov;
- memset(&iov, 0, sizeof(iov));
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsgbuf;
- msg.msg_controllen = sizeof(cmsgbuf);
-
- // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
- EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE),
- SyscallFailsWithErrno(EAGAIN));
-}
-
-TEST_P(UdpSocketTest, SoTimestampOffByDefault) {
- int v = -1;
- socklen_t optlen = sizeof(v);
- ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen),
- SyscallSucceeds());
- ASSERT_EQ(v, kSockOptOff);
- ASSERT_EQ(optlen, sizeof(v));
-}
-
-TEST_P(UdpSocketTest, SoTimestamp) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- int v = 1;
- ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
- SyscallSucceeds());
-
- char buf[3];
- // Send zero length packet from t_ to s_.
- ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0));
-
- char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
- msghdr msg;
- memset(&msg, 0, sizeof(msg));
- iovec iov;
- memset(&iov, 0, sizeof(iov));
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsgbuf;
- msg.msg_controllen = sizeof(cmsgbuf);
-
- ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0));
-
- struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
- ASSERT_NE(cmsg, nullptr);
- ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
- ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP);
- ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval)));
-
- struct timeval tv = {};
- memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval));
-
- ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
-
- // There should be nothing to get via ioctl.
- ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT));
-}
-
-TEST_P(UdpSocketTest, WriteShutdownNotConnected) {
- EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
-}
-
-TEST_P(UdpSocketTest, TimestampIoctl) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- char buf[3];
- // Send packet from t_ to s_.
- ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)),
- SyscallSucceedsWithValue(sizeof(buf)));
-
- // There should be no control messages.
- char recv_buf[sizeof(buf)];
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf)));
-
- // A nonzero timeval should be available via ioctl.
- struct timeval tv = {};
- ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds());
- ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
-}
-
-TEST_P(UdpSocketTest, TimetstampIoctlNothingRead) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- struct timeval tv = {};
- ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT));
-}
-
-// Test that the timestamp accessed via SIOCGSTAMP is still accessible after
-// SO_TIMESTAMP is enabled and used to retrieve a timestamp.
-TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
- ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
- ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
-
- char buf[3];
- // Send packet from t_ to s_.
- ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)),
- SyscallSucceedsWithValue(sizeof(buf)));
- ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0));
-
- // There should be no control messages.
- char recv_buf[sizeof(buf)];
- ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf)));
-
- // A nonzero timeval should be available via ioctl.
- struct timeval tv = {};
- ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds());
- ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
-
- // Enable SO_TIMESTAMP and send a message.
- int v = 1;
- EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
- SyscallSucceeds());
- ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0));
-
- // There should be a message for SO_TIMESTAMP.
- char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
- msghdr msg = {};
- iovec iov = {};
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsgbuf;
- msg.msg_controllen = sizeof(cmsgbuf);
- ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0));
- struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
- cmsg = CMSG_FIRSTHDR(&msg);
- ASSERT_NE(cmsg, nullptr);
-
- // The ioctl should return the exact same values as before.
- struct timeval tv2 = {};
- ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds());
- ASSERT_EQ(tv.tv_sec, tv2.tv_sec);
- ASSERT_EQ(tv.tv_usec, tv2.tv_usec);
-}
-
INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest,
::testing::Values(AddressFamily::kIpv4,
AddressFamily::kIpv6,
diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc
new file mode 100644
index 000000000..54a0594f7
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc
@@ -0,0 +1,57 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __fuchsia__
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/errqueue.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/udp_socket_test_cases.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+TEST_P(UdpSocketTest, ErrorQueue) {
+ char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))];
+ msghdr msg;
+ memset(&msg, 0, sizeof(msg));
+ iovec iov;
+ memset(&iov, 0, sizeof(iov));
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // __fuchsia__
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
new file mode 100644
index 000000000..60c48ed6e
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -0,0 +1,1781 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/udp_socket_test_cases.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#ifndef __fuchsia__
+#include <linux/filter.h>
+#endif // __fuchsia__
+#include <netinet/in.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "absl/strings/str_format.h"
+#ifndef SIOCGSTAMP
+#include <linux/sockios.h>
+#endif
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/syscalls/linux/unix_domain_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// Gets a pointer to the port component of the given address.
+uint16_t* Port(struct sockaddr_storage* addr) {
+ switch (addr->ss_family) {
+ case AF_INET: {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(addr);
+ return &sin->sin_port;
+ }
+ case AF_INET6: {
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr);
+ return &sin6->sin6_port;
+ }
+ }
+
+ return nullptr;
+}
+
+// Sets addr port to "port".
+void SetPort(struct sockaddr_storage* addr, uint16_t port) {
+ switch (addr->ss_family) {
+ case AF_INET: {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(addr);
+ sin->sin_port = port;
+ break;
+ }
+ case AF_INET6: {
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr);
+ sin6->sin6_port = port;
+ break;
+ }
+ }
+}
+
+void UdpSocketTest::SetUp() {
+ addrlen_ = GetAddrLength();
+
+ bind_ =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_));
+ bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+
+ sock_ =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+}
+
+int UdpSocketTest::GetFamily() {
+ if (GetParam() == AddressFamily::kIpv4) {
+ return AF_INET;
+ }
+ return AF_INET6;
+}
+
+PosixError UdpSocketTest::BindLoopback() {
+ bind_addr_storage_ = InetLoopbackAddr();
+ struct sockaddr* bind_addr_ =
+ reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ return BindSocket(bind_.get(), bind_addr_);
+}
+
+PosixError UdpSocketTest::BindAny() {
+ bind_addr_storage_ = InetAnyAddr();
+ struct sockaddr* bind_addr_ =
+ reinterpret_cast<struct sockaddr*>(&bind_addr_storage_);
+ return BindSocket(bind_.get(), bind_addr_);
+}
+
+PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) {
+ socklen_t len = sizeof(bind_addr_storage_);
+
+ // Bind, then check that we get the right address.
+ RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len));
+
+ if (addrlen_ != len) {
+ return PosixError(
+ EINVAL,
+ absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_));
+ }
+ return PosixError(0);
+}
+
+socklen_t UdpSocketTest::GetAddrLength() {
+ struct sockaddr_storage addr;
+ if (GetFamily() == AF_INET) {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ return sizeof(*sin);
+ }
+
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ return sizeof(*sin6);
+}
+
+sockaddr_storage UdpSocketTest::InetAnyAddr() {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+
+ if (GetFamily() == AF_INET) {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_addr.s_addr = htonl(INADDR_ANY);
+ sin->sin_port = htons(0);
+ return addr;
+ }
+
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_addr = IN6ADDR_ANY_INIT;
+ sin6->sin6_port = htons(0);
+ return addr;
+}
+
+sockaddr_storage UdpSocketTest::InetLoopbackAddr() {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily();
+
+ if (GetFamily() == AF_INET) {
+ auto sin = reinterpret_cast<struct sockaddr_in*>(&addr);
+ sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ sin->sin_port = htons(0);
+ return addr;
+ }
+ auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr);
+ sin6->sin6_addr = in6addr_loopback;
+ sin6->sin6_port = htons(0);
+ return addr;
+}
+
+void UdpSocketTest::Disconnect(int sockfd) {
+ sockaddr_storage addr_storage = InetAnyAddr();
+ sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ socklen_t addrlen = sizeof(addr_storage);
+
+ addr->sa_family = AF_UNSPEC;
+ ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds());
+
+ // Check that after disconnect the socket is bound to the ANY address.
+ EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds());
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ struct in6_addr loopback = IN6ADDR_ANY_INIT;
+
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
+ }
+}
+
+TEST_P(UdpSocketTest, Creation) {
+ FileDescriptor sock =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ EXPECT_THAT(close(sock.release()), SyscallSucceeds());
+
+ sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0));
+ EXPECT_THAT(close(sock.release()), SyscallSucceeds());
+
+ ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails());
+}
+
+TEST_P(UdpSocketTest, Getsockname) {
+ // Check that we're not bound.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, Getpeername) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that we're not connected.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ // Connect, then check that we get the right address.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, SendNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Do send & write, they must fail.
+ char buf[512];
+ EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+
+ EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)),
+ SyscallFailsWithErrno(EDESTADDRREQ));
+
+ // Use sendto.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Check that we're bound now.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr), 0);
+}
+
+TEST_P(UdpSocketTest, ConnectBinds) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're bound now.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_NE(*Port(&addr), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveNotBound) {
+ char buf[512];
+ EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, Bind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Try to bind again.
+ EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Check that we're still bound to the original address.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, BindInUse) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Try to bind again.
+ EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_),
+ SyscallFailsWithErrno(EADDRINUSE));
+}
+
+TEST_P(UdpSocketTest, ReceiveAfterConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send from sock_ to bind_
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveAfterDisconnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ for (int i = 0; i < 2; i++) {
+ // Connet sock_ to bound address.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+
+ // Send from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+
+ // Disconnect sock_.
+ struct sockaddr unspec = {};
+ unspec.sa_family = AF_UNSPEC;
+ ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)),
+ SyscallSucceeds());
+ }
+}
+
+TEST_P(UdpSocketTest, Connect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're connected to the right peer.
+ struct sockaddr_storage peer;
+ socklen_t peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
+
+ // Try to bind after connect.
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallFailsWithErrno(EINVAL));
+
+ struct sockaddr_storage bind2_storage = InetLoopbackAddr();
+ struct sockaddr* bind2_addr =
+ reinterpret_cast<struct sockaddr*>(&bind2_storage);
+ FileDescriptor bind2 =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP));
+ ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr));
+
+ // Try to connect again.
+ EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds());
+
+ // Check that peer name changed.
+ peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, ConnectAnyZero) {
+ // TODO(138658473): Enable when we can connect to port 0 with gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, ConnectAnyWithPort) {
+ ASSERT_NO_ERRNO(BindAny());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectAny) {
+ // TODO(138658473): Enable when we can connect to port 0 with gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+ struct sockaddr_storage any = InetAnyAddr();
+ EXPECT_THAT(
+ connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_),
+ SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ Disconnect(sock_.get());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) {
+ ASSERT_NO_ERRNO(BindAny());
+ EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr));
+
+ Disconnect(sock_.get());
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBind) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Bind to the next port above bind_.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr));
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct sockaddr_storage unspec = {};
+ unspec.ss_family = AF_UNSPEC;
+ EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec),
+ sizeof(unspec.ss_family)),
+ SyscallSucceeds());
+
+ // Check that we're still bound.
+ socklen_t addrlen = sizeof(unspec);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0);
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) {
+ ASSERT_NO_ERRNO(BindAny());
+
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ socklen_t addrlen = sizeof(addr);
+
+ // Connect the socket.
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds());
+
+ // If the socket is bound to ANY and connected to a loopback address,
+ // getsockname() has to return the loopback address.
+ if (GetParam() == AddressFamily::kIpv4) {
+ auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr);
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK));
+ } else {
+ auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr);
+ struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT;
+ EXPECT_EQ(addrlen, sizeof(*addr_out));
+ EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0);
+ }
+}
+
+TEST_P(UdpSocketTest, DisconnectAfterBindToAny) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage any_storage = InetAnyAddr();
+ struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
+
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
+
+ // Connect the socket.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ Disconnect(sock_.get());
+
+ // Check that we're still bound.
+ struct sockaddr_storage addr;
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(memcmp(&addr, any, addrlen), 0);
+
+ addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, Disconnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage any_storage = InetAnyAddr();
+ struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage);
+ SetPort(&any_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_NO_ERRNO(BindSocket(sock_.get(), any));
+
+ for (int i = 0; i < 2; i++) {
+ // Try to connect again.
+ EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Check that we're connected to the right peer.
+ struct sockaddr_storage peer;
+ socklen_t peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallSucceeds());
+ EXPECT_EQ(peerlen, addrlen_);
+ EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0);
+
+ // Try to disconnect.
+ struct sockaddr_storage addr = {};
+ addr.ss_family = AF_UNSPEC;
+ EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr),
+ sizeof(addr.ss_family)),
+ SyscallSucceeds());
+
+ peerlen = sizeof(peer);
+ EXPECT_THAT(
+ getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen),
+ SyscallFailsWithErrno(ENOTCONN));
+
+ // Check that we're still bound.
+ socklen_t addrlen = sizeof(addr);
+ EXPECT_THAT(
+ getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+ EXPECT_EQ(addrlen, addrlen_);
+ EXPECT_EQ(*Port(&addr), *Port(&any_storage));
+ }
+}
+
+TEST_P(UdpSocketTest, ConnectBadAddress) {
+ struct sockaddr addr = {};
+ addr.sa_family = GetFamily();
+ ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ struct sockaddr_storage addr_storage = InetAnyAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send to a different destination than we're connected to.
+ char buf[512];
+ EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+}
+
+TEST_P(UdpSocketTest, ZerolengthWriteAllowed) {
+ // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ // Connect to loopback:bind_addr_+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:bind_addr_+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from bind_ to sock_.
+ ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {sock_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000),
+ SyscallSucceedsWithValue(1));
+
+ // Receive the packet.
+ char received[3];
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) {
+ // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:bind_addr_port+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Set sock to non-blocking.
+ int opts = 0;
+ ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK),
+ SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from bind_ to sock_.
+ ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {sock_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // Receive the packet.
+ char received[3];
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(read(sock_.get(), received, sizeof(received)),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+TEST_P(UdpSocketTest, SendAndReceiveNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send some data to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, SendAndReceiveConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:TestPort+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, ReceiveFromNotConnected) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:bind_addr_port+2.
+ struct sockaddr_storage addr2_storage = InetLoopbackAddr();
+ struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
+ ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Check that the data isn't received because it was sent from a different
+ // address than we're connected.
+ EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReceiveBeforeConnect) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Bind sock to loopback:bind_addr_port+2.
+ struct sockaddr_storage addr2_storage = InetLoopbackAddr();
+ struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage);
+ SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2);
+ ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Connect to loopback:TestPort+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Receive the data. It works because it was sent before the connect.
+ char received[sizeof(buf)];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+
+ // Send again. This time it should not be received.
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReceiveFrom) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind sock to loopback:TestPort+1.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Send some data from sock to bind_.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ // Receive the data and sender address.
+ char received[sizeof(buf)];
+ struct sockaddr_storage addr2;
+ socklen_t addr2len = sizeof(addr2);
+ EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0,
+ reinterpret_cast<sockaddr*>(&addr2), &addr2len),
+ SyscallSucceedsWithValue(sizeof(received)));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+ EXPECT_EQ(addr2len, addrlen_);
+ EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0);
+}
+
+TEST_P(UdpSocketTest, Listen) {
+ ASSERT_THAT(listen(sock_.get(), SOMAXCONN),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+TEST_P(UdpSocketTest, Accept) {
+ ASSERT_THAT(accept(sock_.get(), nullptr, nullptr),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+// This test validates that a read shutdown with pending data allows the read
+// to proceed with the data before returning EAGAIN.
+TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ // Bind to loopback:bind_addr_port+1 and connect to bind_addr_.
+ ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Verify that we get EWOULDBLOCK when there is nothing to read.
+ char received[512];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ const char* buf = "abc";
+ EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3));
+
+ int opts = 0;
+ ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK),
+ SyscallSucceeds());
+ ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds());
+ ASSERT_NE(opts & O_NONBLOCK, 0);
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // We should get the data even though read has been shutdown.
+ EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2));
+
+ // Because we read less than the entire packet length, since it's a packet
+ // based socket any subsequent reads should return EWOULDBLOCK.
+ EXPECT_THAT(recv(bind_.get(), received, 1, 0),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+// This test is validating that even after a socket is shutdown if it's
+// reconnected it will reset the shutdown state.
+TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) {
+ char received[512];
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then try to shutdown again.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Connect to loopback:bind_addr_port+1.
+ struct sockaddr_storage addr_storage = InetLoopbackAddr();
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
+ SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1);
+ ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST_P(UdpSocketTest, ReadShutdown) {
+ // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without
+ // MSG_DONTWAIT blocks indefinitely.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ char received[512];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then try to shutdown again.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, ReadShutdownDifferentThread) {
+ // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without
+ // MSG_DONTWAIT blocks indefinitely.
+ SKIP_IF(IsRunningWithHostinet());
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ char received[512];
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Connect the socket, then shutdown from another thread.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Milliseconds(200));
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+ });
+ EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+ t.Join();
+
+ EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_P(UdpSocketTest, WriteShutdown) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, SynchronousReceive) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send some data to bind_ from another thread.
+ char buf[512];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ // Receive the data prior to actually starting the other thread.
+ char received[512];
+ EXPECT_THAT(
+ RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Start the thread.
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Milliseconds(200));
+ ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_,
+ this->addrlen_),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ });
+
+ EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0),
+ SyscallSucceedsWithValue(512));
+ EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send 3 packets from sock to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(psize));
+ }
+
+ // Receive the data as 3 separate packets.
+ char received[6 * psize];
+ for (int i = 0; i < 3; ++i) {
+ EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0),
+ SyscallSucceedsWithValue(psize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 3 * psize), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Direct writes from sock to bind_.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Send 2 packets from sock to bind_, where each packet's data consists of
+ // 2 discontiguous iovecs.
+ constexpr size_t kPieceSize = 100;
+ char buf[4 * kPieceSize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[2];
+ for (int j = 0; j < 2; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ ASSERT_THAT(writev(sock_.get(), iov, 2),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+
+ // Receive the data as 2 separate packets.
+ char received[6 * kPieceSize];
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[3];
+ for (int j = 0; j < 3; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ ASSERT_THAT(readv(bind_.get(), iov, 3),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
+}
+
+TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Send 2 packets from sock to bind_, where each packet's data consists of
+ // 2 discontiguous iovecs.
+ constexpr size_t kPieceSize = 100;
+ char buf[4 * kPieceSize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[2];
+ for (int j = 0; j < 2; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ struct msghdr msg = {};
+ msg.msg_name = bind_addr_;
+ msg.msg_namelen = addrlen_;
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 2;
+ ASSERT_THAT(sendmsg(sock_.get(), &msg, 0),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+
+ // Receive the data as 2 separate packets.
+ char received[6 * kPieceSize];
+ for (int i = 0; i < 2; i++) {
+ struct iovec iov[3];
+ for (int j = 0; j < 3; j++) {
+ iov[j].iov_base = reinterpret_cast<void*>(
+ reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize);
+ iov[j].iov_len = kPieceSize;
+ }
+ struct msghdr msg = {};
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 3;
+ ASSERT_THAT(recvmsg(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(2 * kPieceSize));
+ }
+ EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0);
+}
+
+TEST_P(UdpSocketTest, FIONREADShutdown) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ int n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // A UDP socket must be connected before it can be shutdown.
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+}
+
+TEST_P(UdpSocketTest, FIONREADWriteShutdown) {
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // A UDP socket must be connected before it can be shutdown.
+ ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ const char str[] = "abc";
+ ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0),
+ SyscallSucceedsWithValue(sizeof(str)));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, sizeof(str));
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, sizeof(str));
+}
+
+// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the
+// corresponding macro and become `0x541B`.
+TEST_P(UdpSocketTest, Fionread) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that the bound socket with an empty buffer reports an empty first
+ // packet.
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // Send 3 packets from sock to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(psize));
+
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // Check that regardless of how many packets are in the queue, the size
+ // reported is that of a single packet.
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, psize);
+ }
+}
+
+TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) {
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // Check that the bound socket with an empty buffer reports an empty first
+ // packet.
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ // Send 3 packets from sock to bind_.
+ constexpr int psize = 100;
+ char buf[3 * psize];
+ RandomizeBuffer(buf, sizeof(buf));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ for (int i = 0; i < 3; ++i) {
+ ASSERT_THAT(
+ sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(0));
+
+ // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet
+ // socket does not cause a poll event to be triggered.
+ if (!IsRunningWithHostinet()) {
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+ }
+
+ // Check that regardless of how many packets are in the queue, the size
+ // reported is that of a single packet.
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+ }
+}
+
+TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) {
+ int n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ // A UDP socket must be connected before it can be shutdown.
+ ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ const char str[] = "abc";
+ ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0));
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds());
+
+ n = -1;
+ EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0));
+ EXPECT_EQ(n, 0);
+}
+
+TEST_P(UdpSocketTest, SoNoCheckOffByDefault) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = -1;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoNoCheck) {
+ // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = kSockOptOn;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOn);
+ ASSERT_EQ(optlen, sizeof(v));
+
+ v = kSockOptOff;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen),
+ SyscallSucceeds());
+ v = -1;
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoTimestampOffByDefault) {
+ // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by
+ // hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ int v = -1;
+ socklen_t optlen = sizeof(v);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen),
+ SyscallSucceeds());
+ ASSERT_EQ(v, kSockOptOff);
+ ASSERT_EQ(optlen, sizeof(v));
+}
+
+TEST_P(UdpSocketTest, SoTimestamp) {
+ // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not
+ // supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ int v = 1;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
+ SyscallSucceeds());
+
+ char buf[3];
+ // Send zero length packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
+ msghdr msg;
+ memset(&msg, 0, sizeof(msg));
+ iovec iov;
+ memset(&iov, 0, sizeof(iov));
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval)));
+
+ struct timeval tv = {};
+ memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval));
+
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+
+ // There should be nothing to get via ioctl.
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(UdpSocketTest, WriteShutdownNotConnected) {
+ EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
+}
+
+TEST_P(UdpSocketTest, TimestampIoctl) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be no control messages.
+ char recv_buf[sizeof(buf)];
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf)));
+
+ // A nonzero timeval should be available via ioctl.
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds());
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+}
+
+TEST_P(UdpSocketTest, TimestampIoctlNothingRead) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+// Test that the timestamp accessed via SIOCGSTAMP is still accessible after
+// SO_TIMESTAMP is enabled and used to retrieve a timestamp.
+TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
+ // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not
+ // supported by hostinet.
+ SKIP_IF(IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ char buf[3];
+ // Send packet from sock to bind_.
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)),
+ SyscallSucceedsWithValue(sizeof(buf)));
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ struct pollfd pfd = {bind_.get(), POLLIN, 0};
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be no control messages.
+ char recv_buf[sizeof(buf)];
+ ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf)));
+
+ // A nonzero timeval should be available via ioctl.
+ struct timeval tv = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds());
+ ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0);
+
+ // Enable SO_TIMESTAMP and send a message.
+ int v = 1;
+ EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)),
+ SyscallSucceeds());
+ ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000),
+ SyscallSucceedsWithValue(1));
+
+ // There should be a message for SO_TIMESTAMP.
+ char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))];
+ msghdr msg = {};
+ iovec iov = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cmsgbuf;
+ msg.msg_controllen = sizeof(cmsgbuf);
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0),
+ SyscallSucceedsWithValue(0));
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+
+ // The ioctl should return the exact same values as before.
+ struct timeval tv2 = {};
+ ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds());
+ ASSERT_EQ(tv.tv_sec, tv2.tv_sec);
+ ASSERT_EQ(tv.tv_usec, tv2.tv_usec);
+}
+
+// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on
+// outgoing packets, and that a receiving socket with IP_RECVTOS or
+// IPV6_RECVTCLASS will create the corresponding control message.
+TEST_P(UdpSocketTest, SetAndReceiveTOS) {
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ int recv_level = SOL_IP;
+ int recv_type = IP_RECVTOS;
+ if (GetParam() != AddressFamily::kIpv4) {
+ recv_level = SOL_IPV6;
+ recv_type = IPV6_RECVTCLASS;
+ }
+ ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Set socket TOS.
+ int sent_level = recv_level;
+ int sent_type = IP_TOS;
+ if (sent_level == SOL_IPV6) {
+ sent_type = IPV6_TCLASS;
+ }
+ int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value.
+ ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos,
+ sizeof(sent_tos)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ struct msghdr sent_msg = {};
+ struct iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = &sent_data[0];
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ // Receive message.
+ struct msghdr received_msg = {};
+ struct iovec received_iov = {};
+ char received_data[kDataLength];
+ received_iov.iov_base = &received_data[0];
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ size_t cmsg_data_len = sizeof(int8_t);
+ if (sent_type == IPV6_TCLASS) {
+ cmsg_data_len = sizeof(int);
+ }
+ std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ received_msg.msg_control = &received_cmsgbuf[0];
+ received_msg.msg_controllen = received_cmsgbuf.size();
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, sent_level);
+ EXPECT_EQ(cmsg->cmsg_type, sent_type);
+ int8_t received_tos = 0;
+ memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
+ EXPECT_EQ(received_tos, sent_tos);
+}
+
+// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the
+// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
+// IPV6_RECVTCLASS will create the corresponding control message.
+TEST_P(UdpSocketTest, SendAndReceiveTOS) {
+ // TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
+ SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
+
+ ASSERT_NO_ERRNO(BindLoopback());
+ ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ int recv_level = SOL_IP;
+ int recv_type = IP_RECVTOS;
+ if (GetParam() != AddressFamily::kIpv4) {
+ recv_level = SOL_IPV6;
+ recv_type = IPV6_RECVTCLASS;
+ }
+ int recv_opt = kSockOptOn;
+ ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt,
+ sizeof(recv_opt)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ int sent_level = recv_level;
+ int sent_type = IP_TOS;
+ int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value.
+
+ struct msghdr sent_msg = {};
+ struct iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = &sent_data[0];
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ size_t cmsg_data_len = sizeof(int8_t);
+ if (sent_level == SOL_IPV6) {
+ sent_type = IPV6_TCLASS;
+ cmsg_data_len = sizeof(int);
+ }
+ std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ sent_msg.msg_control = &sent_cmsgbuf[0];
+ sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+
+ // Manually add control message.
+ struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg);
+ sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len);
+ sent_cmsg->cmsg_level = sent_level;
+ sent_cmsg->cmsg_type = sent_type;
+ *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ // Receive message.
+ struct msghdr received_msg = {};
+ struct iovec received_iov = {};
+ char received_data[kDataLength];
+ received_iov.iov_base = &received_data[0];
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len));
+ received_msg.msg_control = &received_cmsgbuf[0];
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, sent_level);
+ EXPECT_EQ(cmsg->cmsg_type, sent_type);
+ int8_t received_tos = 0;
+ memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
+ EXPECT_EQ(received_tos, sent_tos);
+}
+
+TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) {
+ // Discover minimum buffer size by setting it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ int min = 0;
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+
+ // Bind bind_ to loopback.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ {
+ // Send data of size min and verify that it's received.
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ }
+
+ {
+ // Send data of size min + 1 and verify that its received. Both linux and
+ // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer
+ // is currently empty.
+ std::vector<char> buf(min + 1);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ }
+}
+
+// Test that receive buffer limits are enforced.
+TEST_P(UdpSocketTest, RecvBufLimits) {
+ // Bind s_ to loopback.
+ ASSERT_NO_ERRNO(BindLoopback());
+
+ int min = 0;
+ {
+ // Discover minimum buffer size by trying to set it to zero.
+ constexpr int kRcvBufSz = 0;
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz,
+ sizeof(kRcvBufSz)),
+ SyscallSucceeds());
+
+ socklen_t min_len = sizeof(min);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len),
+ SyscallSucceeds());
+ }
+
+ // Now set the limit to min * 4.
+ int new_rcv_buf_sz = min * 4;
+ if (!IsRunningOnGvisor() || IsRunningWithHostinet()) {
+ // Linux doubles the value specified so just set to min * 2.
+ new_rcv_buf_sz = min * 2;
+ }
+
+ ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz,
+ sizeof(new_rcv_buf_sz)),
+ SyscallSucceeds());
+ int rcv_buf_sz = 0;
+ {
+ socklen_t rcv_buf_len = sizeof(rcv_buf_sz);
+ ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz,
+ &rcv_buf_len),
+ SyscallSucceeds());
+ }
+
+ {
+ std::vector<char> buf(min);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ int sent = 4;
+ if (IsRunningOnGvisor() && !IsRunningWithHostinet()) {
+ // Linux seems to drop the 4th packet even though technically it should
+ // fit in the receive buffer.
+ ASSERT_THAT(
+ sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_),
+ SyscallSucceedsWithValue(buf.size()));
+ sent++;
+ }
+
+ for (int i = 0; i < sent - 1; i++) {
+ // Receive the data.
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallSucceedsWithValue(received.size()));
+ EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0);
+ }
+
+ // The last receive should fail with EAGAIN as the last packet should have
+ // been dropped due to lack of space in the receive buffer.
+ std::vector<char> received(buf.size());
+ EXPECT_THAT(
+ recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+ }
+}
+
+#ifndef __fuchsia__
+
+// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+// gVisor currently silently ignores attaching a filter.
+TEST_P(UdpSocketTest, SetSocketDetachFilter) {
+ // Program generated using sudo tcpdump -i lo udp and port 1234 -dd
+ struct sock_filter code[] = {
+ {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd},
+ {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011},
+ {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2},
+ {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2},
+ {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017},
+ {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014},
+ {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e},
+ {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2},
+ {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2},
+ {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000},
+ };
+ struct sock_fprog bpf = {
+ .len = ABSL_ARRAYSIZE(code),
+ .filter = code,
+ };
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)),
+ SyscallSucceeds());
+
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallSucceeds());
+}
+
+TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) {
+ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER.
+ SKIP_IF(IsRunningOnGvisor());
+ constexpr int val = 0;
+ ASSERT_THAT(
+ setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)),
+ SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_P(UdpSocketTest, GetSocketDetachFilter) {
+ int val = 0;
+ socklen_t val_len = sizeof(val);
+ ASSERT_THAT(
+ getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len),
+ SyscallFailsWithErrno(ENOPROTOOPT));
+}
+
+#endif // __fuchsia__
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h
new file mode 100644
index 000000000..f7e25c805
--- /dev/null
+++ b/test/syscalls/linux/udp_socket_test_cases.h
@@ -0,0 +1,82 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
+#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
+
+#include <sys/socket.h>
+
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// The initial port to be be used on gvisor.
+constexpr int TestPort = 40000;
+
+// Fixture for tests parameterized by the address family to use (AF_INET and
+// AF_INET6) when creating sockets.
+class UdpSocketTest
+ : public ::testing::TestWithParam<gvisor::testing::AddressFamily> {
+ protected:
+ // Creates two sockets that will be used by test cases.
+ void SetUp() override;
+
+ // Binds the socket bind_ to the loopback and updates bind_addr_.
+ PosixError BindLoopback();
+
+ // Binds the socket bind_ to Any and updates bind_addr_.
+ PosixError BindAny();
+
+ // Binds given socket to address addr and updates.
+ PosixError BindSocket(int socket, struct sockaddr* addr);
+
+ // Return initialized Any address to port 0.
+ struct sockaddr_storage InetAnyAddr();
+
+ // Return initialized Loopback address to port 0.
+ struct sockaddr_storage InetLoopbackAddr();
+
+ // Disconnects socket sockfd.
+ void Disconnect(int sockfd);
+
+ // Get family for the test.
+ int GetFamily();
+
+ // Socket used by Bind methods
+ FileDescriptor bind_;
+
+ // Second socket used for tests.
+ FileDescriptor sock_;
+
+ // Address for bind_ socket.
+ struct sockaddr* bind_addr_;
+
+ // Initialized to the length based on GetFamily().
+ socklen_t addrlen_;
+
+ // Storage for bind_addr_.
+ struct sockaddr_storage bind_addr_storage_;
+
+ private:
+ // Helper to initialize addrlen_ for the test case.
+ socklen_t GetAddrLength();
+};
+} // namespace testing
+} // namespace gvisor
+
+#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index 6218fbce1..64d6d0b8f 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -14,6 +14,7 @@
#include <errno.h>
#include <grp.h>
+#include <sys/resource.h>
#include <sys/types.h>
#include <unistd.h>
@@ -249,6 +250,26 @@ TEST(UidGidRootTest, Setgroups) {
SyscallFailsWithErrno(EFAULT));
}
+TEST(UidGidRootTest, Setuid_prlimit) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
+
+ // Do seteuid in a separate thread so that after finishing this test, the
+ // process can still open files the test harness created before starting this
+ // test. Otherwise, the files are created by root (UID before the test), but
+ // cannot be opened by the `uid` set below after the test.
+ ScopedThread([&] {
+ // Use syscall instead of glibc setuid wrapper because we want this seteuid
+ // call to only apply to this task. POSIX threads, however, require that all
+ // threads have the same UIDs, so using the seteuid wrapper sets all
+ // threads' UID.
+ EXPECT_THAT(syscall(SYS_setreuid, -1, 65534), SyscallSucceeds());
+
+ // Despite the UID change, we should be able to get our own limits.
+ struct rlimit rl = {};
+ EXPECT_THAT(prlimit(0, RLIMIT_NOFILE, NULL, &rl), SyscallSucceeds());
+ });
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/unix_domain_socket_test_util.cc b/test/syscalls/linux/unix_domain_socket_test_util.cc
index 7fb9eed8d..b05ab2900 100644
--- a/test/syscalls/linux/unix_domain_socket_test_util.cc
+++ b/test/syscalls/linux/unix_domain_socket_test_util.cc
@@ -15,6 +15,7 @@
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
#include <sys/un.h>
+
#include <vector>
#include "gtest/gtest.h"
diff --git a/test/syscalls/linux/unix_domain_socket_test_util.h b/test/syscalls/linux/unix_domain_socket_test_util.h
index 5eca0b7f0..b8073db17 100644
--- a/test/syscalls/linux/unix_domain_socket_test_util.h
+++ b/test/syscalls/linux/unix_domain_socket_test_util.h
@@ -16,6 +16,7 @@
#define GVISOR_TEST_SYSCALLS_UNIX_DOMAIN_SOCKET_TEST_UTIL_H_
#include <string>
+
#include "test/syscalls/linux/socket_test_util.h"
namespace gvisor {
diff --git a/test/syscalls/linux/utimes.cc b/test/syscalls/linux/utimes.cc
index 80716859a..e647d2896 100644
--- a/test/syscalls/linux/utimes.cc
+++ b/test/syscalls/linux/utimes.cc
@@ -20,6 +20,7 @@
#include <time.h>
#include <unistd.h>
#include <utime.h>
+
#include <string>
#include "absl/time/time.h"
@@ -33,17 +34,10 @@ namespace testing {
namespace {
-// TODO(b/36516566): utimes(nullptr) does not pick the "now" time in the
-// application's time domain, so when asserting that times are within a window,
-// we expand the window to allow for differences between the time domains.
-constexpr absl::Duration kClockSlack = absl::Milliseconds(100);
-
// TimeBoxed runs fn, setting before and after to (coarse realtime) times
// guaranteed* to come before and after fn started and completed, respectively.
//
// fn may be called more than once if the clock is adjusted.
-//
-// * See the comment on kClockSlack. gVisor breaks this guarantee.
void TimeBoxed(absl::Time* before, absl::Time* after,
std::function<void()> const& fn) {
do {
@@ -54,12 +48,15 @@ void TimeBoxed(absl::Time* before, absl::Time* after,
// filesystems set it to 1, so we don't do any truncation.
struct timespec ts;
EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds());
- *before = absl::TimeFromTimespec(ts);
+ // FIXME(b/132819225): gVisor filesystem timestamps inconsistently use the
+ // internal or host clock, which may diverge slightly. Allow some slack on
+ // times to account for the difference.
+ *before = absl::TimeFromTimespec(ts) - absl::Seconds(1);
fn();
EXPECT_THAT(clock_gettime(CLOCK_REALTIME_COARSE, &ts), SyscallSucceeds());
- *after = absl::TimeFromTimespec(ts);
+ *after = absl::TimeFromTimespec(ts) + absl::Seconds(1);
if (*after < *before) {
// Clock jumped backwards; retry.
@@ -68,23 +65,17 @@ void TimeBoxed(absl::Time* before, absl::Time* after,
// which could lead to test failures, but that is very unlikely to happen.
continue;
}
-
- if (IsRunningOnGvisor()) {
- // See comment on kClockSlack.
- *before -= kClockSlack;
- *after += kClockSlack;
- }
} while (*after < *before);
}
void TestUtimesOnPath(std::string const& path) {
struct stat statbuf;
- struct timeval times[2] = {{1, 0}, {2, 0}};
+ struct timeval times[2] = {{10, 0}, {20, 0}};
EXPECT_THAT(utimes(path.c_str(), times), SyscallSucceeds());
EXPECT_THAT(stat(path.c_str(), &statbuf), SyscallSucceeds());
- EXPECT_EQ(1, statbuf.st_atime);
- EXPECT_EQ(2, statbuf.st_mtime);
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
absl::Time before;
absl::Time after;
@@ -115,18 +106,18 @@ TEST(UtimesTest, OnDir) {
TEST(UtimesTest, MissingPath) {
auto path = NewTempAbsPath();
- struct timeval times[2] = {{1, 0}, {2, 0}};
+ struct timeval times[2] = {{10, 0}, {20, 0}};
EXPECT_THAT(utimes(path.c_str(), times), SyscallFailsWithErrno(ENOENT));
}
void TestFutimesat(int dirFd, std::string const& path) {
struct stat statbuf;
- struct timeval times[2] = {{1, 0}, {2, 0}};
+ struct timeval times[2] = {{10, 0}, {20, 0}};
EXPECT_THAT(futimesat(dirFd, path.c_str(), times), SyscallSucceeds());
EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds());
- EXPECT_EQ(1, statbuf.st_atime);
- EXPECT_EQ(2, statbuf.st_mtime);
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
absl::Time before;
absl::Time after;
@@ -162,12 +153,12 @@ TEST(FutimesatTest, OnRelPath) {
TEST(FutimesatTest, InvalidNsec) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
struct timeval times[4][2] = {{
- {0, 1}, // Valid
+ {0, 1}, // Valid
{1, static_cast<int64_t>(1e7)} // Invalid
},
{
{1, static_cast<int64_t>(1e7)}, // Invalid
- {0, 1} // Valid
+ {0, 1} // Valid
},
{
{0, 1}, // Valid
@@ -187,11 +178,11 @@ TEST(FutimesatTest, InvalidNsec) {
void TestUtimensat(int dirFd, std::string const& path) {
struct stat statbuf;
- const struct timespec times[2] = {{1, 0}, {2, 0}};
+ const struct timespec times[2] = {{10, 0}, {20, 0}};
EXPECT_THAT(utimensat(dirFd, path.c_str(), times, 0), SyscallSucceeds());
EXPECT_THAT(fstatat(dirFd, path.c_str(), &statbuf, 0), SyscallSucceeds());
- EXPECT_EQ(1, statbuf.st_atime);
- EXPECT_EQ(2, statbuf.st_mtime);
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
// Test setting with UTIME_NOW and UTIME_OMIT.
struct stat statbuf2;
@@ -234,10 +225,7 @@ void TestUtimensat(int dirFd, std::string const& path) {
EXPECT_GE(mtime3, before);
EXPECT_LE(mtime3, after);
- if (!IsRunningOnGvisor()) {
- // FIXME(b/36516566): Gofers set atime and mtime to different "now" times.
- EXPECT_EQ(atime3, mtime3);
- }
+ EXPECT_EQ(atime3, mtime3);
}
TEST(UtimensatTest, OnAbsPath) {
@@ -287,14 +275,15 @@ TEST(UtimeTest, ZeroAtimeandMtime) {
TEST(UtimensatTest, InvalidNsec) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- struct timespec times[2][2] = {{
- {0, UTIME_OMIT}, // Valid
- {2, static_cast<int64_t>(1e10)} // Invalid
- },
- {
- {2, static_cast<int64_t>(1e10)}, // Invalid
- {0, UTIME_OMIT} // Valid
- }};
+ struct timespec times[2][2] = {
+ {
+ {0, UTIME_OMIT}, // Valid
+ {2, static_cast<int64_t>(1e10)} // Invalid
+ },
+ {
+ {2, static_cast<int64_t>(1e10)}, // Invalid
+ {0, UTIME_OMIT} // Valid
+ }};
for (unsigned int i = 0; i < sizeof(times) / sizeof(times[0]); i++) {
std::cout << "test:" << i << "\n";
@@ -315,13 +304,13 @@ TEST(Utimensat, NullPath) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDWR));
struct stat statbuf;
- const struct timespec times[2] = {{1, 0}, {2, 0}};
+ const struct timespec times[2] = {{10, 0}, {20, 0}};
// Call syscall directly.
EXPECT_THAT(syscall(SYS_utimensat, fd.get(), NULL, times, 0),
SyscallSucceeds());
EXPECT_THAT(fstatat(0, f.path().c_str(), &statbuf, 0), SyscallSucceeds());
- EXPECT_EQ(1, statbuf.st_atime);
- EXPECT_EQ(2, statbuf.st_mtime);
+ EXPECT_EQ(10, statbuf.st_atime);
+ EXPECT_EQ(20, statbuf.st_mtime);
}
} // namespace
diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc
index 40c0014b9..ce1899f45 100644
--- a/test/syscalls/linux/vdso_clock_gettime.cc
+++ b/test/syscalls/linux/vdso_clock_gettime.cc
@@ -17,6 +17,7 @@
#include <syscall.h>
#include <time.h>
#include <unistd.h>
+
#include <map>
#include <string>
#include <utility>
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index 0aaba482d..19d05998e 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -191,5 +191,5 @@ int main(int argc, char** argv) {
return gvisor::testing::RunChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/vsyscall.cc b/test/syscalls/linux/vsyscall.cc
index 2c2303358..ae4377108 100644
--- a/test/syscalls/linux/vsyscall.cc
+++ b/test/syscalls/linux/vsyscall.cc
@@ -24,6 +24,7 @@ namespace testing {
namespace {
+#if defined(__x86_64__) || defined(__i386__)
time_t vsyscall_time(time_t* t) {
constexpr uint64_t kVsyscallTimeEntry = 0xffffffffff600400;
return reinterpret_cast<time_t (*)(time_t*)>(kVsyscallTimeEntry)(t);
@@ -37,6 +38,7 @@ TEST(VsyscallTest, VsyscallAlwaysAvailableOnGvisor) {
time_t t;
EXPECT_THAT(vsyscall_time(&t), SyscallSucceeds());
}
+#endif
} // namespace
diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc
index 9b219cfd6..39b5b2f56 100644
--- a/test/syscalls/linux/write.cc
+++ b/test/syscalls/linux/write.cc
@@ -31,14 +31,8 @@ namespace gvisor {
namespace testing {
namespace {
-// This test is currently very rudimentary.
-//
-// TODO(edahlgren):
-// * bad buffer states (EFAULT).
-// * bad fds (wrong permission, wrong type of file, EBADF).
-// * check offset is incremented.
-// * check for EOF.
-// * writing to pipes, symlinks, special files.
+
+// TODO(gvisor.dev/issue/2370): This test is currently very rudimentary.
class WriteTest : public ::testing::Test {
public:
ssize_t WriteBytes(int fd, int bytes) {
diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc
new file mode 100644
index 000000000..cbcf08451
--- /dev/null
+++ b/test/syscalls/linux/xattr.cc
@@ -0,0 +1,610 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <sys/types.h>
+#include <sys/xattr.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_set.h"
+#include "test/syscalls/linux/file_base.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+class XattrTest : public FileTest {};
+
+TEST_F(XattrTest, XattrNonexistentFile) {
+ const char* path = "/does/not/exist";
+ const char* name = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallFailsWithErrno(ENOENT));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENOENT));
+}
+
+TEST_F(XattrTest, XattrNullName) {
+ const char* path = test_file_name_.c_str();
+
+ EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(getxattr(path, nullptr, nullptr, 0),
+ SyscallFailsWithErrno(EFAULT));
+ EXPECT_THAT(removexattr(path, nullptr), SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, XattrEmptyName) {
+ const char* path = test_file_name_.c_str();
+
+ EXPECT_THAT(setxattr(path, "", nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(getxattr(path, "", nullptr, 0), SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, ""), SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, XattrLargeName) {
+ const char* path = test_file_name_.c_str();
+ std::string name = "user.";
+ name += std::string(XATTR_NAME_MAX - name.length(), 'a');
+
+ if (!IsRunningOnGvisor()) {
+ // In gVisor, access to xattrs is controlled with an explicit list of
+ // allowed names. This name isn't going to be configured to allow access, so
+ // don't test it.
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+ }
+
+ name += "a";
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallFailsWithErrno(ERANGE));
+ EXPECT_THAT(removexattr(path, name.c_str()), SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, XattrInvalidPrefix) {
+ const char* path = test_file_name_.c_str();
+ std::string name(XATTR_NAME_MAX, 'a');
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(getxattr(path, name.c_str(), nullptr, 0),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+ EXPECT_THAT(removexattr(path, name.c_str()),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
+// Do not allow save/restore cycles after making the test file read-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrReadOnly_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ DisableSave ds;
+ ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IRUSR));
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0),
+ SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EACCES));
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+}
+
+// Do not allow save/restore cycles after making the test file write-only, as
+// the restore will fail to open it with r/w permissions.
+TEST_F(XattrTest, XattrWriteOnly_NoRandomSave) {
+ // Drop capabilities that allow us to override file and directory permissions.
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false));
+ ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false));
+
+ DisableSave ds;
+ ASSERT_NO_ERRNO(testing::Chmod(test_file_name_, S_IWUSR));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(EACCES));
+
+ // listxattr will succeed even without read permissions.
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrTrustedWithNonadmin) {
+ // TODO(b/148380782): Support setxattr and getxattr with "trusted" prefix.
+ SKIP_IF(IsRunningOnGvisor());
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ const char* path = test_file_name_.c_str();
+ const char name[] = "trusted.abc";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, XattrOnDirectory) {
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(dir.path().c_str(), name, nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(dir.path().c_str(), name, nullptr, 0),
+ SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(dir.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(dir.path().c_str(), name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrOnSymlink) {
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(link.path().c_str(), name, nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(getxattr(link.path().c_str(), name, nullptr, 0),
+ SyscallSucceedsWithValue(0));
+
+ char list[sizeof(name)];
+ EXPECT_THAT(listxattr(link.path().c_str(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(removexattr(link.path().c_str(), name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrOnInvalidFileTypes) {
+ const char name[] = "user.test";
+
+ char char_device[] = "/dev/zero";
+ EXPECT_THAT(setxattr(char_device, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(char_device, name, nullptr, 0),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(char_device, nullptr, 0), SyscallSucceedsWithValue(0));
+
+ // Use tmpfs, where creation of named pipes is supported.
+ const std::string fifo = NewTempAbsPathInDir("/dev/shm");
+ const char* path = fifo.c_str();
+ EXPECT_THAT(mknod(path, S_IFIFO | S_IRUSR | S_IWUSR, 0), SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(listxattr(path, nullptr, 0), SyscallSucceedsWithValue(0));
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(XattrTest, SetxattrSizeSmallerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ size_t size = 1;
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, XATTR_SIZE_MAX),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, SetxattrSizeTooLarge) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+
+ // Note that each particular fs implementation may stipulate a lower size
+ // limit, in which case we actually may fail (e.g. error with ENOSPC) for
+ // some sizes under XATTR_SIZE_MAX.
+ size_t size = XATTR_SIZE_MAX + 1;
+ std::vector<char> val(size);
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallFailsWithErrno(E2BIG));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0),
+ SyscallFailsWithErrno(EFAULT));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, SetxattrNullValueAndZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val(XATTR_SIZE_MAX + 1);
+ std::fill(val.begin(), val.end(), 'a');
+ size_t size = 1;
+ EXPECT_THAT(setxattr(path, name, val.data(), size, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrReplaceWithSmaller) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ std::vector<char> expected_buf = {'a', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, SetxattrReplaceWithLarger) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ EXPECT_THAT(setxattr(path, name, val.data(), 1, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, val.data(), 2, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf = {'-', '-'};
+ EXPECT_THAT(getxattr(path, name, buf.data(), 2), SyscallSucceedsWithValue(2));
+ EXPECT_EQ(buf, val);
+}
+
+TEST_F(XattrTest, SetxattrCreateFlag) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE),
+ SyscallFailsWithErrno(EEXIST));
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrReplaceFlag) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE),
+ SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, SetxattrInvalidFlags) {
+ const char* path = test_file_name_.c_str();
+ int invalid_flags = 0xff;
+ EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(XattrTest, Getxattr) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(getxattr(path, name, &buf, size), SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+}
+
+TEST_F(XattrTest, GetxattrSizeSmallerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ std::vector<char> val = {'a', 'a'};
+ size_t size = val.size();
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, 1), SyscallFailsWithErrno(ERANGE));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, GetxattrSizeLargerThanValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, 1, /*flags=*/0), SyscallSucceeds());
+
+ std::vector<char> buf(XATTR_SIZE_MAX);
+ std::fill(buf.begin(), buf.end(), '-');
+ std::vector<char> expected_buf = buf;
+ expected_buf[0] = 'a';
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(1));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, GetxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+
+ char buf = '-';
+ EXPECT_THAT(getxattr(path, name, &buf, 0),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(buf, '-');
+}
+
+TEST_F(XattrTest, GetxattrSizeTooLarge) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ EXPECT_THAT(setxattr(path, name, &val, sizeof(val), /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> buf(XATTR_SIZE_MAX + 1);
+ std::fill(buf.begin(), buf.end(), '-');
+ std::vector<char> expected_buf = buf;
+ expected_buf[0] = 'a';
+ EXPECT_THAT(getxattr(path, name, buf.data(), buf.size()),
+ SyscallSucceedsWithValue(sizeof(val)));
+ EXPECT_EQ(buf, expected_buf);
+}
+
+TEST_F(XattrTest, GetxattrNullValue) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(getxattr(path, name, nullptr, size),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, GetxattrNullValueAndZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ char val = 'a';
+ size_t size = sizeof(val);
+ // Set value with zero size.
+ EXPECT_THAT(setxattr(path, name, &val, 0, /*flags=*/0), SyscallSucceeds());
+ // Get value with nonzero size.
+ EXPECT_THAT(getxattr(path, name, nullptr, size), SyscallSucceedsWithValue(0));
+
+ // Set value with nonzero size.
+ EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds());
+ // Get value with zero size.
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size));
+}
+
+TEST_F(XattrTest, GetxattrNonexistentName) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, Listxattr) {
+ const char* path = test_file_name_.c_str();
+ const std::string name = "user.test";
+ const std::string name2 = "user.test2";
+ const std::string name3 = "user.test3";
+ EXPECT_THAT(setxattr(path, name.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name2.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+ EXPECT_THAT(setxattr(path, name3.c_str(), nullptr, 0, /*flags=*/0),
+ SyscallSucceeds());
+
+ std::vector<char> list(name.size() + 1 + name2.size() + 1 + name3.size() + 1);
+ char* buf = list.data();
+ EXPECT_THAT(listxattr(path, buf, XATTR_SIZE_MAX),
+ SyscallSucceedsWithValue(list.size()));
+
+ absl::flat_hash_set<std::string> got = {};
+ for (char* p = buf; p < buf + list.size(); p += strlen(p) + 1) {
+ got.insert(std::string{p});
+ }
+
+ absl::flat_hash_set<std::string> expected = {name, name2, name3};
+ EXPECT_EQ(got, expected);
+}
+
+TEST_F(XattrTest, ListxattrNoXattrs) {
+ const char* path = test_file_name_.c_str();
+
+ std::vector<char> list, expected;
+ EXPECT_THAT(listxattr(path, list.data(), sizeof(list)),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(list, expected);
+
+ // Listxattr should succeed if there are no attributes, even if the buffer
+ // passed in is a nullptr.
+ EXPECT_THAT(listxattr(path, nullptr, sizeof(list)),
+ SyscallSucceedsWithValue(0));
+}
+
+TEST_F(XattrTest, ListxattrNullBuffer) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ EXPECT_THAT(listxattr(path, nullptr, sizeof(name)),
+ SyscallFailsWithErrno(EFAULT));
+}
+
+TEST_F(XattrTest, ListxattrSizeTooSmall) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+
+ char list[sizeof(name) - 1];
+ EXPECT_THAT(listxattr(path, list, sizeof(list)),
+ SyscallFailsWithErrno(ERANGE));
+}
+
+TEST_F(XattrTest, ListxattrZeroSize) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(listxattr(path, nullptr, 0),
+ SyscallSucceedsWithValue(sizeof(name)));
+}
+
+TEST_F(XattrTest, RemoveXattr) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds());
+ EXPECT_THAT(removexattr(path, name), SyscallSucceeds());
+ EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, RemoveXattrNonexistentName) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(ENODATA));
+}
+
+TEST_F(XattrTest, LXattrOnSymlink) {
+ const char name[] = "user.test";
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateSymlinkTo(dir.path(), test_file_name_));
+
+ EXPECT_THAT(lsetxattr(link.path().c_str(), name, nullptr, 0, 0),
+ SyscallFailsWithErrno(EPERM));
+ EXPECT_THAT(lgetxattr(link.path().c_str(), name, nullptr, 0),
+ SyscallFailsWithErrno(ENODATA));
+ EXPECT_THAT(llistxattr(link.path().c_str(), nullptr, 0),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(lremovexattr(link.path().c_str(), name),
+ SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(XattrTest, LXattrOnNonsymlink) {
+ const char* path = test_file_name_.c_str();
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(lsetxattr(path, name, &val, size, /*flags=*/0),
+ SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(lgetxattr(path, name, &buf, size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(llistxattr(path, list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(lremovexattr(path, name), SyscallSucceeds());
+}
+
+TEST_F(XattrTest, XattrWithFD) {
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_.c_str(), 0));
+ const char name[] = "user.test";
+ int val = 1234;
+ size_t size = sizeof(val);
+ EXPECT_THAT(fsetxattr(fd.get(), name, &val, size, /*flags=*/0),
+ SyscallSucceeds());
+
+ int buf = 0;
+ EXPECT_THAT(fgetxattr(fd.get(), name, &buf, size),
+ SyscallSucceedsWithValue(size));
+ EXPECT_EQ(buf, val);
+
+ char list[sizeof(name)];
+ EXPECT_THAT(flistxattr(fd.get(), list, sizeof(list)),
+ SyscallSucceedsWithValue(sizeof(name)));
+ EXPECT_STREQ(list, name);
+
+ EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh
deleted file mode 100755
index 864bb2de4..000000000
--- a/test/syscalls/syscall_test_runner.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# syscall_test_runner.sh is a simple wrapper around the go syscall test runner.
-# It exists so that we can build the syscall test runner once, and use it for
-# all syscall tests, rather than build it for each test run.
-
-set -euf -x -o pipefail
-
-echo -- "$@"
-
-if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
- mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
- chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
-fi
-
-# Get location of syscall_test_runner binary.
-readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner)
-
-# Pass the arguments of this script directly to the runner.
-exec "${runner}" "$@"
diff --git a/test/uds/BUILD b/test/uds/BUILD
index a3843e699..51e2c7ce8 100644
--- a/test/uds/BUILD
+++ b/test/uds/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(
default_visibility = ["//:sandbox"],
@@ -9,7 +9,6 @@ go_library(
name = "uds",
testonly = 1,
srcs = ["uds.go"],
- importpath = "gvisor.dev/gvisor/test/uds",
deps = [
"//pkg/log",
"//pkg/unet",
diff --git a/test/util/BUILD b/test/util/BUILD
index 5d2a9cc2c..2a17c33ee 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,5 +1,4 @@
-load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
-load("//test/syscalls:build_defs.bzl", "select_for_linux")
+load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system")
package(
default_visibility = ["//:sandbox"],
@@ -42,7 +41,7 @@ cc_library(
":save_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -56,7 +55,7 @@ cc_library(
":posix_error",
":test_util",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -68,7 +67,7 @@ cc_test(
":proc_util",
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -88,7 +87,7 @@ cc_library(
":file_descriptor",
":posix_error",
"@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -102,7 +101,7 @@ cc_test(
":temp_path",
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -135,19 +134,20 @@ cc_library(
":cleanup",
":posix_error",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
cc_library(
name = "save_util",
testonly = 1,
- srcs = ["save_util.cc"] +
- select_for_linux(
- ["save_util_linux.cc"],
- ["save_util_other.cc"],
- ),
+ srcs = [
+ "save_util.cc",
+ "save_util_linux.cc",
+ "save_util_other.cc",
+ ],
hdrs = ["save_util.h"],
+ defines = select_system(),
)
cc_library(
@@ -166,6 +166,14 @@ cc_library(
)
cc_library(
+ name = "platform_util",
+ testonly = 1,
+ srcs = ["platform_util.cc"],
+ hdrs = ["platform_util.h"],
+ deps = [":test_util"],
+)
+
+cc_library(
name = "posix_error",
testonly = 1,
srcs = ["posix_error.cc"],
@@ -175,7 +183,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -186,7 +194,7 @@ cc_test(
deps = [
":posix_error",
":test_main",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -210,7 +218,7 @@ cc_library(
":cleanup",
":posix_error",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -225,27 +233,34 @@ cc_library(
":test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
cc_library(
name = "test_util",
testonly = 1,
- srcs = ["test_util.cc"],
+ srcs = [
+ "test_util.cc",
+ "test_util_impl.cc",
+ "test_util_runfiles.cc",
+ ],
hdrs = ["test_util.h"],
+ defines = select_system(),
deps = [
":fs_util",
":logging",
":posix_error",
":save_util",
+ "@bazel_tools//tools/cpp/runfiles",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
+ gbenchmark,
],
)
@@ -277,7 +292,7 @@ cc_library(
":posix_error",
":test_util",
"@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -288,7 +303,7 @@ cc_test(
deps = [
":test_main",
":test_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -308,7 +323,7 @@ cc_library(
":file_descriptor",
":posix_error",
":save_util",
- "@com_google_googletest//:gtest",
+ gtest,
],
)
@@ -335,3 +350,9 @@ cc_library(
":save_util",
],
)
+
+cc_library(
+ name = "temp_umask",
+ testonly = 1,
+ hdrs = ["temp_umask.h"],
+)
diff --git a/test/util/capability_util.cc b/test/util/capability_util.cc
index 5d733887b..a1b994c45 100644
--- a/test/util/capability_util.cc
+++ b/test/util/capability_util.cc
@@ -36,10 +36,10 @@ PosixErrorOr<bool> CanCreateUserNamespace() {
ASSIGN_OR_RETURN_ERRNO(
auto child_stack,
MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- int const child_pid =
- clone(+[](void*) { return 0; },
- reinterpret_cast<void*>(child_stack.addr() + kPageSize),
- CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr);
+ int const child_pid = clone(
+ +[](void*) { return 0; },
+ reinterpret_cast<void*>(child_stack.addr() + kPageSize),
+ CLONE_NEWUSER | SIGCHLD, /* arg = */ nullptr);
if (child_pid > 0) {
int status;
int const ret = waitpid(child_pid, &status, /* options = */ 0);
@@ -63,13 +63,13 @@ PosixErrorOr<bool> CanCreateUserNamespace() {
// is in a chroot environment (i.e., the caller's root directory does
// not match the root directory of the mount namespace in which it
// resides)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EPERM";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EPERM" << std::endl;
return false;
} else if (errno == EUSERS) {
// "(since Linux 3.11) CLONE_NEWUSER was specified in flags, and the call
// would cause the limit on the number of nested user namespaces to be
// exceeded. See user_namespaces(7)."
- std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS";
+ std::cerr << "clone(CLONE_NEWUSER) failed with EUSERS" << std::endl;
return false;
} else {
// Unexpected error code; indicate an actual error.
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index 88b1e7911..5418948fe 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) {
return stat_buf;
}
+PosixErrorOr<struct stat> Lstat(absl::string_view path) {
+ struct stat stat_buf;
+ int res = lstat(std::string(path).c_str(), &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("lstat ", path));
+ }
+ return stat_buf;
+}
+
PosixErrorOr<struct stat> Fstat(int fd) {
struct stat stat_buf;
int res = fstat(fd, &stat_buf);
@@ -116,18 +125,18 @@ PosixErrorOr<struct stat> Fstat(int fd) {
PosixErrorOr<bool> Exists(absl::string_view path) {
struct stat stat_buf;
- int res = stat(std::string(path).c_str(), &stat_buf);
+ int res = lstat(std::string(path).c_str(), &stat_buf);
if (res < 0) {
if (errno == ENOENT) {
return false;
}
- return PosixError(errno, absl::StrCat("stat ", path));
+ return PosixError(errno, absl::StrCat("lstat ", path));
}
return true;
}
PosixErrorOr<bool> IsDirectory(absl::string_view path) {
- ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Stat(path));
+ ASSIGN_OR_RETURN_ERRNO(struct stat stat_buf, Lstat(path));
if (S_ISDIR(stat_buf.st_mode)) {
return true;
}
@@ -443,7 +452,7 @@ PosixErrorOr<std::string> MakeAbsolute(absl::string_view filename,
std::string CleanPath(const absl::string_view unclean_path) {
std::string path = std::string(unclean_path);
- const char *src = path.c_str();
+ const char* src = path.c_str();
std::string::iterator dst = path.begin();
// Check for absolute path and determine initial backtrack limit.
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index ee1b341d7..8cdac23a1 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -26,6 +26,17 @@
namespace gvisor {
namespace testing {
+
+// O_LARGEFILE as defined by Linux. glibc tries to be clever by setting it to 0
+// because "it isn't needed", even though Linux can return it via F_GETFL.
+#if defined(__x86_64__)
+constexpr int kOLargeFile = 00100000;
+#elif defined(__aarch64__)
+constexpr int kOLargeFile = 00400000;
+#else
+#error "Unknown architecture"
+#endif
+
// Returns a status or the current working directory.
PosixErrorOr<std::string> GetCWD();
@@ -33,9 +44,14 @@ PosixErrorOr<std::string> GetCWD();
// can't be determined.
PosixErrorOr<bool> Exists(absl::string_view path);
-// Returns a stat structure for the given path or an error.
+// Returns a stat structure for the given path or an error. If the path
+// represents a symlink, it will be traversed.
PosixErrorOr<struct stat> Stat(absl::string_view path);
+// Returns a stat structure for the given path or an error. If the path
+// represents a symlink, it will not be traversed.
+PosixErrorOr<struct stat> Lstat(absl::string_view path);
+
// Returns a stat struct for the given fd.
PosixErrorOr<struct stat> Fstat(int fd);
diff --git a/test/util/fs_util_test.cc b/test/util/fs_util_test.cc
index 2a200320a..657b6a46e 100644
--- a/test/util/fs_util_test.cc
+++ b/test/util/fs_util_test.cc
@@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "test/util/fs_util.h"
+
#include <errno.h>
+
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "test/util/fs_util.h"
#include "test/util/posix_error.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
diff --git a/test/util/mount_util.h b/test/util/mount_util.h
index 38ec6c8a1..09e2281eb 100644
--- a/test/util/mount_util.h
+++ b/test/util/mount_util.h
@@ -17,6 +17,7 @@
#include <errno.h>
#include <sys/mount.h>
+
#include <functional>
#include <string>
@@ -30,10 +31,10 @@ namespace testing {
// Mount mounts the filesystem, and unmounts when the returned reference is
// destroyed.
-inline PosixErrorOr<Cleanup> Mount(const std::string &source,
- const std::string &target,
- const std::string &fstype, uint64_t mountflags,
- const std::string &data,
+inline PosixErrorOr<Cleanup> Mount(const std::string& source,
+ const std::string& target,
+ const std::string& fstype,
+ uint64_t mountflags, const std::string& data,
uint64_t umountflags) {
if (mount(source.c_str(), target.c_str(), fstype.c_str(), mountflags,
data.c_str()) == -1) {
diff --git a/test/util/multiprocess_util.h b/test/util/multiprocess_util.h
index 61526b4e7..2f3bf4a6f 100644
--- a/test/util/multiprocess_util.h
+++ b/test/util/multiprocess_util.h
@@ -99,11 +99,13 @@ inline PosixErrorOr<Cleanup> ForkAndExec(const std::string& filename,
const ExecveArray& argv,
const ExecveArray& envv, pid_t* child,
int* execve_errno) {
- return ForkAndExec(filename, argv, envv, [] {}, child, execve_errno);
+ return ForkAndExec(
+ filename, argv, envv, [] {}, child, execve_errno);
}
// Equivalent to ForkAndExec, except using dirfd and flags with execveat.
-PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd, const std::string& pathname,
+PosixErrorOr<Cleanup> ForkAndExecveat(int32_t dirfd,
+ const std::string& pathname,
const ExecveArray& argv,
const ExecveArray& envv, int flags,
const std::function<void()>& fn,
diff --git a/test/util/platform_util.cc b/test/util/platform_util.cc
new file mode 100644
index 000000000..c9200d381
--- /dev/null
+++ b/test/util/platform_util.cc
@@ -0,0 +1,48 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/platform_util.h"
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PlatformSupport PlatformSupport32Bit() {
+ if (GvisorPlatform() == Platform::kPtrace ||
+ GvisorPlatform() == Platform::kKVM) {
+ return PlatformSupport::NotSupported;
+ } else {
+ return PlatformSupport::Allowed;
+ }
+}
+
+PlatformSupport PlatformSupportAlignmentCheck() {
+ return PlatformSupport::Allowed;
+}
+
+PlatformSupport PlatformSupportMultiProcess() {
+ return PlatformSupport::Allowed;
+}
+
+PlatformSupport PlatformSupportInt3() {
+ if (GvisorPlatform() == Platform::kKVM) {
+ return PlatformSupport::NotSupported;
+ } else {
+ return PlatformSupport::Allowed;
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/platform_util.h b/test/util/platform_util.h
new file mode 100644
index 000000000..28cc92371
--- /dev/null
+++ b/test/util/platform_util.h
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_PLATFORM_UTIL_H_
+#define GVISOR_TEST_UTIL_PLATFORM_UTIL_H_
+
+namespace gvisor {
+namespace testing {
+
+// PlatformSupport is a generic enumeration of classes of support.
+//
+// It is up to the individual functions and callers to agree on the precise
+// definition for each case. The document here generally refers to 32-bit
+// as an example. Many cases will use only NotSupported and Allowed.
+enum class PlatformSupport {
+ // The feature is not supported on the current platform.
+ //
+ // In the case of 32-bit, this means that calls will generally be interpreted
+ // as 64-bit calls, and there is no support for 32-bit binaries, long calls,
+ // etc. This usually means that the underlying implementation just pretends
+ // that 32-bit doesn't exist.
+ NotSupported,
+
+ // Calls will be ignored by the kernel with a fixed error.
+ Ignored,
+
+ // Calls will result in a SIGSEGV or similar fault.
+ Segfault,
+
+ // The feature is supported as expected.
+ //
+ // In the case of 32-bit, this means that the system call or far call will be
+ // handled properly.
+ Allowed,
+};
+
+PlatformSupport PlatformSupport32Bit();
+PlatformSupport PlatformSupportAlignmentCheck();
+PlatformSupport PlatformSupportMultiProcess();
+PlatformSupport PlatformSupportInt3();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_PLATFORM_UTL_H_
diff --git a/test/util/posix_error_test.cc b/test/util/posix_error_test.cc
index d67270842..bf9465abb 100644
--- a/test/util/posix_error_test.cc
+++ b/test/util/posix_error_test.cc
@@ -15,6 +15,7 @@
#include "test/util/posix_error.h"
#include <errno.h>
+
#include "gmock/gmock.h"
#include "gtest/gtest.h"
diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc
index c0fd9a095..c01f916aa 100644
--- a/test/util/pty_util.cc
+++ b/test/util/pty_util.cc
@@ -24,6 +24,14 @@ namespace gvisor {
namespace testing {
PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) {
+ PosixErrorOr<int> n = SlaveID(master);
+ if (!n.ok()) {
+ return PosixErrorOr<FileDescriptor>(n.error());
+ }
+ return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK);
+}
+
+PosixErrorOr<int> SlaveID(const FileDescriptor& master) {
// Get pty index.
int n;
int ret = ioctl(master.get(), TIOCGPTN, &n);
@@ -38,7 +46,7 @@ PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) {
return PosixError(errno, "ioctl(TIOSPTLCK) failed");
}
- return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK);
+ return n;
}
} // namespace testing
diff --git a/test/util/pty_util.h b/test/util/pty_util.h
index 367b14f15..0722da379 100644
--- a/test/util/pty_util.h
+++ b/test/util/pty_util.h
@@ -24,6 +24,9 @@ namespace testing {
// Opens the slave end of the passed master as R/W and nonblocking.
PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master);
+// Get the number of the slave end of the master.
+PosixErrorOr<int> SlaveID(const FileDescriptor& master);
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/rlimit_util.cc b/test/util/rlimit_util.cc
index 684253f78..d7bfc1606 100644
--- a/test/util/rlimit_util.cc
+++ b/test/util/rlimit_util.cc
@@ -15,6 +15,7 @@
#include "test/util/rlimit_util.h"
#include <sys/resource.h>
+
#include <cerrno>
#include "test/util/cleanup.h"
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
index 7a0f14342..d0aea8e6a 100644
--- a/test/util/save_util_linux.cc
+++ b/test/util/save_util_linux.cc
@@ -12,22 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#ifdef __linux__
+
#include <errno.h>
#include <sys/syscall.h>
#include <unistd.h>
#include "test/util/save_util.h"
+#if defined(__x86_64__) || defined(__i386__)
+#define SYS_TRIGGER_SAVE SYS_create_module
+#elif defined(__aarch64__)
+#define SYS_TRIGGER_SAVE SYS_finit_module
+#else
+#error "Unknown architecture"
+#endif
+
namespace gvisor {
namespace testing {
void MaybeSave() {
if (internal::ShouldSave()) {
int orig_errno = errno;
- syscall(SYS_create_module, nullptr, 0);
+ // We use it to trigger saving the sentry state
+ // when this syscall is called.
+ // Notice: this needs to be a valid syscall
+ // that is not used in any of the syscall tests.
+ syscall(SYS_TRIGGER_SAVE, nullptr, 0);
errno = orig_errno;
}
}
} // namespace testing
} // namespace gvisor
+
+#endif
diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc
index 1aca663b7..931af2c29 100644
--- a/test/util/save_util_other.cc
+++ b/test/util/save_util_other.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#ifndef __linux__
+
namespace gvisor {
namespace testing {
@@ -21,3 +23,5 @@ void MaybeSave() {
} // namespace testing
} // namespace gvisor
+
+#endif
diff --git a/test/util/signal_util.cc b/test/util/signal_util.cc
index 26738864f..5ee95ee80 100644
--- a/test/util/signal_util.cc
+++ b/test/util/signal_util.cc
@@ -15,6 +15,7 @@
#include "test/util/signal_util.h"
#include <signal.h>
+
#include <ostream>
#include "gtest/gtest.h"
diff --git a/test/util/signal_util.h b/test/util/signal_util.h
index 7fd2af015..e7b66aa51 100644
--- a/test/util/signal_util.h
+++ b/test/util/signal_util.h
@@ -18,6 +18,7 @@
#include <signal.h>
#include <sys/syscall.h>
#include <unistd.h>
+
#include <ostream>
#include "gmock/gmock.h"
@@ -84,6 +85,20 @@ inline void FixupFault(ucontext_t* ctx) {
// The encoding is 0x48 0xab 0x00.
ctx->uc_mcontext.gregs[REG_RIP] += 3;
}
+#elif __aarch64__
+inline void Fault() {
+ // Zero and dereference x0.
+ asm("mov xzr, x0\r\n"
+ "str xzr, [x0]\r\n"
+ :
+ :
+ : "x0");
+}
+
+inline void FixupFault(ucontext_t* ctx) {
+ // Skip the bad instruction above.
+ ctx->uc_mcontext.pc += 4;
+}
#endif
} // namespace testing
diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc
index 35aacb172..e1bdee7fd 100644
--- a/test/util/temp_path.cc
+++ b/test/util/temp_path.cc
@@ -56,7 +56,7 @@ void TryDeleteRecursively(std::string const& path) {
if (undeleted_dirs || undeleted_files || !status.ok()) {
std::cerr << path << ": failed to delete " << undeleted_dirs
<< " directories and " << undeleted_files
- << " files: " << status;
+ << " files: " << status << std::endl;
}
}
}
@@ -77,6 +77,7 @@ std::string NewTempAbsPath() {
std::string NewTempRelPath() { return NextTempBasename(); }
std::string GetAbsoluteTestTmpdir() {
+ // Note that TEST_TMPDIR is guaranteed to be set.
char* env_tmpdir = getenv("TEST_TMPDIR");
std::string tmp_dir =
env_tmpdir != nullptr ? std::string(env_tmpdir) : "/tmp";
diff --git a/test/util/temp_path.h b/test/util/temp_path.h
index 92d669503..9e5ac11f4 100644
--- a/test/util/temp_path.h
+++ b/test/util/temp_path.h
@@ -16,6 +16,7 @@
#define GVISOR_TEST_UTIL_TEMP_PATH_H_
#include <sys/stat.h>
+
#include <string>
#include <utility>
diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h
index 81a25440c..e7de84a54 100644
--- a/test/syscalls/linux/temp_umask.h
+++ b/test/util/temp_umask.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
-#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_
+#define GVISOR_TEST_UTIL_TEMP_UMASK_H_
#include <sys/stat.h>
#include <sys/types.h>
@@ -36,4 +36,4 @@ class TempUmask {
} // namespace testing
} // namespace gvisor
-#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_
+#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_
diff --git a/test/util/test_main.cc b/test/util/test_main.cc
index 5c7ee0064..1f389e58f 100644
--- a/test/util/test_main.cc
+++ b/test/util/test_main.cc
@@ -16,5 +16,5 @@
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index ba0dcf7d0..d0c1d6426 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -40,24 +40,38 @@
namespace gvisor {
namespace testing {
-#define TEST_ON_GVISOR "TEST_ON_GVISOR"
+constexpr char kGvisorNetwork[] = "GVISOR_NETWORK";
+constexpr char kGvisorVfs[] = "GVISOR_VFS";
+constexpr char kFuseEnabled[] = "FUSE_ENABLED";
bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; }
-Platform GvisorPlatform() {
+const std::string GvisorPlatform() {
// Set by runner.go.
- char* env = getenv(TEST_ON_GVISOR);
+ const char* env = getenv(kTestOnGvisor);
if (!env) {
return Platform::kNative;
}
- if (strcmp(env, "ptrace") == 0) {
- return Platform::kPtrace;
- }
- if (strcmp(env, "kvm") == 0) {
- return Platform::kKVM;
+ return std::string(env);
+}
+
+bool IsRunningWithHostinet() {
+ const char* env = getenv(kGvisorNetwork);
+ return env && strcmp(env, "host") == 0;
+}
+
+bool IsRunningWithVFS1() {
+ const char* env = getenv(kGvisorVfs);
+ if (env == nullptr) {
+ // If not set, it's running on Linux.
+ return false;
}
- std::cerr << "unknown platform " << env;
- abort();
+ return strcmp(env, "VFS1") == 0;
+}
+
+bool IsFUSEEnabled() {
+ const char* env = getenv(kFuseEnabled);
+ return env && strcmp(env, "TRUE") == 0;
}
// Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations
@@ -70,7 +84,6 @@ Platform GvisorPlatform() {
"xchg %%rdi, %%rbx\n" \
: "=a"(a), "=D"(b), "=c"(c), "=d"(d) \
: "a"(a_inp), "2"(c_inp))
-#endif // defined(__x86_64__)
CPUVendor GetCPUVendor() {
uint32_t eax, ebx, ecx, edx;
@@ -87,6 +100,7 @@ CPUVendor GetCPUVendor() {
}
return CPUVendor::kUnknownVendor;
}
+#endif // defined(__x86_64__)
bool operator==(const KernelVersion& first, const KernelVersion& second) {
return first.major == second.major && first.minor == second.minor &&
@@ -116,9 +130,6 @@ PosixErrorOr<KernelVersion> GetKernelVersion() {
return ParseKernelVersion(buf.release);
}
-void SetupGvisorDeathTest() {
-}
-
std::string CPUSetToString(const cpu_set_t& set, size_t cpus) {
std::string str = "cpuset[";
for (unsigned int n = 0; n < cpus; n++) {
@@ -224,15 +235,5 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) {
return abs_diff <= static_cast<uint64_t>(tolerance * target);
}
-void TestInit(int* argc, char*** argv) {
- ::testing::InitGoogleTest(argc, *argv);
- ::absl::ParseCommandLine(*argc, *argv);
-
- // Always mask SIGPIPE as it's common and tests aren't expected to handle it.
- struct sigaction sa = {};
- sa.sa_handler = SIG_IGN;
- TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
-}
-
} // namespace testing
} // namespace gvisor
diff --git a/test/util/test_util.h b/test/util/test_util.h
index b9d2dc2ba..373c54f32 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -26,16 +26,13 @@
// IsRunningOnGvisor returns true if the test is known to be running on gVisor.
// GvisorPlatform can be used to get more detail:
//
-// switch (GvisorPlatform()) {
-// case Platform::kNative:
-// case Platform::kGvisor:
-// EXPECT_THAT(mmap(...), SyscallSucceeds());
-// break;
-// case Platform::kPtrace:
-// EXPECT_THAT(mmap(...), SyscallFailsWithErrno(ENOSYS));
-// break;
+// if (GvisorPlatform() == Platform::kPtrace) {
+// ...
// }
//
+// SetupGvisorDeathTest ensures that signal handling does not interfere with
+/// tests that rely on fatal signals.
+//
// Matchers
// ========
//
@@ -198,6 +195,8 @@
namespace gvisor {
namespace testing {
+constexpr char kTestOnGvisor[] = "TEST_ON_GVISOR";
+
// TestInit must be called prior to RUN_ALL_TESTS.
//
// This parses all arguments and adjusts argc and argv appropriately.
@@ -213,15 +212,24 @@ void TestInit(int* argc, char*** argv);
if (expr) GTEST_SKIP() << #expr; \
} while (0)
-enum class Platform {
- kNative,
- kKVM,
- kPtrace,
-};
+// Platform contains platform names.
+namespace Platform {
+constexpr char kNative[] = "native";
+constexpr char kPtrace[] = "ptrace";
+constexpr char kKVM[] = "kvm";
+constexpr char kFuchsia[] = "fuchsia";
+} // namespace Platform
+
bool IsRunningOnGvisor();
-Platform GvisorPlatform();
+const std::string GvisorPlatform();
+bool IsRunningWithHostinet();
+// TODO(gvisor.dev/issue/1624): Delete once VFS1 is gone.
+bool IsRunningWithVFS1();
+bool IsFUSEEnabled();
+#ifdef __linux__
void SetupGvisorDeathTest();
+#endif
struct KernelVersion {
int major;
@@ -560,6 +568,25 @@ ssize_t ApplyFileIoSyscall(F const& f, size_t const count) {
} // namespace internal
+inline PosixErrorOr<std::string> ReadAllFd(int fd) {
+ std::string all;
+ all.reserve(128 * 1024); // arbitrary.
+
+ std::vector<char> buffer(16 * 1024);
+ for (;;) {
+ auto const bytes = RetryEINTR(read)(fd, buffer.data(), buffer.size());
+ if (bytes < 0) {
+ return PosixError(errno, "file read");
+ }
+ if (bytes == 0) {
+ return std::move(all);
+ }
+ if (bytes > 0) {
+ all.append(buffer.data(), bytes);
+ }
+ }
+}
+
inline ssize_t ReadFd(int fd, void* buf, size_t count) {
return internal::ApplyFileIoSyscall(
[&](size_t completed) {
@@ -762,7 +789,14 @@ MATCHER_P2(EquivalentWithin, target, tolerance,
return Equivalent(arg, target, tolerance);
}
+// Returns the absolute path to the a data dependency. 'path' is the runfile
+// location relative to workspace root.
+#ifdef __linux__
+std::string RunfilePath(std::string path);
+#endif
+
void TestInit(int* argc, char*** argv);
+int RunAllTests(void);
} // namespace testing
} // namespace gvisor
diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc
new file mode 100644
index 000000000..7e1ad9e66
--- /dev/null
+++ b/test/util/test_util_impl.cc
@@ -0,0 +1,52 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+extern bool FLAGS_benchmark_list_tests;
+extern std::string FLAGS_benchmark_filter;
+
+namespace gvisor {
+namespace testing {
+
+void SetupGvisorDeathTest() {}
+
+void TestInit(int* argc, char*** argv) {
+ ::testing::InitGoogleTest(argc, *argv);
+ benchmark::Initialize(argc, *argv);
+ ::absl::ParseCommandLine(*argc, *argv);
+
+ // Always mask SIGPIPE as it's common and tests aren't expected to handle it.
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
+}
+
+int RunAllTests() {
+ if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") {
+ benchmark::RunSpecifiedBenchmarks();
+ return 0;
+ } else {
+ return RUN_ALL_TESTS();
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc
new file mode 100644
index 000000000..694d21692
--- /dev/null
+++ b/test/util/test_util_runfiles.cc
@@ -0,0 +1,50 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __fuchsia__
+
+#include <iostream>
+#include <string>
+
+#include "test/util/fs_util.h"
+#include "test/util/test_util.h"
+#include "tools/cpp/runfiles/runfiles.h"
+
+namespace gvisor {
+namespace testing {
+
+std::string RunfilePath(std::string path) {
+ static const bazel::tools::cpp::runfiles::Runfiles* const runfiles = [] {
+ std::string error;
+ auto* runfiles =
+ bazel::tools::cpp::runfiles::Runfiles::CreateForTest(&error);
+ if (runfiles == nullptr) {
+ std::cerr << "Unable to find runfiles: " << error << std::endl;
+ }
+ return runfiles;
+ }();
+
+ if (!runfiles) {
+ // Can't find runfiles? This probably won't work, but __main__/path is our
+ // best guess.
+ return JoinPath("__main__", path);
+ }
+
+ return runfiles->Rlocation(JoinPath("__main__", path));
+}
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // __fuchsia__
diff --git a/test/util/test_util_test.cc b/test/util/test_util_test.cc
index b7300d9e5..f42100374 100644
--- a/test/util/test_util_test.cc
+++ b/test/util/test_util_test.cc
@@ -15,6 +15,7 @@
#include "test/util/test_util.h"
#include <errno.h>
+
#include <vector>
#include "gmock/gmock.h"
diff --git a/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go
deleted file mode 100644
index 855b2a2b1..000000000
--- a/third_party/gvsync/downgradable_rwmutex_1_12_unsafe.go
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Copyright 2019 The gVisor Authors.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.12
-// +build !go1.13
-
-// TODO(b/133868570): Delete once Go 1.12 is no longer supported.
-
-package gvsync
-
-import _ "unsafe"
-
-//go:linkname runtimeSemrelease112 sync.runtime_Semrelease
-func runtimeSemrelease112(s *uint32, handoff bool)
-
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int) {
- // 'skipframes' is only available starting from 1.13.
- runtimeSemrelease112(s, handoff)
-}
diff --git a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go b/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go
deleted file mode 100644
index 8baec5458..000000000
--- a/third_party/gvsync/downgradable_rwmutex_1_13_unsafe.go
+++ /dev/null
@@ -1,16 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Copyright 2019 The gVisor Authors.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build go1.13
-// +build !go1.14
-
-// Check go:linkname function signatures when updating Go version.
-
-package gvsync
-
-import _ "unsafe"
-
-//go:linkname runtimeSemrelease sync.runtime_Semrelease
-func runtimeSemrelease(s *uint32, handoff bool, skipframes int)
diff --git a/tools/BUILD b/tools/BUILD
new file mode 100644
index 000000000..da83877b1
--- /dev/null
+++ b/tools/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "bzl_library")
+
+package(licenses = ["notice"])
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/bazel.mk b/tools/bazel.mk
new file mode 100644
index 000000000..3e27af7d1
--- /dev/null
+++ b/tools/bazel.mk
@@ -0,0 +1,181 @@
+#!/usr/bin/make -f
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# See base Makefile.
+SHELL=/bin/bash -o pipefail
+BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
+ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
+ xargs -n 1 basename 2>/dev/null)
+
+# Bazel container configuration (see below).
+USER ?= gvisor
+HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
+BUILDER_BASE := gvisor.dev/images/default
+BUILDER_IMAGE := gvisor.dev/images/builder
+BUILDER_NAME ?= gvisor-builder-$(HASH)
+DOCKER_NAME ?= gvisor-bazel-$(HASH)
+DOCKER_PRIVILEGED ?= --privileged
+BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
+GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
+DOCKER_SOCKET := /var/run/docker.sock
+
+# Bazel flags.
+BAZEL := bazel $(STARTUP_OPTIONS)
+OPTIONS += --color=no --curses=no
+
+# Basic options.
+UID := $(shell id -u ${USER})
+GID := $(shell id -g ${USER})
+USERADD_OPTIONS :=
+FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS)
+FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
+FULL_DOCKER_RUN_OPTIONS += --entrypoint ""
+FULL_DOCKER_RUN_OPTIONS += --init
+FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
+FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
+FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
+FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
+FULL_DOCKER_EXEC_OPTIONS += --interactive
+ifeq (true,$(shell [[ -t 0 ]] && echo true))
+FULL_DOCKER_EXEC_OPTIONS += --tty
+endif
+
+# Add docker passthrough options.
+ifneq ($(DOCKER_PRIVILEGED),)
+FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
+FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
+FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
+DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
+ifneq ($(GID),$(DOCKER_GROUP))
+USERADD_OPTIONS += --groups $(DOCKER_GROUP)
+GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) &&
+FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
+endif
+endif
+
+# Add KVM passthrough options.
+ifneq (,$(wildcard /dev/kvm))
+FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm
+KVM_GROUP := $(shell stat -c '%g' /dev/kvm)
+ifneq ($(GID),$(KVM_GROUP))
+USERADD_OPTIONS += --groups $(KVM_GROUP)
+GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) &&
+FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
+endif
+endif
+
+# Load the appropriate config.
+ifneq (,$(BAZEL_CONFIG))
+OPTIONS += --config=$(BAZEL_CONFIG)
+endif
+
+# NOTE: we pass -l to useradd below because otherwise you can hit a bug
+# best described here:
+# https://github.com/moby/moby/issues/5419#issuecomment-193876183
+# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out
+# out of disk space.
+bazel-image: load-default
+ @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi
+ docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \
+ $(BUILDER_BASE) \
+ sh -c "groupadd --gid $(GID) --non-unique $(USER) && \
+ $(GROUPADD_DOCKER) \
+ useradd -l --uid $(UID) --non-unique --no-create-home \
+ --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \
+ if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi"
+ docker commit $(BUILDER_NAME) $(BUILDER_IMAGE)
+ @docker rm -f $(BUILDER_NAME)
+.PHONY: bazel-image
+
+##
+## Bazel helpers.
+##
+## This file supports targets that wrap bazel in a running Docker
+## container to simplify development. Some options are available to
+## control the behavior of this container:
+## USER - The in-container user.
+## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests).
+## DOCKER_NAME - The container name (default: gvisor-bazel-HASH).
+## BAZEL_CACHE - The bazel cache directory (default: detected).
+## GCLOUD_CONFIG - The gcloud config directory (detect: detected).
+## DOCKER_SOCKET - The Docker socket (default: detected).
+##
+bazel-server-start: bazel-image ## Starts the bazel server.
+ @mkdir -p $(BAZEL_CACHE)
+ @mkdir -p $(GCLOUD_CONFIG)
+ @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi
+ # This command runs a bazel server, and the container sticks around
+ # until the bazel server exits. This should ensure that it does not
+ # exit in the middle of running a build, but also it won't stick around
+ # forever. The build commands wrap around an appropriate exec into the
+ # container in order to perform work via the bazel client.
+ docker run -d --rm --name $(DOCKER_NAME) \
+ -v "$(CURDIR):$(CURDIR)" \
+ --workdir "$(CURDIR)" \
+ $(FULL_DOCKER_RUN_OPTIONS) \
+ $(BUILDER_IMAGE) \
+ sh -c "tail -f --pid=\$$($(BAZEL) info server_pid)"
+.PHONY: bazel-server-start
+
+bazel-shutdown: ## Shuts down a running bazel server.
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \
+ rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
+.PHONY: bazel-shutdown
+
+bazel-alias: ## Emits an alias that can be used within the shell.
+ @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'"
+.PHONY: bazel-alias
+
+bazel-server: ## Ensures that the server exists. Used as an internal target.
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start
+.PHONY: bazel-server
+
+build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) "$(TARGETS)"'
+
+build_paths = $(build_cmd) 2>&1 \
+ | tee /proc/self/fd/2 \
+ | grep -E "^ bazel-bin/" \
+ | tr -d '\r' \
+ | awk '{$$1=$$1};1' \
+ | xargs -n 1 -I {} sh -c "$(1)"
+
+build: bazel-server
+ @$(call build_cmd)
+.PHONY: build
+
+copy: bazel-server
+ifeq (,$(DESTINATION))
+ $(error Destination not provided.)
+endif
+ @$(call build_paths,cp -fa {} $(DESTINATION))
+
+run: bazel-server
+ @$(call build_paths,{} $(ARGS))
+.PHONY: run
+
+sudo: bazel-server
+ @$(call build_paths,sudo -E {} $(ARGS))
+.PHONY: sudo
+
+test: OPTIONS += --test_output=errors --keep_going --verbose_failures=true
+test: bazel-server
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS)
+.PHONY: test
+
+query:
+ @$(MAKE) bazel-server >&2 # If we need to start, ensure stdout is not polluted.
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) query $(OPTIONS) "$(TARGETS)" 2>/dev/null'
+.PHONY: query
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
new file mode 100644
index 000000000..8d4356119
--- /dev/null
+++ b/tools/bazeldefs/BUILD
@@ -0,0 +1,106 @@
+load("//tools:defs.bzl", "bzl_library", "rbe_platform", "rbe_toolchain")
+
+package(licenses = ["notice"])
+
+# In bazel, no special support is required for loopback networking. This is
+# just a dummy data target that does not change the test environment.
+genrule(
+ name = "loopback",
+ outs = ["loopback.txt"],
+ cmd = "touch $@",
+ visibility = ["//:sandbox"],
+)
+
+# We need to define a bazel platform and toolchain to specify dockerPrivileged
+# and dockerRunAsRoot options, they are required to run tests on the RBE
+# cluster in Kokoro.
+rbe_platform(
+ name = "rbe_ubuntu1604",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains//constraints:xenial",
+ "@bazel_toolchains//constraints/sanitizers:support_msan",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50"
+ }
+ properties: {
+ name: "dockerAddCapabilities"
+ value: "SYS_ADMIN"
+ }
+ properties: {
+ name: "dockerPrivileged"
+ value: "true"
+ }
+ """,
+)
+
+rbe_toolchain(
+ name = "cc-toolchain-clang-x86_64-default",
+ exec_compatible_with = [],
+ tags = [
+ "manual",
+ ],
+ target_compatible_with = [],
+ toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
+# Updated versions of the above, compatible with bazel3.
+rbe_platform(
+ name = "rbe_ubuntu1604_bazel3",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains_bazel3//constraints:xenial",
+ "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272"
+ }
+ properties: {
+ name: "dockerAddCapabilities"
+ value: "SYS_ADMIN"
+ }
+ properties: {
+ name: "dockerPrivileged"
+ value: "true"
+ }
+ """,
+)
+
+rbe_toolchain(
+ name = "cc-toolchain-clang-x86_64-default_bazel3",
+ exec_compatible_with = [],
+ tags = [
+ "manual",
+ ],
+ target_compatible_with = [],
+ toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
+bzl_library(
+ name = "platforms_bzl",
+ srcs = ["platforms.bzl"],
+ visibility = ["//visibility:private"],
+)
+
+bzl_library(
+ name = "tags_bzl",
+ srcs = ["tags.bzl"],
+ visibility = ["//visibility:private"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
new file mode 100644
index 000000000..db7f379b8
--- /dev/null
+++ b/tools/bazeldefs/defs.bzl
@@ -0,0 +1,181 @@
+"""Bazel implementations of standard rules."""
+
+load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle")
+load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test")
+load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library")
+load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier")
+load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test")
+load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library")
+load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test")
+load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library")
+
+build_test = _build_test
+bzl_library = _bzl_library
+cc_library = _cc_library
+cc_flags_supplier = _cc_flags_supplier
+cc_proto_library = _cc_proto_library
+cc_test = _cc_test
+cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
+gazelle = _gazelle
+go_embed_data = _go_embed_data
+go_path = _go_path
+gtest = "@com_google_googletest//:gtest"
+grpcpp = "@com_github_grpc_grpc//:grpc++"
+gbenchmark = "@com_google_benchmark//:benchmark"
+loopback = "//tools/bazeldefs:loopback"
+pkg_deb = _pkg_deb
+pkg_tar = _pkg_tar
+py_binary = native.py_binary
+rbe_platform = native.platform
+rbe_toolchain = native.toolchain
+vdso_linker_option = "-fuse-ld=gold "
+
+def short_path(path):
+ return path
+
+def proto_library(name, has_services = None, **kwargs):
+ native.proto_library(
+ name = name,
+ **kwargs
+ )
+
+def cc_grpc_library(name, **kwargs):
+ _cc_grpc_library(name = name, grpc_only = True, **kwargs)
+
+def _go_proto_or_grpc_library(go_library_func, name, **kwargs):
+ deps = [
+ dep.replace("_proto", "_go_proto")
+ for dep in (kwargs.pop("deps", []) or [])
+ ]
+ go_library_func(
+ name = name + "_go_proto",
+ importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto",
+ proto = ":" + name + "_proto",
+ deps = deps,
+ **kwargs
+ )
+
+def go_proto_library(name, **kwargs):
+ _go_proto_or_grpc_library(_go_proto_library, name, **kwargs)
+
+def go_grpc_and_proto_libraries(name, **kwargs):
+ _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs)
+
+def cc_binary(name, static = False, **kwargs):
+ """Run cc_binary.
+
+ Args:
+ name: name of the target.
+ static: make a static binary if True
+ **kwargs: the rest of the args.
+ """
+ if static:
+ # How to statically link a c++ program that uses threads, like for gRPC:
+ # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html
+ if "linkopts" not in kwargs:
+ kwargs["linkopts"] = []
+ kwargs["linkopts"] += [
+ "-static",
+ "-lstdc++",
+ "-Wl,--whole-archive",
+ "-lpthread",
+ "-Wl,--no-whole-archive",
+ ]
+ _cc_binary(
+ name = name,
+ **kwargs
+ )
+
+def go_binary(name, static = False, pure = False, **kwargs):
+ """Build a go binary.
+
+ Args:
+ name: name of the target.
+ static: build a static binary.
+ pure: build without cgo.
+ **kwargs: rest of the arguments are passed to _go_binary.
+ """
+ if static:
+ kwargs["static"] = "on"
+ if pure:
+ kwargs["pure"] = "on"
+ _go_binary(
+ name = name,
+ **kwargs
+ )
+
+def go_importpath(target):
+ """Returns the importpath for the target."""
+ return target[GoLibrary].importpath
+
+def go_library(name, **kwargs):
+ _go_library(
+ name = name,
+ importpath = "gvisor.dev/gvisor/" + native.package_name(),
+ **kwargs
+ )
+
+def go_test(name, pure = False, library = None, **kwargs):
+ """Build a go test.
+
+ Args:
+ name: name of the output binary.
+ pure: should it be built without cgo.
+ library: the library to embed.
+ **kwargs: rest of the arguments to pass to _go_test.
+ """
+ if pure:
+ kwargs["pure"] = "on"
+ if library:
+ kwargs["embed"] = [library]
+ _go_test(
+ name = name,
+ **kwargs
+ )
+
+def go_rule(rule, implementation, **kwargs):
+ """Wraps a rule definition with Go attributes.
+
+ Args:
+ rule: rule function (typically rule or aspect).
+ implementation: implementation function.
+ **kwargs: other arguments to pass to rule.
+
+ Returns:
+ The result of invoking the rule.
+ """
+ attrs = kwargs.pop("attrs", [])
+ attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data")
+ attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib")
+ toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"]
+ return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs)
+
+def go_context(ctx):
+ go_ctx = _go_context(ctx)
+ return struct(
+ go = go_ctx.go,
+ env = go_ctx.env,
+ runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs),
+ goos = go_ctx.sdk.goos,
+ goarch = go_ctx.sdk.goarch,
+ tags = go_ctx.tags,
+ )
+
+def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs):
+ values = {
+ "@bazel_tools//src/conditions:linux_x86_64": amd64,
+ "@bazel_tools//src/conditions:linux_aarch64": arm64,
+ }
+ if default:
+ values["//conditions:default"] = default
+ return select(values, **kwargs)
+
+def select_system(linux = ["__linux__"], **kwargs):
+ return linux # Only Linux supported.
+
+def default_installer():
+ return None
+
+def default_net_util():
+ return [] # Nothing needed.
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
new file mode 100644
index 000000000..165b22311
--- /dev/null
+++ b/tools/bazeldefs/platforms.bzl
@@ -0,0 +1,9 @@
+"""List of platforms."""
+
+# Platform to associated tags.
+platforms = {
+ "ptrace": [],
+ "kvm": [],
+}
+
+default_platform = "ptrace"
diff --git a/tools/bazeldefs/tags.bzl b/tools/bazeldefs/tags.bzl
new file mode 100644
index 000000000..f5d7a7b21
--- /dev/null
+++ b/tools/bazeldefs/tags.bzl
@@ -0,0 +1,56 @@
+"""List of special Go suffixes."""
+
+def explode(tagset, suffixes):
+ """explode combines tagset and suffixes in all ways.
+
+ Args:
+ tagset: Original suffixes.
+ suffixes: Suffixes to combine before and after.
+
+ Returns:
+ The set of possible combinations.
+ """
+ result = [t for t in tagset]
+ result += [s for s in suffixes]
+ for t in tagset:
+ result += [t + s for s in suffixes]
+ result += [s + t for s in suffixes]
+ return result
+
+archs = [
+ "_386",
+ "_aarch64",
+ "_amd64",
+ "_arm",
+ "_arm64",
+ "_mips",
+ "_mips64",
+ "_mips64le",
+ "_mipsle",
+ "_ppc64",
+ "_ppc64le",
+ "_riscv64",
+ "_s390x",
+ "_sparc64",
+ "_x86",
+]
+
+oses = [
+ "_linux",
+]
+
+generic = [
+ "_impl",
+ "_race",
+ "_norace",
+ "_unsafe",
+ "_opts",
+]
+
+# State explosion? Sure. This is approximately:
+# len(archs) * (1 + 2 * len(oses) * (1 + 2 * len(generic))
+#
+# This evaluates to 495 at the time of writing. So it's a lot of different
+# combinations, but not so much that it will cause issues. We can probably add
+# quite a few more variants before this becomes a genuine problem.
+go_suffixes = explode(explode(archs, oses), generic)
diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD
new file mode 100644
index 000000000..5748fb390
--- /dev/null
+++ b/tools/bigquery/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bigquery",
+ testonly = 1,
+ srcs = ["bigquery.go"],
+ deps = ["@com_google_cloud_go_bigquery//:go_default_library"],
+)
diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go
new file mode 100644
index 000000000..56f0dc5c9
--- /dev/null
+++ b/tools/bigquery/bigquery.go
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bigquery defines a BigQuery schema for benchmarks.
+//
+// This package contains a schema for BigQuery and methods for publishing
+// benchmark data into tables.
+package bigquery
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "time"
+
+ bq "cloud.google.com/go/bigquery"
+)
+
+// Benchmark is the top level structure of recorded benchmark data. BigQuery
+// will infer the schema from this.
+type Benchmark struct {
+ Name string `bq:"name"`
+ Timestamp time.Time `bq:"timestamp"`
+ Official bool `bq:"official"`
+ Metric []*Metric `bq:"metric"`
+ Metadata *Metadata `bq:"metadata"`
+}
+
+// Metric holds the actual metric data and unit information for this benchmark.
+type Metric struct {
+ Name string `bq:"name"`
+ Unit string `bq:"unit"`
+ Sample float64 `bq:"sample"`
+}
+
+// Metadata about this benchmark.
+type Metadata struct {
+ CL string `bq:"changelist"`
+ IterationID string `bq:"iteration_id"`
+ PendingCL string `bq:"pending_cl"`
+ Workflow string `bq:"workflow"`
+ Platform string `bq:"platform"`
+ Gofer string `bq:"gofer"`
+}
+
+// InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated.
+func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ dataset := client.Dataset(datasetID)
+ if err := dataset.Create(ctx, nil); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create dataset: %s: %v", datasetID, err)
+ }
+
+ table := dataset.Table(tableID)
+ schema, err := bq.InferSchema(Benchmark{})
+ if err != nil {
+ return fmt.Errorf("failed to infer schema: %v", err)
+ }
+
+ if err := table.Create(ctx, &bq.TableMetadata{Schema: schema}); err != nil && !checkDuplicateError(err) {
+ return fmt.Errorf("failed to create table: %s: %v", tableID, err)
+ }
+ return nil
+}
+
+// AddMetric adds a metric to an existing Benchmark.
+func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) {
+ m := &Metric{
+ Name: metricName,
+ Unit: unit,
+ Sample: sample,
+ }
+ bm.Metric = append(bm.Metric, m)
+}
+
+// NewBenchmark initializes a new benchmark.
+func NewBenchmark(name string, official bool) *Benchmark {
+ return &Benchmark{
+ Name: name,
+ Timestamp: time.Now().UTC(),
+ Official: official,
+ Metric: make([]*Metric, 0),
+ }
+}
+
+// SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table.
+func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error {
+ client, err := bq.NewClient(ctx, projectID)
+ if err != nil {
+ return fmt.Errorf("Failed to initialize client on project: %s: %v", projectID, err)
+ }
+ defer client.Close()
+
+ uploader := client.Dataset(datasetID).Table(tableID).Uploader()
+ if err = uploader.Put(ctx, benchmarks); err != nil {
+ return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err)
+ }
+
+ return nil
+}
+
+// BigQuery will error "409" for duplicate tables and datasets.
+func checkDuplicateError(err error) bool {
+ return strings.Contains(err.Error(), "googleapi: Error 409: Already Exists")
+}
diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD
new file mode 100644
index 000000000..b8c3ddf44
--- /dev/null
+++ b/tools/checkescape/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "checkescape",
+ srcs = ["checkescape.go"],
+ nogo = False,
+ visibility = ["//tools/nogo:__subpackages__"],
+ deps = [
+ "//tools/nogo/data",
+ "@org_golang_x_tools//go/analysis:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library",
+ "@org_golang_x_tools//go/ssa:go_tool_library",
+ ],
+)
diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go
new file mode 100644
index 000000000..f8def4823
--- /dev/null
+++ b/tools/checkescape/checkescape.go
@@ -0,0 +1,726 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package checkescape allows recursive escape analysis for hot paths.
+//
+// The analysis tracks multiple types of escapes, in two categories. First,
+// 'hard' escapes are explicit allocations. Second, 'soft' escapes are
+// interface dispatches or dynamic function dispatches; these don't necessarily
+// escape but they *may* escape. The analysis is capable of making assertions
+// recursively: soft escapes cannot be analyzed in this way, and therefore
+// count as escapes for recursive purposes.
+//
+// The different types of escapes are as follows, with the category in
+// parentheses:
+//
+// heap: A direct allocation is made on the heap (hard).
+// builtin: A call is made to a built-in allocation function (hard).
+// stack: A stack split as part of a function preamble (soft).
+// interface: A call is made via an interface whicy *may* escape (soft).
+// dynamic: A dynamic function is dispatched which *may* escape (soft).
+//
+// To the use the package, annotate a function-level comment with either the
+// line "// +checkescape" or "// +checkescape:OPTION[,OPTION]". In the second
+// case, the OPTION field is either a type above, or one of:
+//
+// local: Escape analysis is limited to local hard escapes only.
+// all: All the escapes are included.
+// hard: All hard escapes are included.
+//
+// If the "// +checkescape" annotation is provided, this is equivalent to
+// provided the local and hard options.
+//
+// Some examples of this syntax are:
+//
+// +checkescape:all - Analyzes for all escapes in this function and all calls.
+// +checkescape:local - Analyzes only for default local hard escapes.
+// +checkescape:heap - Only analyzes for heap escapes.
+// +checkescape:interface,dynamic - Only checks for dynamic calls and interface calls.
+// +checkescape - Does the same as +checkescape:local,hard.
+//
+// Note that all of the above can be inverted by using +mustescape. The
+// +checkescape keyword will ensure failure if the class of escape occurs,
+// whereas +mustescape will fail if the given class of escape does not occur.
+//
+// Local exemptions can be made by a comment of the form "// escapes: reason."
+// This must appear on the line of the escape and will also apply to callers of
+// the function as well (for non-local escape analysis).
+package checkescape
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "io"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/buildssa"
+ "golang.org/x/tools/go/ssa"
+ "gvisor.dev/gvisor/tools/nogo/data"
+)
+
+const (
+ // magic is the magic annotation.
+ magic = "// +checkescape"
+
+ // magicParams is the magic annotation with specific parameters.
+ magicParams = magic + ":"
+
+ // testMagic is the test magic annotation (parameters required).
+ testMagic = "// +mustescape:"
+
+ // exempt is the exemption annotation.
+ exempt = "// escapes"
+)
+
+// escapingBuiltins are builtins known to escape.
+//
+// These are lowered at an earlier stage of compilation to explicit function
+// calls, but are not available for recursive analysis.
+var escapingBuiltins = []string{
+ "append",
+ "makemap",
+ "newobject",
+ "mallocgc",
+}
+
+// Analyzer defines the entrypoint.
+var Analyzer = &analysis.Analyzer{
+ Name: "checkescape",
+ Doc: "surfaces recursive escape analysis results",
+ Run: run,
+ Requires: []*analysis.Analyzer{buildssa.Analyzer},
+ FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)},
+}
+
+// packageEscapeFacts is the set of all functions in a package, and whether or
+// not they recursively pass escape analysis.
+//
+// All the type names for receivers are encoded in the full key. The key
+// represents the fully qualified package and type name used at link time.
+type packageEscapeFacts struct {
+ Funcs map[string][]Escape
+}
+
+// AFact implements analysis.Fact.AFact.
+func (*packageEscapeFacts) AFact() {}
+
+// CallSite is a single call site.
+//
+// These can be chained.
+type CallSite struct {
+ LocalPos token.Pos
+ Resolved LinePosition
+}
+
+// Escape is a single escape instance.
+type Escape struct {
+ Reason EscapeReason
+ Detail string
+ Chain []CallSite
+}
+
+// LinePosition is a low-resolution token.Position.
+//
+// This is used to match against possible exemptions placed in the source.
+type LinePosition struct {
+ Filename string
+ Line int
+}
+
+// String implements fmt.Stringer.String.
+func (e *LinePosition) String() string {
+ return fmt.Sprintf("%s:%d", e.Filename, e.Line)
+}
+
+// String implements fmt.Stringer.String.
+//
+// Note that this string will contain new lines.
+func (e *Escape) String() string {
+ var b bytes.Buffer
+ fmt.Fprintf(&b, "%s", e.Reason.String())
+ for i, cs := range e.Chain {
+ if i == len(e.Chain)-1 {
+ fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail)
+ } else {
+ fmt.Fprintf(&b, "\n + %s", cs.Resolved.String())
+ }
+ }
+ return b.String()
+}
+
+// EscapeReason is an escape reason.
+//
+// This is a simple enum.
+type EscapeReason int
+
+const (
+ interfaceInvoke EscapeReason = iota
+ unknownPackage
+ allocation
+ builtin
+ dynamicCall
+ stackSplit
+ reasonCount // Count for below.
+)
+
+// String returns the string for the EscapeReason.
+//
+// Note that this also implicitly defines the reverse string -> EscapeReason
+// mapping, which is the word before the colon (computed below).
+func (e EscapeReason) String() string {
+ switch e {
+ case interfaceInvoke:
+ return "interface: function invocation via interface"
+ case unknownPackage:
+ return "unknown: no package information available"
+ case allocation:
+ return "heap: call to runtime heap allocation"
+ case builtin:
+ return "builtin: call to runtime builtin"
+ case dynamicCall:
+ return "dynamic: call via dynamic function"
+ case stackSplit:
+ return "stack: stack split on function entry"
+ default:
+ panic(fmt.Sprintf("unknown reason: %d", e))
+ }
+}
+
+var hardReasons = []EscapeReason{
+ allocation,
+ builtin,
+}
+
+var softReasons = []EscapeReason{
+ interfaceInvoke,
+ unknownPackage,
+ dynamicCall,
+ stackSplit,
+}
+
+var allReasons = append(hardReasons, softReasons...)
+
+var escapeTypes = func() map[string]EscapeReason {
+ result := make(map[string]EscapeReason)
+ for _, r := range allReasons {
+ parts := strings.Split(r.String(), ":")
+ result[parts[0]] = r // Key before ':'.
+ }
+ return result
+}()
+
+// EscapeCount counts escapes.
+//
+// It is used to avoid accumulating too many escapes for the same reason, for
+// the same function. We limit each class to 3 instances (arbitrarily).
+type EscapeCount struct {
+ byReason [reasonCount]uint32
+}
+
+// maxRecordsPerReason is the number of explicit records.
+//
+// See EscapeCount (and usage), and Record implementation.
+const maxRecordsPerReason = 5
+
+// Record records the reason or returns false if it should not be added.
+func (ec *EscapeCount) Record(reason EscapeReason) bool {
+ ec.byReason[reason]++
+ if ec.byReason[reason] > maxRecordsPerReason {
+ return false
+ }
+ return true
+}
+
+// loadObjdump reads the objdump output.
+//
+// This records if there is a call any function for every source line. It is
+// used only to remove false positives for escape analysis. The call will be
+// elided if escape analysis is able to put the object on the heap exclusively.
+func loadObjdump() (map[LinePosition]string, error) {
+ f, err := os.Open(data.Objdump)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ // Build the map.
+ m := make(map[LinePosition]string)
+ r := bufio.NewReader(f)
+ var (
+ lastField string
+ lastPos LinePosition
+ )
+ for {
+ line, err := r.ReadString('\n')
+ if err != nil && err != io.EOF {
+ return nil, err
+ }
+
+ // We recognize lines corresponding to actual code (not the
+ // symbol name or other metadata) and annotate them if they
+ // correspond to an explicit CALL instruction. We assume that
+ // the lack of a CALL for a given line is evidence that escape
+ // analysis has eliminated an allocation.
+ //
+ // Lines look like this (including the first space):
+ // gohacks_unsafe.go:33 0xa39 488b442408 MOVQ 0x8(SP), AX
+ if len(line) > 0 && line[0] == ' ' {
+ fields := strings.Fields(line)
+ if !strings.Contains(fields[3], "CALL") {
+ continue
+ }
+
+ // Ignore strings containing duffzero, which is just
+ // used by stack allocations for types that are large
+ // enough to warrant Duff's device.
+ if strings.Contains(line, "runtime.duffzero") {
+ continue
+ }
+
+ // Ignore the racefuncenter call, which is used for
+ // race builds. This does not escape.
+ if strings.Contains(line, "runtime.racefuncenter") {
+ continue
+ }
+
+ // Calculate the filename and line. Note that per the
+ // example above, the filename is not a fully qualified
+ // base, just the basename (what we require).
+ if fields[0] != lastField {
+ parts := strings.SplitN(fields[0], ":", 2)
+ lineNum, err := strconv.ParseInt(parts[1], 10, 64)
+ if err != nil {
+ return nil, err
+ }
+ lastPos = LinePosition{
+ Filename: parts[0],
+ Line: int(lineNum),
+ }
+ lastField = fields[0]
+ }
+ if _, ok := m[lastPos]; ok {
+ continue // Already marked.
+ }
+
+ // Save the actual call for the detail.
+ m[lastPos] = strings.Join(fields[3:], " ")
+ }
+ if err == io.EOF {
+ break
+ }
+ }
+
+ return m, nil
+}
+
+// poser is a type that implements Pos.
+type poser interface {
+ Pos() token.Pos
+}
+
+// run performs the analysis.
+func run(pass *analysis.Pass) (interface{}, error) {
+ calls, err := loadObjdump()
+ if err != nil {
+ return nil, err
+ }
+ pef := packageEscapeFacts{
+ Funcs: make(map[string][]Escape),
+ }
+ linePosition := func(inst, parent poser) LinePosition {
+ p := pass.Fset.Position(inst.Pos())
+ if (p.Filename == "" || p.Line == 0) && parent != nil {
+ p = pass.Fset.Position(parent.Pos())
+ }
+ return LinePosition{
+ Filename: filepath.Base(p.Filename),
+ Line: p.Line,
+ }
+ }
+ hasCall := func(inst poser) (string, bool) {
+ p := linePosition(inst, nil)
+ s, ok := calls[p]
+ return s, ok
+ }
+ callSite := func(inst ssa.Instruction) CallSite {
+ return CallSite{
+ LocalPos: inst.Pos(),
+ Resolved: linePosition(inst, inst.Parent()),
+ }
+ }
+ escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape {
+ if !ec.Record(reason) {
+ return nil // Skip.
+ }
+ es := Escape{
+ Reason: reason,
+ Detail: detail,
+ Chain: []CallSite{callSite(inst)},
+ }
+ return []Escape{es}
+ }
+ resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) {
+ for _, e := range sub {
+ if !ec.Record(e.Reason) {
+ continue // Skip.
+ }
+ es = append(es, Escape{
+ Reason: e.Reason,
+ Detail: e.Detail,
+ Chain: append([]CallSite{callSite(inst)}, e.Chain...),
+ })
+ }
+ return es
+ }
+ state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA)
+
+ var loadFunc func(*ssa.Function) []Escape // Used below.
+
+ analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape {
+ switch x := inst.(type) {
+ case *ssa.Call:
+ if x.Call.IsInvoke() {
+ // This is an interface dispatch. There is no
+ // way to know if this is actually escaping or
+ // not, since we don't know the underlying
+ // type.
+ call, _ := hasCall(inst)
+ return escapes(interfaceInvoke, call, inst, ec)
+ }
+ switch x := x.Call.Value.(type) {
+ case *ssa.Function:
+ if x.Pkg == nil {
+ // Can't resolve the package.
+ return escapes(unknownPackage, "no package", inst, ec)
+ }
+
+ // Atomic functions are instrinics. We can
+ // assume that they don't escape.
+ if x.Pkg.Pkg.Name() == "atomic" {
+ return nil
+ }
+
+ // Is this a local function? If yes, call the
+ // function to load the local function. The
+ // local escapes are the escapes found in the
+ // local function.
+ if x.Pkg.Pkg == pass.Pkg {
+ return resolve(loadFunc(x), inst, ec)
+ }
+
+ // Recursively collect information from
+ // the other analyzers.
+ var imp packageEscapeFacts
+ if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) {
+ // Unable to import the dependency; we must
+ // declare these as escaping.
+ return escapes(unknownPackage, "no analysis", inst, ec)
+ }
+
+ // The escapes of this instruction are the
+ // escapes of the called function directly.
+ return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec)
+ case *ssa.Builtin:
+ // Ignore elided escapes.
+ if _, has := hasCall(inst); !has {
+ return nil
+ }
+
+ // Check if the builtin is escaping.
+ for _, name := range escapingBuiltins {
+ if x.Name() == name {
+ return escapes(builtin, name, inst, ec)
+ }
+ }
+ default:
+ // All dynamic calls are counted as soft
+ // escapes. They are similar to interface
+ // dispatches. We cannot actually look up what
+ // this refers to using static analysis alone.
+ call, _ := hasCall(inst)
+ return escapes(dynamicCall, call, inst, ec)
+ }
+ case *ssa.Alloc:
+ // Ignore non-heap allocations.
+ if !x.Heap {
+ return nil
+ }
+
+ // Ignore elided escapes.
+ call, has := hasCall(inst)
+ if !has {
+ return nil
+ }
+
+ // This is a real heap allocation.
+ return escapes(allocation, call, inst, ec)
+ case *ssa.MakeMap:
+ return escapes(builtin, "makemap", inst, ec)
+ case *ssa.MakeSlice:
+ return escapes(builtin, "makeslice", inst, ec)
+ case *ssa.MakeClosure:
+ return escapes(builtin, "makeclosure", inst, ec)
+ case *ssa.MakeChan:
+ return escapes(builtin, "makechan", inst, ec)
+ }
+ return nil // No escapes.
+ }
+
+ var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive.
+ analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) {
+ for _, inst := range block.Instrs {
+ rval = append(rval, analyzeInstruction(inst, ec)...)
+ }
+ return rval // N.B. may be empty.
+ }
+
+ loadFunc = func(fn *ssa.Function) []Escape {
+ // Is this already available?
+ name := fn.RelString(pass.Pkg)
+ if es, ok := pef.Funcs[name]; ok {
+ return es
+ }
+
+ // In the case of a true cycle, we assume that the current
+ // function itself has no escapes until the rest of the
+ // analysis is complete. This will trip the above in the case
+ // of a cycle of any kind.
+ pef.Funcs[name] = nil
+
+ // Perform the basic analysis.
+ var (
+ es []Escape
+ ec EscapeCount
+ )
+ if fn.Recover != nil {
+ es = append(es, analyzeBasicBlock(fn.Recover, &ec)...)
+ }
+ for _, block := range fn.Blocks {
+ es = append(es, analyzeBasicBlock(block, &ec)...)
+ }
+
+ // Check for a stack split.
+ if call, has := hasCall(fn); has {
+ es = append(es, Escape{
+ Reason: stackSplit,
+ Detail: call,
+ Chain: []CallSite{CallSite{
+ LocalPos: fn.Pos(),
+ Resolved: linePosition(fn, fn.Parent()),
+ }},
+ })
+ }
+
+ // Save the result and return.
+ pef.Funcs[name] = es
+ return es
+ }
+
+ // Complete all local functions.
+ for _, fn := range state.SrcFuncs {
+ loadFunc(fn)
+ }
+
+ // Build the exception list.
+ exemptions := make(map[LinePosition]string)
+ for _, f := range pass.Files {
+ for _, cg := range f.Comments {
+ for _, c := range cg.List {
+ p := pass.Fset.Position(c.Slash)
+ if strings.HasPrefix(strings.ToLower(c.Text), exempt) {
+ exemptions[LinePosition{
+ Filename: filepath.Base(p.Filename),
+ Line: p.Line,
+ }] = c.Text[len(exempt):]
+ }
+ }
+ }
+ }
+
+ // Delete everything matching the excemtions.
+ //
+ // This has the implication that exceptions are applied recursively,
+ // since this now modified set is what will be saved.
+ for name, escapes := range pef.Funcs {
+ var newEscapes []Escape
+ for _, escape := range escapes {
+ isExempt := false
+ for line, _ := range exemptions {
+ // Note that an exemption applies if it is
+ // marked as an exemption anywhere in the call
+ // chain. It need not be marked as escapes in
+ // the function itself, nor in the top-level
+ // caller.
+ for _, callSite := range escape.Chain {
+ if callSite.Resolved == line {
+ isExempt = true
+ break
+ }
+ }
+ if isExempt {
+ break
+ }
+ }
+ if !isExempt {
+ // Record this escape; not an exception.
+ newEscapes = append(newEscapes, escape)
+ }
+ }
+ pef.Funcs[name] = newEscapes // Update.
+ }
+
+ // Export all findings for future packages.
+ pass.ExportPackageFact(&pef)
+
+ // Scan all functions for violations.
+ for _, f := range pass.Files {
+ // Scan all declarations.
+ for _, decl := range f.Decls {
+ fdecl, ok := decl.(*ast.FuncDecl)
+ // Function declaration?
+ if !ok {
+ continue
+ }
+ // Is there a comment?
+ if fdecl.Doc == nil {
+ continue
+ }
+ var (
+ reasons []EscapeReason
+ found bool
+ local bool
+ testReasons = make(map[EscapeReason]bool) // reason -> local?
+ )
+ // Does the comment contain a +checkescape line?
+ for _, c := range fdecl.Doc.List {
+ if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) {
+ continue
+ }
+ if c.Text == magic {
+ // Default: hard reasons, local only.
+ reasons = hardReasons
+ local = true
+ } else if strings.HasPrefix(c.Text, magicParams) {
+ // Extract specific reasons.
+ types := strings.Split(c.Text[len(magicParams):], ",")
+ found = true // For below.
+ for i := 0; i < len(types); i++ {
+ if types[i] == "local" {
+ // Limit search to local escapes.
+ local = true
+ } else if types[i] == "all" {
+ // Append all reasons.
+ reasons = append(reasons, allReasons...)
+ } else if types[i] == "hard" {
+ // Append all hard reasons.
+ reasons = append(reasons, hardReasons...)
+ } else {
+ r, ok := escapeTypes[types[i]]
+ if !ok {
+ // This is not a valid escape reason.
+ pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i])
+ continue
+ }
+ reasons = append(reasons, r)
+ }
+ }
+ } else if strings.HasPrefix(c.Text, testMagic) {
+ types := strings.Split(c.Text[len(testMagic):], ",")
+ local := false
+ for i := 0; i < len(types); i++ {
+ if types[i] == "local" {
+ local = true
+ } else {
+ r, ok := escapeTypes[types[i]]
+ if !ok {
+ // This is not a valid escape reason.
+ pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i])
+ continue
+ }
+ if v, ok := testReasons[r]; ok && v {
+ // Already registered as local.
+ continue
+ }
+ testReasons[r] = local
+ }
+ }
+ }
+ }
+ if len(reasons) == 0 && found {
+ // A magic annotation was provided, but no reasons.
+ pass.Reportf(fdecl.Pos(), "no reasons provided")
+ continue
+ }
+
+ // Scan for matches.
+ fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func)
+ name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg)
+ es, ok := pef.Funcs[name]
+ if !ok {
+ pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name)
+ continue
+ }
+ for _, e := range es {
+ for _, r := range reasons {
+ // Is does meet our local requirement?
+ if local && len(e.Chain) > 1 {
+ continue
+ }
+ // Does this match the reason? Emit
+ // with a full stack trace that
+ // explains why this violates our
+ // constraints.
+ if e.Reason == r {
+ pass.Reportf(e.Chain[0].LocalPos, "%s", e.String())
+ }
+ }
+ }
+
+ // Scan for test (required) matches.
+ testReasonsFound := make(map[EscapeReason]bool)
+ for _, e := range es {
+ // Is this local?
+ local, ok := testReasons[e.Reason]
+ wantLocal := len(e.Chain) == 1
+ testReasonsFound[e.Reason] = wantLocal
+ if !ok {
+ continue
+ }
+ if local == wantLocal {
+ delete(testReasons, e.Reason)
+ }
+ }
+ for reason, local := range testReasons {
+ // We didn't find the escapes we wanted.
+ pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local))
+ }
+ if len(testReasons) > 0 {
+ // Dump all reasons found to help in debugging.
+ for _, e := range es {
+ pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String())
+ }
+ }
+ }
+ }
+
+ return nil, nil
+}
diff --git a/tools/checkescape/test1/BUILD b/tools/checkescape/test1/BUILD
new file mode 100644
index 000000000..783403247
--- /dev/null
+++ b/tools/checkescape/test1/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "test1",
+ srcs = ["test1.go"],
+ visibility = ["//tools/checkescape/test2:__pkg__"],
+)
diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go
new file mode 100644
index 000000000..68d3f72cc
--- /dev/null
+++ b/tools/checkescape/test1/test1.go
@@ -0,0 +1,195 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test1 is a test package.
+package test1
+
+import (
+ "fmt"
+ "reflect"
+)
+
+// Interface is a generic interface.
+type Interface interface {
+ Foo()
+}
+
+// Type is a concrete implementation of Interface.
+type Type struct {
+ A uint64
+ B uint64
+}
+
+// Foo implements Interface.Foo.
+//go:nosplit
+func (t Type) Foo() {
+ fmt.Printf("%v", t) // Never executed.
+}
+
+// +checkescape:all,hard
+//go:nosplit
+func InterfaceFunction(i Interface) {
+ // Do nothing; exported for tests.
+}
+
+// +checkesacape:all,hard
+//go:nosplit
+func TypeFunction(t *Type) {
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinMap(x int) map[string]bool {
+ return make(map[string]bool)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinMapRec(x int) map[string]bool {
+ return BuiltinMap(x)
+}
+
+// +temustescapestescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinClosure(x int) func() {
+ return func() {
+ fmt.Printf("%v", x)
+ }
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinClosureRec(x int) func() {
+ return BuiltinClosure(x)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinMakeSlice(x int) []byte {
+ return make([]byte, x)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinMakeSliceRec(x int) []byte {
+ return BuiltinMakeSlice(x)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinAppend(x []byte) []byte {
+ return append(x, 0)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinAppendRec() []byte {
+ return BuiltinAppend(nil)
+}
+
+// +mustescape:local,builtin
+//go:noinline
+//go:nosplit
+func BuiltinChan() chan int {
+ return make(chan int)
+}
+
+// +mustescape:builtin
+//go:noinline
+//go:nosplit
+func builtinChanRec() chan int {
+ return BuiltinChan()
+}
+
+// +mustescape:local,heap
+//go:noinline
+//go:nosplit
+func Heap() *Type {
+ var t Type
+ return &t
+}
+
+// +mustescape:heap
+//go:noinline
+//go:nosplit
+func heapRec() *Type {
+ return Heap()
+}
+
+// +mustescape:local,interface
+//go:noinline
+//go:nosplit
+func Dispatch(i Interface) {
+ i.Foo()
+}
+
+// +mustescape:interface
+//go:noinline
+//go:nosplit
+func dispatchRec(i Interface) {
+ Dispatch(i)
+}
+
+// +mustescape:local,dynamic
+//go:noinline
+//go:nosplit
+func Dynamic(f func()) {
+ f()
+}
+
+// +mustescape:dynamic
+//go:noinline
+//go:nosplit
+func dynamicRec(f func()) {
+ Dynamic(f)
+}
+
+// +mustescape:local,unknown
+//go:noinline
+//go:nosplit
+func Unknown() {
+ _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape.
+}
+
+// +mustescape:unknown
+//go:noinline
+//go:nosplit
+func unknownRec() {
+ Unknown()
+}
+
+//go:noinline
+//go:nosplit
+func internalFunc() {
+}
+
+// +mustescape:local,stack
+//go:noinline
+func Split() {
+ internalFunc()
+}
+
+// +mustescape:stack
+//go:noinline
+func splitRec() {
+ Split()
+}
diff --git a/tools/checkescape/test2/BUILD b/tools/checkescape/test2/BUILD
new file mode 100644
index 000000000..5a11e4b43
--- /dev/null
+++ b/tools/checkescape/test2/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "test2",
+ srcs = ["test2.go"],
+ deps = ["//tools/checkescape/test1"],
+)
diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go
new file mode 100644
index 000000000..7fce3e3be
--- /dev/null
+++ b/tools/checkescape/test2/test2.go
@@ -0,0 +1,94 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test2 is a test package that imports test1.
+package test2
+
+import (
+ "gvisor.dev/gvisor/tools/checkescape/test1"
+)
+
+// +checkescape:all
+//go:nosplit
+func interfaceFunctionCrossPkg() {
+ var i test1.Interface
+ test1.InterfaceFunction(i)
+}
+
+// +checkesacape:all
+//go:nosplit
+func typeFunctionCrossPkg() {
+ var t test1.Type
+ test1.TypeFunction(&t)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinMapCrossPkg(x int) map[string]bool {
+ return test1.BuiltinMap(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinClosureCrossPkg(x int) func() {
+ return test1.BuiltinClosure(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinMakeSliceCrossPkg(x int) []byte {
+ return test1.BuiltinMakeSlice(x)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinAppendCrossPkg() []byte {
+ return test1.BuiltinAppend(nil)
+}
+
+// +mustescape:builtin
+//go:noinline
+func builtinChanCrossPkg() chan int {
+ return test1.BuiltinChan()
+}
+
+// +mustescape:heap
+//go:noinline
+func heapCrossPkg() *test1.Type {
+ return test1.Heap()
+}
+
+// +mustescape:interface
+//go:noinline
+func dispatchCrossPkg(i test1.Interface) {
+ test1.Dispatch(i)
+}
+
+// +mustescape:dynamic
+//go:noinline
+func dynamicCrossPkg(f func()) {
+ test1.Dynamic(f)
+}
+
+// +mustescape:unknown
+//go:noinline
+func unknownCrossPkg() {
+ test1.Unknown()
+}
+
+// +mustescape:stack
+//go:noinline
+func splitCrosssPkt() {
+ test1.Split()
+}
diff --git a/tools/checkunsafe/BUILD b/tools/checkunsafe/BUILD
index d85c56131..0c264151b 100644
--- a/tools/checkunsafe/BUILD
+++ b/tools/checkunsafe/BUILD
@@ -1,12 +1,12 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_tool_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
-go_tool_library(
+go_library(
name = "checkunsafe",
srcs = ["check_unsafe.go"],
- importpath = "checkunsafe",
- visibility = ["//visibility:public"],
+ nogo = False,
+ visibility = ["//tools/nogo:__subpackages__"],
deps = [
"@org_golang_x_tools//go/analysis:go_tool_library",
],
diff --git a/tools/defs.bzl b/tools/defs.bzl
new file mode 100644
index 000000000..e71a26cf4
--- /dev/null
+++ b/tools/defs.bzl
@@ -0,0 +1,253 @@
+"""Wrappers for common build rules.
+
+These wrappers apply common BUILD configurations (e.g., proto_library
+automagically creating cc_ and go_ proto targets) and act as a single point of
+change for Google-internal and bazel-compatible rules.
+"""
+
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
+load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option")
+load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
+load("//tools/bazeldefs:tags.bzl", "go_suffixes")
+load("//tools/nogo:defs.bzl", "nogo_test")
+
+# Delegate directly.
+build_test = _build_test
+bzl_library = _bzl_library
+cc_binary = _cc_binary
+cc_flags_supplier = _cc_flags_supplier
+cc_grpc_library = _cc_grpc_library
+cc_library = _cc_library
+cc_test = _cc_test
+cc_toolchain = _cc_toolchain
+default_installer = _default_installer
+default_net_util = _default_net_util
+gbenchmark = _gbenchmark
+gazelle = _gazelle
+go_embed_data = _go_embed_data
+go_path = _go_path
+go_test = _go_test
+gtest = _gtest
+grpcpp = _grpcpp
+loopback = _loopback
+pkg_deb = _pkg_deb
+pkg_tar = _pkg_tar
+py_binary = _py_binary
+select_arch = _select_arch
+select_system = _select_system
+short_path = _short_path
+rbe_platform = _rbe_platform
+rbe_toolchain = _rbe_toolchain
+vdso_linker_option = _vdso_linker_option
+
+# Platform options.
+default_platform = _default_platform
+platforms = _platforms
+
+def go_binary(name, **kwargs):
+ """Wraps the standard go_binary.
+
+ Args:
+ name: the rule name.
+ **kwargs: standard go_binary arguments.
+ """
+ _go_binary(
+ name = name,
+ **kwargs
+ )
+
+def calculate_sets(srcs):
+ """Calculates special Go sets for templates.
+
+ Args:
+ srcs: the full set of Go sources.
+
+ Returns:
+ A dictionary of the form:
+
+ "": [src1.go, src2.go]
+ "suffix": [src3suffix.go, src4suffix.go]
+
+ Note that suffix will typically start with '_'.
+ """
+ result = dict()
+ for file in srcs:
+ if not file.endswith(".go"):
+ continue
+ target = ""
+ for suffix in go_suffixes:
+ if file.endswith(suffix + ".go"):
+ target = suffix
+ if not target in result:
+ result[target] = [file]
+ else:
+ result[target].append(file)
+ return result
+
+def go_imports(name, src, out):
+ """Simplify a single Go source file by eliminating unused imports."""
+ native.genrule(
+ name = name,
+ srcs = [src],
+ outs = [out],
+ tools = ["@org_golang_x_tools//cmd/goimports:goimports"],
+ cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
+ )
+
+def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, nogo = True, **kwargs):
+ """Wraps the standard go_library and does stateification and marshalling.
+
+ The recommended way is to use this rule with mostly identical configuration as the native
+ go_library rule.
+
+ These definitions provide additional flags (stateify, marshal) that can be used
+ with the generators to automatically supplement the library code.
+
+ load("//tools:defs.bzl", "go_library")
+
+ go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ )
+
+ Args:
+ name: the rule name.
+ srcs: the library sources.
+ deps: the library dependencies.
+ imports: imports required for stateify.
+ stateify: whether statify is enabled (default: true).
+ marshal: whether marshal is enabled (default: false).
+ marshal_debug: whether the gomarshal tools emits debugging output (default: false).
+ **kwargs: standard go_library arguments.
+ """
+ all_srcs = srcs
+ all_deps = deps
+ dirname, _, _ = native.package_name().rpartition("/")
+ full_pkg = dirname + "/" + name
+ if stateify:
+ # Only do stateification for non-state packages without manual autogen.
+ # First, we need to segregate the input files via the special suffixes,
+ # and calculate the final output set.
+ state_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in state_sets.items():
+ go_stateify(
+ name = name + suffix + "_state_autogen_with_imports",
+ srcs = src_subset,
+ imports = imports,
+ package = full_pkg,
+ out = name + suffix + "_state_autogen_with_imports.go",
+ )
+ go_imports(
+ name = name + suffix + "_state_autogen",
+ src = name + suffix + "_state_autogen_with_imports.go",
+ out = name + suffix + "_state_autogen.go",
+ )
+ all_srcs = all_srcs + [
+ name + suffix + "_state_autogen.go"
+ for suffix in state_sets.keys()
+ ]
+ if "//pkg/state" not in all_deps:
+ all_deps = all_deps + ["//pkg/state"]
+
+ if marshal:
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, src_subset) in marshal_sets.items():
+ go_marshal(
+ name = name + suffix + "_abi_autogen",
+ srcs = src_subset,
+ debug = select({
+ "//tools/go_marshal:marshal_config_verbose": True,
+ "//conditions:default": marshal_debug,
+ }),
+ imports = imports,
+ package = name,
+ )
+ extra_deps = [
+ dep
+ for dep in marshal_deps
+ if not dep in all_deps
+ ]
+ all_deps = all_deps + extra_deps
+ all_srcs = all_srcs + [
+ name + suffix + "_abi_autogen_unsafe.go"
+ for suffix in marshal_sets.keys()
+ ]
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+ if nogo:
+ nogo_test(
+ name = name + "_nogo",
+ deps = [":" + name],
+ )
+
+ if marshal:
+ # Ignore importpath for go_test.
+ kwargs.pop("importpath", None)
+
+ # See above.
+ marshal_sets = calculate_sets(srcs)
+ for (suffix, _) in marshal_sets.items():
+ _go_test(
+ name = name + suffix + "_abi_autogen_test",
+ srcs = [name + suffix + "_abi_autogen_test.go"],
+ library = ":" + name,
+ deps = marshal_test_deps,
+ **kwargs
+ )
+
+def proto_library(name, srcs, deps = None, has_services = 0, **kwargs):
+ """Wraps the standard proto_library.
+
+ Given a proto_library named "foo", this produces up to five different
+ targets:
+ - foo_proto: proto_library rule.
+ - foo_go_proto: go_proto_library rule.
+ - foo_cc_proto: cc_proto_library rule.
+ - foo_go_grpc_proto: go_grpc_library rule.
+ - foo_cc_grpc_proto: cc_grpc_library rule.
+
+ Args:
+ name: the name to which _proto, _go_proto, etc, will be appended.
+ srcs: the proto sources.
+ deps: for the proto library and the go_proto_library.
+ has_services: 1 to build gRPC code, otherwise 0.
+ **kwargs: standard proto_library arguments.
+ """
+ _proto_library(
+ name = name + "_proto",
+ srcs = srcs,
+ deps = deps,
+ has_services = has_services,
+ **kwargs
+ )
+ if has_services:
+ _go_grpc_and_proto_libraries(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ else:
+ _go_proto_library(
+ name = name,
+ deps = deps,
+ **kwargs
+ )
+ _cc_proto_library(
+ name = name + "_cc_proto",
+ deps = [":" + name + "_proto"],
+ **kwargs
+ )
+ if has_services:
+ _cc_grpc_library(
+ name = name + "_cc_grpc_proto",
+ srcs = [":" + name + "_proto"],
+ deps = [":" + name + "_cc_proto"],
+ **kwargs
+ )
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index ddb9b6e7b..e5c060024 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -eo pipefail
+set -xeo pipefail
# Discovery the package name from the go.mod file.
-declare -r gomod="$(pwd)/go.mod"
-declare -r module=$(cat "${gomod}" | grep -E "^module" | cut -d' ' -f2)
-declare -r gosum="$(pwd)/go.sum"
+declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2)
+declare -r origpwd=$(pwd)
+declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE")
# Check that gopath has been built.
declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}"
@@ -40,9 +40,15 @@ trap finish EXIT
# Record the current working commit.
declare -r head=$(git describe --always)
-# We expect to have an existing go branch that we will use as the basis for
-# this commit. That branch may be empty, but it must exist.
-declare -r go_branch=$(git show-ref --hash origin/go)
+# We expect to have an existing go branch that we will use as the basis for this
+# commit. That branch may be empty, but it must exist. We search for this branch
+# using the local branch, the "origin" branch, and other remotes, in order.
+git fetch --all
+declare -r go_branch=$( \
+ git show-ref --hash refs/heads/go || \
+ git show-ref --hash refs/remotes/origin/go || \
+ git show-ref --hash go | head -n 1 \
+)
# Clone the current repository to the temporary directory, and check out the
# current go_branch directory. We move to the new repository for convenience.
@@ -65,15 +71,42 @@ git checkout -b go "${go_branch}"
git merge --no-commit --strategy ours ${head} || \
git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
-# Sync the entire gopath_dir and go.mod.
-rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" .
-cp "${gomod}" .
-cp "${gosum}" .
+# Normalize the permissions on the old branch. Note that they should be
+# normalized if constructed by this tool, but we do so before the rsync.
+find . -type f -exec chmod 0644 {} \;
+find . -type d -exec chmod 0755 {} \;
+
+# Sync the entire gopath_dir.
+rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" .
+
+# Add additional files.
+for file in "${othersrc[@]}"; do
+ cp "${origpwd}"/"${file}" .
+done
+
+# Construct a new README.md.
+cat > README.md <<EOF
+# gVisor
+
+This branch is a synthetic branch, containing only Go sources, that is
+compatible with standard Go tools. See the master branch for authoritative
+sources and tests.
+EOF
# There are a few solitary files that can get left behind due to the way bazel
# constructs the gopath target. Note that we don't find all Go files here
# because they may correspond to unused templates, etc.
-cp "${repo_orig}"/runsc/*.go runsc/
+declare -ar binaries=( "runsc" "shim/v1" "shim/v2" )
+for target in "${binaries[@]}"; do
+ mkdir -p "${target}"
+ cp "${repo_orig}/${target}"/*.go "${target}/"
+done
+
+# Normalize all permissions. The way bazel constructs the :gopath tree may leave
+# some strange permissions on files. We don't have anything in this tree that
+# should be execution, only the Go source files, README.md, and ${othersrc}.
+find . -type f -exec chmod 0644 {} \;
+find . -type d -exec chmod 0755 {} \;
# Update the current working set and commit.
git add . && git commit -m "Merge ${head} (automated)"
diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD
index 39318b877..807c08ead 100644
--- a/tools/go_generics/BUILD
+++ b/tools/go_generics/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "bzl_library", "go_binary")
package(licenses = ["notice"])
@@ -9,30 +9,12 @@ go_binary(
"imports.go",
"remove.go",
],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
deps = ["//tools/go_generics/globals"],
)
-genrule(
- name = "go_generics_tests",
- srcs = glob(["generics_tests/**"]) + [":go_generics"],
- outs = ["go_generics_tests.tgz"],
- cmd = "tar -czvhf $@ $(SRCS)",
-)
-
-genrule(
- name = "go_generics_test_bundle",
- srcs = [
- ":go_generics_tests.tgz",
- ":go_generics_unittest.sh",
- ],
- outs = ["go_generics_test.sh"],
- cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@",
- executable = True,
-)
-
-sh_test(
- name = "go_generics_test",
- size = "small",
- srcs = ["go_generics_test.sh"],
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
)
diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl
index c5be52ecd..33329cf28 100644
--- a/tools/go_generics/defs.bzl
+++ b/tools/go_generics/defs.bzl
@@ -1,11 +1,24 @@
+"""Generics support via go_generics."""
+
+TemplateInfo = provider(
+ fields = {
+ "types": "required types",
+ "opt_types": "optional types",
+ "consts": "required consts",
+ "opt_consts": "optional consts",
+ "deps": "package dependencies",
+ "file": "merged template",
+ },
+)
+
def _go_template_impl(ctx):
- input = ctx.files.srcs
+ srcs = ctx.files.srcs
output = ctx.outputs.out
- args = ["-o=%s" % output.path] + [f.path for f in input]
+ args = ["-o=%s" % output.path] + [f.path for f in srcs]
ctx.actions.run(
- inputs = input,
+ inputs = srcs,
outputs = [output],
mnemonic = "GoGenericsTemplate",
progress_message = "Building Go template %s" % ctx.label,
@@ -13,14 +26,14 @@ def _go_template_impl(ctx):
executable = ctx.executable._tool,
)
- return struct(
+ return [TemplateInfo(
types = ctx.attr.types,
opt_types = ctx.attr.opt_types,
consts = ctx.attr.consts,
opt_consts = ctx.attr.opt_consts,
deps = ctx.attr.deps,
file = output,
- )
+ )]
"""
Generates a Go template from a set of Go files.
@@ -43,7 +56,7 @@ go_template = rule(
implementation = _go_template_impl,
attrs = {
"srcs": attr.label_list(mandatory = True, allow_files = True),
- "deps": attr.label_list(allow_files = True),
+ "deps": attr.label_list(allow_files = True, cfg = "target"),
"types": attr.string_list(),
"opt_types": attr.string_list(),
"consts": attr.string_list(),
@@ -55,8 +68,14 @@ go_template = rule(
},
)
+TemplateInstanceInfo = provider(
+ fields = {
+ "srcs": "source files",
+ },
+)
+
def _go_template_instance_impl(ctx):
- template = ctx.attr.template
+ template = ctx.attr.template[TemplateInfo]
output = ctx.outputs.out
# Check that all required types are defined.
@@ -81,20 +100,21 @@ def _go_template_instance_impl(ctx):
# Build the argument list.
args = ["-i=%s" % template.file.path, "-o=%s" % output.path]
- args += ["-p=%s" % ctx.attr.package]
+ if ctx.attr.package:
+ args.append("-p=%s" % ctx.attr.package)
if len(ctx.attr.prefix) > 0:
- args += ["-prefix=%s" % ctx.attr.prefix]
+ args.append("-prefix=%s" % ctx.attr.prefix)
if len(ctx.attr.suffix) > 0:
- args += ["-suffix=%s" % ctx.attr.suffix]
+ args.append("-suffix=%s" % ctx.attr.suffix)
args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()]
args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()]
args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()]
if ctx.attr.anon:
- args += ["-anon"]
+ args.append("-anon")
ctx.actions.run(
inputs = [template.file],
@@ -105,10 +125,9 @@ def _go_template_instance_impl(ctx):
executable = ctx.executable._tool,
)
- # TODO: How can we get the dependencies out?
- return struct(
- files = depset([output]),
- )
+ return [TemplateInstanceInfo(
+ srcs = [output],
+ )]
"""
Instantiates a Go template by replacing all generic types with concrete ones.
@@ -126,14 +145,14 @@ Args:
go_template_instance = rule(
implementation = _go_template_instance_impl,
attrs = {
- "template": attr.label(mandatory = True, providers = ["types"]),
+ "template": attr.label(mandatory = True),
"prefix": attr.string(),
"suffix": attr.string(),
"types": attr.string_dict(),
"consts": attr.string_dict(),
"imports": attr.string_dict(),
"anon": attr.bool(mandatory = False, default = False),
- "package": attr.string(mandatory = True),
+ "package": attr.string(mandatory = False),
"out": attr.output(mandatory = True),
"_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")),
},
diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go
index e9cc2c753..0860ca9db 100644
--- a/tools/go_generics/generics.go
+++ b/tools/go_generics/generics.go
@@ -223,7 +223,9 @@ func main() {
} else {
switch kind {
case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction:
- ident.Name = *prefix + ident.Name + *suffix
+ if ident.Name != "_" {
+ ident.Name = *prefix + ident.Name + *suffix
+ }
case globals.KindTag:
// Modify the state tag appropriately.
if m := stateTagRegexp.FindStringSubmatch(ident.Name); m != nil {
diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt
deleted file mode 100644
index c9d0e09bf..000000000
--- a/tools/go_generics/generics_tests/all_stmts/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=Q
diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt
deleted file mode 100644
index c9d0e09bf..000000000
--- a/tools/go_generics/generics_tests/all_types/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=Q
diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt
deleted file mode 100644
index a5e9d26de..000000000
--- a/tools/go_generics/generics_tests/anon/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=Q -suffix=New -anon
diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt
deleted file mode 100644
index 4fb59dce8..000000000
--- a/tools/go_generics/generics_tests/consts/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC"
diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt
deleted file mode 100644
index 87324be79..000000000
--- a/tools/go_generics/generics_tests/imports/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath
diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt
deleted file mode 100644
index 9c8ecaada..000000000
--- a/tools/go_generics/generics_tests/remove_typedef/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=U
diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt
deleted file mode 100644
index 7832ef66f..000000000
--- a/tools/go_generics/generics_tests/simple/opts.txt
+++ /dev/null
@@ -1 +0,0 @@
--t=T=Q -suffix=New
diff --git a/tools/go_generics/globals/BUILD b/tools/go_generics/globals/BUILD
index 74853c7d2..38caa3ce7 100644
--- a/tools/go_generics/globals/BUILD
+++ b/tools/go_generics/globals/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
package(licenses = ["notice"])
@@ -8,6 +8,6 @@ go_library(
"globals_visitor.go",
"scope.go",
],
- importpath = "gvisor.dev/gvisor/tools/go_generics/globals",
+ stateify = False,
visibility = ["//tools/go_generics:__pkg__"],
)
diff --git a/tools/go_generics/globals/scope.go b/tools/go_generics/globals/scope.go
index 96c965ea2..eec93534b 100644
--- a/tools/go_generics/globals/scope.go
+++ b/tools/go_generics/globals/scope.go
@@ -72,6 +72,10 @@ func (s *scope) deepLookup(n string) *symbol {
}
func (s *scope) add(name string, kind SymKind, pos token.Pos) {
+ if s.syms[name] != nil {
+ return
+ }
+
s.syms[name] = &symbol{
kind: kind,
pos: pos,
diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh
deleted file mode 100755
index 44b22db91..000000000
--- a/tools/go_generics/go_generics_unittest.sh
+++ /dev/null
@@ -1,70 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Bash "safe-mode": Treat command failures as fatal (even those that occur in
-# pipes), and treat unset variables as errors.
-set -eu -o pipefail
-
-# This file will be generated as a self-extracting shell script in order to
-# eliminate the need for any runtime dependencies. The tarball at the end will
-# include the go_generics binary, as well as a subdirectory named
-# generics_tests. See the BUILD file for more information.
-declare -r temp=$(mktemp -d)
-function cleanup() {
- rm -rf "${temp}"
-}
-# trap cleanup EXIT
-
-# Print message in "$1" then exit with status 1.
-function die () {
- echo "$1" 1>&2
- exit 1
-}
-
-# This prints the line number of __BUNDLE__ below, that should be the last line
-# of this script. After that point, the concatenated archive will be the
-# contents.
-declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0`
-tail -n+"${tgz}" $0 | tar -xzv -C "${temp}"
-
-# The target for the test.
-declare -r binary="$(find ${temp} -type f -a -name go_generics)"
-declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*"
-
-# Go through all test cases.
-for f in ${input_dirs}; do
- base=$(basename "${f}")
-
- # Run go_generics on the input file.
- opts=$(head -n 1 ${f}/opts.txt)
- out="${f}/output/generated.go"
- expected="${f}/output/output.go"
- ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\""
-
- # Compare the outputs.
- diff ${expected} ${out}
- if [ $? -ne 0 ]; then
- echo "Expected:"
- cat ${expected}
- echo "Actual:"
- cat ${out}
- die "Actual output is different from expected for test \"${base}\""
- fi
-done
-
-echo "PASS"
-exit 0
-__BUNDLE__
diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD
index 02b09120e..2fd5a200d 100644
--- a/tools/go_generics/go_merge/BUILD
+++ b/tools/go_generics/go_merge/BUILD
@@ -1,9 +1,9 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "go_binary")
package(licenses = ["notice"])
go_binary(
name = "go_merge",
srcs = ["main.go"],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
)
diff --git a/tools/go_generics/rules_tests/BUILD b/tools/go_generics/rules_tests/BUILD
index 9d26a88b7..8a329dfc6 100644
--- a/tools/go_generics/rules_tests/BUILD
+++ b/tools/go_generics/rules_tests/BUILD
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
diff --git a/tools/go_generics/tests/BUILD b/tools/go_generics/tests/BUILD
new file mode 100644
index 000000000..7547a6b53
--- /dev/null
+++ b/tools/go_generics/tests/BUILD
@@ -0,0 +1,7 @@
+load("//tools:defs.bzl", "bzl_library")
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/go_generics/tests/all_stmts/BUILD b/tools/go_generics/tests/all_stmts/BUILD
new file mode 100644
index 000000000..a4a7c775a
--- /dev/null
+++ b/tools/go_generics/tests/all_stmts/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "all_stmts",
+ inputs = ["input.go"],
+ output = "output.go",
+ types = {
+ "T": "Q",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/tests/all_stmts/input.go
index 4791d1ff1..4791d1ff1 100644
--- a/tools/go_generics/generics_tests/all_stmts/input.go
+++ b/tools/go_generics/tests/all_stmts/input.go
diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/tests/all_stmts/output.go
index a53d84535..a53d84535 100644
--- a/tools/go_generics/generics_tests/all_stmts/output/output.go
+++ b/tools/go_generics/tests/all_stmts/output.go
diff --git a/tools/go_generics/tests/all_types/BUILD b/tools/go_generics/tests/all_types/BUILD
new file mode 100644
index 000000000..60b1fd314
--- /dev/null
+++ b/tools/go_generics/tests/all_types/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "all_types",
+ inputs = ["input.go"],
+ output = "output.go",
+ types = {
+ "T": "Q",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/tests/all_types/input.go
index 3575d02ec..6f85bbb69 100644
--- a/tools/go_generics/generics_tests/all_types/input.go
+++ b/tools/go_generics/tests/all_types/input.go
@@ -14,7 +14,9 @@
package tests
-import "./lib"
+import (
+ "./lib"
+)
type T int
diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/tests/all_types/lib/lib.go
index 988786496..988786496 100644
--- a/tools/go_generics/generics_tests/all_types/lib/lib.go
+++ b/tools/go_generics/tests/all_types/lib/lib.go
diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/tests/all_types/output.go
index 41fd147a1..c0bbebfe7 100644
--- a/tools/go_generics/generics_tests/all_types/output/output.go
+++ b/tools/go_generics/tests/all_types/output.go
@@ -14,7 +14,9 @@
package main
-import "./lib"
+import (
+ "./lib"
+)
type newType struct {
a Q
diff --git a/tools/go_generics/tests/anon/BUILD b/tools/go_generics/tests/anon/BUILD
new file mode 100644
index 000000000..ef24f4b25
--- /dev/null
+++ b/tools/go_generics/tests/anon/BUILD
@@ -0,0 +1,18 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "anon",
+ anon = True,
+ inputs = ["input.go"],
+ output = "output.go",
+ suffix = "New",
+ types = {
+ "T": "Q",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/tests/anon/input.go
index 44086d522..44086d522 100644
--- a/tools/go_generics/generics_tests/anon/input.go
+++ b/tools/go_generics/tests/anon/input.go
diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/tests/anon/output.go
index 160cddf79..7fa791853 100644
--- a/tools/go_generics/generics_tests/anon/output/output.go
+++ b/tools/go_generics/tests/anon/output.go
@@ -35,8 +35,8 @@ func (f FooNew) GetBar(name string) Q {
func foobarNew() {
a := BazNew{}
- a.Q = 0 // should not be renamed, this is a limitation
+ a.Q = 0
b := otherpkg.UnrelatedType{}
- b.Q = 0 // should not be renamed, this is a limitation
+ b.Q = 0
}
diff --git a/tools/go_generics/tests/consts/BUILD b/tools/go_generics/tests/consts/BUILD
new file mode 100644
index 000000000..fd7caccad
--- /dev/null
+++ b/tools/go_generics/tests/consts/BUILD
@@ -0,0 +1,23 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "consts",
+ consts = {
+ "c1": "20",
+ "z": "600",
+ "v": "3.3",
+ "s": "\"def\"",
+ "A": "20",
+ "C": "100",
+ "S": "\"def\"",
+ "T": "\"ABC\"",
+ },
+ inputs = ["input.go"],
+ output = "output.go",
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/tests/consts/input.go
index 04b95fcc6..04b95fcc6 100644
--- a/tools/go_generics/generics_tests/consts/input.go
+++ b/tools/go_generics/tests/consts/input.go
diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/tests/consts/output.go
index 18d316cc9..18d316cc9 100644
--- a/tools/go_generics/generics_tests/consts/output/output.go
+++ b/tools/go_generics/tests/consts/output.go
diff --git a/tools/go_generics/tests/defs.bzl b/tools/go_generics/tests/defs.bzl
new file mode 100644
index 000000000..6277c3947
--- /dev/null
+++ b/tools/go_generics/tests/defs.bzl
@@ -0,0 +1,67 @@
+"""Generics tests."""
+
+load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
+
+def _go_generics_test_impl(ctx):
+ runner = ctx.actions.declare_file(ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "exec diff --ignore-blank-lines --ignore-matching-lines=^[[:space:]]*// %s %s" % (
+ ctx.files.template_output[0].short_path,
+ ctx.files.expected_output[0].short_path,
+ ),
+ "",
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+ return [DefaultInfo(
+ executable = runner,
+ runfiles = ctx.runfiles(
+ files = ctx.files.template_output + ctx.files.expected_output,
+ collect_default = True,
+ collect_data = True,
+ ),
+ )]
+
+_go_generics_test = rule(
+ implementation = _go_generics_test_impl,
+ attrs = {
+ "template_output": attr.label(mandatory = True, allow_single_file = True),
+ "expected_output": attr.label(mandatory = True, allow_single_file = True),
+ },
+ test = True,
+)
+
+def go_generics_test(name, inputs, output, types = None, consts = None, **kwargs):
+ """Instantiates a generics test.
+
+ Args:
+ name: the name of the test.
+ inputs: all the input files.
+ output: the output files.
+ types: the template types (dictionary).
+ consts: the template consts (dictionary).
+ **kwargs: additional arguments for the template_instance.
+ """
+ if types == None:
+ types = dict()
+ if consts == None:
+ consts = dict()
+ go_template(
+ name = name + "_template",
+ srcs = inputs,
+ types = types.keys(),
+ consts = consts.keys(),
+ )
+ go_template_instance(
+ name = name + "_output",
+ template = ":" + name + "_template",
+ out = name + "_output.go",
+ types = types,
+ consts = consts,
+ **kwargs
+ )
+ _go_generics_test(
+ name = name + "_test",
+ template_output = name + "_output.go",
+ expected_output = output,
+ )
diff --git a/tools/go_generics/tests/imports/BUILD b/tools/go_generics/tests/imports/BUILD
new file mode 100644
index 000000000..a86223d41
--- /dev/null
+++ b/tools/go_generics/tests/imports/BUILD
@@ -0,0 +1,24 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "imports",
+ consts = {
+ "n": "math.Uint32",
+ "m": "math.Uint64",
+ },
+ imports = {
+ "sync": "sync",
+ "math": "mymathpath",
+ },
+ inputs = ["input.go"],
+ output = "output.go",
+ types = {
+ "T": "sync.Mutex",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/tests/imports/input.go
index 0f032c2a1..0f032c2a1 100644
--- a/tools/go_generics/generics_tests/imports/input.go
+++ b/tools/go_generics/tests/imports/input.go
diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/tests/imports/output.go
index 2488ca58c..2488ca58c 100644
--- a/tools/go_generics/generics_tests/imports/output/output.go
+++ b/tools/go_generics/tests/imports/output.go
diff --git a/tools/go_generics/tests/remove_typedef/BUILD b/tools/go_generics/tests/remove_typedef/BUILD
new file mode 100644
index 000000000..46457cec6
--- /dev/null
+++ b/tools/go_generics/tests/remove_typedef/BUILD
@@ -0,0 +1,16 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "remove_typedef",
+ inputs = ["input.go"],
+ output = "output.go",
+ types = {
+ "T": "U",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/tests/remove_typedef/input.go
index cf632bae7..cf632bae7 100644
--- a/tools/go_generics/generics_tests/remove_typedef/input.go
+++ b/tools/go_generics/tests/remove_typedef/input.go
diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/tests/remove_typedef/output.go
index d44fd8e1c..d44fd8e1c 100644
--- a/tools/go_generics/generics_tests/remove_typedef/output/output.go
+++ b/tools/go_generics/tests/remove_typedef/output.go
diff --git a/tools/go_generics/tests/simple/BUILD b/tools/go_generics/tests/simple/BUILD
new file mode 100644
index 000000000..4b9265ea4
--- /dev/null
+++ b/tools/go_generics/tests/simple/BUILD
@@ -0,0 +1,17 @@
+load("//tools/go_generics/tests:defs.bzl", "go_generics_test")
+
+go_generics_test(
+ name = "simple",
+ inputs = ["input.go"],
+ output = "output.go",
+ suffix = "New",
+ types = {
+ "T": "Q",
+ },
+)
+
+# @unused
+glaze_ignore = [
+ "input.go",
+ "output.go",
+]
diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/tests/simple/input.go
index 2a917f16c..2a917f16c 100644
--- a/tools/go_generics/generics_tests/simple/input.go
+++ b/tools/go_generics/tests/simple/input.go
diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/tests/simple/output.go
index 6bfa0b25b..6bfa0b25b 100644
--- a/tools/go_generics/generics_tests/simple/output/output.go
+++ b/tools/go_generics/tests/simple/output.go
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
index c862b277c..f79defea7 100644
--- a/tools/go_marshal/BUILD
+++ b/tools/go_marshal/BUILD
@@ -1,6 +1,6 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "bzl_library", "go_binary")
-package(licenses = ["notice"])
+licenses(["notice"])
go_binary(
name = "go_marshal",
@@ -12,3 +12,14 @@ go_binary(
"//tools/go_marshal/gomarshal",
],
)
+
+config_setting(
+ name = "marshal_config_verbose",
+ values = {"define": "gomarshal=verbose"},
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
index 481575bd3..68d759083 100644
--- a/tools/go_marshal/README.md
+++ b/tools/go_marshal/README.md
@@ -9,30 +9,16 @@ automatically generating code to marshal go data structures to memory.
`binary.Marshal` by moving the go runtime reflection necessary to marshal a
struct to compile-time.
-`go_marshal` automatically generates implementations for `abi.Marshallable` and
-`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
-implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
-`safemem.Writer.WriteFromBlocks`. Data structures that require custom
-serialization will have manual implementations for these interfaces.
+`go_marshal` automatically generates implementations for `marshal.Marshallable`
+and `safemem.{Reader,Writer}`. Data structures that require custom serialization
+will have manual implementations for these interfaces.
Data structures can be flagged for code generation by adding a struct-level
comment `// +marshal`.
# Usage
-See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`.
-
-The recommended way to generate a go library with marshalling is to use the
-`go_library` with mostly identical configuration as the native go_library rule.
-
-```
-load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
-
-go_library(
- name = "foo",
- srcs = ["foo.go"],
-)
-```
+See `defs.bzl`: a new rule is provided, `go_marshal`.
Under the hood, the `go_marshal` rule is used to generate a file that will
appear in a Go target; the output file should appear explicitly in a srcs list.
@@ -54,11 +40,7 @@ go_library(
"foo.go",
"foo_abi.go",
],
- deps = [
- "<PKGPATH>/gvisor/pkg/abi",
- "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem",
- "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem",
- ],
+ ...
)
```
@@ -69,22 +51,6 @@ These tests use reflection to verify properties of the ABI struct, and should be
considered part of the generated interfaces (but are too expensive to execute at
runtime). Ensure these tests run at some point.
-```
-$ cat BUILD
-load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
-
-go_library(
- name = "foo",
- srcs = ["foo.go"],
-)
-$ blaze build :foo
-$ blaze query ...
-<path-to-dir>:foo_abi_autogen
-<path-to-dir>:foo_abi_autogen_test
-$ blaze test :foo_abi_autogen_test
-<test-output>
-```
-
# Restrictions
Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
@@ -131,22 +97,6 @@ for embedded structs that are not aligned.
Because of this, it's generally best to avoid using `marshal:"unaligned"` and
insert explicit padding fields instead.
-## Debugging go_marshal
-
-To enable debugging output from the go marshal tool, pass the `-debug` flag to
-the tool. When using the build rules from above, add a `debug = True` field to
-the build rule like this:
-
-```
-load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
-
-go_library(
- name = "foo",
- srcs = ["foo.go"],
- debug = True,
-)
-```
-
## Modifying the `go_marshal` Tool
The following are some guidelines for modifying the `go_marshal` tool:
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
index c859ced77..c2a4d45c4 100644
--- a/tools/go_marshal/analysis/BUILD
+++ b/tools/go_marshal/analysis/BUILD
@@ -1,12 +1,11 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "analysis",
testonly = 1,
srcs = ["analysis_unsafe.go"],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis",
visibility = [
"//:sandbox",
],
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
index 9a9a4f298..cd55cf5cb 100644
--- a/tools/go_marshal/analysis/analysis_unsafe.go
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -161,6 +161,10 @@ func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
implicitPad := int(typ.Size()) - nextXOff
f := typ.Field(typ.NumField() - 1) // Final field
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ // Final field explicitly marked unaligned.
+ break
+ }
t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
index c32eb559f..323e33882 100644
--- a/tools/go_marshal/defs.bzl
+++ b/tools/go_marshal/defs.bzl
@@ -1,57 +1,14 @@
-"""Marshal is a tool for generating marshalling interfaces for Go types.
-
-The recommended way is to use the go_library rule defined below with mostly
-identical configuration as the native go_library rule.
-
-load("//tools/go_marshal:defs.bzl", "go_library")
-
-go_library(
- name = "foo",
- srcs = ["foo.go"],
-)
-
-Under the hood, the go_marshal rule is used to generate a file that will
-appear in a Go target; the output file should appear explicitly in a srcs list.
-For example (the above is still the preferred way):
-
-load("//tools/go_marshal:defs.bzl", "go_marshal")
-
-go_marshal(
- name = "foo_abi",
- srcs = ["foo.go"],
- out = "foo_abi.go",
- package = "foo",
-)
-
-go_library(
- name = "foo",
- srcs = [
- "foo.go",
- "foo_abi.go",
- ],
- deps = [
- "//tools/go_marshal:marshal",
- "//pkg/sentry/platform/safecopy",
- "//pkg/sentry/usermem",
- ],
-)
-"""
-
-load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+"""Marshal is a tool for generating marshalling interfaces for Go types."""
def _go_marshal_impl(ctx):
"""Execute the go_marshal tool."""
output = ctx.outputs.lib
output_test = ctx.outputs.test
- (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD")
-
- decl = "/".join(["gvisor.dev/gvisor", build_dir])
# Run the marshal command.
args = ["-output=%s" % output.path]
args += ["-pkg=%s" % ctx.attr.package]
args += ["-output_test=%s" % output_test.path]
- args += ["-declarationPkg=%s" % decl]
if ctx.attr.debug:
args += ["-debug"]
@@ -83,7 +40,6 @@ go_marshal = rule(
implementation = _go_marshal_impl,
attrs = {
"srcs": attr.label_list(mandatory = True, allow_files = True),
- "libname": attr.string(mandatory = True),
"imports": attr.string_list(mandatory = False),
"package": attr.string(mandatory = True),
"debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
@@ -95,58 +51,15 @@ go_marshal = rule(
},
)
-def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs):
- """wraps the standard go_library and does mashalling interface generation.
-
- Args:
- name: Same as native go_library.
- srcs: Same as native go_library.
- deps: Same as native go_library.
- imports: Extra import paths to pass to the go_marshal tool.
- debug: Enables debugging output from the go_marshal tool.
- **kwargs: Remaining args to pass to the native go_library rule unmodified.
- """
- go_marshal(
- name = name + "_abi_autogen",
- libname = name,
- srcs = [src for src in srcs if src.endswith(".go")],
- debug = debug,
- imports = imports,
- package = name,
- )
-
- extra_deps = [
- "//tools/go_marshal/marshal",
- "//pkg/sentry/platform/safecopy",
- "//pkg/sentry/usermem",
- ]
-
- all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
- all_deps = deps + [] # + extra_deps
-
- for extra in extra_deps:
- if extra not in deps:
- all_deps.append(extra)
-
- _go_library(
- name = name,
- srcs = all_srcs,
- deps = all_deps,
- **kwargs
- )
-
- # Don't pass importpath arg to go_test.
- kwargs.pop("importpath", "")
-
- _go_test(
- name = name + "_abi_autogen_test",
- srcs = [name + "_abi_autogen_test.go"],
- # Generated test has a fixed set of dependencies since we generate these
- # tests. They should only depend on the library generated above, and the
- # Marshallable interface.
- deps = [
- ":" + name,
- "//tools/go_marshal/analysis",
- ],
- **kwargs
- )
+# marshal_deps are the dependencies requied by generated code.
+marshal_deps = [
+ "//pkg/gohacks",
+ "//pkg/safecopy",
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+]
+
+# marshal_test_deps are required by test targets.
+marshal_test_deps = [
+ "//tools/go_marshal/analysis",
+]
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index a0eae6492..44cb33ae4 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -1,17 +1,21 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "gomarshal",
srcs = [
"generator.go",
"generator_interfaces.go",
+ "generator_interfaces_array_newtype.go",
+ "generator_interfaces_primitive_newtype.go",
+ "generator_interfaces_struct.go",
"generator_tests.go",
"util.go",
],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ stateify = False,
visibility = [
"//:sandbox",
],
+ deps = ["//tools/tags"],
)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 641ccd938..19bcd4e6a 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -23,17 +23,14 @@ import (
"go/token"
"os"
"sort"
-)
+ "strings"
-const (
- marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
- usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
- safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+ "gvisor.dev/gvisor/tools/tags"
)
-// List of identifiers we use in generated code, that may conflict a
-// similarly-named source identifier. Avoid problems by refusing the generate
-// code when we see these.
+// List of identifiers we use in generated code that may conflict with a
+// similarly-named source identifier. Abort gracefully when we see these to
+// avoid potentially confusing compilation failures in generated code.
//
// This only applies to import aliases at the moment. All other identifiers
// are qualified by a receiver argument, since they're struct fields.
@@ -41,10 +38,21 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner",
+ "length", "limit", "ptr", "size", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
+// Constructed fromt badIdents in init().
+var badIdentsMap map[string]struct{}
+
+func init() {
+ badIdentsMap = make(map[string]struct{})
+ for _, ident := range badIdents {
+ badIdentsMap[ident] = struct{}{}
+ }
+}
+
// Generator drives code generation for a single invocation of the go_marshal
// utility.
//
@@ -62,15 +70,12 @@ type Generator struct {
outputTest *os.File
// Package name for the generated file.
pkg string
- // Go import path for package we're processing. This package should directly
- // declare the type we're generating code for.
- declaration string
// Set of extra packages to import in the generated file.
imports *importTable
}
// NewGenerator creates a new code Generator.
-func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*Generator, error) {
f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
@@ -80,25 +85,29 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports
return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
}
g := Generator{
- inputs: srcs,
- output: f,
- outputTest: fTest,
- pkg: pkg,
- declaration: declaration,
- imports: newImportTable(),
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ imports: newImportTable(),
}
for _, i := range imports {
// All imports on the extra imports list are unconditionally marked as
- // used, so they're always added to the generated code.
+ // used, so that they're always added to the generated code.
g.imports.add(i).markUsed()
}
- g.imports.add(marshalImport).markUsed()
- // The follow imports may or may not be used by the generated
- // code, depending what's required for the target types. Don't
- // mark these imports as used by default.
- g.imports.add(usermemImport)
- g.imports.add(safecopyImport)
+
+ // The following imports may or may not be used by the generated code,
+ // depending on what's required for the target types. Don't mark these as
+ // used by default.
+ g.imports.add("io")
+ g.imports.add("reflect")
+ g.imports.add("runtime")
g.imports.add("unsafe")
+ g.imports.add("gvisor.dev/gvisor/pkg/gohacks")
+ g.imports.add("gvisor.dev/gvisor/pkg/safecopy")
+ g.imports.add("gvisor.dev/gvisor/pkg/usermem")
+ g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal")
return &g, nil
}
@@ -108,6 +117,14 @@ func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports
func (g *Generator) writeHeader() error {
var b sourceBuffer
b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(g.inputs); len(t) > 0 {
+ b.emit(strings.Join(t.Lines(), "\n"))
+ b.emit("\n\n")
+ }
+
+ // Package header.
b.emit("package %s\n\n", g.pkg)
if err := b.write(g.output); err != nil {
return err
@@ -172,10 +189,73 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
-// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// sliceAPI carries information about the '+marshal slice' directive.
+type sliceAPI struct {
+ // Comment node in the AST containing the +marshal tag.
+ comment *ast.Comment
+ // Identifier fragment to use when naming generated functions for the slice
+ // API.
+ ident string
+ // Whether the generated functions should reference the newtype name, or the
+ // inner type name. Only meaningful on newtype declarations on primitives.
+ inner bool
+}
+
+// marshallableType carries information about a type marked with the '+marshal'
+// directive.
+type marshallableType struct {
+ spec *ast.TypeSpec
+ slice *sliceAPI
+}
+
+func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) marshallableType {
+ mt := marshallableType{
+ spec: spec,
+ slice: nil,
+ }
+
+ var unhandledTags []string
+
+ for _, tag := range strings.Fields(strings.TrimPrefix(tagLine.Text, "// +marshal")) {
+ if strings.HasPrefix(tag, "slice:") {
+ tokens := strings.Split(tag, ":")
+ if len(tokens) < 2 || len(tokens) > 3 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive has invalid 'slice' clause. Expecting format 'slice:<IDENTIFIER>[:inner]', got '%v'", tag))
+ }
+ if len(tokens[1]) == 0 {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has empty identifier argument. Expecting '+marshal slice:identifier'")
+ }
+
+ sa := &sliceAPI{
+ comment: tagLine,
+ ident: tokens[1],
+ }
+ mt.slice = sa
+
+ if len(tokens) == 3 {
+ if tokens[2] != "inner" {
+ abortAt(fset.Position(tagLine.Slash), "+marshal slice directive has an invalid argument. Expecting '+marshal slice:<IDENTIFIER>[:inner]'")
+ }
+ sa.inner = true
+ }
+
+ continue
+ }
+
+ unhandledTags = append(unhandledTags, tag)
+ }
+
+ if len(unhandledTags) > 0 {
+ abortAt(fset.Position(tagLine.Slash), fmt.Sprintf("+marshal directive contained the following unknown clauses: %v", strings.Join(unhandledTags, " ")))
+ }
+
+ return mt
+}
+
+// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
- var types []*ast.TypeSpec
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType {
+ var types []marshallableType
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
// Type declaration?
@@ -190,9 +270,11 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
}
// Does the comment contain a "+marshal" line?
marked := false
+ var tagLine *ast.Comment
for _, c := range gdecl.Doc.List {
- if c.Text == "// +marshal" {
+ if strings.HasPrefix(c.Text, "// +marshal") {
marked = true
+ tagLine = c
break
}
}
@@ -201,14 +283,23 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
continue
}
for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier.
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
t := spec.(*ast.TypeSpec)
- if _, ok := t.Type.(*ast.StructType); ok {
- debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
- types = append(types, t)
- continue
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
+ case *ast.ArrayType: // Newtype on array.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on array %s.\n", t.Name.Name)
+ default:
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
- debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ types = append(types, newMarshallableType(f, tagLine, t))
+
}
}
return types
@@ -222,11 +313,6 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
// identifiers in the generated code don't conflict with any imported package
// names.
func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
- badImportNames := make(map[string]bool)
- for _, i := range badIdents {
- badImportNames[i] = true
- }
-
is := make(map[string]importStmt)
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -240,10 +326,10 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
// Make sure we have an import that doesn't use any local names that
// would conflict with identifiers in the generated code.
- if len(i.name) == 1 {
+ if len(i.name) == 1 && i.name != "_" {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
}
- if badImportNames[i.name] {
+ if _, ok := badIdentsMap[i.name]; ok {
abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
}
}
@@ -252,20 +338,40 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
-func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- // We're guaranteed to have only struct type specs by now. See
- // Generator.collectMarshallabeTypes.
- i := newInterfaceGenerator(t, fset)
- i.validate()
- i.emitMarshallable()
+func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator {
+ i := newInterfaceGenerator(t.spec, fset)
+ switch ty := t.spec.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct(t.spec, ty)
+ i.emitMarshallableForStruct(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForStruct(ty, t.slice)
+ }
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype(ty)
+ if t.slice != nil {
+ i.emitMarshallableSliceForPrimitiveNewtype(ty, t.slice)
+ }
+ case *ast.ArrayType:
+ i.validateArrayNewtype(t.spec.Name, ty)
+ // After validate, we can safely call arrayLen.
+ i.emitMarshallableForArrayNewtype(t.spec.Name, ty, ty.Elt.(*ast.Ident))
+ if t.slice != nil {
+ abortAt(fset.Position(t.slice.comment.Slash), fmt.Sprintf("Array type marked as '+marshal slice:...', but this is not supported. Perhaps fold one of the dimensions?"))
+ }
+ default:
+ // This should've been filtered out by collectMarshallabeTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
return i
}
// generateOneTestSuite generates a test suite for the automatically generated
// implementations type t.
-func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
- i := newTestGenerator(t, g.declaration)
- i.emitTests()
+func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator {
+ i := newTestGenerator(t.spec)
+ i.emitTests(t.slice)
return i
}
@@ -304,35 +410,24 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
- for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
- for ref, _ := range impl.ms {
+ for ref := range impl.ms {
ms[ref] = struct{}{}
}
impls = append(impls, impl)
// Collect imports referenced by the generated code and add them to
// the list of imports we need to copy to the generated code.
- for name, _ := range impl.is {
+ for name := range impl.is {
if !g.imports.markUsed(name) {
- panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
}
}
ts = append(ts, g.generateOneTestSuite(t))
}
}
- // Tool was invoked with input files with no data structures marked for code
- // generation. This is probably not what the user intended.
- if len(impls) == 0 {
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
- for _, i := range g.inputs {
- fmt.Fprintf(&buf, " %s\n", i)
- }
- abort(buf.String())
- }
-
// Write output file header. These include things like package name and
// import statements.
if err := g.writeHeader(); err != nil {
@@ -359,11 +454,12 @@ func (g *Generator) Run() error {
// source file.
func (g *Generator) writeTests(ts []*testGenerator) error {
var b sourceBuffer
- b.emit("package %s_test\n\n", g.pkg)
+ b.emit("package %s\n\n", g.pkg)
if err := b.write(g.outputTest); err != nil {
return err
}
+ // Collect and write test import statements.
imports := newImportTable()
for _, t := range ts {
imports.merge(t.imports)
@@ -373,6 +469,27 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
return err
}
+ // Write test functions.
+
+ // If we didn't generate any Marshallable implementations, we can't just
+ // emit an empty test file, since that causes the build to fail with "no
+ // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
+ // omit the entire package since the outputs are already defined before
+ // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
+ // empty example instead.
+ if len(ts) == 0 {
+ b.reset()
+ b.emit("func Example() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty to ensure this file contains at least\n")
+ b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
+ b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
+ b.emit("// in a build failure.\n")
+ })
+ b.emit("}\n")
+ return b.write(g.outputTest)
+ }
+
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index a712c14dc..e3c3dac63 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -77,25 +74,12 @@ func (g *interfaceGenerator) recordUsedMarshallable(m string) {
func (g *interfaceGenerator) recordUsedImport(i string) {
g.is[i] = struct{}{}
-
}
func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
g.as[fieldName] = struct{}{}
}
-func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
-func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
- return fmt.Sprintf("%s.%s", g.r, n.Name)
-}
-
// abortAt aborts the go_marshal tool with the given error message, with a
// reference position to the input source. Same as abortAt, but uses g to
// resolve p to position.
@@ -103,67 +87,6 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
- g.forEachField(func(f *ast.Field) {
- if len(f.Names) == 0 {
- g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
- }
- })
-
- g.forEachField(func(f *ast.Field) {
- fieldDispatcher{
- primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
- },
- selector: func(_, _, _ *ast.Ident) {
- // No validation to perform on selector fields. However this
- // callback must still be provided.
- },
- array: func(n, _ *ast.Ident, len int) {
- a := f.Type.(*ast.ArrayType)
- if a.Len == nil {
- g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
- }
-
- if _, ok := a.Len.(*ast.BasicLit); !ok {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
- }
-
- if _, ok := a.Elt.(*ast.Ident); !ok {
- g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
- }
-
- if len <= 0 {
- g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
- }
- },
- unhandled: func(_ *ast.Ident) {
- g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
- },
- }.dispatch(f)
- })
-}
-
// scalarSize returns the size of type identified by t. If t isn't a primitive
// type, the size isn't known at code generation time, and must be resolved via
// the marshal.Marshallable interface.
@@ -190,7 +113,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalScalar writes a single scalar to a byte slice.
+func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +137,26 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalScalar reads a single scalar from a byte slice.
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,250 +165,112 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
-// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
-// packed. Returns "", false if g.t has no fields that may be potentially
-// packed, otherwise returns <clause>, true, where <clause> is an expression
-// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
-func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
- if len(g.as) == 0 {
- return "", false
- }
-
- cs := make([]string, 0, len(g.as))
- for accessor, _ := range g.as {
- cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
- }
- return strings.Join(cs, " && "), true
+// emitCastToByteSlice unsafely casts an arbitrary type's underlying memory to a
+// byte slice, bypassing escape analysis. The caller is responsible for ensuring
+// srcPtr lives until they're done with dstVar, the runtime does not consider
+// dstVar dependent on srcPtr due to the escape analysis bypass.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "hdr", and cannot be used
+// in a context where it is already bound.
+func (g *interfaceGenerator) emitCastToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.recordUsedImport("gohacks")
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(%s)))\n", srcPtr)
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
}
-func (g *interfaceGenerator) emitMarshallable() {
- // Is g.t a packed struct without consideing field types?
- thisPacked := true
- g.forEachField(func(f *ast.Field) {
- if f.Tag != nil {
- if f.Tag.Value == "`marshal:\"unaligned\"`" {
- if thisPacked {
- debugfAt(g.f.Position(g.t.Pos()),
- fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
- thisPacked = false
- }
- }
- }
- })
-
- g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
- g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
- g.inIndent(func() {
- primitiveSize := 0
- var dynamicSizeTerms []string
-
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
- }
- },
- selector: func(n, tX, tSel *ast.Ident) {
- tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
- g.recordUsedImport(tX.Name)
- g.recordUsedMarshallable(tName)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
- },
- array: func(n, t *ast.Ident, len int) {
- if len < 1 {
- // Zero-length arrays should've been rejected by validate().
- panic("unreachable")
- }
- if size, dynamic := g.scalarSize(t); !dynamic {
- primitiveSize += size * len
- } else {
- g.recordUsedMarshallable(t.Name)
- dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
- }
- },
- }.dispatch)
- g.emit("return %d", primitiveSize)
- if len(dynamicSizeTerms) > 0 {
- g.incIndent()
- }
- {
- for _, d := range dynamicSizeTerms {
- g.emitNoIndent(" +\n")
- g.emit(d)
- }
- }
- if len(dynamicSizeTerms) > 0 {
- g.decIndent()
- }
- })
- g.emit("\n}\n\n")
-
- g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
- g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
- }
- return
- }
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("dst", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
-
- g.emit("for i := 0; i < %d; i++ {\n", size)
- g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
-
- g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
- g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- g.forEachField(fieldDispatcher{
- primitive: func(n, t *ast.Ident) {
- if n.Name == "_" {
- g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can reference here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
- g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
- }
- return
- }
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
- },
- selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
- },
- array: func(n, t *ast.Ident, size int) {
- if n.Name == "_" {
- g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
- if len, dynamic := g.scalarSize(t); !dynamic {
- g.shift("src", len*size)
- } else {
- // We can't use shiftDynamic here because we don't have
- // an instance of the dynamic type we can referece here
- // (since the version in this struct is anonymous). Use
- // a typed nil pointer to call SizeBytes() instead.
- g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
- }
- return
- }
+// emitCastToByteSlice unsafely casts a slice with elements of an abitrary type
+// to a byte slice. As part of the cast, the byte slice is made to look
+// independent of the src slice by bypassing escape analysis. This means the
+// byte slice can be used without causing the source to escape. The caller is
+// responsible for ensuring srcPtr lives until they're done with dstVar, as the
+// runtime no longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifiers "ptr", "val" and "hdr",
+// and cannot be used in a context where these identifiers are already bound.
+func (g *interfaceGenerator) emitCastSliceToByteSlice(srcPtr, dstVar, lenExpr string) {
+ g.emitNoEscapeSliceDataPointer(srcPtr, "val")
+
+ g.emit("// Construct a slice backed by dst's underlying memory.\n")
+ g.emit("var %s []byte\n", dstVar)
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&%s))\n", dstVar)
+ g.emit("hdr.Data = uintptr(val)\n")
+ g.emit("hdr.Len = %s\n", lenExpr)
+ g.emit("hdr.Cap = %s\n\n", lenExpr)
+}
- g.emit("for i := 0; i < %d; i++ {\n", size)
- g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
- })
- g.emit("}\n")
- },
- }.dispatch)
- })
- g.emit("}\n\n")
+// emitNoEscapeSliceDataPointer unsafely casts a slice's data pointer to an
+// unsafe.Pointer, bypassing escape analysis. The caller is responsible for
+// ensuring srcPtr lives until they're done with dstVar, as the runtime no
+// longer considers dstVar dependent on srcPtr and is free to GC it.
+//
+// srcPtr must be a pointer.
+//
+// This function uses internally uses the identifier "ptr" cannot be used in a
+// context where this identifier is already bound.
+func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) {
+ g.recordUsedImport("gohacks")
+ g.emit("ptr := unsafe.Pointer(%s)\n", srcPtr)
+ g.emit("%s := gohacks.Noescape(unsafe.Pointer((*reflect.SliceHeader)(ptr).Data))\n\n", dstVar)
+}
- g.emit("// Packed implements marshal.Marshallable.Packed.\n")
- g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
- g.inIndent(func() {
- expr, fieldsMaybePacked := g.areFieldsPackedExpression()
- switch {
- case !thisPacked:
- g.emit("return false\n")
- case fieldsMaybePacked:
- g.emit("return %s\n", expr)
- default:
- g.emit("return true\n")
+func (g *interfaceGenerator) emitKeepAlive(ptrVar string) {
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar)
+ g.emit("// must live until the use above.\n")
+ g.emit("runtime.KeepAlive(%s)\n", ptrVar)
+}
- }
- })
- g.emit("}\n\n")
+func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) {
+ switch x := e.X.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, x)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", x.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", x.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
- g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
- g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
- g.emit("%s.MarshalBytes(dst)\n", g.r)
- }
- })
- g.emit("}\n\n")
+ fmt.Fprintf(b, "%s", e.Op)
- g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
- g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
- g.inIndent(func() {
- if thisPacked {
- g.recordUsedImport("safecopy")
- g.recordUsedImport("unsafe")
- if cond, ok := g.areFieldsPackedExpression(); ok {
- g.emit("if %s {\n", cond)
- g.inIndent(func() {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- })
- g.emit("} else {\n")
- g.inIndent(func() {
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- })
- g.emit("}\n")
- } else {
- g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
- }
- } else {
- g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
- g.emit("%s.UnmarshalBytes(src)\n", g.r)
- }
- })
- g.emit("}\n\n")
+ switch y := e.Y.(type) {
+ case *ast.BinaryExpr:
+ // Recursively expand sub-expression.
+ g.expandBinaryExpr(b, y)
+ case *ast.Ident:
+ fmt.Fprintf(b, "%s", y.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(b, "%s", y.Value)
+ default:
+ g.abortAt(e.Pos(), "Cannot convert binary expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+}
+// arrayLenExpr returns a string containing a valid golang expression
+// representing the length of array a. The returned expression should be treated
+// as a single value, and will be already parenthesized as required.
+func (g *interfaceGenerator) arrayLenExpr(a *ast.ArrayType) string {
+ var b strings.Builder
+
+ switch l := a.Len.(type) {
+ case *ast.Ident:
+ fmt.Fprintf(&b, "%s", l.Name)
+ case *ast.BasicLit:
+ fmt.Fprintf(&b, "%s", l.Value)
+ case *ast.BinaryExpr:
+ g.expandBinaryExpr(&b, l)
+ return fmt.Sprintf("(%s)", b.String())
+ default:
+ g.abortAt(l.Pos(), "Cannot convert this array len expression to output code. Go-marshal currently only handles simple expressions of literals, constants and basic identifiers")
+ }
+ return b.String()
}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
new file mode 100644
index 000000000..72ef03a22
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go
@@ -0,0 +1,146 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on arrays.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+func (g *interfaceGenerator) validateArrayNewtype(n *ast.Ident, a *ast.ArrayType) {
+ if a.Len == nil {
+ g.abortAt(a.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+}
+
+func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *ast.ArrayType, elt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ lenExpr := g.arrayLenExpr(a)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(elt); !dynamic {
+ g.emit("return %d * %s\n", size, lenExpr)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes() * %s\n", n.Name, lenExpr)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src")
+ })
+ g.emit("}\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Array newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
new file mode 100644
index 000000000..39f654ea8
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go
@@ -0,0 +1,289 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// newtypes on primitives.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+)
+
+// marshalPrimitiveScalar writes a single primitive variable to a byte
+// slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// MarshalBytes, so we don't recursively call %s.MarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("// Explicilty cast to the underlying type before dispatching to\n")
+ g.emit("// UnmarshalBytes, so we don't recursively call %s.UnmarshalBytes\n", accessor)
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Ident, slice *sliceAPI) {
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+
+ eltType := g.typeName()
+ if slice.inner {
+ eltType = nt.Name
+ }
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType)
+ g.emit("//go:nosplit\n")
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType)
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
new file mode 100644
index 000000000..4b9cea08a
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go
@@ -0,0 +1,622 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains the bits of the code generator specific to marshalling
+// structs.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "strings"
+)
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct(ts *ast.TypeSpec, st *ast.StructType) {
+ forEachStructField(st, func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ forEachStructField(st, func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ g.validatePrimitiveNewtype(t)
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, _ *ast.Ident) {
+ g.validateArrayNewtype(n, a)
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+func (g *interfaceGenerator) isStructPacked(st *ast.StructType) bool {
+ packed := true
+ forEachStructField(st, func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if packed {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ packed = false
+ }
+ }
+ }
+ })
+ return packed
+}
+
+func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) {
+ thisPacked := g.isStructPacked(st)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", t.Name))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%d*%s", size, lenExpr))
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%s", t.Name, lenExpr))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", tX.Name, tSel.Name)
+ g.emit("dst = dst[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%s] ~= [%s]%s{0}\n", t.Name, lenExpr, lenExpr, t.Name)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("dst = dst[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ forEachStructField(st, fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We don't have an instance of the dynamic type we can
+ // reference here (since the version in this struct is
+ // anonymous). Use a typed nil pointer to call
+ // SizeBytes() instead.
+ g.shiftDynamic("src", fmt.Sprintf("(*%s)(nil)", t.Name))
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: %s ~= src[:sizeof(%s.%s)]\n", g.fieldAccessor(n), tX.Name, tSel.Name)
+ g.emit("src = src[(*%s.%s)(nil).SizeBytes():]\n", tX.Name, tSel.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s.%s)(nil)", tX.Name, tSel.Name))
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident) {
+ lenExpr := g.arrayLenExpr(a)
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%s]%s(%s), src[:sizeof(%s)*%s])\n", lenExpr, t.Name, g.fieldAccessor(n), t.Name, lenExpr)
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ g.emit("src = src[%d*(%s):]\n", size, lenExpr)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*(%s):]\n", t.Name, lenExpr)
+ }
+ return
+ }
+
+ g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("//go:nosplit\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(fallback)
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(fallback)
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("//go:nosplit\n")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r)
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n")
+ g.emit("// partially unmarshalled struct.\n")
+ g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r)
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.recordUsedImport("io")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r)
+ g.emit("%s.MarshalBytes(buf)\n", g.r)
+ g.emit("length, err := w.Write(buf)\n")
+ g.emit("return int64(length), err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !%s {\n", cond)
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r))
+
+ g.emit("length, err := w.Write(buf)\n")
+ g.emitKeepAlive(g.r)
+ g.emit("return int64(length), err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
+
+func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, slice *sliceAPI) {
+ thisPacked := g.isStructPacked(st)
+
+ if slice.inner {
+ abortAt(g.f.Position(slice.comment.Slash), fmt.Sprintf("The ':inner' argument to '+marshal slice:%s:inner' is only applicable to newtypes on primitives. Remove it from this struct declaration.", slice.ident))
+ }
+
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("usermem")
+
+ g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n\n")
+
+ g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n")
+ g.emit("limit := length/size\n")
+ g.emit("for idx := 0; idx < limit; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n")
+ g.emit("// final element, but may not contain valid data for the entire range. This may\n")
+ g.emit("// result in unmarshalling zero values for some parts of the object.\n")
+ g.emit("if length%size != 0 {\n")
+ g.inIndent(func() {
+ g.emit("idx := limit\n")
+ g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("return length, err\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast deserialization.
+ g.emitCastSliceToByteSlice("&dst", "buf", "size * count")
+
+ g.emit("length, err := task.CopyInBytes(addr, buf)\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName())
+ g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("buf := task.CopyScratchBuffer(size * count)\n")
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return task.CopyOutBytes(addr, buf)\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ // Fast serialization.
+ g.emitCastSliceToByteSlice("&src", "buf", "size * count")
+
+ g.emit("length, err := task.CopyOutBytes(addr, buf)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe%s is like %s.MarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func MarshalUnsafe%s(src []%s, dst []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(src)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !src[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&src", "val")
+
+ g.emit("length, err := safecopy.CopyIn(dst[:(size*count)], val)\n")
+ g.emitKeepAlive("src")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe%s is like %s.UnmarshalUnsafe, but for a []%s.\n", slice.ident, g.typeName(), g.typeName())
+ g.emit("func UnmarshalUnsafe%s(dst []%s, src []byte) (int, error) {\n", slice.ident, g.typeName())
+ g.inIndent(func() {
+ g.emit("count := len(dst)\n")
+ g.emit("if count == 0 {\n")
+ g.inIndent(func() {
+ g.emit("return 0, nil\n")
+ })
+ g.emit("}\n")
+ g.emit("size := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ fallback := func() {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("for idx := 0; idx < count; idx++ {\n")
+ g.inIndent(func() {
+ g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n")
+ })
+ g.emit("}\n")
+ g.emit("return size * count, nil\n")
+ }
+ if thisPacked {
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("unsafe")
+ if _, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if !dst[0].Packed() {\n")
+ g.inIndent(fallback)
+ g.emit("}\n\n")
+ }
+ g.emitNoEscapeSliceDataPointer("&dst", "val")
+
+ g.emit("length, err := safecopy.CopyOut(val, src[:(size*count)])\n")
+ g.emitKeepAlive("dst")
+ g.emit("return length, err\n")
+ } else {
+ fallback()
+ }
+ })
+ g.emit("}\n\n")
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index df25cb5b2..631295373 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -22,12 +22,19 @@ import (
)
var standardImports = []string{
+ "bytes",
"fmt",
"reflect",
"testing",
+
"gvisor.dev/gvisor/tools/go_marshal/analysis",
}
+var sliceAPIImports = []string{
+ "encoding/binary",
+ "gvisor.dev/gvisor/pkg/usermem",
+}
+
type testGenerator struct {
sourceBuffer
@@ -46,10 +53,7 @@ type testGenerator struct {
decl *importStmt
}
-func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
+func newTestGenerator(t *ast.TypeSpec) *testGenerator {
g := &testGenerator{
t: t,
r: receiverName(t),
@@ -59,22 +63,17 @@ func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
for _, i := range standardImports {
g.imports.add(i).markUsed()
}
- g.decl = g.imports.add(declaration)
- g.decl.markUsed()
+ // These imports are used if a type requests the slice API. Don't
+ // mark them as used by default.
+ for _, i := range sliceAPIImports {
+ g.imports.add(i)
+ }
return g
}
func (g *testGenerator) typeName() string {
- return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name)
-}
-
-func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
+ return g.t.Name.Name
}
func (g *testGenerator) testFuncName(base string) string {
@@ -89,10 +88,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) {
func (g *testGenerator) emitTestNonZeroSize() {
g.inTestFunction("TestSizeNonZero", func() {
- g.emit("x := &%s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("if x.SizeBytes() == 0 {\n")
g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
})
g.emit("}\n")
})
@@ -100,7 +99,7 @@ func (g *testGenerator) emitTestNonZeroSize() {
func (g *testGenerator) emitTestSuspectAlignment() {
g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("x := %s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
})
}
@@ -118,35 +117,115 @@ func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
g.emit("y.UnmarshalBytes(buf)\n")
g.emit("if !reflect.DeepEqual(x, y) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
})
g.emit("}\n")
g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
})
g.emit("}\n\n")
g.emit("z.UnmarshalUnsafe(buf)\n")
g.emit("if !reflect.DeepEqual(x, z) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalBytes/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, z))\n")
})
g.emit("}\n")
g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
g.inIndent(func() {
- g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalSlicePreservesData(slice *sliceAPI) {
+ for _, name := range []string{"binary", "usermem"} {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated test for '%s' referenced a non-existent import with local name '%s'", g.typeName(), name))
+ }
+ }
+
+ g.inTestFunction("TestSafeMarshalUnmarshalSlicePreservesData", func() {
+ g.emit("var x, y, yUnsafe [8]%s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+ g.emit("size := (*%s)(nil).SizeBytes() * len(x)\n", g.typeName())
+ g.emit("buf := bytes.NewBuffer(make([]byte, size))\n")
+ g.emit("buf.Reset()\n")
+ g.emit("if err := binary.Write(buf, usermem.ByteOrder, x[:]); err != nil {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"binary.Write failed: %v\", err))\n")
+ })
+ g.emit("}\n")
+ g.emit("bufUnsafe := make([]byte, size)\n")
+ g.emit("MarshalUnsafe%s(x[:], bufUnsafe)\n\n", slice.ident)
+
+ g.emit("UnmarshalUnsafe%s(y[:], buf.Bytes())\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across binary.Write/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("UnmarshalUnsafe%s(yUnsafe[:], bufUnsafe)\n", slice.ident)
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafeSlice/UnmarshalUnsafeSlice cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+ })
+}
+
+func (g *testGenerator) emitTestWriteToUnmarshalPreservesData() {
+ g.inTestFunction("TestWriteToUnmarshalPreservesData", func() {
+ g.emit("var x, y, yUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("var buf bytes.Buffer\n\n")
+
+ g.emit("x.WriteTo(&buf)\n")
+ g.emit("y.UnmarshalBytes(buf.Bytes())\n\n")
+ g.emit("yUnsafe.UnmarshalUnsafe(buf.Bytes())\n\n")
+
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalBytes cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across WriteTo/UnmarshalUnsafe cycle:\\nBefore: %+v\\nAfter: %+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
+ g.inTestFunction("TestSizeBytesOnTypedNilPtr", func() {
+ g.emit("var x %s\n", g.typeName())
+ g.emit("sizeFromConcrete := x.SizeBytes()\n")
+ g.emit("sizeFromTypedNilPtr := (*%s)(nil).SizeBytes()\n\n", g.typeName())
+
+ g.emit("if sizeFromTypedNilPtr != sizeFromConcrete {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatalf(\"SizeBytes() on typed nil pointer (%v) doesn't match size returned by a concrete object (%v).\\n\", sizeFromTypedNilPtr, sizeFromConcrete)\n")
})
g.emit("}\n")
})
}
-func (g *testGenerator) emitTests() {
+func (g *testGenerator) emitTests(slice *sliceAPI) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
+
+ if slice != nil {
+ g.emitTestMarshalUnmarshalSlicePreservesData(slice)
+ }
}
func (g *testGenerator) write(out io.Writer) error {
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
index 967537abf..d94314302 100644
--- a/tools/go_marshal/gomarshal/util.go
+++ b/tools/go_marshal/gomarshal/util.go
@@ -25,7 +25,6 @@ import (
"path"
"reflect"
"sort"
- "strconv"
"strings"
)
@@ -64,12 +63,18 @@ func kindString(e ast.Expr) string {
}
}
+func forEachStructField(st *ast.StructType, fn func(f *ast.Field)) {
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
// fieldDispatcher is a collection of callbacks for handling different types of
// fields in a struct declaration.
type fieldDispatcher struct {
primitive func(n, t *ast.Ident)
selector func(n, tX, tSel *ast.Ident)
- array func(n, t *ast.Ident, size int)
+ array func(n *ast.Ident, a *ast.ArrayType, t *ast.Ident)
unhandled func(n *ast.Ident)
}
@@ -96,22 +101,12 @@ func (fd fieldDispatcher) dispatch(f *ast.Field) {
case *ast.SelectorExpr:
fd.selector(name, v.X.(*ast.Ident), v.Sel)
case *ast.ArrayType:
- len := 0
- if v.Len != nil {
- // Non-literal array length is handled by generatorInterfaces.validate().
- if lenLit, ok := v.Len.(*ast.BasicLit); ok {
- var err error
- len, err = strconv.Atoi(lenLit.Value)
- if err != nil {
- panic(err)
- }
- }
- }
switch t := v.Elt.(type) {
case *ast.Ident:
- fd.array(name, t, len)
+ fd.array(name, v, t)
default:
- fd.array(name, nil, len)
+ // Should be handled with a better error message during validate.
+ panic(fmt.Sprintf("Array element type is of unsupported kind. Expected *ast.Ident, got %v", t))
}
default:
fd.unhandled(name)
@@ -219,6 +214,11 @@ type sourceBuffer struct {
b bytes.Buffer
}
+func (b *sourceBuffer) reset() {
+ b.indent = 0
+ b.b.Reset()
+}
+
func (b *sourceBuffer) incIndent() {
b.indent++
}
@@ -265,6 +265,11 @@ type importStmt struct {
aliased bool
// Indicates whether this import was referenced by generated code.
used bool
+ // AST node and file set representing the import statement, if any. These
+ // are only non-nil if the import statement originates from an input source
+ // file.
+ spec *ast.ImportSpec
+ fset *token.FileSet
}
func newImport(p string) *importStmt {
@@ -290,14 +295,27 @@ func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
name: name,
path: p,
aliased: spec.Name != nil,
+ spec: spec,
+ fset: f,
}
}
+// String implements fmt.Stringer.String. This generates a string for the import
+// statement appropriate for writing directly to generated code.
func (i *importStmt) String() string {
if i.aliased {
- return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ return fmt.Sprintf("%s %q", i.name, i.path)
}
- return fmt.Sprintf("\"%s\"", i.path)
+ return fmt.Sprintf("%q", i.path)
+}
+
+// debugString returns a debug string representing an import statement. This
+// representation is not valid golang code and is used for debugging output.
+func (i *importStmt) debugString() string {
+ if i.spec != nil && i.fset != nil {
+ return fmt.Sprintf("%s: %s", i.fset.Position(i.spec.Path.Pos()), i)
+ }
+ return fmt.Sprintf("(go-marshal import): %s", i)
}
func (i *importStmt) markUsed() {
@@ -305,58 +323,111 @@ func (i *importStmt) markUsed() {
}
func (i *importStmt) equivalent(other *importStmt) bool {
- return i == other
+ return i.name == other.name && i.path == other.path && i.aliased == other.aliased
}
// importTable represents a collection of importStmts.
+//
+// An importTable may contain multiple import statements referencing the same
+// local name. All import statements aliasing to the same local name are
+// technically ambiguous, as if such an import name is used in the generated
+// code, it's not clear which import statement it refers to. We ignore any
+// potential collisions until actually writing the import table to the generated
+// source file. See importTable.write.
+//
+// Given the following import statements across all the files comprising a
+// package marshalled:
+//
+// "sync"
+// "pkg/sync"
+// "pkg/sentry/kernel"
+// ktime "pkg/sentry/kernel/time"
+//
+// An importTable representing them would look like this:
+//
+// importTable {
+// is: map[string][]*importStmt {
+// "sync": []*importStmt{
+// importStmt{name:"sync", path:"sync", aliased:false}
+// importStmt{name:"sync", path:"pkg/sync", aliased:false}
+// },
+// "kernel": []*importStmt{importStmt{
+// name: "kernel",
+// path: "pkg/sentry/kernel",
+// aliased: false
+// }},
+// "ktime": []*importStmt{importStmt{
+// name: "ktime",
+// path: "pkg/sentry/kernel/time",
+// aliased: true,
+// }},
+// }
+// }
+//
+// Note that the local name "sync" is assigned to two different import
+// statements. This is possible if the import statements are from different
+// source files in the same package.
+//
+// Since go-marshal generates a single output file per package regardless of the
+// number of input files, if "sync" is referenced by any generated code, it's
+// unclear which import statement "sync" refers to. While it's theoretically
+// possible to resolve this by assigning a unique local alias to each instance
+// of the sync package, go-marshal currently aborts when it encounters such an
+// ambiguity.
+//
+// TODO(b/151478251): importTable considers the final component of an import
+// path to be the package name, but this is only a convention. The actual
+// package name is determined by the package statement in the source files for
+// the package.
type importTable struct {
// Map of imports and whether they should be copied to the output.
- is map[string]*importStmt
+ is map[string][]*importStmt
}
func newImportTable() *importTable {
return &importTable{
- is: make(map[string]*importStmt),
+ is: make(map[string][]*importStmt),
}
}
-// Merges import statements from other into i. Collisions in import statements
-// result in a panic.
+// Merges import statements from other into i.
func (i *importTable) merge(other *importTable) {
- for name, im := range other.is {
- if dup, ok := i.is[name]; ok && dup.equivalent(im) {
- panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
- }
-
- i.is[name] = im
+ for name, ims := range other.is {
+ i.is[name] = append(i.is[name], ims...)
}
}
+func (i *importTable) addStmt(s *importStmt) *importStmt {
+ i.is[s.name] = append(i.is[s.name], s)
+ return s
+}
+
func (i *importTable) add(s string) *importStmt {
n := newImport(s)
- i.is[n.name] = n
- return n
+ return i.addStmt(n)
}
func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
- n := newImportFromSpec(spec, f)
- i.is[n.name] = n
- return n
+ return i.addStmt(newImportFromSpec(spec, f))
}
// Marks the import named n as used. If no such import is in the table, returns
// false.
func (i *importTable) markUsed(n string) bool {
- if n, ok := i.is[n]; ok {
- n.markUsed()
+ if ns, ok := i.is[n]; ok {
+ for _, n := range ns {
+ n.markUsed()
+ }
return true
}
return false
}
func (i *importTable) clear() {
- for _, i := range i.is {
- i.used = false
+ for _, is := range i.is {
+ for _, i := range is {
+ i.used = false
+ }
}
}
@@ -367,9 +438,42 @@ func (i *importTable) write(out io.Writer) error {
}
imports := make([]string, 0, len(i.is))
- for _, i := range i.is {
- if i.used {
- imports = append(imports, i.String())
+ for name, is := range i.is {
+ var lastUsed *importStmt
+ var ambiguous bool
+
+ for _, i := range is {
+ if i.used {
+ if lastUsed != nil {
+ if !i.equivalent(lastUsed) {
+ ambiguous = true
+ }
+ }
+ lastUsed = i
+ }
+ }
+
+ if ambiguous {
+ // We have two or more import statements across the different source
+ // files that share a local name, and at least one of these imports
+ // are used by the generated code. This ambiguity can't be resolved
+ // by go-marshal and requires the user intervention. Dump a list of
+ // the colliding import statements and let the user modify the input
+ // files as appropriate.
+ var b strings.Builder
+ fmt.Fprintf(&b, "The imported name %q is used by one of the types marked for marshalling, and which import statement the code refers to is ambiguous. Perhaps give the imports unique local names?\n\n", name)
+ fmt.Fprintf(&b, "The following %d import statements are ambiguous for the local name %q:\n", len(is), name)
+ // Note: len(is) is guaranteed to be 1 or greater or ambiguous can't
+ // be true. Therefore the slicing below is safe.
+ for _, i := range is[:len(is)-1] {
+ fmt.Fprintf(&b, " %v\n", i.debugString())
+ }
+ fmt.Fprintf(&b, " %v", is[len(is)-1].debugString())
+ panic(b.String())
+ }
+
+ if lastUsed != nil {
+ imports = append(imports, lastUsed.String())
}
}
sort.Strings(imports)
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
index 3d12eb93c..f74be5c29 100644
--- a/tools/go_marshal/main.go
+++ b/tools/go_marshal/main.go
@@ -1,4 +1,4 @@
-// Copyright 2019 Google LLC
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -31,11 +31,10 @@ import (
)
var (
- pkg = flag.String("pkg", "", "output package")
- output = flag.String("output", "", "output file")
- outputTest = flag.String("output_test", "", "output file for tests")
- imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
- declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on")
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
)
func main() {
@@ -62,7 +61,7 @@ func main() {
// as an import.
extraImports = strings.Split(*imports, ",")
}
- g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports)
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, extraImports)
if err != nil {
panic(err)
}
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
index 47dda97a1..4aec98218 100644
--- a/tools/go_marshal/marshal/BUILD
+++ b/tools/go_marshal/marshal/BUILD
@@ -1,14 +1,17 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools:defs.bzl", "go_library")
-package(licenses = ["notice"])
+licenses(["notice"])
go_library(
name = "marshal",
srcs = [
"marshal.go",
+ "marshal_impl_util.go",
],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal",
visibility = [
"//:sandbox",
],
+ deps = [
+ "//pkg/usermem",
+ ],
)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index a313a27ed..85b196f08 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -20,18 +20,50 @@
// tools/go_marshal. See the go_marshal README for details.
package marshal
-// Marshallable represents a type that can be marshalled to and from memory.
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Task provides a subset of kernel.Task, used in marshalling. We don't import
+// the kernel package directly to avoid circular dependency.
+type Task interface {
+ // CopyScratchBuffer provides a task goroutine-local scratch buffer. See
+ // kernel.CopyScratchBuffer.
+ CopyScratchBuffer(size int) []byte
+
+ // CopyOutBytes writes the contents of b to the task's memory. See
+ // kernel.CopyOutBytes.
+ CopyOutBytes(addr usermem.Addr, b []byte) (int, error)
+
+ // CopyInBytes reads the contents of the task's memory to b. See
+ // kernel.CopyInBytes.
+ CopyInBytes(addr usermem.Addr, b []byte) (int, error)
+}
+
+// Marshallable represents operations on a type that can be marshalled to and
+// from memory.
+//
+// go-marshal automatically generates implementations for this interface for
+// types marked as '+marshal'.
type Marshallable interface {
+ io.WriterTo
+
// SizeBytes is the size of the memory representation of a type in
// marshalled form.
+ //
+ // SizeBytes must handle a nil receiver. Practically, this means SizeBytes
+ // cannot deference any fields on the object implementing it (but will
+ // likely make use of the type of these fields).
SizeBytes() int
- // MarshalBytes serializes a copy of a type to dst. dst must be at least
- // SizeBytes() long.
+ // MarshalBytes serializes a copy of a type to dst.
+ // Precondition: dst must be at least SizeBytes() in length.
MarshalBytes(dst []byte)
- // UnmarshalBytes deserializes a type from src. src must be at least
- // SizeBytes() long.
+ // UnmarshalBytes deserializes a type from src.
+ // Precondition: src must be at least SizeBytes() in length.
UnmarshalBytes(src []byte)
// Packed returns true if the marshalled size of the type is the same as the
@@ -39,6 +71,12 @@ type Marshallable interface {
// starting at unaligned addresses (should always be true by default for ABI
// structs, verified by automatically generated tests when using
// go_marshal), and has no fields marked `marshal:"unaligned"`.
+ //
+ // Packed must return the same result for all possible values of the type
+ // implementing it. Violating this constraint implies the type doesn't have
+ // a static memory layout, and will lead to memory corruption.
+ // Go-marshal-generated code reuses the result of Packed for multiple values
+ // of the same type.
Packed() bool
// MarshalUnsafe serializes a type by bulk copying its in-memory
@@ -46,15 +84,100 @@ type Marshallable interface {
// has no implicit padding, see Marshallable.Packed. When Packed would
// return false, MarshalUnsafe should fall back to the safer but slower
// MarshalBytes.
+ // Precondition: dst must be at least SizeBytes() in length.
MarshalUnsafe(dst []byte)
- // UnmarshalUnsafe deserializes a type directly to the underlying memory
- // allocated for the object by the runtime.
+ // UnmarshalUnsafe deserializes a type by directly copying to the underlying
+ // memory allocated for the object by the runtime.
//
// This allows much faster unmarshalling of types which have no implicit
// padding, see Marshallable.Packed. When Packed would return false,
// UnmarshalUnsafe should fall back to the safer but slower unmarshal
- // mechanism implemented in UnmarshalBytes (usually by calling
- // UnmarshalBytes directly).
+ // mechanism implemented in UnmarshalBytes.
+ // Precondition: src must be at least SizeBytes() in length.
UnmarshalUnsafe(src []byte)
+
+ // CopyIn deserializes a Marshallable type from a task's memory. This may
+ // only be called from a task goroutine. This is more efficient than calling
+ // UnmarshalUnsafe on Marshallable.Packed types, as the type being
+ // marshalled does not escape. The implementation should avoid creating
+ // extra copies in memory by directly deserializing to the object's
+ // underlying memory.
+ //
+ // If the copy-in from the task memory is only partially successful, CopyIn
+ // should still attempt to deserialize as much data as possible. See comment
+ // for UnmarshalBytes.
+ CopyIn(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOut serializes a Marshallable type to a task's memory. This may only
+ // be called from a task goroutine. This is more efficient than calling
+ // MarshalUnsafe on Marshallable.Packed types, as the type being serialized
+ // does not escape. The implementation should avoid creating extra copies in
+ // memory by directly serializing from the object's underlying memory.
+ //
+ // The copy-out to the task memory may be partially successful, in which
+ // case CopyOut returns how much data was serialized. See comment for
+ // MarshalBytes for implications.
+ CopyOut(task Task, addr usermem.Addr) (int, error)
+
+ // CopyOutN is like CopyOut, but explicitly requests a partial
+ // copy-out. Note that this may yield unexpected results for non-packed
+ // types and the caller may only want to allow this for packed types. See
+ // comment on MarshalBytes.
+ //
+ // The limit must be less than or equal to SizeBytes().
+ CopyOutN(task Task, addr usermem.Addr, limit int) (int, error)
}
+
+// go-marshal generates additional functions for a type based on additional
+// clauses to the +marshal directive. They are documented below.
+//
+// Slice API
+// =========
+//
+// Adding a "slice" clause to the +marshal directive for structs or newtypes on
+// primitives like this:
+//
+// // +marshal slice:FooSlice
+// type Foo struct { ... }
+//
+// Generates four additional functions for marshalling slices of Foos like this:
+//
+// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It
+// // might be more efficient that repeatedly calling Foo.MarshalUnsafe
+// // over a []Foo in a loop if the type is Packed.
+// // Preconditions: dst must be at least len(src)*Foo.SizeBytes() in length.
+// func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... }
+//
+// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It
+// // might be more efficient that repeatedly calling Foo.UnmarshalUnsafe
+// // over a []Foo in a loop if the type is Packed.
+// // Preconditions: src must be at least len(dst)*Foo.SizeBytes() in length.
+// func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... }
+//
+// // CopyFooSliceIn copies in a slice of Foo objects from the task's memory.
+// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... }
+//
+// // CopyFooSliceIn copies out a slice of Foo objects to the task's memory.
+// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... }
+//
+// The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where
+// %s is the first argument to the slice clause. This directive is not supported
+// for newtypes on arrays.
+//
+// The slice clause also takes an optional second argument, which must be the
+// value "inner":
+//
+// // +marshal slice:Int32Slice:inner
+// type Int32 int32
+//
+// This is only valid on newtypes on primitives, and causes the generated
+// functions to accept slices of the inner type instead:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... }
+//
+// Without "inner", they would instead be:
+//
+// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... }
+//
+// This may help avoid a cast depending on how the generated functions are used.
diff --git a/tools/go_marshal/marshal/marshal_impl_util.go b/tools/go_marshal/marshal/marshal_impl_util.go
new file mode 100644
index 000000000..89c7d3575
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal_impl_util.go
@@ -0,0 +1,78 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package marshal
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// StubMarshallable implements the Marshallable interface.
+// StubMarshallable is a convenient embeddable type for satisfying the
+// marshallable interface, but provides no actual implementation. It is
+// useful when the marshallable interface needs to be implemented manually,
+// but the caller doesn't require the full marshallable interface.
+type StubMarshallable struct{}
+
+// WriteTo implements Marshallable.WriteTo.
+func (StubMarshallable) WriteTo(w io.Writer) (n int64, err error) {
+ panic("Please implement your own WriteTo function")
+}
+
+// SizeBytes implements Marshallable.SizeBytes.
+func (StubMarshallable) SizeBytes() int {
+ panic("Please implement your own SizeBytes function")
+}
+
+// MarshalBytes implements Marshallable.MarshalBytes.
+func (StubMarshallable) MarshalBytes(dst []byte) {
+ panic("Please implement your own MarshalBytes function")
+}
+
+// UnmarshalBytes implements Marshallable.UnmarshalBytes.
+func (StubMarshallable) UnmarshalBytes(src []byte) {
+ panic("Please implement your own UnMarshalBytes function")
+}
+
+// Packed implements Marshallable.Packed.
+func (StubMarshallable) Packed() bool {
+ panic("Please implement your own Packed function")
+}
+
+// MarshalUnsafe implements Marshallable.MarshalUnsafe.
+func (StubMarshallable) MarshalUnsafe(dst []byte) {
+ panic("Please implement your own MarshalUnsafe function")
+}
+
+// UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe.
+func (StubMarshallable) UnmarshalUnsafe(src []byte) {
+ panic("Please implement your own UnmarshalUnsafe function")
+}
+
+// CopyIn implements Marshallable.CopyIn.
+func (StubMarshallable) CopyIn(task Task, addr usermem.Addr) (int, error) {
+ panic("Please implement your own CopyIn function")
+}
+
+// CopyOut implements Marshallable.CopyOut.
+func (StubMarshallable) CopyOut(task Task, addr usermem.Addr) (int, error) {
+ panic("Please implement your own CopyOut function")
+}
+
+// CopyOutN implements Marshallable.CopyOutN.
+func (StubMarshallable) CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) {
+ panic("Please implement your own CopyOutN function")
+}
diff --git a/tools/go_marshal/primitive/BUILD b/tools/go_marshal/primitive/BUILD
new file mode 100644
index 000000000..cc08ba63a
--- /dev/null
+++ b/tools/go_marshal/primitive/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "primitive",
+ srcs = [
+ "primitive.go",
+ ],
+ marshal = True,
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ ],
+)
diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go
new file mode 100644
index 000000000..d93edda8b
--- /dev/null
+++ b/tools/go_marshal/primitive/primitive.go
@@ -0,0 +1,247 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package primitive defines marshal.Marshallable implementations for primitive
+// types.
+package primitive
+
+import (
+ "io"
+
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// Int8 is a marshal.Marshallable implementation for int8.
+//
+// +marshal slice:Int8Slice:inner
+type Int8 int8
+
+// Uint8 is a marshal.Marshallable implementation for uint8.
+//
+// +marshal slice:Uint8Slice:inner
+type Uint8 uint8
+
+// Int16 is a marshal.Marshallable implementation for int16.
+//
+// +marshal slice:Int16Slice:inner
+type Int16 int16
+
+// Uint16 is a marshal.Marshallable implementation for uint16.
+//
+// +marshal slice:Uint16Slice:inner
+type Uint16 uint16
+
+// Int32 is a marshal.Marshallable implementation for int32.
+//
+// +marshal slice:Int32Slice:inner
+type Int32 int32
+
+// Uint32 is a marshal.Marshallable implementation for uint32.
+//
+// +marshal slice:Uint32Slice:inner
+type Uint32 uint32
+
+// Int64 is a marshal.Marshallable implementation for int64.
+//
+// +marshal slice:Int64Slice:inner
+type Int64 int64
+
+// Uint64 is a marshal.Marshallable implementation for uint64.
+//
+// +marshal slice:Uint64Slice:inner
+type Uint64 uint64
+
+// ByteSlice is a marshal.Marshallable implementation for []byte.
+// This is a convenience wrapper around a dynamically sized type, and can't be
+// embedded in other marshallable types because it breaks assumptions made by
+// go-marshal internals. It violates the "no dynamically-sized types"
+// constraint of the go-marshal library.
+type ByteSlice []byte
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (b *ByteSlice) SizeBytes() int {
+ return len(*b)
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (b *ByteSlice) MarshalBytes(dst []byte) {
+ copy(dst, *b)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (b *ByteSlice) UnmarshalBytes(src []byte) {
+ copy(*b, src)
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (b *ByteSlice) Packed() bool {
+ return false
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (b *ByteSlice) MarshalUnsafe(dst []byte) {
+ b.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (b *ByteSlice) UnmarshalUnsafe(src []byte) {
+ b.UnmarshalBytes(src)
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ return task.CopyInBytes(addr, *b)
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ return task.CopyOutBytes(addr, *b)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ return task.CopyOutBytes(addr, (*b)[:limit])
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) {
+ n, err := w.Write(*b)
+ return int64(n), err
+}
+
+var _ marshal.Marshallable = (*ByteSlice)(nil)
+
+// Below, we define some convenience functions for marshalling primitive types
+// using the newtypes above, without requiring superfluous casts.
+
+// 16-bit integers
+
+// CopyInt16In is a convenient wrapper for copying in an int16 from the task's
+// memory.
+func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) {
+ var buf Int16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int16(buf)
+ return n, nil
+}
+
+// CopyInt16Out is a convenient wrapper for copying out an int16 to the task's
+// memory.
+func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) {
+ srcP := Int16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint16In is a convenient wrapper for copying in a uint16 from the task's
+// memory.
+func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) {
+ var buf Uint16
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint16(buf)
+ return n, nil
+}
+
+// CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's
+// memory.
+func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) {
+ srcP := Uint16(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 32-bit integers
+
+// CopyInt32In is a convenient wrapper for copying in an int32 from the task's
+// memory.
+func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) {
+ var buf Int32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int32(buf)
+ return n, nil
+}
+
+// CopyInt32Out is a convenient wrapper for copying out an int32 to the task's
+// memory.
+func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) {
+ srcP := Int32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint32In is a convenient wrapper for copying in a uint32 from the task's
+// memory.
+func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) {
+ var buf Uint32
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint32(buf)
+ return n, nil
+}
+
+// CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's
+// memory.
+func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) {
+ srcP := Uint32(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// 64-bit integers
+
+// CopyInt64In is a convenient wrapper for copying in an int64 from the task's
+// memory.
+func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) {
+ var buf Int64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = int64(buf)
+ return n, nil
+}
+
+// CopyInt64Out is a convenient wrapper for copying out an int64 to the task's
+// memory.
+func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) {
+ srcP := Int64(src)
+ return srcP.CopyOut(task, addr)
+}
+
+// CopyUint64In is a convenient wrapper for copying in a uint64 from the task's
+// memory.
+func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) {
+ var buf Uint64
+ n, err := buf.CopyIn(task, addr)
+ if err != nil {
+ return n, err
+ }
+ *dst = uint64(buf)
+ return n, nil
+}
+
+// CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's
+// memory.
+func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) {
+ srcP := Uint64(src)
+ return srcP.CopyOut(task, addr)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
index fa82f8e9b..3d989823a 100644
--- a/tools/go_marshal/test/BUILD
+++ b/tools/go_marshal/test/BUILD
@@ -1,8 +1,6 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools:defs.bzl", "go_library", "go_test")
-package(licenses = ["notice"])
-
-load("//tools/go_marshal:defs.bzl", "go_library")
+licenses(["notice"])
package_group(
name = "gomarshal_test",
@@ -17,7 +15,7 @@ go_test(
deps = [
":test",
"//pkg/binary",
- "//pkg/sentry/usermem",
+ "//pkg/usermem",
"//tools/go_marshal/analysis",
],
)
@@ -26,6 +24,21 @@ go_library(
name = "test",
testonly = 1,
srcs = ["test.go"],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/test",
+ marshal = True,
+ visibility = ["//tools/go_marshal/test:__subpackages__"],
deps = ["//tools/go_marshal/test/external"],
)
+
+go_test(
+ name = "marshal_test",
+ size = "small",
+ srcs = ["marshal_test.go"],
+ deps = [
+ ":test",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//tools/go_marshal/analysis",
+ "//tools/go_marshal/marshal",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
index e70db06d8..224d308c7 100644
--- a/tools/go_marshal/test/benchmark_test.go
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -22,9 +22,9 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/tools/go_marshal/analysis"
- test "gvisor.dev/gvisor/tools/go_marshal/test"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
)
// Marshalling using the standard encoding/binary package.
@@ -176,3 +176,45 @@ func BenchmarkGoMarshalUnsafe(b *testing.B) {
panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
}
}
+
+func BenchmarkBinarySlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+func BenchmarkGoMarshalUnsafeSlice(b *testing.B) {
+ var s1, s2 [64]test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, (*test.Stat)(nil).SizeBytes()*len(s1))
+ test.MarshalUnsafeStatSlice(s1[:], buf)
+ test.UnmarshalUnsafeStatSlice(s2[:], buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD
new file mode 100644
index 000000000..f74e6ffae
--- /dev/null
+++ b/tools/go_marshal/test/escape/BUILD
@@ -0,0 +1,14 @@
+load("//tools:defs.bzl", "go_library")
+
+licenses(["notice"])
+
+go_library(
+ name = "escape",
+ testonly = 1,
+ srcs = ["escape.go"],
+ deps = [
+ "//pkg/usermem",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/test",
+ ],
+)
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go
new file mode 100644
index 000000000..6a46ddbf8
--- /dev/null
+++ b/tools/go_marshal/test/escape/escape.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package escape
+
+import (
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// dummyTask implements marshal.Task.
+type dummyTask struct {
+}
+
+func (*dummyTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) {
+ return len(b), nil
+}
+
+func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := t.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalBytes(buf)
+ t.CopyOutBytes(addr, buf)
+}
+
+func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) {
+ buf := t.CopyScratchBuffer(marshallable.SizeBytes())
+ marshallable.MarshalUnsafe(buf)
+ t.CopyOutBytes(addr, buf)
+}
+
+// +checkescape:all
+//go:nosplit
+func doCopyIn(t *dummyTask) {
+ var stat test.Stat
+ stat.CopyIn(t, usermem.Addr(0xf000ba12))
+}
+
+// +checkescape:all
+//go:nosplit
+func doCopyOut(t *dummyTask) {
+ var stat test.Stat
+ stat.CopyOut(t, usermem.Addr(0xf000ba12))
+}
+
+// +mustescape:builtin
+// +mustescape:stack
+func doMarshalBytesDirect(t *dummyTask) {
+ var stat test.Stat
+ buf := t.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalBytes(buf)
+ t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// +mustescape:builtin
+// +mustescape:stack
+func doMarshalUnsafeDirect(t *dummyTask) {
+ var stat test.Stat
+ buf := t.CopyScratchBuffer(stat.SizeBytes())
+ stat.MarshalUnsafe(buf)
+ t.CopyOutBytes(usermem.Addr(0xf000ba12), buf)
+}
+
+// +mustescape:local,heap
+// +mustescape:stack
+func doMarshalBytesViaMarshallable(t *dummyTask) {
+ var stat test.Stat
+ t.MarshalBytes(usermem.Addr(0xf000ba12), &stat)
+}
+
+// +mustescape:local,heap
+// +mustescape:stack
+func doMarshalUnsafeViaMarshallable(t *dummyTask) {
+ var stat test.Stat
+ t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat)
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
index 8fb43179b..0cf6da603 100644
--- a/tools/go_marshal/test/external/BUILD
+++ b/tools/go_marshal/test/external/BUILD
@@ -1,11 +1,11 @@
-package(licenses = ["notice"])
+load("//tools:defs.bzl", "go_library")
-load("//tools/go_marshal:defs.bzl", "go_library")
+licenses(["notice"])
go_library(
name = "external",
testonly = 1,
srcs = ["external.go"],
- importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external",
+ marshal = True,
visibility = ["//tools/go_marshal/test:gomarshal_test"],
)
diff --git a/tools/go_marshal/test/external/external.go b/tools/go_marshal/test/external/external.go
index 4be3722f3..26fe8e0c8 100644
--- a/tools/go_marshal/test/external/external.go
+++ b/tools/go_marshal/test/external/external.go
@@ -21,3 +21,11 @@ package external
type External struct {
j int64
}
+
+// NotPacked is an unaligned Marshallable type for use in testing.
+//
+// +marshal
+type NotPacked struct {
+ a int32
+ b byte `marshal:"unaligned"`
+}
diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go
new file mode 100644
index 000000000..16829ee45
--- /dev/null
+++ b/tools/go_marshal/test/marshal_test.go
@@ -0,0 +1,515 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal_test contains manual tests for the marshal interface. These
+// are intended to test behaviour not covered by the automatically generated
+// tests.
+package marshal_test
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+var simulatedErr error = syserror.EFAULT
+
+// mockTask implements marshal.Task.
+type mockTask struct {
+ taskMem usermem.BytesIO
+}
+
+// populate fills the task memory with the contents of val.
+func (t *mockTask) populate(val interface{}) {
+ var buf bytes.Buffer
+ // Use binary.Write so we aren't testing go-marshal against its own
+ // potentially buggy implementation.
+ if err := binary.Write(&buf, usermem.ByteOrder, val); err != nil {
+ panic(err)
+ }
+ t.taskMem.Bytes = buf.Bytes()
+}
+
+func (t *mockTask) setLimit(n int) {
+ if len(t.taskMem.Bytes) < n {
+ grown := make([]byte, n)
+ copy(grown, t.taskMem.Bytes)
+ t.taskMem.Bytes = grown
+ return
+ }
+ t.taskMem.Bytes = t.taskMem.Bytes[:n]
+}
+
+// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer.
+func (t *mockTask) CopyScratchBuffer(size int) []byte {
+ return make([]byte, size)
+}
+
+// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation
+// completely ignores the target address and stores a copy of b in its
+// internally buffer, overriding any previous contents.
+func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{})
+}
+
+// CopyInBytes implements marshal.Task.CopyInBytes. The implementation
+// completely ignores the source address and always fills b from the begining of
+// its internal buffer.
+func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) {
+ return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{})
+}
+
+// unsafeMemory returns the underlying memory for m. The returned slice is only
+// valid for the lifetime for m. The garbage collector isn't aware that the
+// returned slice is related to m, the caller must ensure m lives long enough.
+func unsafeMemory(m marshal.Marshallable) []byte {
+ if !m.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ // reflect.ValueOf(m)
+ // .Elem() // Unwrap interface to inner concrete object
+ // .Addr() // Pointer value to object
+ // .Pointer() // Actual address from the pointer value
+ ptr := reflect.ValueOf(m).Elem().Addr().Pointer()
+
+ size := m.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = ptr
+ hdr.Len = size
+ hdr.Cap = size
+
+ return mem
+}
+
+// unsafeMemorySlice returns the underlying memory for m. The returned slice is
+// only valid for the lifetime for m. The garbage collector isn't aware that the
+// returned slice is related to m, the caller must ensure m lives long enough.
+//
+// Precondition: m must be a slice.
+func unsafeMemorySlice(m interface{}, elt marshal.Marshallable) []byte {
+ kind := reflect.TypeOf(m).Kind()
+ if kind != reflect.Slice {
+ panic("unsafeMemorySlice called on non-slice")
+ }
+
+ if !elt.Packed() {
+ // We can't return a slice pointing to the underlying memory
+ // since the layout isn't packed. Allocate a temporary buffer
+ // and marshal instead.
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, usermem.ByteOrder, m); err != nil {
+ panic(err)
+ }
+ return buf.Bytes()
+ }
+
+ v := reflect.ValueOf(m)
+ length := v.Len() * elt.SizeBytes()
+
+ var mem []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&mem))
+ hdr.Data = v.Pointer() // This is a pointer to the first elem for slices.
+ hdr.Len = length
+ hdr.Cap = length
+
+ return mem
+}
+
+func isZeroes(buf []byte) bool {
+ for _, b := range buf {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// compareMemory compares the first n bytes of two chuncks of memory represented
+// by expected and actual.
+func compareMemory(t *testing.T, expected, actual []byte, n int) {
+ t.Logf("Expected (%d): %v (%d) + (%d) %v\n", len(expected), expected[:n], n, len(expected)-n, expected[n:])
+ t.Logf("Actual (%d): %v (%d) + (%d) %v\n", len(actual), actual[:n], n, len(actual)-n, actual[n:])
+
+ if diff := cmp.Diff(expected[:n], actual[:n]); diff != "" {
+ t.Errorf("Memory buffers don't match:\n--- expected only\n+++ actual only\n%v", diff)
+ }
+}
+
+// limitedCopyIn populates task memory with src, then unmarshals task memory to
+// dst. The task signals an error at limit bytes during copy-in, which should
+// result in a truncated unmarshalling.
+func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) {
+ var task mockTask
+ task.populate(src)
+ task.setLimit(limit)
+
+ n, err := dst.CopyIn(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := unsafeMemory(dst)
+ defer runtime.KeepAlive(dst)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The last n bytes should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything in the last n bytes if the
+ // layout is packed.
+ if dst.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", dst.SizeBytes()-n, actualMem)
+ }
+}
+
+// limitedCopyOut marshals src to task memory. The task signals an error at
+// limit bytes during copy-out, which should result in a truncated marshalling.
+func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOut(&task, usermem.Addr(0))
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if err != simulatedErr {
+ t.Errorf("CopyOut returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// copyOutN marshals src to task memory, requesting the marshalling to be
+// limited to limit bytes.
+func copyOutN(t *testing.T, src marshal.Marshallable, limit int) {
+ var task mockTask
+ task.setLimit(limit)
+
+ n, err := src.CopyOutN(&task, usermem.Addr(0), limit)
+ if err != nil {
+ t.Errorf("CopyOut returned unexpected error: %v", err)
+ }
+ if n != limit {
+ t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+
+ expectedMem := unsafeMemory(src)
+ defer runtime.KeepAlive(src)
+ actualMem := task.taskMem.Bytes
+
+ t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:])
+ t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:])
+
+ compareMemory(t, expectedMem, actualMem, n)
+}
+
+// TestLimitedMarshalling verifies marshalling/unmarshalling succeeds when the
+// underyling copy in/out operations partially succeed.
+func TestLimitedMarshalling(t *testing.T) {
+ types := []reflect.Type{
+ // Packed types.
+ reflect.TypeOf((*test.Type2)(nil)),
+ reflect.TypeOf((*test.Type3)(nil)),
+ reflect.TypeOf((*test.Timespec)(nil)),
+ reflect.TypeOf((*test.Stat)(nil)),
+ reflect.TypeOf((*test.InetAddr)(nil)),
+ reflect.TypeOf((*test.SignalSet)(nil)),
+ reflect.TypeOf((*test.SignalSetAlias)(nil)),
+ // Non-packed types.
+ reflect.TypeOf((*test.Type1)(nil)),
+ reflect.TypeOf((*test.Type4)(nil)),
+ reflect.TypeOf((*test.Type5)(nil)),
+ reflect.TypeOf((*test.Type6)(nil)),
+ reflect.TypeOf((*test.Type7)(nil)),
+ reflect.TypeOf((*test.Type8)(nil)),
+ }
+
+ for _, tyPtr := range types {
+ // Remove one level of pointer-indirection from the type. We get this
+ // back when we pass the type to reflect.New.
+ ty := tyPtr.Elem()
+
+ // Partial copy-in.
+ t.Run(fmt.Sprintf("PartialCopyIn_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ actual := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyIn(t, expected, actual, expected.SizeBytes()/2)
+ })
+
+ // Partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOut_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ limitedCopyOut(t, expected, expected.SizeBytes()/2)
+ })
+
+ // Explicitly request partial copy-out.
+ t.Run(fmt.Sprintf("PartialCopyOutN_%v", ty), func(t *testing.T) {
+ expected := reflect.New(ty).Interface().(marshal.Marshallable)
+ analysis.RandomizeValue(expected)
+
+ copyOutN(t, expected, expected.SizeBytes()/2)
+ })
+ }
+}
+
+// TestLimitedMarshalling verifies marshalling/unmarshalling of slices of
+// marshallable types succeed when the underyling copy in/out operations
+// partially succeed.
+func TestLimitedSliceMarshalling(t *testing.T) {
+ types := []struct {
+ arrayPtrType reflect.Type
+ copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error)
+ copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error)
+ unsafeMemory func(arrPtr interface{}) []byte
+ }{
+ // Packed types.
+ {
+ reflect.TypeOf((*[20]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Stat)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Stat)[:]
+ return test.CopyStatSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Stat)[:]
+ return test.CopyStatSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Stat)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[5]test.SignalSetAlias)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[5]test.SignalSetAlias)[:]
+ return test.CopySignalSetAliasSliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[5]test.SignalSetAlias)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ // Non-packed types.
+ {
+ reflect.TypeOf((*[20]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[20]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[20]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[20]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[1]test.Type1)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[1]test.Type1)[:]
+ return test.CopyType1SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[1]test.Type1)[:]
+ return test.CopyType1SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[1]test.Type1)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ {
+ reflect.TypeOf((*[7]test.Type8)(nil)),
+ func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) {
+ slice := dst.(*[7]test.Type8)[:]
+ return test.CopyType8SliceIn(task, addr, slice)
+ },
+ func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) {
+ slice := src.(*[7]test.Type8)[:]
+ return test.CopyType8SliceOut(task, addr, slice)
+ },
+ func(a interface{}) []byte {
+ slice := a.(*[7]test.Type8)[:]
+ return unsafeMemorySlice(slice, &slice[0])
+ },
+ },
+ }
+
+ for _, tt := range types {
+ // The body of this loop is generic over the type tt.arrayPtrType, with
+ // the help of reflection. To aid in readability, comments below show
+ // the equivalent go code assuming
+ // tt.arrayPtrType = typeof(*[20]test.Stat).
+
+ // Equivalent:
+ // var x *[20]test.Stat
+ // arrayTy := reflect.TypeOf(*x)
+ arrayTy := tt.arrayPtrType.Elem()
+
+ // Partial copy-in of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceIn_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+ actual := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceIn(&task, usermem.Addr(0), actual)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := tt.unsafeMemory(actual)
+ defer runtime.KeepAlive(actual)
+
+ compareMemory(t, expectedMem, actualMem, n)
+
+ // The last n bytes should be zero for actual, since actual was
+ // zero-initialized, and CopyIn shouldn't have touched those bytes. However
+ // we can only guarantee we didn't touch anything in the last n bytes if the
+ // layout is packed.
+ if elem.Packed() && !isZeroes(actualMem[n:]) {
+ t.Errorf("Expected the last %d bytes of copied in object to be zeroes, got %v\n", (elem.SizeBytes()*length)-n, actualMem)
+ }
+ })
+
+ // Partial copy-out of slices.
+ t.Run(fmt.Sprintf("PartialCopySliceOut_%v", arrayTy), func(t *testing.T) {
+ // Equivalent:
+ // var x [20]test.Stat
+ // length := len(x)
+ length := arrayTy.Len()
+ if length < 1 {
+ panic("Test type can't be zero-length array")
+ }
+ // Equivalent:
+ // elem := new(test.Stat).(marshal.Marshallable)
+ elem := reflect.New(arrayTy.Elem()).Interface().(marshal.Marshallable)
+
+ // Equivalent:
+ // var expected, actual interface{}
+ // expected = new([20]test.Stat)
+ // actual = new([20]test.Stat)
+ expected := reflect.New(arrayTy).Interface()
+
+ analysis.RandomizeValue(expected)
+
+ limit := (length * elem.SizeBytes()) / 2
+ // Also make sure the limit is partially inside one of the elements.
+ limit += elem.SizeBytes() / 2
+ analysis.RandomizeValue(expected)
+
+ var task mockTask
+ task.populate(expected)
+ task.setLimit(limit)
+
+ n, err := tt.copySliceOut(&task, usermem.Addr(0), expected)
+ if n != limit {
+ t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n)
+ }
+ if n < length*elem.SizeBytes() && err != simulatedErr {
+ t.Errorf("CopyIn returned unexpected error, expected %v, got %v", simulatedErr, err)
+ }
+
+ expectedMem := tt.unsafeMemory(expected)
+ defer runtime.KeepAlive(expected)
+ actualMem := task.taskMem.Bytes
+
+ compareMemory(t, expectedMem, actualMem, n)
+ })
+ }
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 8de02d707..f75ca1b7f 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -23,7 +23,7 @@ import (
// Type1 is a test data type.
//
-// +marshal
+// +marshal slice:Type1Slice
type Type1 struct {
a Type2
x, y int64 // Multiple field names.
@@ -75,6 +75,34 @@ type Type5 struct {
m int64
}
+// Type6 is a test data type ends mid-word.
+//
+// +marshal
+type Type6 struct {
+ a int64
+ b int64
+ // If c isn't marked unaligned, analysis fails (as it should, since
+ // the unsafe API corrupts Type7).
+ c byte `marshal:"unaligned"`
+}
+
+// Type7 is a test data type that contains a child struct that ends
+// mid-word.
+// +marshal
+type Type7 struct {
+ x Type6
+ y int64
+}
+
+// Type8 is a test data type which contains an external non-packed field.
+//
+// +marshal slice:Type8Slice
+type Type8 struct {
+ a int64
+ np ex.NotPacked
+ b int64
+}
+
// Timespec represents struct timespec in <time.h>.
//
// +marshal
@@ -85,7 +113,7 @@ type Timespec struct {
// Stat represents struct stat.
//
-// +marshal
+// +marshal slice:StatSlice
type Stat struct {
Dev uint64
Ino uint64
@@ -103,3 +131,46 @@ type Stat struct {
CTime Timespec
_ [3]int64
}
+
+// InetAddr is an example marshallable newtype on an array.
+//
+// +marshal
+type InetAddr [4]byte
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal slice:SignalSetSlice:inner
+type SignalSet uint64
+
+// SignalSetAlias is an example newtype on another marshallable type.
+//
+// +marshal slice:SignalSetAliasSlice
+type SignalSetAlias SignalSet
+
+const sizeA = 64
+const sizeB = 8
+
+// TestArray is a test data structure on an array with a constant length.
+//
+// +marshal
+type TestArray [sizeA]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// constants for the array length.
+//
+// +marshal
+type TestArray2 [sizeA * sizeB]int32
+
+// TestArray2 is a newtype on an array with a simple arithmetic expression of
+// mixed constants and literals for the array length.
+//
+// +marshal
+type TestArray3 [sizeA*sizeB + 12]int32
+
+// Type9 is a test data type containing an array with a non-literal length.
+//
+// +marshal
+type Type9 struct {
+ x int64
+ y [sizeA]int32
+}
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
index bb53f8ae9..913558b4e 100644
--- a/tools/go_stateify/BUILD
+++ b/tools/go_stateify/BUILD
@@ -1,9 +1,16 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools:defs.bzl", "bzl_library", "go_binary")
package(licenses = ["notice"])
go_binary(
name = "stateify",
srcs = ["main.go"],
- visibility = ["//visibility:public"],
+ visibility = ["//:sandbox"],
+ deps = ["//tools/tags"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
)
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index 3ce36c1c8..6a5e666f0 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -1,41 +1,4 @@
-"""Stateify is a tool for generating state wrappers for Go types.
-
-The recommended way is to use the go_library rule defined below with mostly
-identical configuration as the native go_library rule.
-
-load("//tools/go_stateify:defs.bzl", "go_library")
-
-go_library(
- name = "foo",
- srcs = ["foo.go"],
-)
-
-Under the hood, the go_stateify rule is used to generate a file that will
-appear in a Go target; the output file should appear explicitly in a srcs list.
-For example (the above is still the preferred way):
-
-load("//tools/go_stateify:defs.bzl", "go_stateify")
-
-go_stateify(
- name = "foo_state",
- srcs = ["foo.go"],
- out = "foo_state.go",
- package = "foo",
-)
-
-go_library(
- name = "foo",
- srcs = [
- "foo.go",
- "foo_state.go",
- ],
- deps = [
- "//pkg/state",
- ],
-)
-"""
-
-load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library")
+"""Stateify is a tool for generating state wrappers for Go types."""
def _go_stateify_impl(ctx):
"""Implementation for the stateify tool."""
@@ -43,12 +6,12 @@ def _go_stateify_impl(ctx):
# Run the stateify command.
args = ["-output=%s" % output.path]
- args += ["-pkg=%s" % ctx.attr.package]
+ args.append("-fullpkg=%s" % ctx.attr.package)
if ctx.attr._statepkg:
- args += ["-statepkg=%s" % ctx.attr._statepkg]
+ args.append("-statepkg=%s" % ctx.attr._statepkg)
if ctx.attr.imports:
- args += ["-imports=%s" % ",".join(ctx.attr.imports)]
- args += ["--"]
+ args.append("-imports=%s" % ",".join(ctx.attr.imports))
+ args.append("--")
for src in ctx.attr.srcs:
args += [f.path for f in src.files.to_list()]
ctx.actions.run(
@@ -80,14 +43,11 @@ for statified types.
mandatory = False,
),
"package": attr.string(
- doc = "The package name for the input sources.",
+ doc = "The fully qualified package name for the input sources.",
mandatory = True,
),
"out": attr.output(
- doc = """
-The name of the generated file output. This must not conflict with any other
-files and must be added to the srcs of the relevant go_library.
-""",
+ doc = "Name of the generator output file.",
mandatory = True,
),
"_tool": attr.label(
@@ -98,39 +58,3 @@ files and must be added to the srcs of the relevant go_library.
"_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"),
},
)
-
-def go_library(name, srcs, deps = [], imports = [], **kwargs):
- """Standard go_library wrapped which generates state source files.
-
- Args:
- name: the name of the go_library rule.
- srcs: sources of the go_library. Each will be processed for stateify
- annotations.
- deps: dependencies for the go_library.
- imports: an optional list of extra non-aliased, Go-style absolute import
- paths required for stateified types.
- **kwargs: passed to go_library.
- """
- if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs:
- # Only do stateification for non-state packages without manual autogen.
- go_stateify(
- name = name + "_state_autogen",
- srcs = [src for src in srcs if src.endswith(".go")],
- imports = imports,
- package = name,
- out = name + "_state_autogen.go",
- )
- all_srcs = srcs + [name + "_state_autogen.go"]
- if "//pkg/state" not in deps:
- all_deps = deps + ["//pkg/state"]
- else:
- all_deps = deps
- else:
- all_deps = deps
- all_srcs = srcs
- _go_library(
- name = name,
- srcs = all_srcs,
- deps = all_deps,
- **kwargs
- )
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index db7a7107b..4f6ed208a 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -23,13 +23,16 @@ import (
"go/parser"
"go/token"
"os"
+ "path/filepath"
"reflect"
"strings"
"sync"
+
+ "gvisor.dev/gvisor/tools/tags"
)
var (
- pkg = flag.String("pkg", "", "output package")
+ fullPkg = flag.String("fullpkg", "", "fully qualified output package")
imports = flag.String("imports", "", "extra imports for the output file")
output = flag.String("output", "", "output file")
statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
@@ -100,7 +103,7 @@ type scanFunctions struct {
// skipped if nil.
//
// Fields tagged nosave are skipped.
-func scanFields(ss *ast.StructType, fn scanFunctions) {
+func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) {
if ss.Fields.List == nil {
// No fields.
return
@@ -124,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) {
continue
}
- switch tag := extractStateTag(field.Tag); tag {
+ // Is this a anonymous struct? If yes, then continue the
+ // recursion with the given prefix. We don't pay attention to
+ // any tags on the top-level struct field.
+ tag := extractStateTag(field.Tag)
+ if anon, ok := field.Type.(*ast.StructType); ok && tag == "" {
+ scanFields(anon, name+".", fn)
+ continue
+ }
+
+ switch tag {
case "zerovalue":
if fn.zerovalue != nil {
fn.zerovalue(name)
@@ -168,7 +180,7 @@ func main() {
flag.Usage()
os.Exit(1)
}
- if *pkg == "" {
+ if *fullPkg == "" {
fmt.Fprintf(os.Stderr, "Error: package required.")
os.Exit(1)
}
@@ -198,33 +210,25 @@ func main() {
// initCalls is dumped at the end.
var initCalls []string
- // Declare our emission closures.
+ // Common closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name))
}
emitZeroCheck := func(name string) {
- fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
- }
- emitLoadValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName)
- }
- emitLoad := func(name string) {
- fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name)
- }
- emitLoadWait := func(name string) {
- fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name)
- }
- emitSaveValue := func(name, typName string) {
- fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
- fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name)
+ fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name)
}
- emitSave := func(name string) {
- fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name)
+
+ // Automated warning.
+ fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
+
+ // Emit build tags.
+ if t := tags.Aggregate(flag.Args()); len(t) > 0 {
+ fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
}
// Emit the package name.
- fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
- fmt.Fprintf(outputFile, "package %s\n\n", *pkg)
+ _, pkg := filepath.Split(*fullPkg)
+ fmt.Fprintf(outputFile, "package %s\n\n", pkg)
// Emit the imports lazily.
var once sync.Once
@@ -256,6 +260,7 @@ func main() {
fmt.Fprintf(os.Stderr, "Input %q can't be parsed: %v\n", filename, err)
os.Exit(1)
}
+
files = append(files, f)
}
@@ -317,87 +322,140 @@ func main() {
continue
}
- // Only generate code for types marked
- // "// +stateify savable" in one of the proceeding
- // comment lines.
+ // Only generate code for types marked "// +stateify
+ // savable" in one of the proceeding comment lines. If
+ // the line is marked "// +stateify type" then only
+ // generate type information and register the type.
if d.Doc == nil {
continue
}
- savable := false
+ var (
+ generateTypeInfo = false
+ generateSaverLoader = false
+ )
for _, l := range d.Doc.List {
if l.Text == "// +stateify savable" {
- savable = true
+ generateTypeInfo = true
+ generateSaverLoader = true
break
}
+ if l.Text == "// +stateify type" {
+ generateTypeInfo = true
+ }
}
- if !savable {
+ if !generateTypeInfo && !generateSaverLoader {
continue
}
for _, gs := range d.Specs {
ts := gs.(*ast.TypeSpec)
- switch ts.Type.(type) {
- case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr:
- // Don't register.
- break
+ switch x := ts.Type.(type) {
case *ast.StructType:
maybeEmitImports()
- ss := ts.Type.(*ast.StructType)
+ // Record the slot for each field.
+ fieldCount := 0
+ fields := make(map[string]int)
+ emitField := func(name string) {
+ fmt.Fprintf(outputFile, " \"%s\",\n", name)
+ fields[name] = fieldCount
+ fieldCount++
+ }
+ emitFieldValue := func(name string, _ string) {
+ emitField(name)
+ }
+ emitLoadValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName)
+ }
+ emitLoad := func(name string) {
+ fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name)
+ }
+ emitLoadWait := func(name string) {
+ fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name)
+ }
+ emitSaveValue := func(name, typName string) {
+ fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name))
+ fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name)
+ }
+ emitSave := func(name string) {
+ fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name)
+ }
+
+ // Generate the type name method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
+ fmt.Fprintf(outputFile, "}\n\n")
+
+ // Generate the fields method.
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return []string{\n")
+ scanFields(x, "", scanFunctions{
+ normal: emitField,
+ wait: emitField,
+ value: emitFieldValue,
+ })
+ fmt.Fprintf(outputFile, " }\n")
+ fmt.Fprintf(outputFile, "}\n\n")
- // Define beforeSave if a definition was not found. This
- // prevents the code from compiling if a custom beforeSave
- // was defined in a file not provided to this binary and
- // prevents inherited methods from being called multiple times
- // by overriding them.
- if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok {
- fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name)
+ // Define beforeSave if a definition was not found. This prevents
+ // the code from compiling if a custom beforeSave was defined in a
+ // file not provided to this binary and prevents inherited methods
+ // from being called multiple times by overriding them.
+ if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name)
}
// Generate the save method.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " x.beforeSave()\n")
- scanFields(ss, scanFunctions{zerovalue: emitZeroCheck})
- scanFields(ss, scanFunctions{value: emitSaveValue})
- scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave})
- fmt.Fprintf(outputFile, "}\n\n")
+ //
+ // N.B. For historical reasons, we perform the value saves first,
+ // and perform the value loads last. There should be no dependency
+ // on this specific behavior, but the ability to specify slots
+ // allows a manual implementation to be order-dependent.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix)
+ fmt.Fprintf(outputFile, " x.beforeSave()\n")
+ scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck})
+ scanFields(x, "", scanFunctions{value: emitSaveValue})
+ scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave})
+ fmt.Fprintf(outputFile, "}\n\n")
+ }
- // Define afterLoad if a definition was not found. We do this
- // for the same reason that we do it for beforeSave.
+ // Define afterLoad if a definition was not found. We do this for
+ // the same reason that we do it for beforeSave.
_, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}]
- if !hasAfterLoad {
- fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name)
+ if !hasAfterLoad && generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name)
}
// Generate the load method.
//
- // Note that the manual loads always follow the
- // automated loads.
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait})
- scanFields(ss, scanFunctions{value: emitLoadValue})
- if hasAfterLoad {
- // The call to afterLoad is made conditionally, because when
- // AfterLoad is called, the object encodes a dependency on
- // referred objects (i.e. fields). This means that afterLoad
- // will not be called until the other afterLoads are called.
- fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ // N.B. See the comment above for the save method.
+ if generateSaverLoader {
+ fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix)
+ scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait})
+ scanFields(x, "", scanFunctions{value: emitLoadValue})
+ if hasAfterLoad {
+ // The call to afterLoad is made conditionally, because when
+ // AfterLoad is called, the object encodes a dependency on
+ // referred objects (i.e. fields). This means that afterLoad
+ // will not be called until the other afterLoads are called.
+ fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n")
+ }
+ fmt.Fprintf(outputFile, "}\n\n")
}
- fmt.Fprintf(outputFile, "}\n\n")
// Add to our registration.
emitRegister(ts.Name.Name)
+
case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType:
maybeEmitImports()
- _, val := resolveTypeName(ts.Name.Name, ts.Type)
-
- // Dispatch directly.
- fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val)
+ // Generate the info methods.
+ fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name)
fmt.Fprintf(outputFile, "}\n\n")
- fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix)
- fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val)
+ fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name)
+ fmt.Fprintf(outputFile, " return nil\n")
fmt.Fprintf(outputFile, "}\n\n")
// See above.
diff --git a/tools/installers/BUILD b/tools/installers/BUILD
new file mode 100644
index 000000000..13d3cc5e0
--- /dev/null
+++ b/tools/installers/BUILD
@@ -0,0 +1,41 @@
+# Installers for use by the tools/vm_test rules.
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+sh_binary(
+ name = "head",
+ srcs = ["head.sh"],
+ data = [
+ "//runsc",
+ ],
+)
+
+sh_binary(
+ name = "images",
+ srcs = ["images.sh"],
+ data = [
+ "//images",
+ ],
+)
+
+sh_binary(
+ name = "master",
+ srcs = ["master.sh"],
+)
+
+sh_binary(
+ name = "containerd",
+ srcs = ["containerd.sh"],
+)
+
+sh_binary(
+ name = "shim",
+ srcs = ["shim.sh"],
+ data = [
+ "//shim/v1:gvisor-containerd-shim",
+ "//shim/v2:containerd-shim-runsc-v1",
+ ],
+)
diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh
new file mode 100755
index 000000000..6b7bb261c
--- /dev/null
+++ b/tools/installers/containerd.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0}
+declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')"
+declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')"
+
+# Default to an older version for crictl for containerd <= 1.2.
+if [[ "${CONTAINERD_MAJOR}" -eq 1 ]] && [[ "${CONTAINERD_MINOR}" -le 2 ]]; then
+ declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.13.0}
+else
+ declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.18.0}
+fi
+
+# Helper for Go packages below.
+install_helper() {
+ PACKAGE="${1}"
+ TAG="${2}"
+
+ # Clone the repository.
+ mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
+ git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
+
+ # Checkout and build the repository.
+ (cd "${GOPATH}"/src/"${PACKAGE}" && \
+ git checkout "${TAG}" && \
+ make && \
+ make install)
+}
+
+# Install dependencies for the crictl tests.
+while true; do
+ if (apt-get update && apt-get install -y \
+ btrfs-tools \
+ libseccomp-dev); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Install containerd & cri-tools.
+declare -rx GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
+install_helper github.com/containerd/containerd "v${CONTAINERD_VERSION}" "${GOPATH}"
+install_helper github.com/kubernetes-sigs/cri-tools "v${CRITOOLS_VERSION}" "${GOPATH}"
+
+# Configure containerd-shim.
+#
+# Note that for versions <= 1.1 the legacy shim must be installed in /usr/bin,
+# which should align with the installer script in head.sh (or master.sh).
+if [[ "${CONTAINERD_MAJOR}" -le 1 ]] && [[ "${CONTAINERD_MINOR}" -lt 2 ]]; then
+ declare -r shim_config_path=/etc/containerd/gvisor-containerd-shim.toml
+ mkdir -p $(dirname ${shim_config_path})
+ cat > ${shim_config_path} <<-EOF
+ runc_shim = "/usr/bin/containerd-shim"
+
+[runsc_config]
+ debug = "true"
+ debug-log = "/tmp/runsc-logs/"
+ strace = "true"
+ file-access = "shared"
+EOF
+fi
+
+# Configure CNI.
+(cd "${GOPATH}" && src/github.com/containerd/containerd/script/setup/install-cni)
+cat <<EOF | sudo tee /etc/cni/net.d/10-bridge.conf
+{
+ "cniVersion": "0.3.1",
+ "name": "bridge",
+ "type": "bridge",
+ "bridge": "cnio0",
+ "isGateway": true,
+ "ipMasq": true,
+ "ipam": {
+ "type": "host-local",
+ "ranges": [
+ [{"subnet": "10.200.0.0/24"}]
+ ],
+ "routes": [{"dst": "0.0.0.0/0"}]
+ }
+}
+EOF
+cat <<EOF | sudo tee /etc/cni/net.d/99-loopback.conf
+{
+ "cniVersion": "0.3.1",
+ "type": "loopback"
+}
+EOF
+
+# Configure crictl.
+cat <<EOF | sudo tee /etc/crictl.yaml
+runtime-endpoint: unix:///run/containerd/containerd.sock
+EOF
+
+# Cleanup.
+rm -rf "${GOPATH}"
diff --git a/tools/installers/head.sh b/tools/installers/head.sh
new file mode 100755
index 000000000..a613fcb5b
--- /dev/null
+++ b/tools/installers/head.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install our runtime.
+runfiles=.
+if [[ -d "$0.runfiles" ]]; then
+ runfiles="$0.runfiles"
+fi
+$(find -L "${runfiles}" -executable -type f -name runsc) install
+
+# Restart docker.
+if service docker status 2>/dev/null; then
+ service docker restart
+fi
diff --git a/tools/installers/images.sh b/tools/installers/images.sh
new file mode 100755
index 000000000..52e750f57
--- /dev/null
+++ b/tools/installers/images.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeuo pipefail
+
+# Find the images directory.
+for images in $(find . -type d -name images); do
+ if [[ -f "${images}"/Makefile ]]; then
+ make -C "${images}" load-all-images
+ fi
+done
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
new file mode 100755
index 000000000..2c6001c6c
--- /dev/null
+++ b/tools/installers/master.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install runsc from the master branch.
+set -e
+
+curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
+add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
+
+while true; do
+ if (apt-get update && apt-get install -y runsc); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+runsc install
+service docker restart
diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh
new file mode 100755
index 000000000..8153ce283
--- /dev/null
+++ b/tools/installers/shim.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install all the shims.
+#
+# Note that containerd looks at the current executable directory
+# in order to find the shim binary. So we need to check in order
+# of preference. The local containerd installer will install to
+# /usr/local, so we use that first.
+if [[ -x /usr/local/bin/containerd ]]; then
+ containerd_install_dir=/usr/local/bin
+else
+ containerd_install_dir=/usr/bin
+fi
+runfiles=.
+if [[ -d "$0.runfiles" ]]; then
+ runfiles="$0.runfiles"
+fi
+find -L "${runfiles}" -executable -type f -name containerd-shim-runsc-v1 -exec cp -L {} "${containerd_install_dir}" \;
+find -L "${runfiles}" -executable -type f -name gvisor-containerd-shim -exec cp -L {} "${containerd_install_dir}" \;
diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD
new file mode 100644
index 000000000..4ef1a3124
--- /dev/null
+++ b/tools/issue_reviver/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "issue_reviver",
+ srcs = ["main.go"],
+ deps = [
+ "//tools/issue_reviver/github",
+ "//tools/issue_reviver/reviver",
+ ],
+)
diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD
new file mode 100644
index 000000000..0eabc2835
--- /dev/null
+++ b/tools/issue_reviver/github/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "github",
+ srcs = ["github.go"],
+ nogo = False,
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+ deps = [
+ "//tools/issue_reviver/reviver",
+ "@com_github_google_go_github_v28//github:go_default_library",
+ "@org_golang_x_oauth2//:go_default_library",
+ ],
+)
+
+go_test(
+ name = "github_test",
+ size = "small",
+ srcs = ["github_test.go"],
+ library = ":github",
+)
diff --git a/tools/issue_reviver/github/github.go b/tools/issue_reviver/github/github.go
new file mode 100644
index 000000000..8ffd7e606
--- /dev/null
+++ b/tools/issue_reviver/github/github.go
@@ -0,0 +1,176 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package github implements reviver.Bugger interface on top of Github issues.
+package github
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/google/go-github/github"
+ "golang.org/x/oauth2"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+// Bugger implements reviver.Bugger interface for github issues.
+type Bugger struct {
+ owner string
+ repo string
+ dryRun bool
+
+ client *github.Client
+ issues map[int]*github.Issue
+}
+
+// NewBugger creates a new Bugger.
+func NewBugger(token, owner, repo string, dryRun bool) (*Bugger, error) {
+ b := &Bugger{
+ owner: owner,
+ repo: repo,
+ dryRun: dryRun,
+ issues: map[int]*github.Issue{},
+ }
+ if err := b.load(token); err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+func (b *Bugger) load(token string) error {
+ ctx := context.Background()
+ if len(token) == 0 {
+ fmt.Print("No OAUTH token provided, using unauthenticated account.\n")
+ b.client = github.NewClient(nil)
+ } else {
+ ts := oauth2.StaticTokenSource(
+ &oauth2.Token{AccessToken: token},
+ )
+ tc := oauth2.NewClient(ctx, ts)
+ b.client = github.NewClient(tc)
+ }
+
+ err := processAllPages(func(listOpts github.ListOptions) (*github.Response, error) {
+ opts := &github.IssueListByRepoOptions{State: "open", ListOptions: listOpts}
+ tmps, resp, err := b.client.Issues.ListByRepo(ctx, b.owner, b.repo, opts)
+ if err != nil {
+ return resp, err
+ }
+ for _, issue := range tmps {
+ b.issues[issue.GetNumber()] = issue
+ }
+ return resp, nil
+ })
+ if err != nil {
+ return err
+ }
+
+ fmt.Printf("Loaded %d issues from github.com/%s/%s\n", len(b.issues), b.owner, b.repo)
+ return nil
+}
+
+// Activate implements reviver.Bugger.
+func (b *Bugger) Activate(todo *reviver.Todo) (bool, error) {
+ id, err := parseIssueNo(todo.Issue)
+ if err != nil {
+ return true, err
+ }
+ if id <= 0 {
+ return false, nil
+ }
+
+ // Check against active issues cache.
+ if _, ok := b.issues[id]; ok {
+ fmt.Printf("%q is active: OK\n", todo.Issue)
+ return true, nil
+ }
+
+ fmt.Printf("%q is not active: reopening issue %d\n", todo.Issue, id)
+
+ // Format comment with TODO locations and search link.
+ comment := strings.Builder{}
+ fmt.Fprintln(&comment, "There are TODOs still referencing this issue:")
+ for _, l := range todo.Locations {
+ fmt.Fprintf(&comment,
+ "1. [%s:%d](https://github.com/%s/%s/blob/HEAD/%s#%d): %s\n",
+ l.File, l.Line, b.owner, b.repo, l.File, l.Line, l.Comment)
+ }
+ fmt.Fprintf(&comment,
+ "\n\nSearch [TODO](https://github.com/%s/%s/search?q=%%22%s%%22)", b.owner, b.repo, todo.Issue)
+
+ if b.dryRun {
+ fmt.Printf("[dry-run: skipping change to issue %d]\n%s\n=======================\n", id, comment.String())
+ return true, nil
+ }
+
+ ctx := context.Background()
+ req := &github.IssueRequest{State: github.String("open")}
+ _, _, err = b.client.Issues.Edit(ctx, b.owner, b.repo, id, req)
+ if err != nil {
+ return true, fmt.Errorf("failed to reactivate issue %d: %v", id, err)
+ }
+
+ cmt := &github.IssueComment{
+ Body: github.String(comment.String()),
+ Reactions: &github.Reactions{Confused: github.Int(1)},
+ }
+ if _, _, err := b.client.Issues.CreateComment(ctx, b.owner, b.repo, id, cmt); err != nil {
+ return true, fmt.Errorf("failed to add comment to issue %d: %v", id, err)
+ }
+
+ return true, nil
+}
+
+// parseIssueNo parses the issue number out of the issue url.
+func parseIssueNo(url string) (int, error) {
+ const prefix = "gvisor.dev/issue/"
+
+ // First check if I can handle the TODO.
+ idStr := strings.TrimPrefix(url, prefix)
+ if len(url) == len(idStr) {
+ return 0, nil
+ }
+
+ id, err := strconv.ParseInt(strings.TrimRight(idStr, "/"), 10, 64)
+ if err != nil {
+ return 0, err
+ }
+ return int(id), nil
+}
+
+func processAllPages(fn func(github.ListOptions) (*github.Response, error)) error {
+ opts := github.ListOptions{PerPage: 1000}
+ for {
+ resp, err := fn(opts)
+ if err != nil {
+ if rateErr, ok := err.(*github.RateLimitError); ok {
+ duration := rateErr.Rate.Reset.Sub(time.Now())
+ if duration > 5*time.Minute {
+ return fmt.Errorf("Rate limited for too long: %v", duration)
+ }
+ fmt.Printf("Rate limited, sleeping for: %v\n", duration)
+ time.Sleep(duration)
+ continue
+ }
+ return err
+ }
+ if resp.NextPage == 0 {
+ return nil
+ }
+ opts.Page = resp.NextPage
+ }
+}
diff --git a/tools/issue_reviver/github/github_test.go b/tools/issue_reviver/github/github_test.go
new file mode 100644
index 000000000..a78b230ef
--- /dev/null
+++ b/tools/issue_reviver/github/github_test.go
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package github
+
+import (
+ "testing"
+)
+
+func TestParseIssueNo(t *testing.T) {
+ testCases := []struct {
+ issue string
+ expectErr bool
+ expected int
+ }{
+ {
+ issue: "gvisor.dev/issue/123",
+ expected: 123,
+ },
+ {
+ issue: "gvisor.dev/issue/123/",
+ expected: 123,
+ },
+ {
+ issue: "not a url",
+ expected: 0,
+ },
+ {
+ issue: "gvisor.dev/issue//",
+ expectErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.issue, func(t *testing.T) {
+ id, err := parseIssueNo(tc.issue)
+ if err != nil && !tc.expectErr {
+ t.Errorf("got error: %v", err)
+ } else if tc.expected != id {
+ t.Errorf("got: %v, want: %v", id, tc.expected)
+ }
+ })
+ }
+}
diff --git a/tools/issue_reviver/main.go b/tools/issue_reviver/main.go
new file mode 100644
index 000000000..47c796b8a
--- /dev/null
+++ b/tools/issue_reviver/main.go
@@ -0,0 +1,100 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package main is the entry point for issue_reviver.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/issue_reviver/github"
+ "gvisor.dev/gvisor/tools/issue_reviver/reviver"
+)
+
+var (
+ owner string
+ repo string
+ tokenFile string
+ path string
+ dryRun bool
+)
+
+// Keep the options simple for now. Supports only a single path and repo.
+func init() {
+ flag.StringVar(&owner, "owner", "", "Github project org/owner to look for issues")
+ flag.StringVar(&repo, "repo", "", "Github repo to look for issues")
+ flag.StringVar(&tokenFile, "oauth-token-file", "", "Path to file containing the OAUTH token to be used as credential to github")
+ flag.StringVar(&path, "path", ".", "Path to scan for TODOs")
+ flag.BoolVar(&dryRun, "dry-run", false, "If set to true, no changes are made to issues")
+}
+
+func main() {
+ // Set defaults from the environment.
+ repository := os.Getenv("GITHUB_REPOSITORY")
+ if parts := strings.SplitN(repository, "/", 2); len(parts) == 2 {
+ owner = parts[0]
+ repo = parts[1]
+ }
+
+ // Parse flags.
+ flag.Parse()
+
+ // Check for mandatory parameters.
+ if len(owner) == 0 {
+ fmt.Println("missing --owner option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(repo) == 0 {
+ fmt.Println("missing --repo option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+ if len(path) == 0 {
+ fmt.Println("missing --path option.")
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ // The access token may be passed as a file so it doesn't show up in
+ // command line arguments. It also may be provided through the
+ // environment to faciliate use through GitHub's CI system.
+ token := os.Getenv("GITHUB_TOKEN")
+ if len(tokenFile) != 0 {
+ bytes, err := ioutil.ReadFile(tokenFile)
+ if err != nil {
+ fmt.Println(err.Error())
+ os.Exit(1)
+ }
+ token = string(bytes)
+ }
+
+ bugger, err := github.NewBugger(token, owner, repo, dryRun)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "Error getting github issues:", err)
+ os.Exit(1)
+ }
+ rev := reviver.New([]string{path}, []reviver.Bugger{bugger})
+ if errs := rev.Run(); len(errs) > 0 {
+ fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs))
+ for _, err := range errs {
+ fmt.Fprintf(os.Stderr, "\t%v\n", err)
+ }
+ os.Exit(1)
+ }
+}
diff --git a/tools/issue_reviver/reviver/BUILD b/tools/issue_reviver/reviver/BUILD
new file mode 100644
index 000000000..d262932bd
--- /dev/null
+++ b/tools/issue_reviver/reviver/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "reviver",
+ srcs = ["reviver.go"],
+ visibility = [
+ "//tools/issue_reviver:__subpackages__",
+ ],
+)
+
+go_test(
+ name = "reviver_test",
+ size = "small",
+ srcs = ["reviver_test.go"],
+ library = ":reviver",
+)
diff --git a/tools/issue_reviver/reviver/reviver.go b/tools/issue_reviver/reviver/reviver.go
new file mode 100644
index 000000000..2af7f0d59
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver.go
@@ -0,0 +1,192 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package reviver scans the code looking for TODOs and pass them to registered
+// Buggers to ensure TODOs point to active issues.
+package reviver
+
+import (
+ "bufio"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "regexp"
+ "sync"
+)
+
+// regexTodo matches a TODO or FIXME comment.
+var regexTodo = regexp.MustCompile(`(\/\/|#)\s*(TODO|FIXME)\(([a-zA-Z0-9.\/]+)\):\s*(.+)`)
+
+// Bugger interface is called for every TODO found in the code. If it can handle
+// the TODO, it must return true. If it returns false, the next Bugger is
+// called. If no Bugger handles the TODO, it's dropped on the floor.
+type Bugger interface {
+ Activate(todo *Todo) (bool, error)
+}
+
+// Location saves the location where the TODO was found.
+type Location struct {
+ Comment string
+ File string
+ Line uint
+}
+
+// Todo represents a unique TODO. There can be several TODOs pointing to the
+// same issue in the code. They are all grouped together.
+type Todo struct {
+ Issue string
+ Locations []Location
+}
+
+// Reviver scans the given paths for TODOs and calls Buggers to handle them.
+type Reviver struct {
+ paths []string
+ buggers []Bugger
+
+ mu sync.Mutex
+ todos map[string]*Todo
+ errs []error
+}
+
+// New create a new Reviver.
+func New(paths []string, buggers []Bugger) *Reviver {
+ return &Reviver{
+ paths: paths,
+ buggers: buggers,
+ todos: map[string]*Todo{},
+ }
+}
+
+// Run runs. It returns all errors found during processing, it doesn't stop
+// on errors.
+func (r *Reviver) Run() []error {
+ // Process each directory in parallel.
+ wg := sync.WaitGroup{}
+ for _, path := range r.paths {
+ wg.Add(1)
+ go func(path string) {
+ defer wg.Done()
+ r.processPath(path, &wg)
+ }(path)
+ }
+
+ wg.Wait()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ fmt.Printf("Processing %d TODOs (%d errors)...\n", len(r.todos), len(r.errs))
+ dropped := 0
+ for _, todo := range r.todos {
+ ok, err := r.processTodo(todo)
+ if err != nil {
+ r.errs = append(r.errs, err)
+ }
+ if !ok {
+ dropped++
+ }
+ }
+ fmt.Printf("Processed %d TODOs, %d were skipped (%d errors)\n", len(r.todos)-dropped, dropped, len(r.errs))
+
+ return r.errs
+}
+
+func (r *Reviver) processPath(path string, wg *sync.WaitGroup) {
+ fmt.Printf("Processing dir %q\n", path)
+ fis, err := ioutil.ReadDir(path)
+ if err != nil {
+ r.addErr(fmt.Errorf("error processing dir %q: %v", path, err))
+ return
+ }
+
+ for _, fi := range fis {
+ childPath := filepath.Join(path, fi.Name())
+ switch {
+ case fi.Mode().IsDir():
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ r.processPath(childPath, wg)
+ }()
+
+ case fi.Mode().IsRegular():
+ file, err := os.Open(childPath)
+ if err != nil {
+ r.addErr(err)
+ continue
+ }
+
+ scanner := bufio.NewScanner(file)
+ lineno := uint(0)
+ for scanner.Scan() {
+ lineno++
+ line := scanner.Text()
+ if todo := r.processLine(line, childPath, lineno); todo != nil {
+ r.addTodo(todo)
+ }
+ }
+ }
+ }
+}
+
+func (r *Reviver) processLine(line, path string, lineno uint) *Todo {
+ matches := regexTodo.FindStringSubmatch(line)
+ if matches == nil {
+ return nil
+ }
+ if len(matches) != 5 {
+ panic(fmt.Sprintf("regex returned wrong matches for %q: %v", line, matches))
+ }
+ return &Todo{
+ Issue: matches[3],
+ Locations: []Location{
+ {
+ File: path,
+ Line: lineno,
+ Comment: matches[4],
+ },
+ },
+ }
+}
+
+func (r *Reviver) addTodo(newTodo *Todo) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if todo := r.todos[newTodo.Issue]; todo == nil {
+ r.todos[newTodo.Issue] = newTodo
+ } else {
+ todo.Locations = append(todo.Locations, newTodo.Locations...)
+ }
+}
+
+func (r *Reviver) addErr(err error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.errs = append(r.errs, err)
+}
+
+func (r *Reviver) processTodo(todo *Todo) (bool, error) {
+ for _, bugger := range r.buggers {
+ ok, err := bugger.Activate(todo)
+ if err != nil {
+ return false, err
+ }
+ if ok {
+ return true, nil
+ }
+ }
+ return false, nil
+}
diff --git a/tools/issue_reviver/reviver/reviver_test.go b/tools/issue_reviver/reviver/reviver_test.go
new file mode 100644
index 000000000..a9fb1f9f1
--- /dev/null
+++ b/tools/issue_reviver/reviver/reviver_test.go
@@ -0,0 +1,88 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package reviver
+
+import (
+ "testing"
+)
+
+func TestProcessLine(t *testing.T) {
+ for _, tc := range []struct {
+ line string
+ want *Todo
+ }{
+ {
+ line: "// TODO(foobar.com/issue/123): comment, bla. blabla.",
+ want: &Todo{
+ Issue: "foobar.com/issue/123",
+ Locations: []Location{
+ {Comment: "comment, bla. blabla."},
+ },
+ },
+ },
+ {
+ line: "// FIXME(b/123): internal bug",
+ want: &Todo{
+ Issue: "b/123",
+ Locations: []Location{
+ {Comment: "internal bug"},
+ },
+ },
+ },
+ {
+ line: "TODO(issue): not todo",
+ },
+ {
+ line: "FIXME(issue): not todo",
+ },
+ {
+ line: "// TODO (issue): not todo",
+ },
+ {
+ line: "// TODO(issue) not todo",
+ },
+ {
+ line: "// todo(issue): not todo",
+ },
+ {
+ line: "// TODO(issue):",
+ },
+ } {
+ t.Logf("Testing: %s", tc.line)
+ r := Reviver{}
+ got := r.processLine(tc.line, "test", 0)
+ if got == nil {
+ if tc.want != nil {
+ t.Errorf("failed to process line, want: %+v", tc.want)
+ }
+ } else {
+ if tc.want == nil {
+ t.Errorf("expected error, got: %+v", got)
+ continue
+ }
+ if got.Issue != tc.want.Issue {
+ t.Errorf("wrong issue, got: %v, want: %v", got.Issue, tc.want.Issue)
+ }
+ if len(got.Locations) != len(tc.want.Locations) {
+ t.Errorf("wrong number of locations, got: %v, want: %v, locations: %+v", len(got.Locations), len(tc.want.Locations), got.Locations)
+ }
+ for i, wantLoc := range tc.want.Locations {
+ if got.Locations[i].Comment != wantLoc.Comment {
+ t.Errorf("wrong comment, got: %v, want: %v", got.Locations[i].Comment, wantLoc.Comment)
+ }
+ }
+ }
+ }
+}
diff --git a/tools/make_apt.sh b/tools/make_apt.sh
new file mode 100755
index 000000000..3fb1066e5
--- /dev/null
+++ b/tools/make_apt.sh
@@ -0,0 +1,139 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ "$#" -le 3 ]]; then
+ echo "usage: $0 <private-key> <suite> <root> <packages...>"
+ exit 1
+fi
+declare -r private_key=$(readlink -e "$1"); shift
+declare -r suite="$1"; shift
+declare -r root="$1"; shift
+
+# Ensure that we have the correct packages installed.
+function apt_install() {
+ while true; do
+ sudo apt-get update &&
+ sudo apt-get install -y "$@" &&
+ true
+ result="${?}"
+ case $result in
+ 0)
+ break
+ ;;
+ 100)
+ # 100 is the error code that apt-get returns.
+ ;;
+ *)
+ exit $result
+ ;;
+ esac
+ done
+}
+dpkg-sig --help >/dev/null 2>&1 || apt_install dpkg-sig
+apt-ftparchive --help >/dev/null 2>&1 || apt_install apt-utils
+xz --help >/dev/null 2>&1 || apt_install xz-utils
+
+# Verbose from this point.
+set -xeo pipefail
+
+# Create a directory for the release.
+declare -r release="${root}/dists/${suite}"
+mkdir -p "${release}"
+
+# Create a temporary keyring, and ensure it is cleaned up.
+declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
+cleanup() {
+ rm -f "${keyring}"
+}
+trap cleanup EXIT
+
+# We attempt the import twice because the first one will fail if the public key
+# is not found. This isn't actually a failure for us, because we don't require
+# the public (this may be stored separately). The second import will succeed
+# because, in reality, the first import succeeded and it's a no-op.
+gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" || \
+ gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}"
+
+# Copy the packages into the root.
+for pkg in "$@"; do
+ ext=${pkg##*.}
+ name=$(basename "${pkg}" ".${ext}")
+ arch=${name##*_}
+ if [[ "${name}" == "${arch}" ]]; then
+ continue # Not a regular package.
+ fi
+ if [[ "${pkg}" =~ ^.*\.deb$ ]]; then
+ # Extract from the debian file.
+ version=$(dpkg --info "${pkg}" | grep -E 'Version:' | cut -d':' -f2)
+ elif [[ "${pkg}" =~ ^.*\.changes$ ]]; then
+ # Extract from the changes file.
+ version=$(grep -E 'Version:' "${pkg}" | cut -d':' -f2)
+ else
+ # Unsupported file type.
+ echo "Unknown file type: ${pkg}"
+ exit 1
+ fi
+
+ # The package may already exist, in which case we leave it alone.
+ version=${version// /} # Trim whitespace.
+ destdir="${root}/pool/${version}/binary-${arch}"
+ target="${destdir}/${name}.${ext}"
+ if [[ -f "${target}" ]]; then
+ continue
+ fi
+
+ # Copy & sign the package.
+ mkdir -p "${destdir}"
+ cp -a "${pkg}" "${target}"
+ chmod 0644 "${target}"
+ if [[ "${ext}" == "deb" ]]; then
+ dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}"
+ fi
+done
+
+# Build the package list.
+declare arches=()
+for dir in "${root}"/pool/*/binary-*; do
+ name=$(basename "${dir}")
+ arch=${name##binary-}
+ arches+=("${arch}")
+ repo_packages="${release}"/main/"${name}"
+ mkdir -p "${repo_packages}"
+ (cd "${root}" && apt-ftparchive --arch "${arch}" packages pool > "${repo_packages}"/Packages)
+ (cd "${repo_packages}" && cat Packages | gzip > Packages.gz)
+ (cd "${repo_packages}" && cat Packages | xz > Packages.xz)
+done
+
+# Build the release list.
+cat > "${release}"/apt.conf <<EOF
+APT {
+ FTPArchive {
+ Release {
+ Architectures "${arches[@]}";
+ Suite "${suite}";
+ Components "main";
+ };
+ };
+};
+EOF
+(cd "${release}" && apt-ftparchive -c=apt.conf release . > Release)
+rm "${release}"/apt.conf
+
+# Sign the release.
+declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512")
+(cd "${release}" && rm -f Release.gpg InRelease)
+(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release)
+(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release)
diff --git a/tools/make_release.sh b/tools/make_release.sh
new file mode 100755
index 000000000..9137dd9bb
--- /dev/null
+++ b/tools/make_release.sh
@@ -0,0 +1,81 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if [[ "$#" -le 2 ]]; then
+ echo "usage: $0 <private-key> <root> <binaries & packages...>"
+ echo "The environment variable NIGHTLY may be set to control"
+ echo "whether the nightly packages are produced or not."
+ exit 1
+fi
+
+set -xeo pipefail
+declare -r private_key="$1"; shift
+declare -r root="$1"; shift
+declare -a binaries
+declare -a pkgs
+
+# Collect binaries & packages.
+for arg in "$@"; do
+ if [[ "${arg}" == *.deb ]] || [[ "${arg}" == *.changes ]]; then
+ pkgs+=("${arg}")
+ else
+ binaries+=("${arg}")
+ fi
+done
+
+# install_raw installs raw artifacts.
+install_raw() {
+ mkdir -p "${root}/$1"
+ for binary in "${binaries[@]}"; do
+ # Copy the raw file & generate a sha512sum.
+ name=$(basename "${binary}")
+ cp -f "${binary}" "${root}/$1"
+ (cd "${root}/$1" && sha512sum "${name}" > "${name}.sha512")
+ done
+}
+
+# install_apt installs an apt repository.
+install_apt() {
+ tools/make_apt.sh "${private_key}" "$1" "${root}" "${pkgs[@]}"
+}
+
+# If nightly, install only nightly artifacts.
+if [[ "${NIGHTLY:-false}" == "true" ]]; then
+ # The "latest" directory and current date.
+ stamp="$(date -Idate)"
+ install_raw "nightly/latest"
+ install_raw "nightly/${stamp}"
+ install_apt "nightly"
+else
+ # Is it a tagged release? Build that.
+ tags="$(git tag --points-at HEAD 2>/dev/null || true)"
+ if ! [[ -z "${tags}" ]]; then
+ # Note that a given commit can match any number of tags. We have to iterate
+ # through all possible tags and produce associated artifacts.
+ for tag in ${tags}; do
+ name=$(echo "${tag}" | cut -d'-' -f2)
+ base=$(echo "${name}" | cut -d'.' -f1)
+ install_raw "release/${name}"
+ install_raw "release/latest"
+ install_apt "release"
+ install_apt "${base}"
+ done
+ else
+ # Otherwise, assume it is a raw master commit.
+ install_raw "master/latest"
+ install_apt "master"
+ fi
+fi
diff --git a/tools/make_repository.sh b/tools/make_repository.sh
deleted file mode 100755
index 071f72b74..000000000
--- a/tools/make_repository.sh
+++ /dev/null
@@ -1,79 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Parse arguments. We require more than two arguments, which are the private
-# keyring, the e-mail associated with the signer, and the list of packages.
-if [ "$#" -le 3 ]; then
- echo "usage: $0 <private-key> <signer-email> <component> <packages...>"
- exit 1
-fi
-declare -r private_key=$(readlink -e "$1")
-declare -r signer="$2"
-declare -r component="$3"
-shift; shift; shift
-
-# Verbose from this point.
-set -xeo pipefail
-
-# Create a temporary working directory. We don't remove this, as we ultimately
-# print this result and allow the caller to copy wherever they would like.
-declare -r tmpdir=$(mktemp -d /tmp/repoXXXXXX)
-
-# Create a temporary keyring, and ensure it is cleaned up.
-declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
-cleanup() {
- rm -f "${keyring}"
-}
-trap cleanup EXIT
-gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2
-
-# Copy the packages, and ensure permissions are correct.
-for pkg in "$@"; do
- name=$(basename "${pkg}" .deb)
- name=$(basename "${name}" .changes)
- arch=${name##*_}
- if [[ "${name}" == "${arch}" ]]; then
- continue # Not a regular package.
- fi
- mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}"
- cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}"
-done
-find "${tmpdir}" -type f -exec chmod 0644 {} \;
-
-# Ensure there are no symlinks hanging around; these may be remnants of the
-# build process. They may be useful for other things, but we are going to build
-# an index of the actual packages here.
-find "${tmpdir}" -type l -exec rm -f {} \;
-
-# Sign all packages.
-for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do
- dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2
-done
-
-# Build the package list.
-for dir in "${tmpdir}"/"${component}"/binary-*; do
- (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz)
-done
-
-# Build the release list.
-(cd "${tmpdir}" && apt-ftparchive release . > Release)
-
-# Sign the release.
-(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2)
-(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2)
-
-# Show the results.
-echo "${tmpdir}"
diff --git a/tools/nogo.js b/tools/nogo.js
deleted file mode 100644
index fc0a4d1f0..000000000
--- a/tools/nogo.js
+++ /dev/null
@@ -1,7 +0,0 @@
-{
- "checkunsafe": {
- "exclude_files": {
- "/external/": "not subject to constraint"
- }
- }
-}
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
new file mode 100644
index 000000000..e1bfb9a2c
--- /dev/null
+++ b/tools/nogo/BUILD
@@ -0,0 +1,55 @@
+load("//tools:defs.bzl", "bzl_library", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "nogo",
+ srcs = [
+ "build.go",
+ "config.go",
+ "matchers.go",
+ "nogo.go",
+ "register.go",
+ ],
+ nogo = False,
+ visibility = ["//:sandbox"],
+ deps = [
+ "//tools/checkescape",
+ "//tools/checkunsafe",
+ "//tools/nogo/data",
+ "@org_golang_x_tools//go/analysis:go_tool_library",
+ "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/assign:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/atomic:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/bools:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/buildtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/cgocall:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/composite:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/copylock:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/errorsas:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/httpresponse:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/loopclosure:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/lostcancel:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilfunc:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/nilness:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/printf:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shadow:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/shift:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stdmethods:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/stringintconv:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/structtag:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/tests:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unmarshal:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unreachable:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unsafeptr:go_tool_library",
+ "@org_golang_x_tools//go/analysis/passes/unusedresult:go_tool_library",
+ "@org_golang_x_tools//go/gcexportdata:go_tool_library",
+ ],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/nogo/README.md b/tools/nogo/README.md
new file mode 100644
index 000000000..6e4db18de
--- /dev/null
+++ b/tools/nogo/README.md
@@ -0,0 +1,31 @@
+# Extended "nogo" analysis
+
+This package provides a build aspect that perform nogo analysis. This will be
+automatically injected to all relevant libraries when using the default
+`go_binary` and `go_library` rules.
+
+It exists for several reasons.
+
+* The default `nogo` provided by bazel is insufficient with respect to the
+ possibility of binary analysis. This package allows us to analyze the
+ generated binary in addition to using the standard analyzers.
+
+* The configuration provided in this package is much richer than the standard
+ `nogo` JSON blob. Specifically, it allows us to exclude specific structures
+ from the composite rules (such as the Ranges that are common with the set
+ types).
+
+* The bazel version of `nogo` is run directly against the `go_library` and
+ `go_binary` targets, meaning that any change to the configuration requires a
+ rebuild from scratch (for some reason included all C++ source files in the
+ process). Using an aspect is more efficient in this regard.
+
+* The checks supported by this package are exported as tests, which makes it
+ easier to reason about and plumb into the build system.
+
+* For uninteresting reasons, it is impossible to integrate the default `nogo`
+ analyzer provided by bazel with internal Google tooling. To provide a
+ consistent experience, this package allows those systems to be unified.
+
+To use this package, import `nogo_test` from `defs.bzl` and add a single
+dependency which is a `go_binary` or `go_library` rule.
diff --git a/tools/nogo/build.go b/tools/nogo/build.go
new file mode 100644
index 000000000..433d13738
--- /dev/null
+++ b/tools/nogo/build.go
@@ -0,0 +1,40 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "fmt"
+ "io"
+ "os"
+)
+
+var (
+ // internalPrefix is the internal path prefix. Note that this is not
+ // special, as paths should be passed relative to the repository root
+ // and should not have any special prefix applied.
+ internalPrefix = fmt.Sprintf("^")
+
+ // externalPrefix is external workspace packages.
+ externalPrefix = "^external/"
+)
+
+// findStdPkg needs to find the bundled standard library packages.
+func (i *importer) findStdPkg(path string) (io.ReadCloser, error) {
+ if path == "C" {
+ // Cgo builds cannot be analyzed. Skip.
+ return nil, ErrSkip
+ }
+ return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", i.GOOS, i.GOARCH, path))
+}
diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD
new file mode 100644
index 000000000..e2d76cd5c
--- /dev/null
+++ b/tools/nogo/check/BUILD
@@ -0,0 +1,12 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+# Note that the check binary must be public, since an aspect may be applied
+# across lots of different rules in different repositories.
+go_binary(
+ name = "check",
+ srcs = ["main.go"],
+ visibility = ["//visibility:public"],
+ deps = ["//tools/nogo"],
+)
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
new file mode 100644
index 000000000..3828edf3a
--- /dev/null
+++ b/tools/nogo/check/main.go
@@ -0,0 +1,24 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary check is the nogo entrypoint.
+package main
+
+import (
+ "gvisor.dev/gvisor/tools/nogo"
+)
+
+func main() {
+ nogo.Main()
+}
diff --git a/tools/nogo/config.go b/tools/nogo/config.go
new file mode 100644
index 000000000..6958fca69
--- /dev/null
+++ b/tools/nogo/config.go
@@ -0,0 +1,116 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/passes/asmdecl"
+ "golang.org/x/tools/go/analysis/passes/assign"
+ "golang.org/x/tools/go/analysis/passes/atomic"
+ "golang.org/x/tools/go/analysis/passes/bools"
+ "golang.org/x/tools/go/analysis/passes/buildtag"
+ "golang.org/x/tools/go/analysis/passes/cgocall"
+ "golang.org/x/tools/go/analysis/passes/composite"
+ "golang.org/x/tools/go/analysis/passes/copylock"
+ "golang.org/x/tools/go/analysis/passes/errorsas"
+ "golang.org/x/tools/go/analysis/passes/httpresponse"
+ "golang.org/x/tools/go/analysis/passes/loopclosure"
+ "golang.org/x/tools/go/analysis/passes/lostcancel"
+ "golang.org/x/tools/go/analysis/passes/nilfunc"
+ "golang.org/x/tools/go/analysis/passes/nilness"
+ "golang.org/x/tools/go/analysis/passes/printf"
+ "golang.org/x/tools/go/analysis/passes/shadow"
+ "golang.org/x/tools/go/analysis/passes/shift"
+ "golang.org/x/tools/go/analysis/passes/stdmethods"
+ "golang.org/x/tools/go/analysis/passes/stringintconv"
+ "golang.org/x/tools/go/analysis/passes/structtag"
+ "golang.org/x/tools/go/analysis/passes/tests"
+ "golang.org/x/tools/go/analysis/passes/unmarshal"
+ "golang.org/x/tools/go/analysis/passes/unreachable"
+ "golang.org/x/tools/go/analysis/passes/unsafeptr"
+ "golang.org/x/tools/go/analysis/passes/unusedresult"
+
+ "gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/checkunsafe"
+)
+
+var analyzerConfig = map[*analysis.Analyzer]matcher{
+ // Standard analyzers.
+ asmdecl.Analyzer: alwaysMatches(),
+ assign.Analyzer: externalExcluded(
+ ".*gazelle/walk/walk.go", // False positive.
+ ),
+ atomic.Analyzer: alwaysMatches(),
+ bools.Analyzer: alwaysMatches(),
+ buildtag.Analyzer: alwaysMatches(),
+ cgocall.Analyzer: alwaysMatches(),
+ composite.Analyzer: and(
+ disableMatches(), // Disabled for now.
+ resultExcluded{
+ "Object_",
+ "Range{",
+ },
+ ),
+ copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos).
+ errorsas.Analyzer: alwaysMatches(),
+ httpresponse.Analyzer: alwaysMatches(),
+ loopclosure.Analyzer: alwaysMatches(),
+ lostcancel.Analyzer: internalMatches(), // Common external issues.
+ nilfunc.Analyzer: alwaysMatches(),
+ nilness.Analyzer: and(
+ internalMatches(), // Common "tautological checks".
+ internalExcluded(
+ "pkg/sentry/platform/kvm/kvm_test.go", // Intentional.
+ "tools/bigquery/bigquery.go", // False positive.
+ ),
+ ),
+ printf.Analyzer: alwaysMatches(),
+ shift.Analyzer: alwaysMatches(),
+ stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write").
+ stringintconv.Analyzer: and(
+ internalExcluded(),
+ externalExcluded(
+ ".*protobuf/.*.go", // Bad conversions.
+ ".*flate/huffman_bit_writer.go", // Bad conversion.
+ ),
+ ),
+ shadow.Analyzer: disableMatches(), // Disabled for now.
+ structtag.Analyzer: internalMatches(), // External not subject to rules.
+ tests.Analyzer: alwaysMatches(),
+ unmarshal.Analyzer: alwaysMatches(),
+ unreachable.Analyzer: internalMatches(),
+ unsafeptr.Analyzer: and(
+ internalMatches(),
+ internalExcluded(
+ ".*_test.go", // Exclude tests.
+ "pkg/flipcall/.*_unsafe.go", // Special case.
+ "pkg/gohacks/gohacks_unsafe.go", // Special case.
+ "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case.
+ "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case.
+ "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case.
+ "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case.
+ "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case.
+ "pkg/sentry/vfs/mount_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case.
+ "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case.
+ ),
+ ),
+ unusedresult.Analyzer: alwaysMatches(),
+
+ // Internal analyzers: external packages not subject.
+ checkescape.Analyzer: internalMatches(),
+ checkunsafe.Analyzer: internalMatches(),
+}
diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD
new file mode 100644
index 000000000..b7564cc44
--- /dev/null
+++ b/tools/nogo/data/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "data",
+ srcs = ["data.go"],
+ nogo = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/nogo/data/data.go b/tools/nogo/data/data.go
new file mode 100644
index 000000000..eb84d0d27
--- /dev/null
+++ b/tools/nogo/data/data.go
@@ -0,0 +1,21 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package data contains shared data for nogo analysis.
+//
+// This is used to break a dependency cycle.
+package data
+
+// Objdump is the dumped binary under analysis.
+var Objdump string
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
new file mode 100644
index 000000000..d399079c5
--- /dev/null
+++ b/tools/nogo/defs.bzl
@@ -0,0 +1,176 @@
+"""Nogo rules."""
+
+load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule")
+
+# NogoInfo is the serialized set of package facts for a nogo analysis.
+#
+# Each go_library rule will generate a corresponding nogo rule, which will run
+# with the source files as input. Note however, that the individual nogo rules
+# are simply stubs that enter into the shadow dependency tree (the "aspect").
+NogoInfo = provider(
+ fields = {
+ "facts": "serialized package facts",
+ "importpath": "package import path",
+ "binaries": "package binary files",
+ },
+)
+
+def _nogo_aspect_impl(target, ctx):
+ # If this is a nogo rule itself (and not the shadow of a go_library or
+ # go_binary rule created by such a rule), then we simply return nothing.
+ # All work is done in the shadow properties for go rules. For a proto
+ # library, we simply skip the analysis portion but still need to return a
+ # valid NogoInfo to reference the generated binary.
+ if ctx.rule.kind == "go_library":
+ srcs = ctx.rule.files.srcs
+ elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc":
+ srcs = []
+ else:
+ return [NogoInfo()]
+
+ go_ctx = go_context(ctx)
+
+ # Construct the Go environment from the go_ctx.env dictionary.
+ env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()])
+
+ # Start with all target files and srcs as input.
+ inputs = target.files.to_list() + srcs
+
+ # Generate a shell script that dumps the binary. Annoyingly, this seems
+ # necessary as the context in which a run_shell command runs does not seem
+ # to cleanly allow us redirect stdout to the actual output file. Perhaps
+ # I'm missing something here, but the intermediate script does work.
+ binaries = target.files.to_list()
+ disasm_file = ctx.actions.declare_file(target.label.name + ".out")
+ dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name)
+ ctx.actions.write(dumper, "\n".join([
+ "#!/bin/bash",
+ "%s %s tool objdump %s > %s\n" % (
+ env_prefix,
+ go_ctx.go.path,
+ [f.path for f in binaries if f.path.endswith(".a")][0],
+ disasm_file.path,
+ ),
+ ]), is_executable = True)
+ ctx.actions.run(
+ inputs = binaries,
+ outputs = [disasm_file],
+ tools = go_ctx.runfiles,
+ mnemonic = "GoObjdump",
+ progress_message = "Objdump %s" % target.label,
+ executable = dumper,
+ )
+ inputs.append(disasm_file)
+
+ # Extract the importpath for this package.
+ importpath = go_importpath(target)
+
+ # The nogo tool requires a configfile serialized in JSON format to do its
+ # work. This must line up with the nogo.Config fields.
+ facts = ctx.actions.declare_file(target.label.name + ".facts")
+ config = struct(
+ ImportPath = importpath,
+ GoFiles = [src.path for src in srcs if src.path.endswith(".go")],
+ NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")],
+ # Google's internal build system needs a bit more help to find std.
+ StdZip = go_ctx.std_zip.short_path if hasattr(go_ctx, "std_zip") else "",
+ GOOS = go_ctx.goos,
+ GOARCH = go_ctx.goarch,
+ Tags = go_ctx.tags,
+ FactMap = {}, # Constructed below.
+ ImportMap = {}, # Constructed below.
+ FactOutput = facts.path,
+ Objdump = disasm_file.path,
+ )
+
+ # Collect all info from shadow dependencies.
+ for dep in ctx.rule.attr.deps:
+ # There will be no file attribute set for all transitive dependencies
+ # that are not go_library or go_binary rules, such as a proto rules.
+ # This is handled by the ctx.rule.kind check above.
+ info = dep[NogoInfo]
+ if not hasattr(info, "facts"):
+ continue
+
+ # Configure where to find the binary & fact files. Note that this will
+ # use .x and .a regardless of whether this is a go_binary rule, since
+ # these dependencies must be go_library rules.
+ x_files = [f.path for f in info.binaries if f.path.endswith(".x")]
+ if not len(x_files):
+ x_files = [f.path for f in info.binaries if f.path.endswith(".a")]
+ config.ImportMap[info.importpath] = x_files[0]
+ config.FactMap[info.importpath] = info.facts.path
+
+ # Ensure the above are available as inputs.
+ inputs.append(info.facts)
+ inputs += info.binaries
+
+ # Write the configuration and run the tool.
+ config_file = ctx.actions.declare_file(target.label.name + ".cfg")
+ ctx.actions.write(config_file, config.to_json())
+ inputs.append(config_file)
+
+ # Run the nogo tool itself.
+ ctx.actions.run(
+ inputs = inputs,
+ outputs = [facts],
+ tools = go_ctx.runfiles,
+ executable = ctx.files._nogo[0],
+ mnemonic = "GoStaticAnalysis",
+ progress_message = "Analyzing %s" % target.label,
+ arguments = ["-config=%s" % config_file.path],
+ )
+
+ # Return the package facts as output.
+ return [NogoInfo(
+ facts = facts,
+ importpath = importpath,
+ binaries = binaries,
+ )]
+
+nogo_aspect = go_rule(
+ aspect,
+ implementation = _nogo_aspect_impl,
+ attr_aspects = ["deps"],
+ attrs = {
+ "_nogo": attr.label(
+ default = "//tools/nogo/check:check",
+ allow_single_file = True,
+ ),
+ },
+)
+
+def _nogo_test_impl(ctx):
+ """Check nogo findings."""
+
+ # Build a runner that checks for the existence of the facts file. Note that
+ # the actual build will fail in the case of a broken analysis. We things
+ # this way so that any test applied is effectively pushed down to all
+ # upstream dependencies through the aspect.
+ inputs = []
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+ runner_content = ["#!/bin/bash"]
+ for dep in ctx.attr.deps:
+ info = dep[NogoInfo]
+ inputs.append(info.facts)
+
+ # Draw a sweet unicode checkmark with the package name (in green).
+ runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath)
+ runner_content.append("exit 0\n")
+ ctx.actions.write(runner, "\n".join(runner_content), is_executable = True)
+ return [DefaultInfo(
+ runfiles = ctx.runfiles(files = inputs),
+ executable = runner,
+ )]
+
+_nogo_test = rule(
+ implementation = _nogo_test_impl,
+ attrs = {
+ "deps": attr.label_list(aspects = [nogo_aspect]),
+ },
+ test = True,
+)
+
+def nogo_test(**kwargs):
+ tags = kwargs.pop("tags", []) + ["nogo"]
+ _nogo_test(tags = tags, **kwargs)
diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch
new file mode 100644
index 000000000..6b64b2e85
--- /dev/null
+++ b/tools/nogo/io_bazel_rules_go-visibility.patch
@@ -0,0 +1,25 @@
+diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch
+index 133fbccc..5f0d9a47 100644
+--- a/third_party/org_golang_x_tools-extras.patch
++++ b/third_party/org_golang_x_tools-extras.patch
+@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
+
+ go_library(
+ name = "go_default_library",
+-@@ -14,6 +14,23 @@
++@@ -14,6 +14,20 @@
+ ],
+ )
+
+@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/
+ + "imports.go",
+ + ],
+ + importpath = "golang.org/x/tools/go/analysis/internal/facts",
+-+ visibility = [
+-+ "//go/analysis:__subpackages__",
+-+ "@io_bazel_rules_go//go/tools/builders:__pkg__",
+-+ ],
+++ visibility = ["//visibility:public"],
+ + deps = [
+ + "//go/analysis:go_tool_library",
+ + "//go/types/objectpath:go_tool_library",
diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go
new file mode 100644
index 000000000..57a250501
--- /dev/null
+++ b/tools/nogo/matchers.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "go/token"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+type matcher interface {
+ ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool
+}
+
+// pathRegexps filters explicit paths.
+type pathRegexps struct {
+ expr []*regexp.Regexp
+
+ // include, if true, indicates that paths matching any regexp in expr
+ // match.
+ //
+ // If false, paths matching no regexps in expr match.
+ include bool
+}
+
+// buildRegexps builds a list of regular expressions.
+//
+// This will panic on error.
+func buildRegexps(prefix string, args ...string) []*regexp.Regexp {
+ result := make([]*regexp.Regexp, 0, len(args))
+ for _, arg := range args {
+ result = append(result, regexp.MustCompile(filepath.Join(prefix, arg)))
+ }
+ return result
+}
+
+// ShouldReport implements matcher.ShouldReport.
+func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
+ fullPos := fs.Position(d.Pos).String()
+ for _, path := range p.expr {
+ if path.MatchString(fullPos) {
+ return p.include
+ }
+ }
+ return !p.include
+}
+
+// internalExcluded excludes specific internal paths.
+func internalExcluded(paths ...string) *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(internalPrefix, paths...),
+ include: false,
+ }
+}
+
+// excludedExcluded excludes specific external paths.
+func externalExcluded(paths ...string) *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(externalPrefix, paths...),
+ include: false,
+ }
+}
+
+// internalMatches returns a path matcher for internal packages.
+func internalMatches() *pathRegexps {
+ return &pathRegexps{
+ expr: buildRegexps(internalPrefix, ".*"),
+ include: true,
+ }
+}
+
+// resultExcluded excludes explicit message contents.
+type resultExcluded []string
+
+// ShouldReport implements matcher.ShouldReport.
+func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool {
+ for _, str := range r {
+ if strings.Contains(d.Message, str) {
+ return false
+ }
+ }
+ return true // Not excluded.
+}
+
+// andMatcher is a composite matcher.
+type andMatcher struct {
+ first matcher
+ second matcher
+}
+
+// ShouldReport implements matcher.ShouldReport.
+func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool {
+ return a.first.ShouldReport(d, fs) && a.second.ShouldReport(d, fs)
+}
+
+// and is a syntactic convension for andMatcher.
+func and(first matcher, second matcher) *andMatcher {
+ return &andMatcher{
+ first: first,
+ second: second,
+ }
+}
+
+// anyMatcher matches everything.
+type anyMatcher struct{}
+
+// ShouldReport implements matcher.ShouldReport.
+func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
+ return true
+}
+
+// alwaysMatches returns an anyMatcher instance.
+func alwaysMatches() anyMatcher {
+ return anyMatcher{}
+}
+
+// neverMatcher will never match.
+type neverMatcher struct{}
+
+// ShouldReport implements matcher.ShouldReport.
+func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool {
+ return false
+}
+
+// disableMatches returns a neverMatcher instance.
+func disableMatches() neverMatcher {
+ return neverMatcher{}
+}
diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go
new file mode 100644
index 000000000..ea1e97076
--- /dev/null
+++ b/tools/nogo/nogo.go
@@ -0,0 +1,326 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package nogo implements binary analysis similar to bazel's nogo,
+// or the unitchecker package. It exists in order to provide additional
+// facilities for analysis, namely plumbing through the output from
+// dumping the generated binary (to analyze actual produced code).
+package nogo
+
+import (
+ "encoding/json"
+ "errors"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/build"
+ "go/parser"
+ "go/token"
+ "go/types"
+ "io"
+ "io/ioutil"
+ "log"
+ "os"
+ "path/filepath"
+ "reflect"
+
+ "golang.org/x/tools/go/analysis"
+ "golang.org/x/tools/go/analysis/internal/facts"
+ "golang.org/x/tools/go/gcexportdata"
+ "gvisor.dev/gvisor/tools/nogo/data"
+)
+
+// pkgConfig is serialized as the configuration.
+//
+// This contains everything required for the analysis.
+type pkgConfig struct {
+ ImportPath string
+ GoFiles []string
+ NonGoFiles []string
+ Tags []string
+ GOOS string
+ GOARCH string
+ ImportMap map[string]string
+ FactMap map[string]string
+ FactOutput string
+ Objdump string
+ StdZip string
+}
+
+// loadFacts finds and loads facts per FactMap.
+func (c *pkgConfig) loadFacts(path string) ([]byte, error) {
+ realPath, ok := c.FactMap[path]
+ if !ok {
+ return nil, nil // No facts available.
+ }
+
+ // Read the files file.
+ data, err := ioutil.ReadFile(realPath)
+ if err != nil {
+ return nil, err
+ }
+ return data, nil
+}
+
+// shouldInclude indicates whether the file should be included.
+//
+// NOTE: This does only basic parsing of tags.
+func (c *pkgConfig) shouldInclude(path string) (bool, error) {
+ ctx := build.Default
+ ctx.GOOS = c.GOOS
+ ctx.GOARCH = c.GOARCH
+ ctx.BuildTags = c.Tags
+ return ctx.MatchFile(filepath.Dir(path), filepath.Base(path))
+}
+
+// importer is an implementation of go/types.Importer.
+//
+// This wraps a configuration, which provides the map of package names to
+// files, and the facts. Note that this importer implementation will always
+// pass when a given package is not available.
+type importer struct {
+ pkgConfig
+ fset *token.FileSet
+ cache map[string]*types.Package
+ lastErr error
+}
+
+// Import implements types.Importer.Import.
+func (i *importer) Import(path string) (*types.Package, error) {
+ if path == "unsafe" {
+ // Special case: go/types has pre-defined type information for
+ // unsafe. We ensure that this package is correct, in case any
+ // analyzers are specifically looking for this.
+ return types.Unsafe, nil
+ }
+ realPath, ok := i.ImportMap[path]
+ var (
+ rc io.ReadCloser
+ err error
+ )
+ if !ok {
+ // Not found in the import path. Attempt to find the package
+ // via the standard library.
+ rc, err = i.findStdPkg(path)
+ } else {
+ // Open the file.
+ rc, err = os.Open(realPath)
+ }
+ if err != nil {
+ i.lastErr = err
+ return nil, err
+ }
+ defer rc.Close()
+
+ // Load all exported data.
+ r, err := gcexportdata.NewReader(rc)
+ if err != nil {
+ return nil, err
+ }
+
+ return gcexportdata.Read(r, i.fset, i.cache, path)
+}
+
+// ErrSkip indicates the package should be skipped.
+var ErrSkip = errors.New("skipped")
+
+// checkPackage runs all analyzers.
+//
+// The implementation was adapted from [1], which was in turn adpated from [2].
+// This returns a list of matching analysis issues, or an error if the analysis
+// could not be completed.
+//
+// [1] bazelbuid/rules_go/tools/builders/nogo_main.go
+// [2] golang.org/x/tools/go/checker/internal/checker
+func checkPackage(config pkgConfig) ([]string, error) {
+ imp := &importer{
+ pkgConfig: config,
+ fset: token.NewFileSet(),
+ cache: make(map[string]*types.Package),
+ }
+
+ // Load all source files.
+ var syntax []*ast.File
+ for _, file := range config.GoFiles {
+ include, err := config.shouldInclude(file)
+ if err != nil {
+ return nil, fmt.Errorf("error evaluating file %q: %v", file, err)
+ }
+ if !include {
+ continue
+ }
+ s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments)
+ if err != nil {
+ return nil, fmt.Errorf("error parsing file %q: %v", file, err)
+ }
+ syntax = append(syntax, s)
+ }
+
+ // Check type information.
+ typesSizes := types.SizesFor("gc", config.GOARCH)
+ typeConfig := types.Config{Importer: imp}
+ typesInfo := &types.Info{
+ Types: make(map[ast.Expr]types.TypeAndValue),
+ Uses: make(map[*ast.Ident]types.Object),
+ Defs: make(map[*ast.Ident]types.Object),
+ Implicits: make(map[ast.Node]types.Object),
+ Scopes: make(map[ast.Node]*types.Scope),
+ Selections: make(map[*ast.SelectorExpr]*types.Selection),
+ }
+ types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo)
+ if err != nil && imp.lastErr != ErrSkip {
+ return nil, fmt.Errorf("error checking types: %w", err)
+ }
+
+ // Load all package facts.
+ facts, err := facts.Decode(types, config.loadFacts)
+ if err != nil {
+ return nil, fmt.Errorf("error decoding facts: %w", err)
+ }
+
+ // Set the binary global for use.
+ data.Objdump = config.Objdump
+
+ // Register fact types and establish dependencies between analyzers.
+ // The visit closure will execute recursively, and populate results
+ // will all required analysis results.
+ diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic)
+ results := make(map[*analysis.Analyzer]interface{})
+ var visit func(*analysis.Analyzer) error // For recursion.
+ visit = func(a *analysis.Analyzer) error {
+ if _, ok := results[a]; ok {
+ return nil
+ }
+
+ // Run recursively for all dependencies.
+ for _, req := range a.Requires {
+ if err := visit(req); err != nil {
+ return err
+ }
+ }
+
+ // Prepare the matcher.
+ m := analyzerConfig[a]
+ report := func(d analysis.Diagnostic) {
+ if m.ShouldReport(d, imp.fset) {
+ diagnostics[a] = append(diagnostics[a], d)
+ }
+ }
+
+ // Run the analysis.
+ factFilter := make(map[reflect.Type]bool)
+ for _, f := range a.FactTypes {
+ factFilter[reflect.TypeOf(f)] = true
+ }
+ p := &analysis.Pass{
+ Analyzer: a,
+ Fset: imp.fset,
+ Files: syntax,
+ Pkg: types,
+ TypesInfo: typesInfo,
+ ResultOf: results, // All results.
+ Report: report,
+ ImportPackageFact: facts.ImportPackageFact,
+ ExportPackageFact: facts.ExportPackageFact,
+ ImportObjectFact: facts.ImportObjectFact,
+ ExportObjectFact: facts.ExportObjectFact,
+ AllPackageFacts: func() []analysis.PackageFact { return facts.AllPackageFacts(factFilter) },
+ AllObjectFacts: func() []analysis.ObjectFact { return facts.AllObjectFacts(factFilter) },
+ TypesSizes: typesSizes,
+ }
+ result, err := a.Run(p)
+ if err != nil {
+ return fmt.Errorf("error running analysis %s: %v", a, err)
+ }
+
+ // Sanity check & save the result.
+ if got, want := reflect.TypeOf(result), a.ResultType; got != want {
+ return fmt.Errorf("error: analyzer %s returned a result of type %v, but declared ResultType %v", a, got, want)
+ }
+ results[a] = result
+ return nil // Success.
+ }
+
+ // Visit all analysis recursively.
+ for a, _ := range analyzerConfig {
+ if imp.lastErr == ErrSkip {
+ continue // No local analysis.
+ }
+ if err := visit(a); err != nil {
+ return nil, err // Already has context.
+ }
+ }
+
+ // Write the output file.
+ if config.FactOutput != "" {
+ factData := facts.Encode()
+ if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil {
+ return nil, fmt.Errorf("error: unable to open facts output %q: %v", config.FactOutput, err)
+ }
+ }
+
+ // Convert all diagnostics to strings.
+ findings := make([]string, 0, len(diagnostics))
+ for a, ds := range diagnostics {
+ for _, d := range ds {
+ // Include the anlyzer name for debugability and configuration.
+ findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message))
+ }
+ }
+
+ // Return all findings.
+ return findings, nil
+}
+
+var (
+ configFile = flag.String("config", "", "configuration file (in JSON format)")
+)
+
+// Main is the entrypoint; it should be called directly from main.
+//
+// N.B. This package registers it's own flags.
+func Main() {
+ // Parse all flags.
+ flag.Parse()
+
+ // Load the configuration.
+ f, err := os.Open(*configFile)
+ if err != nil {
+ log.Fatalf("unable to open configuration %q: %v", *configFile, err)
+ }
+ defer f.Close()
+ config := new(pkgConfig)
+ dec := json.NewDecoder(f)
+ dec.DisallowUnknownFields()
+ if err := dec.Decode(config); err != nil {
+ log.Fatalf("unable to decode configuration: %v", err)
+ }
+
+ // Process the package.
+ findings, err := checkPackage(*config)
+ if err != nil {
+ log.Fatalf("error checking package: %v", err)
+ }
+
+ // No findings?
+ if len(findings) == 0 {
+ os.Exit(0)
+ }
+
+ // Print findings and exit with non-zero code.
+ for _, finding := range findings {
+ fmt.Fprintf(os.Stdout, "%s\n", finding)
+ }
+ os.Exit(1)
+}
diff --git a/tools/nogo/register.go b/tools/nogo/register.go
new file mode 100644
index 000000000..62b499661
--- /dev/null
+++ b/tools/nogo/register.go
@@ -0,0 +1,64 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package nogo
+
+import (
+ "encoding/gob"
+ "log"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+// analyzers returns all configured analyzers.
+func analyzers() (all []*analysis.Analyzer) {
+ for a, _ := range analyzerConfig {
+ all = append(all, a)
+ }
+ return all
+}
+
+func init() {
+ // Validate basic configuration.
+ if err := analysis.Validate(analyzers()); err != nil {
+ log.Fatalf("unable to validate analyzer: %v", err)
+ }
+
+ // Register all fact types.
+ //
+ // N.B. This needs to be done recursively, because there may be
+ // analyzers in the Requires list that do not appear explicitly above.
+ registered := make(map[*analysis.Analyzer]struct{})
+ var register func(*analysis.Analyzer)
+ register = func(a *analysis.Analyzer) {
+ if _, ok := registered[a]; ok {
+ return
+ }
+
+ // Regsiter dependencies.
+ for _, da := range a.Requires {
+ register(da)
+ }
+
+ // Register local facts.
+ for _, f := range a.FactTypes {
+ gob.Register(f)
+ }
+
+ registered[a] = struct{}{} // Done.
+ }
+ for _, a := range analyzers() {
+ register(a)
+ }
+}
diff --git a/tools/tag_release.sh b/tools/tag_release.sh
index 9d5a60583..b0bab74b4 100755
--- a/tools/tag_release.sh
+++ b/tools/tag_release.sh
@@ -18,16 +18,28 @@
# validate a provided release name, create a tag and push it. It must be
# run manually when a release is created.
-set -xeu
+set -xeuo pipefail
# Check arguments.
-if [ "$#" -ne 2 ]; then
- echo "usage: $0 <commit|revid> <release.rc>"
+if [[ "$#" -ne 3 ]]; then
+ echo "usage: $0 <commit|revid> <release.rc> <message-file>"
exit 1
fi
declare -r target_commit="$1"
declare -r release="$2"
+declare -r message_file="$3"
+
+if [[ -z "${target_commit}" ]]; then
+ echo "error: <commit|revid> is empty."
+fi
+if [[ -z "${release}" ]]; then
+ echo "error: <release.rc> is empty."
+fi
+if ! [[ -r "${message_file}" ]]; then
+ echo "error: message file '${message_file}' is not readable."
+ exit 1
+fi
closest_commit() {
while read line; do
@@ -62,7 +74,9 @@ if ! [[ "${release}" =~ ^20[0-9]{6}\.[0-9]+$ ]]; then
exit 1
fi
-# Tag the given commit (annotated, to record the committer).
+# Tag the given commit (annotated, to record the committer). Note that the tag
+# here is applied as a force, in case the tag already exists and is the same.
+# The push will fail in this case (because it is not forced).
declare -r tag="release-${release}"
-(git tag -a "${tag}" "${commit}" && git push origin tag "${tag}") || \
- (git tag -d "${tag}" && false)
+git tag -f -F "${message_file}" -a "${tag}" "${commit}" && \
+ git push origin tag "${tag}"
diff --git a/tools/tags/BUILD b/tools/tags/BUILD
new file mode 100644
index 000000000..1c02e2c89
--- /dev/null
+++ b/tools/tags/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "tags",
+ srcs = ["tags.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//tools:__subpackages__"],
+)
diff --git a/tools/tags/tags.go b/tools/tags/tags.go
new file mode 100644
index 000000000..f35904e0a
--- /dev/null
+++ b/tools/tags/tags.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tags is a utility for parsing build tags.
+package tags
+
+import (
+ "fmt"
+ "io/ioutil"
+ "strings"
+)
+
+// OrSet is a set of tags on a single line.
+//
+// Note that tags may include ",", and we don't distinguish this case in the
+// logic below. Ideally, this constraints can be split into separate top-level
+// build tags in order to resolve any issues.
+type OrSet []string
+
+// Line returns the line for this or.
+func (or OrSet) Line() string {
+ return fmt.Sprintf("// +build %s", strings.Join([]string(or), " "))
+}
+
+// AndSet is the set of all OrSets.
+type AndSet []OrSet
+
+// Lines returns the lines to be printed.
+func (and AndSet) Lines() (ls []string) {
+ for _, or := range and {
+ ls = append(ls, or.Line())
+ }
+ return
+}
+
+// Join joins this AndSet with another.
+func (and AndSet) Join(other AndSet) AndSet {
+ return append(and, other...)
+}
+
+// Tags returns the unique set of +build tags.
+//
+// Derived form the runtime's canBuild.
+func Tags(file string) (tags AndSet) {
+ data, err := ioutil.ReadFile(file)
+ if err != nil {
+ return nil
+ }
+ // Check file contents for // +build lines.
+ for _, p := range strings.Split(string(data), "\n") {
+ p = strings.TrimSpace(p)
+ if p == "" {
+ continue
+ }
+ if !strings.HasPrefix(p, "//") {
+ break
+ }
+ if !strings.Contains(p, "+build") {
+ continue
+ }
+ fields := strings.Fields(p[2:])
+ if len(fields) < 1 || fields[0] != "+build" {
+ continue
+ }
+ tags = append(tags, OrSet(fields[1:]))
+ }
+ return tags
+}
+
+// Aggregate aggregates all tags from a set of files.
+//
+// Note that these may be in conflict, in which case the build will fail.
+func Aggregate(files []string) (tags AndSet) {
+ for _, file := range files {
+ tags = tags.Join(Tags(file))
+ }
+ return tags
+}
diff --git a/tools/vm/BUILD b/tools/vm/BUILD
new file mode 100644
index 000000000..d95ca6c63
--- /dev/null
+++ b/tools/vm/BUILD
@@ -0,0 +1,63 @@
+load("//tools:defs.bzl", "bzl_library", "cc_binary", "gtest")
+load("//tools/vm:defs.bzl", "vm_image", "vm_test")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+sh_binary(
+ name = "zone",
+ srcs = ["zone.sh"],
+)
+
+sh_binary(
+ name = "builder",
+ srcs = ["build.sh"],
+)
+
+sh_binary(
+ name = "executer",
+ srcs = ["execute.sh"],
+)
+
+cc_binary(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+vm_image(
+ name = "ubuntu1604",
+ family = "ubuntu-1604-lts",
+ project = "ubuntu-os-cloud",
+ scripts = [
+ "//tools/vm/ubuntu1604",
+ ],
+)
+
+vm_image(
+ name = "ubuntu1804",
+ family = "ubuntu-1804-lts",
+ project = "ubuntu-os-cloud",
+ scripts = [
+ "//tools/vm/ubuntu1804",
+ ],
+)
+
+vm_test(
+ name = "vm_test",
+ shard_count = 2,
+ targets = [":test"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/tools/vm/README.md b/tools/vm/README.md
new file mode 100644
index 000000000..1e9859e66
--- /dev/null
+++ b/tools/vm/README.md
@@ -0,0 +1,48 @@
+# VM Images & Tests
+
+All commands in this directory require the `gcloud` project to be set.
+
+For example: `gcloud config set project gvisor-kokoro-testing`.
+
+Images can be generated by using the `vm_image` rule. This rule will generate a
+binary target that builds an image in an idempotent way, and can be referenced
+from other rules.
+
+For example:
+
+```
+vm_image(
+ name = "ubuntu",
+ project = "ubuntu-1604-lts",
+ family = "ubuntu-os-cloud",
+ scripts = [
+ "script.sh",
+ "other.sh",
+ ],
+)
+```
+
+These images can be built manually by executing the target. The output on
+`stdout` will be the image id (in the current project).
+
+For example:
+
+```
+$ bazel build :ubuntu
+```
+
+Images are always named per the hash of all the hermetic input scripts. This
+allows images to be memoized quickly and easily.
+
+The `vm_test` rule can be used to execute a command remotely. This is still
+under development however, and will likely change over time.
+
+For example:
+
+```
+vm_test(
+ name = "mycommand",
+ image = ":ubuntu",
+ targets = [":test"],
+)
+```
diff --git a/tools/image_build.sh b/tools/vm/build.sh
index 9b20a740d..752b2b77b 100755
--- a/tools/image_build.sh
+++ b/tools/vm/build.sh
@@ -18,81 +18,100 @@
# virtualization enabled, and 2) has been completely set up with the
# image_setup.sh script. This script should be idempotent, as we memoize the
# setup script with a hash and check for that name.
-#
-# The GCP project name should be defined via a gcloud config.
-set -xeo pipefail
+set -eou pipefail
# Parameters.
-declare -r ZONE=${ZONE:-us-central1-f}
declare -r USERNAME=${USERNAME:-test}
declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
+declare -r ZONE=${ZONE:-us-central1-f}
# Random names.
declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
-# Hashes inputs.
-declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@")
-declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
-declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH}
+# Hash inputs in order to memoize the produced image.
+declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
+declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH}
# Does the image already exist? Skip the build.
-declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
if ! [[ -z "${existing}" ]]; then
echo "${existing}"
exit 0
fi
-# Set the zone for all actions.
-gcloud config set compute/zone "${ZONE}"
+# Standard arguments (applies only on script execution).
+declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
+
+# gcloud has path errors; is this a result of being a genrule?
+export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin}
# Start a unique instance. Note that this instance will have a unique persistent
# disk as it's boot disk with the same name as the instance.
-gcloud compute instances create \
+(set -x; gcloud compute instances create \
--quiet \
--image-project "${IMAGE_PROJECT}" \
--image-family "${IMAGE_FAMILY}" \
--boot-disk-size "200GB" \
- "${INSTANCE_NAME}"
+ --zone "${ZONE}" \
+ "${INSTANCE_NAME}" >/dev/null)
function cleanup {
- gcloud compute instances delete --quiet "${INSTANCE_NAME}"
+ (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}")
}
trap cleanup EXIT
-# Wait for the instance to become available.
-declare attempts=0
-while [[ "${attempts}" -lt 30 ]]; do
- attempts=$((${attempts}+1))
- if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then
- break
+# Wait for the instance to become available (up to 5 minutes).
+echo -n "Waiting for ${INSTANCE_NAME}" >&2
+declare timeout=300
+declare success=0
+declare internal=""
+declare -r start=$(date +%s)
+declare -r end=$((${start}+${timeout}))
+while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
+ echo -n "." >&2
+ if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ internal="--internal-ip"
fi
done
-if [[ "${attempts}" -ge 30 ]]; then
- echo "too many attempts: failed"
+
+if [[ "${success}" -eq "0" ]]; then
+ echo "connect timed out after ${timeout} seconds." >&2
exit 1
+else
+ echo "done." >&2
fi
# Run the install scripts provided.
for arg; do
- gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}"
+ (set -x; gcloud compute ssh ${internal} \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ sudo bash - <"${arg}" >/dev/null)
done
# Stop the instance; required before creating an image.
-gcloud compute instances stop --quiet "${INSTANCE_NAME}"
+(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null)
# Create a snapshot of the instance disk.
-gcloud compute disks snapshot \
+(set -x; gcloud compute disks snapshot \
--quiet \
- --zone="${ZONE}" \
+ --zone "${ZONE}" \
--snapshot-names="${SNAPSHOT_NAME}" \
- "${INSTANCE_NAME}"
+ "${INSTANCE_NAME}" >/dev/null)
# Create the disk image.
-gcloud compute images create \
+(set -x; gcloud compute images create \
--quiet \
--source-snapshot="${SNAPSHOT_NAME}" \
--licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
- "${IMAGE_NAME}"
+ "${IMAGE_NAME}" >/dev/null)
+
+# Finish up.
+echo "${IMAGE_NAME}"
diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl
new file mode 100644
index 000000000..9af5ad3b4
--- /dev/null
+++ b/tools/vm/defs.bzl
@@ -0,0 +1,202 @@
+"""Image configuration. See README.md."""
+
+load("//tools:defs.bzl", "default_installer")
+
+# vm_image_builder is a rule that will construct a shell script that actually
+# generates a given VM image. Note that this does not _run_ the shell script
+# (although it can be run manually). It will be run manually during generation
+# of the vm_image target itself. This level of indirection is used so that the
+# build system itself only runs the builder once when multiple targets depend
+# on it, avoiding a set of races and conflicts.
+def _vm_image_builder_impl(ctx):
+ # Generate a binary that actually builds the image.
+ builder = ctx.actions.declare_file(ctx.label.name)
+ script_paths = []
+ for script in ctx.files.scripts:
+ script_paths.append(script.short_path)
+ builder_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE_PROJECT=%s" % ctx.attr.project,
+ "export IMAGE_FAMILY=%s" % ctx.attr.family,
+ "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)),
+ "",
+ ])
+ ctx.actions.write(builder, builder_content, is_executable = True)
+
+ # Note that the scripts should only be files, and should not include any
+ # indirect transitive dependencies. The build script wouldn't work.
+ return [DefaultInfo(
+ executable = builder,
+ runfiles = ctx.runfiles(
+ files = ctx.files.scripts + ctx.files._builder + ctx.files.zone,
+ ),
+ )]
+
+vm_image_builder = rule(
+ attrs = {
+ "_builder": attr.label(
+ executable = True,
+ default = "//tools/vm:builder",
+ cfg = "host",
+ ),
+ "username": attr.string(default = "$(whoami)"),
+ "zone": attr.label(
+ executable = True,
+ default = "//tools/vm:zone",
+ cfg = "host",
+ ),
+ "family": attr.string(mandatory = True),
+ "project": attr.string(mandatory = True),
+ "scripts": attr.label_list(allow_files = True),
+ },
+ executable = True,
+ implementation = _vm_image_builder_impl,
+)
+
+# See vm_image_builder above.
+def _vm_image_impl(ctx):
+ # Run the builder to generate our output.
+ echo = ctx.actions.declare_file(ctx.label.name)
+ resolved_inputs, argv, runfiles_manifests = ctx.resolve_command(
+ command = "\n".join([
+ "set -e",
+ "image=$(%s)" % ctx.files.builder[0].path,
+ "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path,
+ "chmod 0755 %s" % echo.path,
+ ]),
+ tools = [ctx.attr.builder],
+ )
+ ctx.actions.run_shell(
+ tools = resolved_inputs,
+ outputs = [echo],
+ progress_message = "Building image...",
+ execution_requirements = {"local": "true"},
+ command = argv,
+ input_manifests = runfiles_manifests,
+ )
+
+ # Return just the echo command. All of the builder runfiles have been
+ # resolved and consumed in the generation of the trivial echo script.
+ return [DefaultInfo(executable = echo)]
+
+_vm_image_test = rule(
+ attrs = {
+ "builder": attr.label(
+ executable = True,
+ cfg = "host",
+ ),
+ },
+ test = True,
+ implementation = _vm_image_impl,
+)
+
+def vm_image(name, **kwargs):
+ vm_image_builder(
+ name = name + "_builder",
+ **kwargs
+ )
+ _vm_image_test(
+ name = name,
+ builder = ":" + name + "_builder",
+ tags = [
+ "local",
+ "manual",
+ ],
+ )
+
+def _vm_test_impl(ctx):
+ runner = ctx.actions.declare_file("%s-executer" % ctx.label.name)
+
+ # Note that the remote execution case must actually generate an
+ # intermediate target in order to collect all the relevant runfiles so that
+ # they can be copied over for remote execution.
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "export ZONE=$(%s)" % ctx.files.zone[0].short_path,
+ "export USERNAME=%s" % ctx.attr.username,
+ "export IMAGE=$(%s)" % ctx.files.image[0].short_path,
+ "export SUDO=%s" % "true" if ctx.attr.sudo else "false",
+ "%s %s" % (
+ ctx.executable.executer.short_path,
+ " ".join([
+ target.files_to_run.executable.short_path
+ for target in ctx.attr.targets
+ ]),
+ ),
+ "",
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ depset(target.data_runfiles.files)
+ for target in ctx.attr.targets
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.executer + ctx.files.zone + ctx.files.image +
+ ctx.files.targets,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_vm_test = rule(
+ attrs = {
+ "image": attr.label(
+ executable = True,
+ default = "//tools/vm:ubuntu1804",
+ cfg = "host",
+ ),
+ "executer": attr.label(
+ executable = True,
+ default = "//tools/vm:executer",
+ cfg = "host",
+ ),
+ "username": attr.string(default = "$(whoami)"),
+ "zone": attr.label(
+ executable = True,
+ default = "//tools/vm:zone",
+ cfg = "host",
+ ),
+ "sudo": attr.bool(default = True),
+ "machine": attr.string(default = "n1-standard-1"),
+ "targets": attr.label_list(
+ mandatory = True,
+ allow_empty = False,
+ cfg = "target",
+ ),
+ },
+ test = True,
+ implementation = _vm_test_impl,
+)
+
+def vm_test(
+ installers = None,
+ **kwargs):
+ """Runs the given targets as a remote test.
+
+ Args:
+ installer: Script to run before all targets.
+ **kwargs: All test arguments. Should include targets and image.
+ """
+ targets = kwargs.pop("targets", [])
+ if installers == None:
+ installers = [
+ "//tools/installers:head",
+ "//tools/installers:images",
+ ]
+ targets = installers + targets
+ if default_installer():
+ targets = [default_installer()] + targets
+ _vm_test(
+ tags = [
+ "local",
+ "manual",
+ ],
+ targets = targets,
+ local = 1,
+ **kwargs
+ )
diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh
new file mode 100755
index 000000000..1f1f3ce01
--- /dev/null
+++ b/tools/vm/execute.sh
@@ -0,0 +1,160 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Required input.
+if ! [[ -v IMAGE ]]; then
+ echo "no image provided: set IMAGE."
+ exit 1
+fi
+
+# Parameters.
+declare -r USERNAME=${USERNAME:-test}
+declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX)
+declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX)
+declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z)
+declare -r MACHINE=${MACHINE:-n1-standard-1}
+declare -r ZONE=${ZONE:-us-central1-f}
+declare -r SUDO=${SUDO:-false}
+
+# Standard arguments (applies only on script execution).
+declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--")
+
+# This script is executed as a test rule, which will reset the value of HOME.
+# Unfortunately, it is needed to load the gconfig credentials. We will reset
+# HOME when we actually execute in the remote environment, defined below.
+export HOME=$(eval echo ~$(whoami))
+
+# Generate unique keys for this test.
+[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}"
+cat > "${SSHKEYS}" <<EOF
+${USERNAME}:$(cat ${KEYNAME}.pub)
+EOF
+
+# Start a unique instance. This means that we first generate a unique set of ssh
+# keys to ensure that only we have access to this instance. Note that we must
+# constrain ourselves to Haswell or greater in order to have nested
+# virtualization available.
+gcloud compute instances create \
+ --min-cpu-platform "Intel Haswell" \
+ --preemptible \
+ --no-scopes \
+ --metadata block-project-ssh-keys=TRUE \
+ --metadata-from-file ssh-keys="${SSHKEYS}" \
+ --machine-type "${MACHINE}" \
+ --image "${IMAGE}" \
+ --zone "${ZONE}" \
+ "${INSTANCE_NAME}"
+function cleanup {
+ gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}"
+}
+trap cleanup EXIT
+
+# Wait for the instance to become available (up to 5 minutes).
+declare timeout=300
+declare success=0
+declare -r start=$(date +%s)
+declare -r end=$((${start}+${timeout}))
+while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do
+ if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then
+ success=$((${success}+1))
+ fi
+done
+if [[ "${success}" -eq "0" ]]; then
+ echo "connect timed out after ${timeout} seconds."
+ exit 1
+fi
+
+# Copy the local directory over.
+tar czf - --dereference --exclude=.git . |
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ tar xzf -
+
+# Execute the command remotely.
+for cmd; do
+ # Setup relevant environment.
+ #
+ # N.B. This is not a complete test environment, but is complete enough to
+ # provide rudimentary sharding and test output support.
+ declare -a PREFIX=( "env" )
+ if [[ -v TEST_SHARD_INDEX ]]; then
+ PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" )
+ fi
+ if [[ -v TEST_SHARD_STATUS_FILE ]]; then
+ SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX)
+ PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" )
+ fi
+ if [[ -v TEST_TOTAL_SHARDS ]]; then
+ PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" )
+ fi
+ if [[ -v TEST_TMPDIR ]]; then
+ REMOTE_TMPDIR=$(mktemp -u test-XXXXXX)
+ PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" )
+ # Create remotely.
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ mkdir -p "/tmp/${REMOTE_TMPDIR}"
+ fi
+ if [[ -v XML_OUTPUT_FILE ]]; then
+ TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX)
+ PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" )
+ fi
+ if [[ "${SUDO}" == "true" ]]; then
+ PREFIX+=( "sudo" "-E" )
+ fi
+
+ # Execute the command.
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ "${PREFIX[@]}" "${cmd}"
+
+ # Collect relevant results.
+ if [[ -v TEST_SHARD_STATUS_FILE ]]; then
+ gcloud compute scp \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \
+ "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail.
+ fi
+ if [[ -v XML_OUTPUT_FILE ]]; then
+ gcloud compute scp \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \
+ "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail.
+ fi
+
+ # Clean up the temporary directory.
+ if [[ -v TEST_TMPDIR ]]; then
+ gcloud compute ssh \
+ --ssh-key-file="${KEYNAME}" \
+ --zone "${ZONE}" \
+ "${USERNAME}"@"${INSTANCE_NAME}" -- \
+ "${SSH_ARGS[@]}" \
+ rm -rf "/tmp/${REMOTE_TMPDIR}"
+ fi
+done
diff --git a/tools/vm/test.cc b/tools/vm/test.cc
new file mode 100644
index 000000000..c0ceacda1
--- /dev/null
+++ b/tools/vm/test.cc
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+
+namespace {
+
+TEST(Image, Sanity0) {
+ // Do nothing (in shard 0).
+}
+
+TEST(Image, Sanity1) {
+ // Do nothing (in shard 1).
+}
+
+} // namespace
diff --git a/kokoro/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh
index e87a6eee8..629f7cf7a 100755
--- a/kokoro/ubuntu1604/10_core.sh
+++ b/tools/vm/ubuntu1604/10_core.sh
@@ -17,14 +17,27 @@
set -xeo pipefail
# Install all essential build tools.
-apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config
+while true; do
+ if (apt-get update && apt-get install -y \
+ make \
+ git-core \
+ build-essential \
+ linux-headers-$(uname -r) \
+ pkg-config); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install a recent go toolchain.
if ! [[ -d /usr/local/go ]]; then
- wget https://dl.google.com/go/go1.12.linux-amd64.tar.gz
- tar -xvf go1.12.linux-amd64.tar.gz
+ wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz
+ tar -xvf go1.13.5.linux-amd64.tar.gz
mv go /usr/local
fi
# Link the Go binary from /usr/bin; replacing anything there.
-(cd /usr/bin && rm -f go && sudo ln -fs /usr/local/go/bin/go go)
+(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go)
diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh
new file mode 100755
index 000000000..bc2e5eccc
--- /dev/null
+++ b/tools/vm/ubuntu1604/15_gcloud.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Install all essential build tools.
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ gnupg); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Add gcloud repositories.
+echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \
+ tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
+
+# Add the appropriate key.
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
+ apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
+
+# Install the gcloud SDK.
+while true; do
+ if (apt-get update && apt-get install -y google-cloud-sdk); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
diff --git a/kokoro/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh
index b9a894024..bb7afa676 100755
--- a/kokoro/ubuntu1604/20_bazel.sh
+++ b/tools/vm/ubuntu1604/20_bazel.sh
@@ -16,10 +16,20 @@
set -xeo pipefail
-declare -r BAZEL_VERSION=0.29.1
+declare -r BAZEL_VERSION=2.0.0
# Install bazel dependencies.
-apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
+while true; do
+ if (apt-get update && apt-get install -y \
+ openjdk-8-jdk-headless \
+ unzip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Use the release installer.
curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/kokoro/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/30_docker.sh
index 1d3defcd3..d393133e4 100755
--- a/kokoro/ubuntu1604/25_docker.sh
+++ b/tools/vm/ubuntu1604/30_docker.sh
@@ -15,12 +15,20 @@
# limitations under the License.
# Add dependencies.
-apt-get update && apt-get -y install \
- apt-transport-https \
- ca-certificates \
- curl \
- gnupg-agent \
- software-properties-common
+while true; do
+ if (apt-get update && apt-get install -y \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
# Install the key.
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
@@ -32,4 +40,25 @@ add-apt-repository \
stable"
# Install docker.
-apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io
+while true; do
+ if (apt-get update && apt-get install -y \
+ docker-ce \
+ docker-ce-cli \
+ containerd.io); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# Enable experimental features, for cross-building aarch64 images.
+# Enable Docker IPv6.
+cat > /etc/docker/daemon.json <<EOF
+{
+ "experimental": true,
+ "fixed-cidr-v6": "2001:db8:1::/64",
+ "ipv6": true
+}
+EOF
diff --git a/kokoro/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh
index 64772d74d..d3b96c9ad 100755
--- a/kokoro/ubuntu1604/40_kokoro.sh
+++ b/tools/vm/ubuntu1604/40_kokoro.sh
@@ -23,16 +23,34 @@ declare -r ssh_public_keys=(
)
# Install dependencies.
-apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm
-
-# We need a kbuilder user.
-if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then
- # User was added successfully; we add the relevant SSH keys here.
- mkdir -p ~kbuilder/.ssh
- (IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
- chmod 0600 ~kbuilder/.ssh/authorized_keys
- chown -R kbuilder ~kbuilder/.ssh
-fi
+while true; do
+ if (apt-get update && apt-get install -y \
+ rsync \
+ coreutils \
+ python-psutil \
+ qemu-kvm \
+ python-pip \
+ python3-pip \
+ zip); then
+ break
+ fi
+ result=$?
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+
+# junitparser is used to merge junit xml files.
+pip install --no-cache-dir junitparser
+
+# We need a kbuilder user, which may already exist.
+useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true
+
+# We need to provision appropriate keys.
+mkdir -p ~kbuilder/.ssh
+(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
+chmod 0600 ~kbuilder/.ssh/authorized_keys
+chown -R kbuilder ~kbuilder/.ssh
# Give passwordless sudo access.
cat > /etc/sudoers.d/kokoro <<EOF
diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD
new file mode 100644
index 000000000..ab1df0c4c
--- /dev/null
+++ b/tools/vm/ubuntu1604/BUILD
@@ -0,0 +1,7 @@
+package(licenses = ["notice"])
+
+filegroup(
+ name = "ubuntu1604",
+ srcs = glob(["*.sh"]),
+ visibility = ["//:sandbox"],
+)
diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD
new file mode 100644
index 000000000..0c8856dde
--- /dev/null
+++ b/tools/vm/ubuntu1804/BUILD
@@ -0,0 +1,7 @@
+package(licenses = ["notice"])
+
+alias(
+ name = "ubuntu1804",
+ actual = "//tools/vm/ubuntu1604",
+ visibility = ["//:sandbox"],
+)
diff --git a/tools/vm/zone.sh b/tools/vm/zone.sh
new file mode 100755
index 000000000..79569fb19
--- /dev/null
+++ b/tools/vm/zone.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+exec gcloud config get-value compute/zone
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
index fb09ff331..a22c8c9f2 100755
--- a/tools/workspace_status.sh
+++ b/tools/workspace_status.sh
@@ -15,4 +15,4 @@
# limitations under the License.
# The STABLE_ prefix will trigger a re-link if it changes.
-echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty)
+echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty || echo 0.0.0)
diff --git a/vdso/BUILD b/vdso/BUILD
index 7ceed349e..c70bb8218 100644
--- a/vdso/BUILD
+++ b/vdso/BUILD
@@ -3,20 +3,10 @@
# normal system VDSO (time, gettimeofday, clock_gettimeofday) but which uses
# timekeeping parameters managed by the sandbox kernel.
-load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", "cc_flags_supplier")
+load("//tools:defs.bzl", "cc_flags_supplier", "cc_toolchain", "select_arch", "vdso_linker_option")
package(licenses = ["notice"])
-config_setting(
- name = "x86_64",
- constraint_values = ["@bazel_tools//platforms:x86_64"],
-)
-
-config_setting(
- name = "aarch64",
- constraint_values = ["@bazel_tools//platforms:aarch64"],
-)
-
genrule(
name = "vdso",
srcs = [
@@ -39,14 +29,15 @@ genrule(
"-O2 " +
"-std=c++11 " +
"-fPIC " +
+ "-fno-sanitize=all " +
# Some toolchains enable stack protector by default. Disable it, the
# VDSO has no hooks to handle failures.
"-fno-stack-protector " +
- "-fuse-ld=gold " +
- select({
- ":x86_64": "-m64 ",
- "//conditions:default": "",
- }) +
+ vdso_linker_option +
+ select_arch(
+ amd64 = "-m64 ",
+ arm64 = "",
+ ) +
"-shared " +
"-nostdlib " +
"-Wl,-soname=linux-vdso.so.1 " +
@@ -55,12 +46,10 @@ genrule(
"-Wl,-Bsymbolic " +
"-Wl,-z,max-page-size=4096 " +
"-Wl,-z,common-page-size=4096 " +
- select(
- {
- ":x86_64": "-Wl,-T$(location vdso_amd64.lds) ",
- ":aarch64": "-Wl,-T$(location vdso_arm64.lds) ",
- },
- no_match_error = "Unsupported architecture",
+ select_arch(
+ amd64 = "-Wl,-T$(location vdso_amd64.lds) ",
+ arm64 = "-Wl,-T$(location vdso_arm64.lds) ",
+ no_match_error = "unsupported architecture",
) +
"-o $(location vdso.so) " +
"$(location vdso.cc) " +
@@ -68,14 +57,14 @@ genrule(
"&& $(location :check_vdso) " +
"--check-data " +
"--vdso $(location vdso.so) ",
+ exec_tools = [
+ ":check_vdso",
+ ],
features = ["-pie"],
toolchains = [
- "@bazel_tools//tools/cpp:current_cc_toolchain",
+ cc_toolchain,
":no_pie_cc_flags",
],
- tools = [
- ":check_vdso",
- ],
visibility = ["//:sandbox"],
)
@@ -87,6 +76,6 @@ cc_flags_supplier(
py_binary(
name = "check_vdso",
srcs = ["check_vdso.py"],
- python_version = "PY2",
+ python_version = "PY3",
visibility = ["//:sandbox"],
)
diff --git a/vdso/syscalls.h b/vdso/syscalls.h
index f5865bb72..0c6a922a0 100644
--- a/vdso/syscalls.h
+++ b/vdso/syscalls.h
@@ -26,6 +26,9 @@
#include <stddef.h>
#include <sys/types.h>
+#define __stringify_1(x...) #x
+#define __stringify(x...) __stringify_1(x)
+
namespace vdso {
#if __x86_64__
@@ -51,22 +54,15 @@ static inline int sys_getcpu(unsigned* cpu, unsigned* node,
return num;
}
-#elif __aarch64__
-
-static inline int sys_rt_sigreturn(void) {
- int num = __NR_rt_sigreturn;
-
- asm volatile(
- "mov x8, %0\n"
- "svc #0 \n"
- : "+r"(num)
- :
- :);
- return num;
+static inline void sys_rt_sigreturn(void) {
+ asm volatile("movl $" __stringify(__NR_rt_sigreturn)", %eax \n"
+ "syscall \n");
}
-static inline int sys_clock_gettime(clockid_t _clkid, struct timespec *_ts) {
- register struct timespec *ts asm("x1") = _ts;
+#elif __aarch64__
+
+static inline int sys_clock_gettime(clockid_t _clkid, struct timespec* _ts) {
+ register struct timespec* ts asm("x1") = _ts;
register clockid_t clkid asm("x0") = _clkid;
register long ret asm("x0");
register long nr asm("x8") = __NR_clock_gettime;
@@ -78,8 +74,8 @@ static inline int sys_clock_gettime(clockid_t _clkid, struct timespec *_ts) {
return ret;
}
-static inline int sys_clock_getres(clockid_t _clkid, struct timespec *_ts) {
- register struct timespec *ts asm("x1") = _ts;
+static inline int sys_clock_getres(clockid_t _clkid, struct timespec* _ts) {
+ register struct timespec* ts asm("x1") = _ts;
register clockid_t clkid asm("x0") = _clkid;
register long ret asm("x0");
register long nr asm("x8") = __NR_clock_getres;
@@ -91,6 +87,11 @@ static inline int sys_clock_getres(clockid_t _clkid, struct timespec *_ts) {
return ret;
}
+static inline void sys_rt_sigreturn(void) {
+ asm volatile("mov x8, #" __stringify(__NR_rt_sigreturn)" \n"
+ "svc #0 \n");
+}
+
#else
#error "unsupported architecture"
#endif
diff --git a/vdso/vdso.cc b/vdso/vdso.cc
index 8bb80a7a4..3b6653b5d 100644
--- a/vdso/vdso.cc
+++ b/vdso/vdso.cc
@@ -69,6 +69,12 @@ int __common_gettimeofday(struct timeval* tv, struct timezone* tz) {
}
} // namespace
+// __kernel_rt_sigreturn() implements rt_sigreturn()
+extern "C" void __kernel_rt_sigreturn(unsigned long unused) {
+ // No optimizations yet, just make the real system call.
+ sys_rt_sigreturn();
+}
+
#if __x86_64__
// __vdso_clock_gettime() implements clock_gettime()
@@ -126,6 +132,10 @@ extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) {
case CLOCK_REALTIME:
case CLOCK_MONOTONIC:
case CLOCK_BOOTTIME: {
+ if (res == nullptr) {
+ return 0;
+ }
+
res->tv_sec = 0;
res->tv_nsec = 1;
break;
@@ -139,12 +149,6 @@ extern "C" int __kernel_clock_getres(clockid_t clock, struct timespec* res) {
return ret;
}
-// __kernel_rt_sigreturn() implements gettimeofday()
-extern "C" int __kernel_rt_sigreturn(unsigned long unused) {
- // No optimizations yet, just make the real system call.
- return sys_rt_sigreturn();
-}
-
#else
#error "unsupported architecture"
#endif
diff --git a/vdso/vdso_amd64.lds b/vdso/vdso_amd64.lds
index e2615ae9e..d114290da 100644
--- a/vdso/vdso_amd64.lds
+++ b/vdso/vdso_amd64.lds
@@ -95,6 +95,7 @@ VERSION {
__vdso_getcpu;
time;
__vdso_time;
+ __kernel_rt_sigreturn;
local: *;
};
diff --git a/website/BUILD b/website/BUILD
new file mode 100644
index 000000000..7b61d13c8
--- /dev/null
+++ b/website/BUILD
@@ -0,0 +1,188 @@
+load("//tools:defs.bzl", "bzl_library", "pkg_tar")
+load("//website:defs.bzl", "doc", "docs")
+
+package(licenses = ["notice"])
+
+# website is the full container image. Note that this actually just collects
+# other dependendcies and runs Docker locally to import and tag the image.
+sh_binary(
+ name = "website",
+ srcs = ["import.sh"],
+ data = [":files"],
+ tags = [
+ "local",
+ "manual",
+ ],
+)
+
+# files is the full file system of the generated container.
+#
+# It must collect the all tarballs (produced by the rules below), and run it
+# through the Dockerfile to generate the site. Note that this checks all links,
+# and therefore requires all static content to be present as well.
+#
+# Note that this rule violates most aspects of hermetic builds. However, this
+# works much more reliably than depending on the container_image rules from
+# bazel itself, which are convoluted and seem to have a hard time even finding
+# the toolchain.
+genrule(
+ name = "files",
+ srcs = [
+ ":config",
+ ":css",
+ ":docs",
+ ":static",
+ ":syscallmd",
+ "//website/blog:posts",
+ "//website/cmd/server",
+ ],
+ outs = ["files.tgz"],
+ cmd = "set -x; " +
+ "T=$$(mktemp -d); " +
+ "mkdir -p $$T/input && " +
+ "mkdir -p $$T/output/_site && " +
+ "tar -xf $(location :config) -C $$T/input && " +
+ "tar -xf $(location :css) -C $$T/input && " +
+ "tar -xf $(location :docs) -C $$T/input && " +
+ "tar -xf $(location :syscallmd) -C $$T/input && " +
+ "tar -xf $(location //website/blog:posts) -C $$T/input && " +
+ "find $$T/input -type f -exec chmod u+rw {} \\; && " +
+ "docker run -i --user $$(id -u):$$(id -g) " +
+ "-v $$(readlink -m $$T/input):/input " +
+ "-v $$(readlink -m $$T/output/_site):/output " +
+ "gvisor.dev/images/jekyll && " +
+ "tar -xf $(location :static) -C $$T/output/_site && " +
+ "docker run -i --user $$(id -u):$$(id -g) " +
+ "-v $$(readlink -m $$T/output/_site):/output " +
+ "gvisor.dev/images/jekyll " +
+ "ruby /checks.rb " +
+ "/output && " +
+ "cp $(location //website/cmd/server) $$T/output/server && " +
+ "tar -zcf $@ -C $$T/output . && " +
+ "rm -rf $$T",
+ tags = [
+ "local",
+ "manual",
+ "nosandbox",
+ ],
+)
+
+# static are the purely static parts of the site. These are effectively copied
+# in after jekyll generates all the dynamic content.
+pkg_tar(
+ name = "static",
+ srcs = [
+ "archive.key",
+ ] + glob([
+ "performance/**",
+ ]),
+ strip_prefix = "./",
+)
+
+# main.scss requires front-matter to be processed.
+genrule(
+ name = "css",
+ srcs = glob([
+ "css/**",
+ ]),
+ outs = [
+ "css.tar",
+ ],
+ cmd = "T=$$(mktemp -d); " +
+ "mkdir -p $$T/css && " +
+ "for file in $(SRCS); do " +
+ "echo -en '---\\n---\\n' > $$T/css/$$(basename $$file) && " +
+ "cat $$file >> $$T/css/$$(basename $$file); " +
+ "done && " +
+ "tar -C $$T -czf $@ . && " +
+ "rm -rf $$T",
+)
+
+# config is "mostly" static content. These are parts of the site that are
+# present when jekyll runs, but are not dynamically generated.
+pkg_tar(
+ name = "config",
+ srcs = [
+ ":css",
+ "_config.yml",
+ "//website/blog:index.html",
+ ] + glob([
+ "assets/**",
+ "_includes/**",
+ "_layouts/**",
+ "_plugins/**",
+ "_sass/**",
+ ]),
+ strip_prefix = "./",
+)
+
+# index is the index file.
+doc(
+ name = "index",
+ src = "index.md",
+ layout = "base",
+ permalink = "/",
+)
+
+# docs is the dynamic content of the site.
+docs(
+ name = "docs",
+ deps = [
+ ":index",
+ "//:code_of_conduct",
+ "//:contributing",
+ "//:governance",
+ "//:security",
+ "//g3doc:community",
+ "//g3doc:index",
+ "//g3doc:roadmap",
+ "//g3doc:style",
+ "//g3doc/architecture_guide:performance",
+ "//g3doc/architecture_guide:platforms",
+ "//g3doc/architecture_guide:resources",
+ "//g3doc/architecture_guide:security",
+ "//g3doc/user_guide:FAQ",
+ "//g3doc/user_guide:checkpoint_restore",
+ "//g3doc/user_guide:compatibility",
+ "//g3doc/user_guide:debugging",
+ "//g3doc/user_guide:filesystem",
+ "//g3doc/user_guide:install",
+ "//g3doc/user_guide:networking",
+ "//g3doc/user_guide:platforms",
+ "//g3doc/user_guide/containerd:configuration",
+ "//g3doc/user_guide/containerd:containerd_11",
+ "//g3doc/user_guide/containerd:quick_start",
+ "//g3doc/user_guide/quick_start:docker",
+ "//g3doc/user_guide/quick_start:kubernetes",
+ "//g3doc/user_guide/quick_start:oci",
+ "//g3doc/user_guide/tutorials:cni",
+ "//g3doc/user_guide/tutorials:docker",
+ "//g3doc/user_guide/tutorials:kubernetes",
+ ],
+)
+
+# Generate JSON for system call tables
+genrule(
+ name = "syscalljson",
+ outs = ["syscalls.json"],
+ cmd = "$(location //runsc) -- help syscalls -format json -filename $@",
+ tools = ["//runsc"],
+)
+
+# Generate markdown from the json dump.
+genrule(
+ name = "syscallmd",
+ srcs = [":syscalljson"],
+ outs = ["syscallsmd"],
+ cmd = "T=$$(mktemp -d) && " +
+ "$(location //website/cmd/syscalldocs) -in $< -out $$T && " +
+ "tar -C $$T -czf $@ . && " +
+ "rm -rf $$T",
+ tools = ["//website/cmd/syscalldocs"],
+)
+
+bzl_library(
+ name = "defs_bzl",
+ srcs = ["defs.bzl"],
+ visibility = ["//visibility:private"],
+)
diff --git a/website/_config.yml b/website/_config.yml
new file mode 100644
index 000000000..b08602970
--- /dev/null
+++ b/website/_config.yml
@@ -0,0 +1,36 @@
+destination: _site
+markdown: kramdown
+kramdown:
+ syntax_highlighter: rouge
+ toc_levels: "2,3"
+highlighter: rouge
+paginate: 5
+paginate_path: "/blog/page:num/"
+plugins:
+ - jekyll-paginate
+ - jekyll-autoprefixer
+ - jekyll-inline-svg
+ - jekyll-relative-links
+ - jekyll-feed
+ - jekyll-sitemap
+site_url: https://gvisor.dev
+feed:
+ path: blog/index.xml
+svg:
+ optimize: true
+defaults:
+ - scope:
+ path: ""
+ values:
+ layout: default
+analytics: "UA-150193582-1"
+authors:
+ jsprad:
+ name: Jeremiah Spradlin
+ email: jsprad@google.com
+ zkoopmans:
+ name: Zach Koopmans
+ email: zkoopmans@google.com
+ igudger:
+ name: Ian Gudger
+ email: igudger@google.com
diff --git a/website/_includes/byline.html b/website/_includes/byline.html
new file mode 100644
index 000000000..d8ae22cb0
--- /dev/null
+++ b/website/_includes/byline.html
@@ -0,0 +1,18 @@
+By
+{% assign last_pos=include.authors.size | minus: 1 %}
+{% assign and_pos=include.authors.size | minus: 2 %}
+{% for i in (0..last_pos) %}
+ {% assign author_id=include.authors[i] %}
+ {% assign author=site.authors[author_id] %}
+ {% if author %}
+ <a href="mailto:{{ author.email }}">{{ author.name }}</a>
+ {% else %}
+ {{ author_id }}
+ {% endif %}
+ {% if i == and_pos %}
+ and
+ {% elsif i < and_pos %}
+ ,
+ {% endif %}
+{% endfor %}
+on <span class="text-muted">{{ include.date | date_to_long_string }}</span>
diff --git a/website/_includes/footer-links.html b/website/_includes/footer-links.html
new file mode 100644
index 000000000..2036dbaa9
--- /dev/null
+++ b/website/_includes/footer-links.html
@@ -0,0 +1,43 @@
+<div class="container">
+ <div class="row">
+ <div class="col-sm-3 col-md-2">
+ <p>About</p>
+ <ul class="list-unstyled">
+ <li><a href="/roadmap/">Roadmap</a></li>
+ <li><a href="/contributing/">Contributing</a></li>
+ <li><a href="/security/">Security</a></li>
+ <li><a href="/community/governance/">Governance</a></li>
+ <li><a href="https://policies.google.com/privacy">Privacy Policy</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-2">
+ <p>Support</p>
+ <ul class="list-unstyled">
+ <li><a href="https://github.com/google/gvisor/issues">Issues</a></li>
+ <li><a href="/docs">Documentation</a></li>
+ <li><a href="/docs/user_guide/faq">FAQ</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-2">
+ <p>Connect</p>
+ <ul class="list-unstyled">
+ <li><a href="https://github.com/google/gvisor">GitHub</a></li>
+ <li><a href="https://groups.google.com/forum/#!forum/gvisor-users">User Mailing List</a></li>
+ <li><a href="https://groups.google.com/forum/#!forum/gvisor-dev">Developer Mailing List</a></li>
+ <li><a href="https://gitter.im/gvisor/community">Gitter Chat</a></li>
+ <li><a href="/blog">Blog</a></li>
+ </ul>
+ </div>
+ <div class="col-sm-3 col-md-3"></div>
+ <div class="hidden-xs hidden-sm col-md-3">
+ <a href="https://cloud.google.com/run">
+ <img style="float: right;" src="/assets/logos/powered-gvisor.png" alt="Powered by gVisor"/>
+ </a>
+ </div>
+ </div>
+ <div class="row">
+ <div class="col-lg-12">
+ <p>&copy; {{ 'now' | date: "%Y" }} The gVisor Authors</p>
+ </div>
+ </div>
+</div>
diff --git a/website/_includes/footer.html b/website/_includes/footer.html
new file mode 100644
index 000000000..c1a373329
--- /dev/null
+++ b/website/_includes/footer.html
@@ -0,0 +1,72 @@
+<footer class="footer">
+ {% include footer-links.html %}
+</footer>
+
+<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js" integrity="sha256-FgpCb/KJQlLNfOu91ta32o/NMZxltwRo8QtmkMRdAu8=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.10.1/js/all.min.js" integrity="sha256-Z1Nvg/+y2+vRFhFgFij7Lv0r77yG3hOvWz2wI0SfTa0=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha256-U5ZEeKfGNOja007MMD3YBI0A3OSZOQbeG6z2f2Y0hu8=" crossorigin="anonymous"></script>
+<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.13.0/d3.min.js" integrity="sha256-hYXbQJK4qdJiAeDVjjQ9G0D6A0xLnDQ4eJI9dkm7Fpk=" crossorigin="anonymous"></script>
+
+{% if site.analytics %}
+<script>
+var doNotTrack = false;
+if (!doNotTrack) {
+ window.ga=window.ga||function(){(ga.q=ga.q||[]).push(arguments)};ga.l=+new Date;
+ ga('create', '{{ site.analytics }}', 'auto');
+ ga('send', 'pageview');
+}
+</script>
+<script async src='https://www.google-analytics.com/analytics.js'></script>
+{% endif %}
+
+<script>
+ var shiftWindow = function() {
+ if (location.hash.length !== 0) {
+ window.scrollBy(0, -50);
+ }
+ };
+ window.addEventListener("hashchange", shiftWindow);
+
+ var highlightCurrentSidebarNav = function() {
+ var href = location.pathname;
+ var item = $('#sidebar-nav [href$="' + href + '"]');
+ if (item) {
+ var li = item.parent();
+ li.addClass("active");
+
+ if (li.parent() && li.parent().is("ul")) {
+ do {
+ var ul = li.parent();
+ if (ul.hasClass("collapse")) {
+ ul.collapse("show");
+ }
+ li = ul.parent();
+ } while (li && li.is("li"));
+ }
+ }
+ };
+
+ $(document).ready(function() {
+ // Scroll to anchor of location hash, adjusted for fixed navbar.
+ window.setTimeout(function() {
+ shiftWindow();
+ }, 1);
+
+ // Flip the caret when submenu toggles are clicked.
+ $(".sidebar-submenu").on("show.bs.collapse", function() {
+ var toggle = $('[href$="#' + $(this).attr('id') + '"]');
+ if (toggle) {
+ toggle.addClass("dropup");
+ }
+ });
+ $(".sidebar-submenu").on("hide.bs.collapse", function() {
+ var toggle = $('[href$="#' + $(this).attr('id') + '"]');
+ if (toggle) {
+ toggle.removeClass("dropup");
+ }
+ });
+
+ // Highlight the current page on the sidebar nav.
+ highlightCurrentSidebarNav();
+ });
+</script>
diff --git a/website/_includes/graph.html b/website/_includes/graph.html
new file mode 100644
index 000000000..ba4cf9840
--- /dev/null
+++ b/website/_includes/graph.html
@@ -0,0 +1,205 @@
+{::nomarkdown}
+{% assign fn = include.id | remove: " " | remove: "-" | downcase %}
+<figure><a href="{{ include.url }}"><svg id="{{ include.id }}" width=500 height=200 onload="render_{{ fn }}()"><title>{{ include.title }}</title></svg></a></figure>
+<script>
+function render_{{ fn }}() {
+d3.csv("{{ include.url }}", function(d, i, columns) {
+ return d; // Transformed below.
+}, function(error, data) {
+ if (error) throw(error);
+
+ // Create a new data that pivots on runtime.
+ //
+ // To start, we have:
+ // runtime, ..., result
+ // runc, ..., 1
+ // runsc, ..., 2
+ //
+ // In the end we want:
+ // ..., runsc, runc
+ // ..., 1, 2
+
+ // Filter by metric, if required.
+ if ("{{ include.metric }}" != "") {
+ orig_columns = data.columns;
+ data = data.filter(d => d.metric == "{{ include.metric }}");
+ data.columns = orig_columns;
+ }
+
+ // Filter by method, if required.
+ if ("{{ include.method }}" != "") {
+ orig_columns = data.columns;
+ data = data.filter(d => d.method == "{{ include.method }}");
+ data.columns = orig_columns.filter(key => key != "method");
+ }
+
+ // Enumerate runtimes.
+ var runtimes = Array.from(new Set(data.map(d => d.runtime)));
+ var metrics = Array.from(new Set(data.map(d => d.metric)));
+ if (metrics.length < 1) {
+ console.log(data);
+ throw("need at least one metric");
+ } else if (metrics.length == 1) {
+ metric = metrics[0];
+ data.columns = data.columns.filter(key => key != "metric");
+ } else {
+ metric = ""; // Used for grouping.
+ }
+
+ var isSubset = function(a, sup) {
+ var ap = Object.getOwnPropertyNames(a);
+ for (var i = 0; i < ap.length; i++) {
+ if (a[ap[i]] !== sup[ap[i]]) {
+ return false;
+ }
+ }
+ return true;
+ };
+
+ // Execute a pivot to include runtimes as attributes.
+ var new_data = data.map(function(data_item) {
+ // Generate a prototype data item.
+ var proto_item = Object.assign({}, data_item);
+ delete proto_item.runtime;
+ delete proto_item.result;
+ var next_item = Object.assign({}, proto_item);
+
+ // Find all matching runtime items.
+ data.forEach(function(d) {
+ if (isSubset(proto_item, d)) {
+ // Add the result result.
+ next_item[d.runtime] = d.result;
+ }
+ });
+ return next_item;
+ });
+
+ // Remove any duplication.
+ new_data = Array.from(new Set(new_data));
+ new_data.columns = data.columns;
+ new_data.columns = new_data.columns.filter(key => key != "runtime" && key != "result");
+ new_data.columns = new_data.columns.concat(runtimes);
+ data = new_data;
+
+ // Slice based on the first key.
+ if (data.columns.length != runtimes.length) {
+ x0_key = new_data.columns[0];
+ var x1_domain = data.columns.slice(1);
+ } else {
+ x0_key = "runtime";
+ var x1_domain = runtimes;
+ }
+
+ // Determine varaible margins.
+ var x0_domain = data.map(d => d[x0_key]);
+ var margin_bottom_pad = 0;
+ if (x0_domain.length > 8) {
+ margin_bottom_pad = 50;
+ }
+
+ // Use log scale if required.
+ var y_min = 0;
+ if ({{ include.log | default: "false" }}) {
+ // Need to cap lower end of the domain at 1.
+ y_min = 1;
+ }
+
+ if ({{ include.y_min | default: "false" }}) {
+ y_min = "{{ include.y_min }}";
+ }
+
+ var svg = d3.select("#{{ include.id }}"),
+ margin = {top: 20, right: 20, bottom: 30 + margin_bottom_pad, left: 50},
+ width = +svg.attr("width") - margin.left - margin.right,
+ height = +svg.attr("height") - margin.top - margin.bottom,
+ g = svg.append("g").attr("transform", "translate(" + margin.left + "," + margin.top + ")");
+
+ var x0 = d3.scaleBand()
+ .rangeRound([margin.left / 2, width - (4 * margin.right)])
+ .paddingInner(0.1);
+
+ var x1 = d3.scaleBand()
+ .padding(0.05);
+
+ var y = d3.scaleLinear()
+ .rangeRound([height, 0]);
+ if ({{ include.log | default: "false" }}) {
+ y = d3.scaleLog()
+ .rangeRound([height, 0]);
+ }
+
+ var z = d3.scaleOrdinal()
+ .range(["#262362", "#FBB03B", "#286FD7", "#6b486b"]);
+
+ // Set all domains.
+ x0.domain(x0_domain);
+ x1.domain(x1_domain).rangeRound([0, x0.bandwidth()]);
+ y.domain([y_min, d3.max(data, d => d3.max(x1_domain, key => parseFloat(d[key])))]).nice();
+
+ // The data.
+ g.append("g")
+ .selectAll("g")
+ .data(data)
+ .enter().append("g")
+ .attr("transform", function(d) { return "translate(" + x0(d[x0_key]) + ",0)"; })
+ .selectAll("rect")
+ .data(d => x1_domain.map(key => ({key, value: d[key]})))
+ .enter().append("rect")
+ .attr("x", d => x1(d.key))
+ .attr("y", d => y(d.value))
+ .attr("width", x1.bandwidth())
+ .attr("height", d => y(y_min) - y(d.value))
+ .attr("fill", d => z(d.key));
+
+ // X0 ticks and labels.
+ var x0_axis = g.append("g")
+ .attr("class", "axis")
+ .attr("transform", "translate(0," + height + ")")
+ .call(d3.axisBottom(x0));
+ if (x0_domain.length > 8) {
+ x0_axis.selectAll("text")
+ .style("text-anchor", "end")
+ .attr("dx", "-.8em")
+ .attr("dy", ".15em")
+ .attr("transform", "rotate(-65)");
+ }
+
+ // Y ticks and top-label.
+ if (metric == "default") {
+ metric = ""; // Don't display.
+ }
+ g.append("g")
+ .attr("class", "axis")
+ .call(d3.axisLeft(y).ticks(null, "s"))
+ .append("text")
+ .attr("x", -30.0)
+ .attr("y", y(y.ticks().pop()) - 10.0)
+ .attr("dy", "0.32em")
+ .attr("fill", "#000")
+ .attr("font-weight", "bold")
+ .attr("text-anchor", "start")
+ .text(metric);
+
+ // The legend.
+ var legend = g.append("g")
+ .attr("font-family", "sans-serif")
+ .attr("font-size", 10)
+ .attr("text-anchor", "end")
+ .selectAll("g")
+ .data(x1_domain.slice().reverse())
+ .enter().append("g")
+ .attr("transform", function(d, i) { return "translate(0," + i * 20 + ")"; });
+ legend.append("rect")
+ .attr("x", width - 19)
+ .attr("width", 19)
+ .attr("height", 19)
+ .attr("fill", z);
+ legend.append("text")
+ .attr("x", width - 24)
+ .attr("y", 9.5)
+ .attr("dy", "0.32em")
+ .text(function(d) { return d; });
+});
+}
+</script>
+{:/}
diff --git a/website/_includes/header-links.html b/website/_includes/header-links.html
new file mode 100644
index 000000000..4232fdaa5
--- /dev/null
+++ b/website/_includes/header-links.html
@@ -0,0 +1,19 @@
+<nav class="navbar navbar-expand-sm navbar-inverse navbar-fixed-top">
+ <div class="container">
+ <div class="navbar-brand">
+ <a href="/">
+ <img src="/assets/logos/logo_solo_on_dark.svg" height="25" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo" />
+ gVisor
+ </a>
+ </div>
+
+ <div class="collapse navbar-collapse">
+ <ul class="nav navbar-nav navbar-right">
+ <li><a href="/docs">Documentation</a></li>
+ <li><a href="/blog">Blog</a></li>
+ <li><a href="/community/">Community</a></li>
+ <li><a href="https://github.com/google/gvisor">GitHub</a></li>
+ </ul>
+ </div>
+ </div>
+</nav>
diff --git a/website/_includes/header.html b/website/_includes/header.html
new file mode 100644
index 000000000..c80310069
--- /dev/null
+++ b/website/_includes/header.html
@@ -0,0 +1,30 @@
+ <head>
+ <meta charset="utf-8">
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ {% if page.title %}
+ <title>{{ page.title }} - gVisor</title>
+ {% else %}
+ <title>gVisor</title>
+ {% endif %}
+ <link rel="canonical" href="{{ page.url | replace:'index.html','' | prepend: site_root }}">
+
+ <!-- Dependencies. -->
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/3.3.7/css/bootstrap.min.css" integrity="sha256-916EbMg70RQy9LHiGkXzG8hSg9EdNy97GazNG/aiY1w=" crossorigin="anonymous" />
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.10.1/css/all.min.css" integrity="sha256-fdcFNFiBMrNfWL6OcAGQz6jDgNTRxnrLEd4vJYFWScE=" crossorigin="anonymous" />
+
+ <!-- Our own style sheet. -->
+ <link rel="stylesheet" type="text/css" href="/css/main.css">
+ <link rel="icon" type="image/png" href="/assets/favicons/favicon-32x32.png" sizes="32x32">
+ <link rel="icon" type="image/png" href="/assets/favicons/favicon-16x16.png" sizes="16x16">
+
+ {% if page.title %}
+ <meta name="og:title" content="{{ page.title }}">
+ {% else %}
+ <meta name="og:title" content="gVisor">
+ {% endif %}
+ {% if page.description %}
+ <meta name="og:description" content="{{ page.description }}">
+ {% endif %}
+ <meta name="og:image" content="{{ site.site_url }}/assets/logos/logo_solo_on_white_bordered.svg">
+ </head>
diff --git a/website/_includes/paginator.html b/website/_includes/paginator.html
new file mode 100644
index 000000000..b4ff4c3b1
--- /dev/null
+++ b/website/_includes/paginator.html
@@ -0,0 +1,10 @@
+<nav aria-label="...">
+ <ul class="pager">
+ {% if paginator.previous_page %}
+ <li class="previous"><a href="{{ paginator.previous_page_path }}"><span aria-hidden="true">&larr;</span> Newer</a></li>
+ {% endif %}
+ {% if paginator.next_page %}
+ <li class="next"><a href="{{ paginator.next_page_path }}">Older <span aria-hidden="true">&rarr;</span></a></li>
+ {% endif %}
+ </ul>
+</nav>
diff --git a/website/_includes/required_linux.html b/website/_includes/required_linux.html
new file mode 100644
index 000000000..e9d1b7548
--- /dev/null
+++ b/website/_includes/required_linux.html
@@ -0,0 +1,2 @@
+> Note: gVisor supports only x86\_64 and requires Linux 4.14.77+
+> ([older Linux](/docs/user_guide/networking/#gso)).
diff --git a/website/_layouts/base.html b/website/_layouts/base.html
new file mode 100644
index 000000000..b30bee0dc
--- /dev/null
+++ b/website/_layouts/base.html
@@ -0,0 +1,9 @@
+<!DOCTYPE html>
+<html lang="en" itemscope itemtype="https://schema.org/WebPage">
+ {% include header.html %}
+ <body>
+ {% include header-links.html %}
+ {{ content }}
+ {% include footer.html %}
+ </body>
+</html>
diff --git a/website/_layouts/blog.html b/website/_layouts/blog.html
new file mode 100644
index 000000000..6c371ab50
--- /dev/null
+++ b/website/_layouts/blog.html
@@ -0,0 +1,17 @@
+---
+layout: base
+---
+
+<div class="container">
+ <div class="row">
+ <div class="col-lg-2"></div>
+ <div class="col-lg-8">
+ <h1>{{ page.title }}</h1>
+ {% if page.feed %}
+ <a class="btn-inverse" href="/blog/index.xml">Feed&nbsp;<i class="fas fa-rss ml-2"></i></a>
+ {% endif %}
+ {{ content }}
+ </div>
+ <div class="col-lg-2"></div>
+ </div>
+</div>
diff --git a/website/_layouts/default.html b/website/_layouts/default.html
new file mode 100644
index 000000000..e5523e3fc
--- /dev/null
+++ b/website/_layouts/default.html
@@ -0,0 +1,14 @@
+---
+layout: base
+---
+{% if page.title %}
+<div class="container">
+ <div class="page-header">
+ <h1>{{ page.title }}</h1>
+ </div>
+</div>
+{% endif %}
+
+<div class="container">
+ {{ content }}
+</div>
diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html
new file mode 100644
index 000000000..0422f9fb0
--- /dev/null
+++ b/website/_layouts/docs.html
@@ -0,0 +1,54 @@
+---
+layout: base
+categories:
+ - Project
+ - User Guide
+ - Architecture Guide
+ - Compatibility
+---
+
+<div class="container">
+ <div class="row">
+ <div class="col-md-3">
+ <nav class="sidebar" id="sidebar-nav">
+ {% for category in layout.categories %}
+ <h3>{{ category }}</h3>
+ <ul class="sidebar-nav">
+ {% assign subcats = site.pages | where: 'layout', 'docs' | where: 'category', category | group_by: 'subcategory' | sort: 'name', 'first' %}
+ {% for subcategory in subcats %}
+ {% assign sorted_pages = subcategory.items | sort: 'weight', 'last' %}
+ {% if subcategory.name != "" %}
+ {% assign ac = "aria-controls" %}
+ {% assign cid = category | remove: " " | downcase %}
+ {% assign sid = subcategory.name | remove: " " | downcase %}
+ <li>
+ <a class="sidebar-nav-heading" data-toggle="collapse" href="#{{ cid }}-{{ sid }}" aria-expanded="false" {{ ac }}="{{ cid }}-{{ sid }}">{{ subcategory.name }}<span class="caret"></span></a>
+ <ul class="collapse sidebar-nav sidebar-submenu" id="{{ cid }}-{{ sid }}">
+ {% endif %}
+ {% for p in sorted_pages %}
+ <li><a href="{{ p.url }}">{{ p.title }}</a></li>
+ {% endfor %}
+ {% if subcategory.name != "" %}
+ </li>
+ </ul>
+ {% endif %}
+ {% endfor %}
+ </ul>
+ {% endfor %}
+ </nav>
+ </div>
+
+ <div class="col-md-9">
+ <h1>{{ page.title }}</h1>
+ {% if page.editpath %}
+ <p>
+ <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank" rel="noopener"><i class="fa fa-edit fa-fw"></i> Edit this page</a>
+ <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank" rel="noopener"><i class="fab fa-github fa-fw"></i> Create issue</a>
+ </p>
+ {% endif %}
+ <div class="docs-content">
+ {{ content }}
+ </div>
+ </div>
+ </div>
+</div>
diff --git a/website/_layouts/post.html b/website/_layouts/post.html
new file mode 100644
index 000000000..640bee5af
--- /dev/null
+++ b/website/_layouts/post.html
@@ -0,0 +1,10 @@
+---
+layout: blog
+---
+
+<div class="blog-meta">
+ {% include byline.html authors=page.authors date=page.date %}
+</div>
+<div class="blog-content">
+ {{ content }}
+</div>
diff --git a/website/_plugins/svg_mime_type.rb b/website/_plugins/svg_mime_type.rb
new file mode 100644
index 000000000..ad6bb6480
--- /dev/null
+++ b/website/_plugins/svg_mime_type.rb
@@ -0,0 +1,3 @@
+require 'webrick'
+include WEBrick
+WEBrick::HTTPUtils::DefaultMimeTypes.store 'svg', 'image/svg+xml'
diff --git a/website/_sass/footer.scss b/website/_sass/footer.scss
new file mode 100644
index 000000000..ec2ba5e20
--- /dev/null
+++ b/website/_sass/footer.scss
@@ -0,0 +1,15 @@
+.footer {
+ margin-top: 40px;
+ background-color: #222;
+ color: #fff;
+ padding: 20px;
+
+ a {
+ color: $inverse-link-color;
+
+ &:hover,
+ &:focus {
+ color: $inverse-link-hover-color;
+ }
+ }
+}
diff --git a/website/_sass/front.scss b/website/_sass/front.scss
new file mode 100644
index 000000000..0e4208f3c
--- /dev/null
+++ b/website/_sass/front.scss
@@ -0,0 +1,17 @@
+.jumbotron {
+ background-image: url(/assets/images/background.jpg);
+ background-position: center;
+ background-repeat: no-repeat;
+ background-size: cover;
+ background-blend-mode: darken;
+ background-color: rgba(0, 0, 0, 0.3);
+
+ p {
+ color: #fff;
+ margin-top: 0;
+ margin-bottom: 0;
+ font-weight: 300;
+ font-size: 24px;
+ line-height: 30px;
+ }
+}
diff --git a/website/_sass/navbar.scss b/website/_sass/navbar.scss
new file mode 100644
index 000000000..65bc573ac
--- /dev/null
+++ b/website/_sass/navbar.scss
@@ -0,0 +1,26 @@
+.navbar-inverse {
+ background-color: $primary;
+ border-bottom: 1px solid $primary;
+
+ .navbar-brand > a {
+ color: #fff;
+
+ &:focus,
+ &:hover {
+ color: #fff;
+ }
+ }
+
+ .navbar-nav > li > a {
+ color: $inverse-link-color;
+
+ &:focus,
+ &:hover {
+ color: $inverse-link-hover-color;
+ }
+ }
+
+ .navbar-nav .nav-icon {
+ font-size: 18px;
+ }
+}
diff --git a/website/_sass/sidebar.scss b/website/_sass/sidebar.scss
new file mode 100644
index 000000000..f4ca05df9
--- /dev/null
+++ b/website/_sass/sidebar.scss
@@ -0,0 +1,61 @@
+$sidebar-border-color: #fff;
+$sidebar-hover-border-color: #66bb6a;
+
+.sidebar {
+ margin-top: 40px;
+
+ ul.sidebar-nav {
+ list-style-type: none;
+ padding: 0;
+ transition: height 0.01s;
+
+ li {
+ &.sidebar-nav-heading {
+ padding: 10px 0;
+ margin: 0;
+ display: block;
+ font-size: 16px;
+ font-weight: 300;
+ }
+
+ a {
+ padding: 4px 0;
+ display: block;
+ border-right: 2px solid $sidebar-border-color;
+
+ &:focus {
+ text-decoration: none;
+ }
+
+ .caret {
+ float: right;
+ margin-top: 8px;
+ margin-right: 10px;
+ }
+ }
+
+ &.active {
+ a {
+ border-left: 2px solid $sidebar-hover-border-color;
+ padding-left: 6px;
+ }
+ }
+ }
+
+ ul.sidebar-nav {
+ padding-left: 10px;
+ }
+ }
+}
+
+@media (min-width: 992px) {
+ .sidebar-toggle {
+ display: none;
+ }
+
+ .sidebar {
+ &.collapse {
+ display: block;
+ }
+ }
+}
diff --git a/website/_sass/style.scss b/website/_sass/style.scss
new file mode 100644
index 000000000..4deb945d4
--- /dev/null
+++ b/website/_sass/style.scss
@@ -0,0 +1,154 @@
+$primary: #262362;
+$secondary: #fff;
+$link-color: #286fd7;
+$inverse-link-color: #fff;
+
+$link-hover-color: darken($link-color, 10%);
+$inverse-link-hover-color: darken($inverse-link-color, 10%);
+
+$text-color: #444;
+
+$body-font-family: 'Roboto', 'Helvetica Neue', Helvetica, Arial, sans-serif;
+$code-font-family: 'Source Code Pro', monospace;
+
+html {
+ position: relative;
+ min-height: 100%;
+}
+
+body {
+ color: $text-color;
+ font-family: $body-font-family;
+ padding-top: 40px;
+}
+
+a {
+ color: $link-color;
+
+ &:hover,
+ &:focus {
+ color: $link-hover-color;
+ text-decoration: none;
+ }
+
+ code {
+ color: $link-color;
+ }
+}
+
+h1,
+h2,
+h3,
+h4,
+h5,
+h6 {
+ color: $text-color;
+ font-weight: 400;
+}
+
+h1 code,
+h2 code,
+h3 code,
+h4 code,
+h5 code,
+h6 code {
+ color: $text-color;
+ background: transparent;
+}
+
+h1 {
+ font-size: 30px;
+ margin-top: 40px;
+ margin-bottom: 40px;
+}
+
+h2 {
+ font-size: 24px;
+ margin-top: 30px;
+ margin-bottom: 30px;
+
+ code {
+ font-size: 24px;
+ }
+}
+
+h3 {
+ font-size: 20px;
+ margin-top: 24px;
+ margin-bottom: 24px;
+
+ code {
+ font-size: 20px;
+ }
+}
+
+h4 {
+ font-size: 18px;
+ margin-top: 20px;
+ margin-bottom: 20px;
+
+ code {
+ font-size: 18px;
+ }
+}
+
+p,
+li {
+ font-size: 14px;
+ line-height: 22px;
+}
+
+code {
+ font-family: $code-font-family;
+ font-size: 13px;
+}
+
+.btn {
+ color: $text-color;
+ background-color: $inverse-link-color;
+}
+
+.btn-inverse {
+ color: $text-color;
+ background-color: #fff;
+}
+
+.well {
+ box-shadow: none;
+}
+
+table {
+ width: 100%;
+}
+
+table td,
+table th {
+ border: 1px solid #ddd;
+ padding: 8px;
+}
+
+table tr:nth-child(even) {
+ background-color: #eee;
+}
+
+table th {
+ padding-top: 12px;
+ padding-bottom: 12px;
+ background-color: $primary;
+ color: $secondary;
+}
+
+.blog-meta {
+ margin-top: 10px;
+ margin-bottom: 20px;
+}
+
+.docs-content * img {
+ display: block;
+ margin: 20px auto;
+}
+
+.blog-content * img {
+ display: block;
+ margin: 20px auto;
+}
diff --git a/website/archive.key b/website/archive.key
new file mode 100644
index 000000000..1a91698bf
--- /dev/null
+++ b/website/archive.key
@@ -0,0 +1,29 @@
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+
+mQINBF0meAYBEACcBYPOSBiKtid+qTQlbgKGPxUYt0cNZiQqWXylhYUT4PuNlNx5
+s+sBLFvNTpdTrXMmZ8NkekyjD1HardWvebvJT4u+Ho/9jUr4rP71cNwNtocz/w8G
+DsUXSLgH8SDkq6xw0L+5eGc78BBg9cOeBeFBm3UPgxTBXS9Zevoi2w1lzSxkXvjx
+cGzltzMZfPXERljgLzp9AAfhg/2ouqVQm37fY+P/NDzFMJ1XHPIIp9KJl/prBVud
+jJJteFZ5sgL6MwjBQq2kw+q2Jb8Zfjl0BeXDgGMN5M5lGhX2wTfiMbfo7KWyzRnB
+RpSP3BxlLqYeQUuLG5Yx8z3oA3uBkuKaFOKvXtiScxmGM/+Ri2YM3m66imwDhtmP
+AKwTPI3Re4gWWOffglMVSv2sUAY32XZ74yXjY1VhK3bN3WFUPGrgQx4X7GP0A1Te
+lzqkT3VSMXieImTASosK5L5Q8rryvgCeI9tQLn9EpYFCtU3LXvVgTreGNEEjMOnL
+dR7yOU+Fs775stn6ucqmdYarx7CvKUrNAhgEeHMonLe1cjYScF7NfLO1GIrQKJR2
+DE0f+uJZ52inOkO8ufh3WVQJSYszuS3HCY7w5oj1aP38k/y9zZdZvVvwAWZaiqBQ
+iwjVs6Kub76VVZZhRDf4iYs8k1Zh64nXdfQt250d8U5yMPF3wIJ+c1yhxwARAQAB
+tCpUaGUgZ1Zpc29yIEF1dGhvcnMgPGd2aXNvci1ib3RAZ29vZ2xlLmNvbT6JAlQE
+EwEKAD4WIQRvHfheOnHCSRjnJ9VvxtVU4yvZQwUCXSZ4BgIbAwUJA8JnAAULCQgH
+AgYVCgkICwIEFgIDAQIeAQIXgAAKCRBvxtVU4yvZQ5WFD/9VZXMW5I2rKV+2gTHT
+CsW74kZVi1VFdAVYiUJZXw2jJNtcg3xdgBcscYPyecyka/6TS2q7q2fOGAzCZkcR
+e3lLzkGAngMlZ7PdHAE0PDMNFaeMZW0dxNH68vn7AiA1y2XwENnxVec7iXQH6aX5
+xUNg2OCiv5f6DJItHc/Q4SvFUi8QK7TT/GYE1RJXVJlLqfO6y4V8SeqfM+FHpHZM
+gzrwdTgsNiEm4lMjWcgb2Ib4i2JUVAjIRPfcpysiV5E7c3SPXyu4bOovKKlbhiJ1
+Q1M9M0zHik34Kjf4YNO1EW936j7Msd181CJt5Bl9XvlhPb8gey/ygpIvcicLx6M5
+lRJTy4z1TtkmtZ7E8EbJZWoPTaHlA6hoMtGeE35j3vMZN1qZYaYt26eFOxxhh7PA
+J0h1lS7T2O8u1c2JKhKvajtdmbqbJgI8FRhVsMoVBnqDK5aE9MOAso36OibfweEL
+8iV2z8JnBpWtbbUEaWro4knPtbLJbQFvXVietm3cFsbGg+DMIwI6x6HcU91IEFYI
+Sv4orK7xgLuM+f6dxo/Wel3ht18dg3x3krBLALTYBidRfnQYYR3sTfLquB8b5WaY
+o829L2Bop9GBygdLevkHHN5It6q8CVpn0H5HEJMNaDOX1LcPbf0CKwkkAVCBd9YZ
+eAX38ds9LliK7XPXdC4c+zEkGA==
+=x8TG
+-----END PGP PUBLIC KEY BLOCK-----
diff --git a/website/assets/favicons/apple-touch-icon-180x180.png b/website/assets/favicons/apple-touch-icon-180x180.png
new file mode 100644
index 000000000..bf4b6ce9b
--- /dev/null
+++ b/website/assets/favicons/apple-touch-icon-180x180.png
Binary files differ
diff --git a/website/assets/favicons/favicon-16x16.png b/website/assets/favicons/favicon-16x16.png
new file mode 100644
index 000000000..083264206
--- /dev/null
+++ b/website/assets/favicons/favicon-16x16.png
Binary files differ
diff --git a/website/assets/favicons/favicon-32x32.png b/website/assets/favicons/favicon-32x32.png
new file mode 100644
index 000000000..b8e4caff1
--- /dev/null
+++ b/website/assets/favicons/favicon-32x32.png
Binary files differ
diff --git a/website/assets/favicons/favicon.ico b/website/assets/favicons/favicon.ico
new file mode 100644
index 000000000..9238b79d9
--- /dev/null
+++ b/website/assets/favicons/favicon.ico
Binary files differ
diff --git a/website/assets/favicons/pwa-192x192.png b/website/assets/favicons/pwa-192x192.png
new file mode 100644
index 000000000..5d2fab785
--- /dev/null
+++ b/website/assets/favicons/pwa-192x192.png
Binary files differ
diff --git a/website/assets/favicons/pwa-512x512.png b/website/assets/favicons/pwa-512x512.png
new file mode 100644
index 000000000..23824439e
--- /dev/null
+++ b/website/assets/favicons/pwa-512x512.png
Binary files differ
diff --git a/website/assets/favicons/tile150x150.png b/website/assets/favicons/tile150x150.png
new file mode 100644
index 000000000..f76fcffae
--- /dev/null
+++ b/website/assets/favicons/tile150x150.png
Binary files differ
diff --git a/website/assets/favicons/tile310x150.png b/website/assets/favicons/tile310x150.png
new file mode 100644
index 000000000..4f87e4c12
--- /dev/null
+++ b/website/assets/favicons/tile310x150.png
Binary files differ
diff --git a/website/assets/favicons/tile310x310.png b/website/assets/favicons/tile310x310.png
new file mode 100644
index 000000000..a2926d0bd
--- /dev/null
+++ b/website/assets/favicons/tile310x310.png
Binary files differ
diff --git a/website/assets/favicons/tile70x70.png b/website/assets/favicons/tile70x70.png
new file mode 100644
index 000000000..96cc69fc4
--- /dev/null
+++ b/website/assets/favicons/tile70x70.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure1.png b/website/assets/images/2019-11-18-security-basics-figure1.png
new file mode 100644
index 000000000..2a8134a7a
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure1.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure2.png b/website/assets/images/2019-11-18-security-basics-figure2.png
new file mode 100644
index 000000000..f8b416e1d
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure2.png
Binary files differ
diff --git a/website/assets/images/2019-11-18-security-basics-figure3.png b/website/assets/images/2019-11-18-security-basics-figure3.png
new file mode 100644
index 000000000..833e3e2b5
--- /dev/null
+++ b/website/assets/images/2019-11-18-security-basics-figure3.png
Binary files differ
diff --git a/website/assets/images/2020-04-02-networking-security-figure1.png b/website/assets/images/2020-04-02-networking-security-figure1.png
new file mode 100644
index 000000000..b49cb0242
--- /dev/null
+++ b/website/assets/images/2020-04-02-networking-security-figure1.png
Binary files differ
diff --git a/website/assets/images/background.jpg b/website/assets/images/background.jpg
new file mode 100644
index 000000000..81f8e332b
--- /dev/null
+++ b/website/assets/images/background.jpg
Binary files differ
diff --git a/website/assets/logos/Makefile b/website/assets/logos/Makefile
new file mode 100644
index 000000000..49289ecc1
--- /dev/null
+++ b/website/assets/logos/Makefile
@@ -0,0 +1,13 @@
+#!/usr/bin/make -f
+
+srcs := $(wildcard *.svg)
+dsts := $(patsubst %.svg,%.png,$(srcs))
+
+all: $(dsts)
+.PHONY: all
+
+%.png %-16.png %-128.png %-1024.png: %.svg
+ @inkscape -z -e $*.png $<
+ @inkscape -z -w 16 -e $*-16.png $<
+ @inkscape -z -w 128 -e $*-128.png $<
+ @inkscape -z -w 1024 -e $*-1024.png $<
diff --git a/website/assets/logos/README.md b/website/assets/logos/README.md
new file mode 100644
index 000000000..2964982dd
--- /dev/null
+++ b/website/assets/logos/README.md
@@ -0,0 +1,10 @@
+# Logos
+
+This directory contains logo assets.
+
+The colors used are:
+
+* Background (blue): #262262
+* Highlight (yellow): #FBB03B
+
+Use `make` to generate sized PNGs from SVGs.
diff --git a/website/assets/logos/logo_solo_monochrome.png b/website/assets/logos/logo_solo_monochrome.png
new file mode 100644
index 000000000..e09c5ad5e
--- /dev/null
+++ b/website/assets/logos/logo_solo_monochrome.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_monochrome.svg b/website/assets/logos/logo_solo_monochrome.svg
new file mode 100644
index 000000000..73126fd8f
--- /dev/null
+++ b/website/assets/logos/logo_solo_monochrome.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.20036"
+ viewBox="0 0 175.35599 193.20036"
+ sodipodi:docname="logo_solo_monochrome.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="374.99057"
+ inkscape:cy="88.483321"
+ inkscape:window-x="0"
+ inkscape:window-y="9"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g48"
+ transform="translate(548.2423,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.051 2.18,-28.09 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.247,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.34,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.543,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.524,-5.837 0,0 24.645,2.263 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.073 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_dark-1024.png b/website/assets/logos/logo_solo_on_dark-1024.png
new file mode 100644
index 000000000..6df428c65
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark-128.png b/website/assets/logos/logo_solo_on_dark-128.png
new file mode 100644
index 000000000..78a85475f
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark-16.png b/website/assets/logos/logo_solo_on_dark-16.png
new file mode 100644
index 000000000..4f1e91c02
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark.png b/website/assets/logos/logo_solo_on_dark.png
new file mode 100644
index 000000000..da20756f7
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark.svg b/website/assets/logos/logo_solo_on_dark.svg
new file mode 100644
index 000000000..ae8d9e879
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.19984"
+ viewBox="0 0 175.35599 193.19985"
+ sodipodi:docname="logo_solo_on_dark.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title /></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="1278"
+ inkscape:window-height="699"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.8617185"
+ inkscape:cx="257.20407"
+ inkscape:cy="172.193"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.96254)"><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_dark_full-1024.png b/website/assets/logos/logo_solo_on_dark_full-1024.png
new file mode 100644
index 000000000..8d597dd3d
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full-128.png b/website/assets/logos/logo_solo_on_dark_full-128.png
new file mode 100644
index 000000000..fe6dd5dea
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full-16.png b/website/assets/logos/logo_solo_on_dark_full-16.png
new file mode 100644
index 000000000..f9aa7dfdd
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full.png b/website/assets/logos/logo_solo_on_dark_full.png
new file mode 100644
index 000000000..611b0565e
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_dark_full.svg b/website/assets/logos/logo_solo_on_dark_full.svg
new file mode 100644
index 000000000..6440835b1
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_dark_full.svg
@@ -0,0 +1,79 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="246.17325"
+ height="246.17325"
+ viewBox="0 0 246.17325 246.17326"
+ sodipodi:docname="logo_solo_on_dark_full.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title /></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="1278"
+ inkscape:window-height="699"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.8617185"
+ inkscape:cx="291.21659"
+ inkscape:cy="198.28704"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="1"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-580.43785,665.8419)"><circle
+ id="path83"
+ cx="527.64337"
+ cy="-407.06647"
+ r="92.314972"
+ transform="scale(1,-1)"
+ style="fill:#262262;fill-opacity:1;stroke-width:0.48076925" /><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_white.png b/website/assets/logos/logo_solo_on_white.png
new file mode 100644
index 000000000..ca539cdff
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white.svg b/website/assets/logos/logo_solo_on_white.svg
new file mode 100644
index 000000000..d794ad8e7
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white.svg
@@ -0,0 +1,73 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="175.35599"
+ height="193.20036"
+ viewBox="0 0 175.35599 193.20037"
+ sodipodi:docname="logo_solo_on_white.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="370.53985"
+ inkscape:cy="50.91009"
+ inkscape:window-x="0"
+ inkscape:window-y="9"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g46"
+ transform="translate(548.2428,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.742,-0.023 8.417,0.007 2.221,0.025 4.442,0.102 6.663,0.175 4.79,0.154 4.818,0.165 5.881,-4.582 C 42.434,52.544 41.469,38.505 39.11,24.477 38.859,22.985 38.409,21.521 38.246,20.023 37.05,9.081 30.089,3.645 20.164,1.097 13.786,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.248,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.341,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.542,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.525,-5.837 0,0 24.644,2.263 34.463,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.651,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.948,4.193 3.173,0.073 5.98,0.037 7.457,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_solo_on_white_bordered-1024.png b/website/assets/logos/logo_solo_on_white_bordered-1024.png
new file mode 100644
index 000000000..62bb88d50
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered-128.png b/website/assets/logos/logo_solo_on_white_bordered-128.png
new file mode 100644
index 000000000..a8988766c
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-128.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered-16.png b/website/assets/logos/logo_solo_on_white_bordered-16.png
new file mode 100644
index 000000000..a545c49cf
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered-16.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered.png b/website/assets/logos/logo_solo_on_white_bordered.png
new file mode 100644
index 000000000..cc99b7c51
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered.png
Binary files differ
diff --git a/website/assets/logos/logo_solo_on_white_bordered.svg b/website/assets/logos/logo_solo_on_white_bordered.svg
new file mode 100644
index 000000000..2e26f144a
--- /dev/null
+++ b/website/assets/logos/logo_solo_on_white_bordered.svg
@@ -0,0 +1,82 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="190.7361"
+ height="207.92123"
+ viewBox="0 0 190.7361 207.92124"
+ sodipodi:docname="logo_solo_on_white_bordered.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.609327"
+ inkscape:cx="298.55736"
+ inkscape:cy="108.65533"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g14" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-612.10927,647.00852)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g46"
+ transform="translate(594.4321,453.1439)"><path
+ d="m 0,0 c 0,0 -0.204,0.139 -0.45,0.277 -0.906,0.547 -1.856,1 -2.744,1.321 -1.685,0.679 -3.62,1.262 -5.761,1.738 l -2.063,2.533 c -1.793,2.283 -6.09,7.358 -13.132,13.334 -8.142,6.209 -24.212,15.045 -47.595,12.442 -4.578,-0.444 -9.077,-1.318 -13.368,-2.597 l -0.232,-0.068 c -12.978,-3.918 -24.155,-11.512 -33.24,-22.601 -0.6,-0.754 -1.179,-1.504 -1.783,-2.307 l -0.134,-0.191 -0.062,-0.074 c -1.194,-1.596 -2.36,-3.258 -3.485,-4.969 l -0.125,-0.198 c -3.559,-5.915 -6.126,-12.72 -6.85,-14.73 -11.284,-34.556 2.735,-61.502 5.669,-66.567 8.482,-14.4 23.945,-32.461 50.005,-38.975 1.42,-0.354 2.872,-0.676 4.356,-0.96 0.016,-0.004 0.036,-0.009 0.053,-0.013 6.537,-1.118 16.647,-1.928 29.969,-0.317 3.21,0.621 8.236,2.535 8.646,8.445 2.209,0.842 10.261,3.812 10.261,3.812 8.572,3.874 18.586,11.106 24.334,24.546 1.21,2.83 0.277,6.128 -2.171,7.994 l -0.202,0.152 c 0.639,1.557 1.125,3.209 1.488,4.93 l 0.019,0.009 c 0,0 0.063,0.325 0.164,0.854 0.013,0.073 0.028,0.144 0.041,0.215 0.402,2.114 1.294,6.91 1.719,10.035 0.02,0.149 0.029,0.268 0.033,0.371 2.136,14.686 2.099,26.608 -0.156,37.847 1.443,0.114 2.672,1.095 3.106,2.477 l 0.655,2.086 C 9.188,-12.066 6.243,-4.003 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(551.9267,364.2689)"><path
+ d="m 0,0 c 16.398,20.796 22.346,43.748 18.045,69.9 3.02,0 5.654,-0.022 8.288,0.007 2.187,0.025 4.374,0.101 6.56,0.172 4.716,0.152 4.743,0.163 5.79,-4.512 C 41.779,51.734 40.83,37.911 38.507,24.1 38.26,22.631 37.817,21.189 37.656,19.715 36.479,8.941 29.625,3.588 19.853,1.08 13.574,-0.531 7.209,-0.816 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(548.4282,396.6898)"><path
+ d="M 0,0 C 0,2.553 -3.404,4.255 -3.404,4.255 -0.851,5.105 0,8.51 0,8.51 0,8.51 0.851,5.105 3.404,4.255 3.404,4.255 0,2.553 0,0 m -16.835,6.354 c 0,6.638 -8.851,11.063 -8.851,11.063 6.638,2.213 8.851,11.063 8.851,11.063 0,0 2.212,-8.85 8.85,-11.063 0,0 -8.85,-4.425 -8.85,-11.063"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g><g
+ id="g58"
+ transform="translate(489.744,429.7865)"><path
+ d="m 0,0 c 0,0 -2.18,3.308 -9.585,2.026 0,0 5.894,28.296 35.737,38.209 C 26.152,40.235 -2.777,23.901 0,0 m 0.574,-32.696 -0.003,-0.003 c -1.277,-1.994 -2.778,-3.523 -4.46,-4.544 -1.492,-0.919 -3.13,-1.403 -4.67,-1.403 -0.519,0 -1.028,0.055 -1.516,0.167 l -2.174,0.5 2.15,0.593 c 1.605,0.443 3.016,1.325 4.194,2.623 0.997,1.09 1.826,2.447 2.536,4.149 1.158,2.775 1.757,6.066 1.783,9.781 -0.048,3.661 -0.673,6.945 -1.857,9.75 -0.693,1.629 -1.56,3.009 -2.58,4.099 -1.208,1.29 -2.63,2.143 -4.228,2.536 l -2.154,0.529 2.145,0.564 c 0.473,0.125 0.983,0.193 1.516,0.205 l 0.031,0.002 1.662,-0.152 c 1.029,-0.203 2.071,-0.605 3.014,-1.167 1.727,-1.015 3.25,-2.527 4.526,-4.493 2.122,-3.323 3.264,-7.421 3.305,-11.857 C 3.767,-25.296 2.653,-29.4 0.574,-32.696 M 100.951,17.69 c 0,0 -0.074,0.05 -0.223,0.136 -0.532,0.32 -1.095,0.593 -1.688,0.801 -2.997,1.223 -9.729,3.139 -21.568,2.583 -0.029,0 -0.055,0 -0.084,-0.001 C 52.1,20.798 29.559,10.754 29.559,10.754 c 0,0 1.417,1.733 3.4,3.636 0,0 10e-4,0 10e-4,10e-4 1.036,0.958 2.319,2.044 3.853,3.176 0.044,0.031 0.086,0.062 0.122,0.091 8.573,6.286 24.949,13.945 50.829,9.394 -0.615,0.698 -1.254,1.341 -1.887,1.998 l 0.057,-0.009 c 0,0 -5.858,6.886 -16.555,12.616 -0.613,0.331 -1.252,0.659 -1.91,0.985 -0.038,0.018 -0.073,0.038 -0.112,0.057 -0.067,0.033 -0.128,0.057 -0.195,0.089 -8.007,3.888 -19.263,7.05 -33.564,5.458 -3.873,-0.374 -8.015,-1.116 -12.272,-2.399 -0.007,-0.002 -0.013,-0.003 -0.02,-0.005 L 21.304,45.84 c -10,-3.018 -20.632,-9.029 -29.888,-20.328 -0.563,-0.707 -1.105,-1.409 -1.632,-2.109 -0.064,-0.096 -0.138,-0.198 -0.222,-0.301 -1.164,-1.557 -2.236,-3.091 -3.229,-4.602 -3.216,-5.343 -5.544,-11.486 -6.214,-13.337 -10.417,-31.901 2.546,-56.661 5.065,-61.011 8.793,-14.923 24.186,-31.856 49.988,-36.751 0.195,-0.046 0.377,-0.101 0.573,-0.145 1.697,-0.361 4.79,-0.914 8.771,-1.178 1.5,-0.067 3.04,-0.093 4.626,-0.065 1.525,-0.01 2.953,0.016 4.268,0.063 0.39,0.027 0.729,0.04 1.029,0.044 5.023,0.233 8.144,0.762 8.144,0.762 -26.133,1.279 -39.232,13.203 -44.987,20.816 -1.304,1.622 -2.422,3.367 -3.324,5.234 -0.356,0.699 -0.515,1.097 -0.515,1.097 8.328,-7.069 19.98,-13.155 19.98,-13.155 10.449,-4.917 21.402,-7.337 33.008,-5.747 0,0 24.264,2.227 33.932,24.702 -0.417,0.317 -0.361,0.275 -0.777,0.592 -0.642,-0.518 -1.274,-1.007 -1.898,-1.474 -0.975,-0.64 -1.933,-1.336 -2.89,-2.038 -5.184,-3.431 -9.414,-5.048 -11.934,-5.788 -19.06,-4.811 -36.698,-1.232 -52.114,12.263 -7.14,6.251 -11.572,14.131 -11.759,23.917 -0.122,6.365 -0.188,12.734 -0.135,19.101 0.084,10.023 7.135,17.645 17.126,18.986 20.245,2.716 40.598,3.652 60.993,4.129 3.125,0.072 5.888,0.036 7.342,-3.306 0.025,-0.057 0.358,0.021 0.543,0.035 1.379,4.393 -0.607,9.127 -4.223,11.444"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path60"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_monochrome.png b/website/assets/logos/logo_with_text_monochrome.png
new file mode 100644
index 000000000..17442f55d
--- /dev/null
+++ b/website/assets/logos/logo_with_text_monochrome.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_monochrome.svg b/website/assets/logos/logo_with_text_monochrome.svg
new file mode 100644
index 000000000..4648e06c0
--- /dev/null
+++ b/website/assets/logos/logo_with_text_monochrome.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.20036"
+ viewBox="0 0 607.97212 193.20036"
+ sodipodi:docname="logo_with_text_monochrome.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="296.21626"
+ inkscape:cy="101.98009"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2876)"><path
+ d="m 0,0 c -0.698,-5.234 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.407 -14.132,15.178 0.698,8.374 5.408,12.719 14.132,13.033 C -4.362,17.776 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.413 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.167 -6.542,-18.842 0,-8.026 2.006,-14.395 6.019,-19.105 4.012,-4.71 10.031,-7.066 18.057,-7.066 5.164,0 9.7,1.57 13.608,4.711 v -2.617 c 0,-3.141 -0.986,-6.019 -2.957,-8.636 -1.972,-2.617 -5.749,-3.925 -11.331,-3.925 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.279 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.372 V 29.31 L 0,30.88 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9331)"><path
+ d="M 0,0 -19.986,51.176 H -37.513 L -8.457,-21.105 H 8.891 L 37.875,51.176 H 20.348 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1089)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.707 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.234 1.395,-1.396 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.698 5.757,2.094 1.395,1.395 2.094,3.14 2.094,5.234 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.707 2.512,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8296)"><path
+ d="m 0,0 c -3.315,2.477 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.643 -1.187,1.064 -1.187,2.251 0,3.559 1.186,1.309 3.541,1.962 7.065,1.962 3.56,0 6.525,-1.412 8.898,-4.239 l 8.165,9.212 c -2.374,2.338 -4.92,4.1 -7.642,5.286 -2.721,1.186 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.423 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.898 1.771,-9.91 5.313,-12.038 3.541,-2.129 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.645 0.819,-1.307 0.645,-2.615 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.093 -11.043,4.921 l -8.113,-9.161 c 2.373,-2.373 5.208,-4.265 8.505,-5.678 3.298,-1.413 7.424,-2.12 12.379,-2.12 5.199,0 9.752,1.36 13.66,4.083 3.908,2.721 5.862,6.558 5.862,11.514 C 4.972,-6.787 3.315,-2.479 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9751)"><path
+ d="m 0,0 c -2.478,-2.67 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.334 -9.682,4.004 -2.478,2.669 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.997 2.477,2.686 5.686,4.03 9.63,4.03 4.012,0 7.257,-1.344 9.735,-4.03 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.669 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.548 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.573 2.546,-13.915 7.641,-19.025 5.095,-5.113 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.547 19.051,7.642 5.095,5.093 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.094 -11.445,7.642 -19.051,7.642"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1675)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.92 -13.66,-5.758 v 5.81 L -28.315,-1.519 V -52.34 h 14.655 v 32.19 c 2.372,4.92 5.338,7.501 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.051 2.18,-28.09 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.247,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.34,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.543,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.524,-5.837 0,0 24.645,2.263 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.073 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_dark-1024.png b/website/assets/logos/logo_with_text_on_dark-1024.png
new file mode 100644
index 000000000..a02a9014b
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark-128.png b/website/assets/logos/logo_with_text_on_dark-128.png
new file mode 100644
index 000000000..efae725b8
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-128.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark-16.png b/website/assets/logos/logo_with_text_on_dark-16.png
new file mode 100644
index 000000000..a6069f98f
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark-16.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark.png b/website/assets/logos/logo_with_text_on_dark.png
new file mode 100644
index 000000000..24de18c11
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark.svg b/website/assets/logos/logo_with_text_on_dark.svg
new file mode 100644
index 000000000..52d8e52da
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.19984"
+ viewBox="0 0 607.97212 193.19985"
+ sodipodi:docname="logo_with_text_on_dark.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="296.21626"
+ inkscape:cy="101.97992"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.96254)"><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2874)"><path
+ d="m 0,0 c -0.698,-5.233 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.408 -14.132,15.178 0.698,8.375 5.408,12.719 14.132,13.033 C -4.362,17.777 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.414 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.166 -6.542,-18.842 0,-8.026 2.006,-14.394 6.019,-19.105 4.012,-4.71 10.031,-7.065 18.057,-7.065 5.164,0 9.7,1.569 13.608,4.71 v -2.616 c 0,-3.141 -0.986,-6.02 -2.957,-8.636 -1.972,-2.618 -5.749,-3.926 -11.331,-3.926 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.28 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.373 V 29.31 L 0,30.88 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9339)"><path
+ d="m 0,0 -19.986,51.175 h -17.527 l 29.056,-72.28 H 8.891 l 28.984,72.28 H 20.348 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1087)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.706 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.233 1.395,-1.397 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.697 5.757,2.094 1.395,1.394 2.094,3.139 2.094,5.233 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.706 2.512,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8294)"><path
+ d="m 0,0 c -3.315,2.478 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.644 -1.187,1.063 -1.187,2.25 0,3.558 1.186,1.309 3.541,1.963 7.065,1.963 3.56,0 6.525,-1.413 8.898,-4.239 l 8.165,9.211 c -2.374,2.338 -4.92,4.101 -7.642,5.286 -2.721,1.187 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.422 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.897 1.771,-9.91 5.313,-12.038 3.541,-2.128 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.644 0.819,-1.308 0.645,-2.616 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.092 -11.043,4.92 l -8.113,-9.16 c 2.373,-2.374 5.208,-4.266 8.505,-5.678 3.298,-1.414 7.424,-2.121 12.379,-2.121 5.199,0 9.752,1.361 13.66,4.083 3.908,2.722 5.862,6.559 5.862,11.515 C 4.972,-6.787 3.315,-2.478 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9749)"><path
+ d="m 0,0 c -2.478,-2.669 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.335 -9.682,4.004 -2.478,2.67 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.998 2.477,2.685 5.686,4.029 9.63,4.029 4.012,0 7.257,-1.344 9.735,-4.029 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.67 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.547 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.572 2.546,-13.914 7.641,-19.025 5.095,-5.112 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.548 19.051,7.642 5.095,5.094 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.095 -11.445,7.642 -19.051,7.642"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1673)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.919 -13.66,-5.757 v 5.81 l -14.655,-1.571 v -50.821 h 14.655 v 32.189 c 2.372,4.92 5.338,7.502 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_dark_full-1024.png b/website/assets/logos/logo_with_text_on_dark_full-1024.png
new file mode 100644
index 000000000..eb2e63981
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-1024.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full-128.png b/website/assets/logos/logo_with_text_on_dark_full-128.png
new file mode 100644
index 000000000..4ed21e5cb
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-128.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full-16.png b/website/assets/logos/logo_with_text_on_dark_full-16.png
new file mode 100644
index 000000000..d3968da5e
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full-16.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full.png b/website/assets/logos/logo_with_text_on_dark_full.png
new file mode 100644
index 000000000..21feea356
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_dark_full.svg b/website/assets/logos/logo_with_text_on_dark_full.svg
new file mode 100644
index 000000000..017e72414
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_dark_full.svg
@@ -0,0 +1,120 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="786.19244"
+ height="334.21716"
+ viewBox="0 0 786.19246 334.21718"
+ sodipodi:docname="logo_with_text_on_dark_full.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath20"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path18"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="567"
+ inkscape:window-height="462"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.3046635"
+ inkscape:cx="541.19762"
+ inkscape:cy="67.525134"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-524.53324,714.8519)"><path
+ d="M 393.39994,285.47606 H 983.04431 V 536.13894 H 393.39994 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none;stroke-width:0.36289462"
+ id="path12"
+ inkscape:connector-curvature="0" /><g
+ id="g14"><g
+ id="g16"
+ clip-path="url(#clipPath20)"><g
+ id="g22"
+ transform="translate(668.8995,400.2874)"><path
+ d="m 0,0 c -0.698,-5.233 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.408 -14.132,15.178 0.698,8.375 5.408,12.719 14.132,13.033 C -4.362,17.777 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.414 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.166 -6.542,-18.842 0,-8.026 2.006,-14.394 6.019,-19.105 4.012,-4.71 10.031,-7.065 18.057,-7.065 5.164,0 9.7,1.569 13.608,4.71 v -2.616 c 0,-3.141 -0.986,-6.02 -2.957,-8.636 -1.972,-2.618 -5.749,-3.926 -11.331,-3.926 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.28 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.373 V 29.31 L 0,30.88 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path24"
+ inkscape:connector-curvature="0" /></g><g
+ id="g26"
+ transform="translate(720.3033,399.9339)"><path
+ d="m 0,0 -19.986,51.175 h -17.527 l 29.056,-72.28 H 8.891 l 28.984,72.28 H 20.348 Z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path30"
+ inkscape:connector-curvature="0" /><g
+ id="g32"
+ transform="translate(769.7443,451.1087)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.706 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.233 1.395,-1.397 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.697 5.757,2.094 1.395,1.394 2.094,3.139 2.094,5.233 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.706 2.512,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path34"
+ inkscape:connector-curvature="0" /></g><g
+ id="g36"
+ transform="translate(816.9264,406.8294)"><path
+ d="m 0,0 c -3.315,2.478 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.644 -1.187,1.063 -1.187,2.25 0,3.558 1.186,1.309 3.541,1.963 7.065,1.963 3.56,0 6.525,-1.413 8.898,-4.239 l 8.165,9.211 c -2.374,2.338 -4.92,4.101 -7.642,5.286 -2.721,1.187 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.422 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.897 1.771,-9.91 5.313,-12.038 3.541,-2.128 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.644 0.819,-1.308 0.645,-2.616 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.092 -11.043,4.92 l -8.113,-9.16 c 2.373,-2.374 5.208,-4.266 8.505,-5.678 3.298,-1.414 7.424,-2.121 12.379,-2.121 5.199,0 9.752,1.361 13.66,4.083 3.908,2.722 5.862,6.559 5.862,11.515 C 4.972,-6.787 3.315,-2.478 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path38"
+ inkscape:connector-curvature="0" /></g><g
+ id="g40"
+ transform="translate(862.2008,394.9749)"><path
+ d="m 0,0 c -2.478,-2.669 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.335 -9.682,4.004 -2.478,2.67 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.998 2.477,2.685 5.686,4.029 9.63,4.029 4.012,0 7.257,-1.344 9.735,-4.029 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.67 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.547 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.572 2.546,-13.914 7.641,-19.025 5.095,-5.112 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.548 19.051,7.642 5.095,5.094 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.095 -11.445,7.642 -19.051,7.642"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path42"
+ inkscape:connector-curvature="0" /></g><g
+ id="g44"
+ transform="translate(911.9489,431.1673)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.919 -13.66,-5.757 v 5.81 l -14.655,-1.571 v -50.821 h 14.655 v 32.189 c 2.372,4.92 5.338,7.502 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path46"
+ inkscape:connector-curvature="0" /></g><g
+ id="g48"
+ transform="translate(548.2423,363.2484)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.434 18.328,70.995 3.068,0 5.743,-0.023 8.417,0.007 2.222,0.025 4.443,0.102 6.664,0.175 4.79,0.154 4.818,0.165 5.88,-4.582 3.145,-14.05 2.18,-28.089 -0.179,-42.118 -0.25,-1.492 -0.7,-2.956 -0.864,-4.454 C 37.05,9.081 30.089,3.645 20.165,1.097 13.787,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path50"
+ inkscape:connector-curvature="0" /></g><g
+ id="g52"
+ transform="translate(544.6891,396.1771)"><path
+ d="M 0,0 C 0,2.593 -3.457,4.321 -3.457,4.321 -0.864,5.186 0,8.644 0,8.644 0,8.644 0.865,5.186 3.458,4.321 3.458,4.321 0,2.593 0,0 m -17.099,6.453 c 0,6.742 -8.989,11.237 -8.989,11.237 6.742,2.247 8.989,11.237 8.989,11.237 0,0 2.247,-8.99 8.99,-11.237 0,0 -8.99,-4.495 -8.99,-11.237"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path54"
+ inkscape:connector-curvature="0" /></g><g
+ id="g56"
+ transform="translate(485.0861,429.7923)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.578 -4.53,-4.615 -1.515,-0.934 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.056 -1.539,0.169 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.214 1.175,2.819 1.784,6.161 1.81,9.935 -0.049,3.719 -0.683,7.054 -1.886,9.902 -0.703,1.655 -1.585,3.057 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.127 0.998,0.196 1.539,0.209 l 0.031,10e-4 1.688,-0.154 c 1.045,-0.206 2.104,-0.615 3.061,-1.184 1.755,-1.032 3.302,-2.568 4.598,-4.565 2.155,-3.374 3.315,-7.537 3.357,-12.043 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.051 -0.226,0.137 -0.541,0.326 -1.113,0.602 -1.715,0.814 -3.044,1.241 -9.881,3.187 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.923 30.022,10.923 c 0,0 1.439,1.76 3.453,3.691 10e-4,10e-4 10e-4,10e-4 0.002,0.002 1.052,0.973 2.355,2.076 3.912,3.226 0.046,0.032 0.088,0.063 0.124,0.094 8.708,6.383 25.34,14.162 51.625,9.54 -0.989,1.124 -2.002,2.193 -3.036,3.215 -1.112,0.884 -2.231,1.694 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.868 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 c -0.571,-0.718 -1.122,-1.431 -1.657,-2.14 -0.065,-0.1 -0.141,-0.202 -0.226,-0.307 -1.182,-1.582 -2.271,-3.141 -3.279,-4.674 -3.266,-5.427 -5.631,-11.666 -6.311,-13.546 -10.58,-32.401 2.586,-57.549 5.144,-61.967 8.93,-15.157 24.565,-32.355 50.771,-37.327 0.197,-0.046 0.382,-0.101 0.582,-0.146 1.723,-0.367 4.864,-0.929 8.908,-1.197 1.524,-0.069 3.088,-0.094 4.699,-0.066 1.548,-0.01 2.999,0.017 4.335,0.064 0.396,0.027 0.74,0.041 1.044,0.044 5.102,0.237 8.272,0.774 8.272,0.774 -26.543,1.3 -39.847,13.41 -45.691,21.142 -1.325,1.648 -2.46,3.421 -3.377,5.316 -0.361,0.711 -0.523,1.115 -0.523,1.115 8.459,-7.18 20.294,-13.361 20.294,-13.361 10.611,-4.993 21.737,-7.452 33.524,-5.838 0,0 24.645,2.264 34.464,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.652,-0.526 -1.294,-1.022 -1.926,-1.496 -0.991,-0.65 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.127 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.353 -11.944,24.291 -0.124,6.466 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.947,4.193 3.174,0.074 5.981,0.037 7.458,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.462 -0.617,9.27 -4.29,11.624"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path58"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_white.png b/website/assets/logos/logo_with_text_on_white.png
new file mode 100644
index 000000000..bf420a057
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_white.svg b/website/assets/logos/logo_with_text_on_white.svg
new file mode 100644
index 000000000..4275efe83
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white.svg
@@ -0,0 +1,116 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.97211"
+ height="193.20036"
+ viewBox="0 0 607.97212 193.20037"
+ sodipodi:docname="logo_with_text_on_white.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.43085925"
+ inkscape:cx="325.455"
+ inkscape:cy="50.910097"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-614.45037,638.9628)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g20"
+ transform="translate(668.8995,400.2876)"><path
+ d="m 0,0 c -0.698,-5.234 -4.362,-8.375 -10.991,-9.421 -9.072,0.349 -13.783,5.407 -14.132,15.178 0.698,8.374 5.408,12.719 14.132,13.033 C -4.362,17.776 -0.698,14.341 0,8.479 Z m 0,26.117 c -2.442,2.826 -6.629,4.413 -12.561,4.763 -8.026,0 -14.219,-2.443 -18.581,-7.327 -4.361,-4.886 -6.542,-11.167 -6.542,-18.842 0,-8.026 2.006,-14.395 6.019,-19.105 4.012,-4.71 10.031,-7.066 18.057,-7.066 5.164,0 9.7,1.57 13.608,4.711 v -2.617 c 0,-3.141 -0.986,-6.019 -2.957,-8.636 -1.972,-2.617 -5.749,-3.925 -11.331,-3.925 -5.479,0.349 -10.137,1.919 -13.975,4.71 l -6.804,-7.85 c 5.582,-5.582 13.433,-8.48 23.553,-8.689 l 3.14,0.314 c 6.629,0.279 12.124,2.879 16.487,7.798 4.361,4.92 6.542,11.044 6.542,18.372 V 29.31 L 0,30.88 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path22"
+ inkscape:connector-curvature="0" /></g><g
+ id="g24"
+ transform="translate(720.3033,399.9331)"><path
+ d="M 0,0 -19.986,51.176 H -37.513 L -8.457,-21.105 H 8.891 L 37.875,51.176 H 20.348 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path26"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.522,378.828 h 14.655 v 52.392 h -14.655 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /><g
+ id="g30"
+ transform="translate(769.7443,451.1089)"><path
+ d="m 0,0 c -2.373,0 -4.257,-0.707 -5.652,-2.12 -1.397,-1.413 -2.094,-3.166 -2.094,-5.26 0,-2.094 0.697,-3.839 2.094,-5.234 1.395,-1.396 3.315,-2.094 5.757,-2.094 2.442,0 4.361,0.698 5.757,2.094 1.395,1.395 2.094,3.14 2.094,5.234 0,2.094 -0.699,3.847 -2.094,5.26 C 4.466,-0.707 2.512,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path32"
+ inkscape:connector-curvature="0" /></g><g
+ id="g34"
+ transform="translate(816.9264,406.8296)"><path
+ d="m 0,0 c -3.315,2.477 -7.572,3.96 -12.771,4.449 -3.524,0.697 -5.879,1.578 -7.065,2.643 -1.187,1.064 -1.187,2.251 0,3.559 1.186,1.309 3.541,1.962 7.065,1.962 3.56,0 6.525,-1.412 8.898,-4.239 l 8.165,9.212 c -2.374,2.338 -4.92,4.1 -7.642,5.286 -2.721,1.186 -6.542,1.78 -11.462,1.78 -4.257,0 -8.4,-1.423 -12.43,-4.265 -4.03,-2.845 -6.046,-6.971 -6.046,-12.379 0,-5.898 1.771,-9.91 5.313,-12.038 3.541,-2.129 7.44,-3.437 11.697,-3.925 4.712,-0.454 7.476,-1.335 8.297,-2.645 0.819,-1.307 0.645,-2.615 -0.523,-3.924 -1.171,-1.309 -3.761,-1.963 -7.774,-1.963 -4.99,0.453 -8.67,2.093 -11.043,4.921 l -8.113,-9.161 c 2.373,-2.373 5.208,-4.265 8.505,-5.678 3.298,-1.413 7.424,-2.12 12.379,-2.12 5.199,0 9.752,1.36 13.66,4.083 3.908,2.721 5.862,6.558 5.862,11.514 C 4.972,-6.787 3.315,-2.479 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path36"
+ inkscape:connector-curvature="0" /></g><g
+ id="g38"
+ transform="translate(862.2008,394.9751)"><path
+ d="m 0,0 c -2.478,-2.67 -5.705,-4.004 -9.683,-4.004 -3.977,0 -7.205,1.334 -9.682,4.004 -2.478,2.669 -3.717,5.992 -3.717,9.97 0,3.978 1.239,7.31 3.717,9.997 2.477,2.686 5.686,4.03 9.63,4.03 4.012,0 7.257,-1.344 9.735,-4.03 C 2.477,17.28 3.716,13.948 3.716,9.97 3.716,5.992 2.477,2.669 0,0 m -9.63,36.716 c -7.608,0 -13.957,-2.548 -19.052,-7.642 -5.095,-5.095 -7.641,-11.445 -7.641,-19.051 0,-7.573 2.546,-13.915 7.641,-19.025 5.095,-5.113 11.444,-7.669 19.052,-7.669 7.606,0 13.956,2.547 19.051,7.642 5.095,5.093 7.641,11.445 7.641,19.052 0,7.606 -2.546,13.956 -7.641,19.051 -5.095,5.094 -11.445,7.642 -19.051,7.642"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path40"
+ inkscape:connector-curvature="0" /></g><g
+ id="g42"
+ transform="translate(911.9489,431.1675)"><path
+ d="m 0,0 c -5.374,0 -9.927,-1.92 -13.66,-5.758 v 5.81 L -28.315,-1.519 V -52.34 h 14.655 v 32.19 c 2.372,4.92 5.338,7.501 8.898,7.746 3.175,-0.105 5.617,-0.925 7.327,-2.46 L 4.868,-0.419 C 3.332,-0.14 1.709,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path44"
+ inkscape:connector-curvature="0" /></g><g
+ id="g46"
+ transform="translate(548.2428,363.2485)"><path
+ d="m 0,0 c 16.655,21.121 22.696,44.433 18.328,70.995 3.068,0 5.742,-0.023 8.417,0.007 2.221,0.025 4.442,0.102 6.663,0.175 4.79,0.154 4.818,0.165 5.881,-4.582 C 42.434,52.544 41.469,38.505 39.11,24.477 38.859,22.985 38.409,21.521 38.246,20.023 37.05,9.081 30.089,3.645 20.164,1.097 13.786,-0.54 7.323,-0.829 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(544.6891,396.1763)"><path
+ d="M 0,0 C 0,2.594 -3.457,4.322 -3.457,4.322 -0.864,5.187 0,8.644 0,8.644 0,8.644 0.865,5.187 3.458,4.322 3.458,4.322 0,2.594 0,0 m -17.099,6.454 c 0,6.742 -8.989,11.236 -8.989,11.236 6.742,2.248 8.989,11.238 8.989,11.238 0,0 2.248,-8.99 8.99,-11.238 0,0 -8.99,-4.494 -8.99,-11.236"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(485.0861,429.7925)"><path
+ d="m 0,0 c 0,0 -2.214,3.359 -9.736,2.059 0,0 5.987,28.738 36.298,38.806 C 26.562,40.865 -2.82,24.275 0,0 M 0.583,-33.208 0.58,-33.211 c -1.297,-2.026 -2.821,-3.579 -4.53,-4.616 -1.515,-0.933 -3.178,-1.425 -4.743,-1.425 -0.528,0 -1.044,0.057 -1.539,0.17 l -2.209,0.507 2.184,0.603 c 1.63,0.451 3.063,1.347 4.259,2.664 1.014,1.108 1.856,2.485 2.577,4.213 1.175,2.82 1.784,6.162 1.81,9.936 -0.049,3.718 -0.683,7.054 -1.886,9.902 -0.703,1.654 -1.585,3.056 -2.621,4.163 -1.227,1.311 -2.671,2.178 -4.294,2.576 l -2.187,0.538 2.179,0.572 c 0.48,0.126 0.998,0.196 1.539,0.209 h 0.031 l 1.688,-0.153 c 1.045,-0.206 2.104,-0.616 3.061,-1.185 1.755,-1.031 3.302,-2.567 4.598,-4.565 2.155,-3.374 3.315,-7.536 3.357,-12.042 -0.028,-4.548 -1.159,-8.717 -3.271,-12.064 m 101.949,51.176 c 0,0 -0.075,0.05 -0.226,0.136 -0.541,0.327 -1.113,0.603 -1.715,0.815 -3.044,1.241 -9.881,3.186 -21.906,2.623 -0.029,0 -0.056,0 -0.085,-0.001 C 52.916,21.123 30.022,10.922 30.022,10.922 c 0,0 1.439,1.761 3.453,3.692 10e-4,0 10e-4,10e-4 0.002,10e-4 1.052,0.974 2.355,2.076 3.912,3.227 0.046,0.031 0.088,0.063 0.124,0.093 8.708,6.384 25.341,14.163 51.625,9.541 -0.989,1.124 -2.002,2.192 -3.036,3.215 -1.112,0.883 -2.231,1.693 -3.354,2.456 0.02,-0.012 0.039,-0.023 0.059,-0.036 0,0 -17.016,19.415 -48.683,15.891 C 30.19,48.622 25.983,47.867 21.66,46.564 21.653,46.563 21.646,46.562 21.64,46.56 L 21.638,46.558 C 11.48,43.492 0.683,37.387 -8.719,25.911 -9.29,25.193 -9.841,24.479 -10.376,23.77 c -0.065,-0.099 -0.141,-0.202 -0.226,-0.307 -1.182,-1.581 -2.271,-3.14 -3.279,-4.674 -3.266,-5.427 -5.631,-11.665 -6.311,-13.545 -10.58,-32.401 2.586,-57.55 5.144,-61.967 8.93,-15.158 24.565,-32.355 50.771,-37.327 0.197,-0.047 0.382,-0.101 0.582,-0.147 1.723,-0.367 4.864,-0.929 8.908,-1.196 1.524,-0.069 3.088,-0.094 4.699,-0.067 1.548,-0.009 2.999,0.017 4.335,0.064 0.396,0.028 0.74,0.041 1.044,0.044 5.102,0.238 8.272,0.775 8.272,0.775 -26.542,1.299 -39.847,13.409 -45.691,21.142 -1.325,1.648 -2.46,3.42 -3.377,5.316 -0.361,0.71 -0.523,1.115 -0.523,1.115 8.459,-7.181 20.294,-13.362 20.294,-13.362 10.611,-4.993 21.737,-7.451 33.525,-5.837 0,0 24.644,2.263 34.463,25.09 -0.423,0.322 -0.366,0.278 -0.79,0.6 -0.651,-0.526 -1.294,-1.023 -1.926,-1.496 -0.991,-0.651 -1.964,-1.357 -2.937,-2.07 -5.265,-3.485 -9.561,-5.128 -12.12,-5.879 -19.359,-4.887 -37.273,-1.252 -52.93,12.455 -7.253,6.349 -11.754,14.352 -11.944,24.291 -0.124,6.465 -0.19,12.935 -0.136,19.4 0.085,10.181 7.246,17.921 17.394,19.284 20.561,2.759 41.234,3.71 61.948,4.193 3.173,0.073 5.98,0.037 7.457,-3.356 0.025,-0.058 0.363,0.02 0.552,0.035 1.4,4.461 -0.617,9.27 -4.29,11.624"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/logo_with_text_on_white_bordered.png b/website/assets/logos/logo_with_text_on_white_bordered.png
new file mode 100644
index 000000000..bd1a1e4b7
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white_bordered.png
Binary files differ
diff --git a/website/assets/logos/logo_with_text_on_white_bordered.svg b/website/assets/logos/logo_with_text_on_white_bordered.svg
new file mode 100644
index 000000000..08125629d
--- /dev/null
+++ b/website/assets/logos/logo_with_text_on_white_bordered.svg
@@ -0,0 +1,122 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!-- Created with Inkscape (http://www.inkscape.org/) -->
+
+<svg
+ xmlns:dc="http://purl.org/dc/elements/1.1/"
+ xmlns:cc="http://creativecommons.org/ns#"
+ xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
+ xmlns:svg="http://www.w3.org/2000/svg"
+ xmlns="http://www.w3.org/2000/svg"
+ xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
+ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
+ version="1.1"
+ id="svg2"
+ xml:space="preserve"
+ width="607.64587"
+ height="207.92123"
+ viewBox="0 0 607.64589 207.92124"
+ sodipodi:docname="logo_with_text_on_white_bordered.svg"
+ inkscape:version="0.92.3 (2405546, 2018-03-11)"><metadata
+ id="metadata8"><rdf:RDF><cc:Work
+ rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type
+ rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs
+ id="defs6"><clipPath
+ clipPathUnits="userSpaceOnUse"
+ id="clipPath18"><path
+ d="M 0,821.614 H 1366 V 0 H 0 Z"
+ id="path16"
+ inkscape:connector-curvature="0" /></clipPath></defs><sodipodi:namedview
+ pagecolor="#ffffff"
+ bordercolor="#666666"
+ borderopacity="1"
+ objecttolerance="10"
+ gridtolerance="10"
+ guidetolerance="10"
+ inkscape:pageopacity="0"
+ inkscape:pageshadow="2"
+ inkscape:window-width="640"
+ inkscape:window-height="480"
+ id="namedview4"
+ showgrid="false"
+ fit-margin-top="0"
+ fit-margin-left="0"
+ fit-margin-right="0"
+ fit-margin-bottom="0"
+ inkscape:zoom="0.21542963"
+ inkscape:cx="298.55736"
+ inkscape:cy="108.65533"
+ inkscape:window-x="0"
+ inkscape:window-y="0"
+ inkscape:window-maximized="0"
+ inkscape:current-layer="g10" /><g
+ id="g10"
+ inkscape:groupmode="layer"
+ inkscape:label="gvisor_final-logo_20190313"
+ transform="matrix(1.3333333,0,0,-1.3333333,-612.10927,647.00852)"><g
+ id="g12"><g
+ id="g14"
+ clip-path="url(#clipPath18)"><g
+ id="g20"
+ transform="translate(670.7226,400.7367)"><path
+ d="m 0,0 c -0.687,-5.153 -4.294,-8.246 -10.821,-9.275 -8.933,0.342 -13.571,5.324 -13.914,14.943 0.687,8.245 5.325,12.522 13.914,12.832 C -4.294,17.503 -0.687,14.12 0,8.348 Z m 0,25.715 c -2.404,2.782 -6.527,4.345 -12.367,4.688 -7.902,0 -14,-2.404 -18.294,-7.214 -4.295,-4.81 -6.442,-10.994 -6.442,-18.55 0,-7.903 1.976,-14.172 5.927,-18.811 3.949,-4.638 9.876,-6.956 17.778,-6.956 5.085,0 9.55,1.545 13.398,4.638 v -2.577 c 0,-3.092 -0.971,-5.926 -2.911,-8.502 -1.942,-2.577 -5.66,-3.866 -11.156,-3.866 -5.395,0.344 -9.981,1.89 -13.76,4.638 l -6.699,-7.729 c 5.496,-5.496 13.226,-8.349 23.19,-8.555 l 3.091,0.309 c 6.527,0.275 11.938,2.834 16.233,7.678 4.294,4.844 6.441,10.873 6.441,18.088 V 28.857 L 0,30.403 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path22"
+ inkscape:connector-curvature="0" /></g><g
+ id="g24"
+ transform="translate(721.3339,400.388)"><path
+ d="M 0,0 -19.678,50.387 H -36.935 L -8.326,-20.779 H 8.754 L 37.291,50.387 H 20.034 Z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path26"
+ inkscape:connector-curvature="0" /></g><path
+ d="m 762.901,379.609 h 14.429 v 51.583 h -14.429 z"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path28"
+ inkscape:connector-curvature="0" /><g
+ id="g30"
+ transform="translate(770.0126,450.7747)"><path
+ d="m 0,0 c -2.337,0 -4.191,-0.696 -5.565,-2.088 -1.375,-1.391 -2.062,-3.117 -2.062,-5.179 0,-2.061 0.687,-3.779 2.062,-5.153 1.374,-1.374 3.263,-2.061 5.669,-2.061 2.404,0 4.293,0.687 5.667,2.061 1.375,1.374 2.062,3.092 2.062,5.153 0,2.062 -0.687,3.788 -2.062,5.179 C 4.397,-0.696 2.474,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path32"
+ inkscape:connector-curvature="0" /></g><g
+ id="g34"
+ transform="translate(816.4667,407.1781)"><path
+ d="m 0,0 c -3.264,2.438 -7.455,3.898 -12.573,4.381 -3.471,0.686 -5.789,1.554 -6.957,2.601 -1.168,1.048 -1.168,2.216 0,3.504 1.168,1.289 3.486,1.933 6.957,1.933 3.504,0 6.424,-1.391 8.76,-4.174 l 8.039,9.069 c -2.336,2.302 -4.844,4.038 -7.524,5.206 -2.68,1.167 -6.441,1.751 -11.285,1.751 -4.192,0 -8.271,-1.4 -12.239,-4.199 -3.968,-2.801 -5.951,-6.863 -5.951,-12.187 0,-5.807 1.742,-9.758 5.229,-11.854 3.486,-2.094 7.326,-3.383 11.518,-3.863 4.638,-0.447 7.36,-1.315 8.168,-2.604 0.807,-1.288 0.635,-2.576 -0.515,-3.864 -1.152,-1.289 -3.703,-1.932 -7.653,-1.932 -4.913,0.446 -8.537,2.06 -10.873,4.843 l -7.988,-9.017 c 2.336,-2.337 5.127,-4.201 8.374,-5.592 3.247,-1.392 7.309,-2.087 12.187,-2.087 5.119,0 9.602,1.339 13.45,4.02 3.848,2.679 5.772,6.458 5.772,11.336 C 4.896,-6.683 3.264,-2.44 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path36"
+ inkscape:connector-curvature="0" /></g><g
+ id="g38"
+ transform="translate(861.0429,395.5062)"><path
+ d="m 0,0 c -2.439,-2.628 -5.616,-3.942 -9.533,-3.942 -3.916,0 -7.095,1.314 -9.533,3.942 -2.44,2.628 -3.66,5.9 -3.66,9.816 0,3.917 1.22,7.198 3.66,9.843 2.438,2.646 5.598,3.968 9.482,3.968 3.949,0 7.145,-1.322 9.584,-3.968 C 2.439,17.014 3.659,13.733 3.659,9.816 3.659,5.9 2.439,2.628 0,0 m -9.481,36.149 c -7.491,0 -13.743,-2.507 -18.758,-7.523 -5.017,-5.017 -7.524,-11.269 -7.524,-18.757 0,-7.456 2.507,-13.7 7.524,-18.732 5.015,-5.034 11.267,-7.55 18.758,-7.55 7.489,0 13.741,2.508 18.757,7.523 5.016,5.016 7.524,11.269 7.524,18.759 0,7.488 -2.508,13.74 -7.524,18.757 -5.016,5.016 -11.268,7.523 -18.757,7.523"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path40"
+ inkscape:connector-curvature="0" /></g><g
+ id="g42"
+ transform="translate(910.0244,431.14)"><path
+ d="m 0,0 c -5.291,0 -9.773,-1.89 -13.449,-5.668 v 5.72 l -14.43,-1.546 v -50.037 h 14.43 v 31.691 c 2.335,4.845 5.255,7.387 8.76,7.627 3.126,-0.102 5.531,-0.91 7.214,-2.422 L 4.792,-0.412 C 3.28,-0.138 1.683,0 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path44"
+ inkscape:connector-curvature="0" /></g><g
+ id="g46"
+ transform="translate(594.4321,453.1439)"><path
+ d="m 0,0 c 0,0 -0.204,0.139 -0.45,0.277 -0.906,0.547 -1.856,1 -2.744,1.321 -1.685,0.679 -3.62,1.262 -5.761,1.738 l -2.063,2.533 c -1.793,2.283 -6.09,7.358 -13.132,13.334 -8.142,6.209 -24.212,15.045 -47.595,12.442 -4.578,-0.444 -9.077,-1.318 -13.368,-2.597 l -0.232,-0.068 c -12.978,-3.918 -24.155,-11.512 -33.24,-22.601 -0.6,-0.754 -1.179,-1.504 -1.783,-2.307 l -0.134,-0.191 -0.062,-0.074 c -1.194,-1.596 -2.36,-3.258 -3.485,-4.969 l -0.125,-0.198 c -3.559,-5.915 -6.126,-12.72 -6.85,-14.73 -11.284,-34.556 2.735,-61.502 5.669,-66.567 8.482,-14.4 23.945,-32.461 50.005,-38.975 1.42,-0.354 2.872,-0.676 4.356,-0.96 0.016,-0.004 0.036,-0.009 0.053,-0.013 6.537,-1.118 16.647,-1.928 29.969,-0.317 3.21,0.621 8.236,2.535 8.646,8.445 2.209,0.842 10.261,3.812 10.261,3.812 8.572,3.874 18.586,11.106 24.334,24.546 1.21,2.83 0.277,6.128 -2.171,7.994 l -0.202,0.152 c 0.639,1.557 1.125,3.209 1.488,4.93 l 0.019,0.009 c 0,0 0.063,0.325 0.164,0.854 0.013,0.073 0.028,0.144 0.041,0.215 0.402,2.114 1.294,6.91 1.719,10.035 0.02,0.149 0.029,0.268 0.033,0.371 2.136,14.686 2.099,26.608 -0.156,37.847 1.443,0.114 2.672,1.095 3.106,2.477 l 0.655,2.086 C 9.188,-12.066 6.243,-4.003 0,0"
+ style="fill:#262262;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path48"
+ inkscape:connector-curvature="0" /></g><g
+ id="g50"
+ transform="translate(551.9267,364.2689)"><path
+ d="m 0,0 c 16.398,20.796 22.346,43.748 18.045,69.9 3.02,0 5.654,-0.022 8.288,0.007 2.187,0.025 4.374,0.101 6.56,0.172 4.716,0.152 4.743,0.163 5.79,-4.512 C 41.779,51.734 40.83,37.911 38.507,24.1 38.26,22.631 37.817,21.189 37.656,19.715 36.479,8.941 29.625,3.588 19.853,1.08 13.574,-0.531 7.209,-0.816 0,0"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path52"
+ inkscape:connector-curvature="0" /></g><g
+ id="g54"
+ transform="translate(548.4282,396.6898)"><path
+ d="M 0,0 C 0,2.553 -3.404,4.255 -3.404,4.255 -0.851,5.105 0,8.51 0,8.51 0,8.51 0.851,5.105 3.404,4.255 3.404,4.255 0,2.553 0,0 m -16.835,6.354 c 0,6.638 -8.851,11.063 -8.851,11.063 6.638,2.213 8.851,11.063 8.851,11.063 0,0 2.212,-8.85 8.85,-11.063 0,0 -8.85,-4.425 -8.85,-11.063"
+ style="fill:#fbb03b;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path56"
+ inkscape:connector-curvature="0" /></g><g
+ id="g58"
+ transform="translate(489.744,429.7865)"><path
+ d="m 0,0 c 0,0 -2.18,3.308 -9.585,2.026 0,0 5.894,28.296 35.737,38.209 C 26.152,40.235 -2.777,23.901 0,0 m 0.574,-32.696 -0.003,-0.003 c -1.277,-1.994 -2.778,-3.523 -4.46,-4.544 -1.492,-0.919 -3.13,-1.403 -4.67,-1.403 -0.519,0 -1.028,0.055 -1.516,0.167 l -2.174,0.5 2.15,0.593 c 1.605,0.443 3.016,1.325 4.194,2.623 0.997,1.09 1.826,2.447 2.536,4.149 1.158,2.775 1.757,6.066 1.783,9.781 -0.048,3.661 -0.673,6.945 -1.857,9.75 -0.693,1.629 -1.56,3.009 -2.58,4.099 -1.208,1.29 -2.63,2.143 -4.228,2.536 l -2.154,0.529 2.145,0.564 c 0.473,0.125 0.983,0.193 1.516,0.205 l 0.031,0.002 1.662,-0.152 c 1.029,-0.203 2.071,-0.605 3.014,-1.167 1.727,-1.015 3.25,-2.527 4.526,-4.493 2.122,-3.323 3.264,-7.421 3.305,-11.857 C 3.767,-25.296 2.653,-29.4 0.574,-32.696 M 100.951,17.69 c 0,0 -0.074,0.05 -0.223,0.136 -0.532,0.32 -1.095,0.593 -1.688,0.801 -2.997,1.223 -9.729,3.139 -21.568,2.583 -0.029,0 -0.055,0 -0.084,-0.001 C 52.1,20.798 29.559,10.754 29.559,10.754 c 0,0 1.417,1.733 3.4,3.636 0,0 10e-4,0 10e-4,10e-4 1.036,0.958 2.319,2.044 3.853,3.176 0.044,0.031 0.086,0.062 0.122,0.091 8.573,6.286 24.949,13.945 50.829,9.394 -0.615,0.698 -1.254,1.341 -1.887,1.998 l 0.057,-0.009 c 0,0 -5.858,6.886 -16.555,12.616 -0.613,0.331 -1.252,0.659 -1.91,0.985 -0.038,0.018 -0.073,0.038 -0.112,0.057 -0.067,0.033 -0.128,0.057 -0.195,0.089 -8.007,3.888 -19.263,7.05 -33.564,5.458 -3.873,-0.374 -8.015,-1.116 -12.272,-2.399 -0.007,-0.002 -0.013,-0.003 -0.02,-0.005 L 21.304,45.84 c -10,-3.018 -20.632,-9.029 -29.888,-20.328 -0.563,-0.707 -1.105,-1.409 -1.632,-2.109 -0.064,-0.096 -0.138,-0.198 -0.222,-0.301 -1.164,-1.557 -2.236,-3.091 -3.229,-4.602 -3.216,-5.343 -5.544,-11.486 -6.214,-13.337 -10.417,-31.901 2.546,-56.661 5.065,-61.011 8.793,-14.923 24.186,-31.856 49.988,-36.751 0.195,-0.046 0.377,-0.101 0.573,-0.145 1.697,-0.361 4.79,-0.914 8.771,-1.178 1.5,-0.067 3.04,-0.093 4.626,-0.065 1.525,-0.01 2.953,0.016 4.268,0.063 0.39,0.027 0.729,0.04 1.029,0.044 5.023,0.233 8.144,0.762 8.144,0.762 -26.133,1.279 -39.232,13.203 -44.987,20.816 -1.304,1.622 -2.422,3.367 -3.324,5.234 -0.356,0.699 -0.515,1.097 -0.515,1.097 8.328,-7.069 19.98,-13.155 19.98,-13.155 10.449,-4.917 21.402,-7.337 33.008,-5.747 0,0 24.264,2.227 33.932,24.702 -0.417,0.317 -0.361,0.275 -0.777,0.592 -0.642,-0.518 -1.274,-1.007 -1.898,-1.474 -0.975,-0.64 -1.933,-1.336 -2.89,-2.038 -5.184,-3.431 -9.414,-5.048 -11.934,-5.788 -19.06,-4.811 -36.698,-1.232 -52.114,12.263 -7.14,6.251 -11.572,14.131 -11.759,23.917 -0.122,6.365 -0.188,12.734 -0.135,19.101 0.084,10.023 7.135,17.645 17.126,18.986 20.245,2.716 40.598,3.652 60.993,4.129 3.125,0.072 5.888,0.036 7.342,-3.306 0.025,-0.057 0.358,0.021 0.543,0.035 1.379,4.393 -0.607,9.127 -4.223,11.444"
+ style="fill:#ffffff;fill-opacity:1;fill-rule:nonzero;stroke:none"
+ id="path60"
+ inkscape:connector-curvature="0" /></g></g></g></g></svg> \ No newline at end of file
diff --git a/website/assets/logos/powered-gvisor.png b/website/assets/logos/powered-gvisor.png
new file mode 100644
index 000000000..e00c74a33
--- /dev/null
+++ b/website/assets/logos/powered-gvisor.png
Binary files differ
diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md
new file mode 100644
index 000000000..76bbabc13
--- /dev/null
+++ b/website/blog/2019-11-18-security-basics.md
@@ -0,0 +1,306 @@
+# gVisor Security Basics - Part 1
+
+This blog is a space for engineers and community members to share perspectives
+and deep dives on technology and design within the gVisor project. Though our
+logo suggests we're in the business of space exploration (or perhaps fighting
+sea monsters), we're actually in the business of sandboxing Linux containers.
+When we created gVisor, we had three specific goals in mind; _container-native
+security_, _resource efficiency_, and _platform portability_. To put it simply,
+gVisor provides _efficient defense-in-depth for containers anywhere_.
+
+This post addresses gVisor's _container-native security_, specifically how
+gVisor provides strong isolation between an application and the host OS. Future
+posts will address _resource efficiency_ (how gVisor preserves container
+benefits like fast starts, smaller snapshots, and less memory overhead than VMs)
+and _platform portability_ (run gVisor wherever Linux OCI containers run).
+Delivering on each of these goals requires careful security considerations and a
+robust design.
+
+## What does "sandbox" mean?
+
+gVisor allows the execution of untrusted containers, preventing them from
+adversely affecting the host. This means that the untrusted container is
+prevented from attacking or spying on either the host kernel or any other peer
+userspace processes on the host.
+
+For example, if you are a cloud container hosting service, running containers
+from different customers on the same virtual machine means that compromises
+expose customer data. Properly configured, gVisor can provide sufficient
+isolation to allow different customers to run containers on the same host. There
+are many aspects to the proper configuration, including limiting file and
+network access, which we will discuss in future posts.
+
+## The cost of compromise
+
+gVisor was designed around the premise that any security boundary could
+potentially be compromised with enough time and resources. We tried to optimize
+for a solution that was as costly and time-consuming for an attacker as
+possible, at every layer.
+
+Consequently, gVisor was built through a combination of intentional design
+principles and specific technology choices that work together to provide the
+security isolation needed for running hostile containers on a host. We'll dig
+into it in the next section!
+
+# Design Principles
+
+gVisor was designed with some common
+[secure design](https://en.wikipedia.org/wiki/Secure_by_design) principles in
+mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface Reduction
+and Secure-by-Default[^1].
+
+In general, Design Principles outline good engineering practices, but in the
+case of security, they also can be thought of as a set of tactics. In a
+real-life castle, there is no single defensive feature. Rather, there are many
+in combination: redundant walls, scattered draw bridges, small bottle-neck
+entrances, moats, etc.
+
+A simplified version of the design is below
+([more detailed version](/docs/))[^2]:
+
+![Figure 1](/assets/images/2019-11-18-security-basics-figure1.png "Simplified design of gVisor.")
+
+In order to discuss design principles, the following components are important to
+know:
+
+* runsc - binary that packages the Sentry, platform, and Gofer(s) that run
+ containers. runsc is the drop-in binary for running gVisor in Docker and
+ Kubernetes.
+* Untrusted Application - container running in the sandbox. Untrusted
+ application/container are used interchangeably in this article.
+* Platform Syscall Switcher - intercepts syscalls from the application and
+ passes them to the Sentry with no further handling.
+* Sentry - The "application kernel" in userspace that serves the untrusted
+ application. Each application instance has its own Sentry. The Sentry
+ handles syscalls, routes I/O to gofers, and manages memory and CPU, all in
+ userspace. The Sentry is allowed to make limited, filtered syscalls to the
+ host OS.
+* Gofer - a process that specifically handles different types of I/O for the
+ Sentry (usually disk I/O). Gofers are also allowed to make filtered syscalls
+ to the Host OS.
+* Host OS - the actual OS on which gVisor containers are running, always some
+ flavor of Linux (sorry, Windows/MacOS users).
+
+It is important to emphasize what is being protected from the untrusted
+application in this diagram: the host OS and other userspace applications.
+
+In this post, we are only discussing security-related features of gVisor, and
+you might ask, "What about performance, compatibility and stability?" We will
+cover these considerations in future posts.
+
+## Defense-in-Depth
+
+For gVisor, Defense-in-Depth means each component of the software stack trusts
+the other components as little as possible.
+
+It may seem strange that we would want our own software components to distrust
+each other. But by limiting the trust between small, discrete components, each
+component is forced to defend itself against potentially malicious input. And
+when you stack these components on top of each other, you can ensure that
+multiple security barriers must be overcome by an attacker.
+
+And this leads us to how Defense-in-Depth is applied to gVisor: no single
+vulnerability should compromise the host.
+
+In the "Attacker's Advantage / Defender's Dilemma," the defender must succeed
+all the time while the attacker only needs to succeed once. Defense in Depth
+inverts this principle: once the attacker successfully compromises any given
+software component, they are immediately faced with needing to compromise a
+subsequent, distinct layer in order to move laterally or acquire more privilege.
+
+For example, the untrusted container is isolated from the Sentry. The Sentry is
+isolated from host I/O operations by serving those requests in separate
+processes called Gofers. And both the untrusted container and its associated
+Gofers are isolated from the host process that is running the sandbox.
+
+An additional benefit is that this generally leads to more robust and stable
+software, forcing interfaces to be strictly defined and tested to ensure all
+inputs are properly parsed and bounds checked.
+
+## Least-Privilege
+
+The principle of Least-Privilege implies that each software component has only
+the permissions it needs to function, and no more.
+
+Least-Privilege is applied throughout gVisor. Each component and more
+importantly, each interface between the components, is designed so that only the
+minimum level of permission is required for it to perform its function.
+Specifically, the closer you are to the untrusted application, the less
+privilege you have.
+
+![Figure 2](/assets/images/2019-11-18-security-basics-figure2.png "runsc components and their privileges.")
+
+This is evident in how runsc (the drop in gVisor binary for Docker/Kubernetes)
+constructs the sandbox. The Sentry has the least privilege possible (it can't
+even open a file!). Gofers are only allowed file access, so even if it were
+compromised, the host network would be unavailable. Only the runsc binary itself
+has full access to the host OS, and even runsc's access to the host OS is often
+limited through capabilities / chroot / namespacing.
+
+Designing a system with Defense-in-Depth and Least-Privilege in mind encourages
+small, separate, single-purpose components, each with very restricted
+privileges.
+
+## Attack Surface Reduction
+
+There are no bugs in unwritten code. In other words, gVisor supports a feature
+if and only if it is needed to run host Linux containers.
+
+### Host Application/Sentry Interface:
+
+There are a lot of things gVisor does not need to do. For example, it does not
+need to support arbitrary device drivers, nor does it need to support video
+playback. By not implementing what will not be used, we avoid introducing
+potential bugs in our code.
+
+That is not to say gVisor has limited functionality! Quite the opposite, we
+analyzed what is actually needed to run Linux containers and today the Sentry
+supports 237 syscalls[^3]<sup>,</sup>[^4], along with the range of critical
+/proc and /dev files. However, gVisor does not support every syscall in the
+Linux kernel. There are about 350 syscalls[^5] within the 5.3.11 version of the
+Linux kernel, many of which do not apply to Linux containers that typically host
+cloud-like workloads. For example, we don't support old versions of epoll
+(epoll_ctl_old, epoll_wait_old), because they are deprecated in Linux and no
+supported workloads use them.
+
+Furthermore, any exploited vulnerabilities in the implemented syscalls (or
+Sentry code in general) only apply to gaining control of the Sentry. More on
+this in a later post.
+
+### Sentry/Host OS Interface:
+
+The Sentry's interactions with the Host OS are restricted in many ways. For
+instance, no syscall is "passed-through" from the untrusted application to the
+host OS. All syscalls are intercepted and interpreted. In the case where the
+Sentry needs to call the Host OS, we severely limit the syscalls that the Sentry
+itself is allowed to make to the host kernel[^6].
+
+For example, there are many file-system based attacks, where manipulation of
+files or their paths, can lead to compromise of the host[^7]. As a result, the
+Sentry does not allow any syscall that creates or opens a file descriptor. All
+file descriptors must be donated to the sandbox. By disallowing open or creation
+of file descriptors, we eliminate entire categories of these file-based attacks.
+
+This does not affect functionality though. For example, during startup, runsc
+will donate FDs the Sentry that allow for mapping STDIN/STDOUT/STDERR to the
+sandboxed application. Also the Gofer may donate an FD to the Sentry, allowing
+for direct access to some files. And most files will be remotely accessed
+through the Gofers, in which case no FDs are donated to the Sentry.
+
+The Sentry itself is only allowed access to specific
+[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/boot/config.go).
+Without networking, the Sentry needs 53 host syscalls in order to function, and
+with networking, it uses an additional 15[^8]. By limiting the whitelist to only
+these needed syscalls, we radically reduce the amount of host OS attack surface.
+If any attempts are made to call something outside the whitelist, it is
+immediately blocked and the sandbox is killed by the Host OS.
+
+### Sentry/Gofer Interface:
+
+The Sentry communicates with the Gofer through a local unix domain socket (UDS)
+via a version of the 9P protocol[^9]. The UDS file descriptor is passed to the
+sandbox during initialization and all communication between the Sentry and Gofer
+happens via 9P. We will go more into how Gofers work in future posts.
+
+### End Result
+
+So, of the 350 syscalls in the Linux kernel, the Sentry needs to implement only
+237 of them to support containers. At most, the Sentry only needs to call 68 of
+the host Linux syscalls. In other words, with gVisor, applications get the vast
+majority (and growing) functionality of Linux containers for only 68 possible
+syscalls to the Host OS. 350 syscalls to 68 is attack surface reduction.
+
+![Figure 3](/assets/images/2019-11-18-security-basics-figure3.png "Reduction of Attack Surface of the Syscall Table. Note that the Senty's Syscall Emulation Layer keeps the Containerized Process from ever calling the Host OS.")
+
+## Secure-by-default
+
+The default choice for a user should be safe. If users need to run a less secure
+configuration of the sandbox for the sake of performance or application
+compatibility, they must make the choice explicitly.
+
+An example of this might be a networking application that is performance
+sensitive. Instead of using the safer, Go-based Netstack in the Sentry, the
+untrusted container can instead use the host Linux networking stack directly.
+However, this means the untrusted container will be directly interacting with
+the host, without the safety benefits of the sandbox. It also means that an
+attack could directly compromise the host through his path.
+
+These less secure configurations are **not** the default. In fact, the user must
+take action to change the configuration and run in a less secure mode.
+Additionally, these actions make it very obvious that a less secure
+configuration is being used.
+
+This can be as simple as forcing a default runtime flag option to the secure
+option. gVisor does this by always using its internal netstack by default.
+However, for certain performance sensitive applications, we allow the usage of
+the host OS networking stack, but it requires the user to actively set a
+flag[^10].
+
+# Technology Choices
+
+Technology choices for gVisor mainly involve things that will give us a security
+boundary.
+
+At a higher level, boundaries in software might be describing a great many
+things. It may be discussing the boundaries between threads, boundaries between
+processes, boundaries between CPU privilege levels, and more.
+
+Security boundaries are interfaces that are designed and built so that entire
+classes of bugs/vulnerabilities are eliminated.
+
+For example, the Sentry and Gofers are implemented using Go. Go was chosen for a
+number of the features it provided. Go is a fast, statically-typed, compiled
+language that has efficient multi-threading support, garbage collection and a
+constrained set of "unsafe" operations.
+
+Using these features enabled safe array and pointer handling. This means entire
+classes of vulnerabilities were eliminated, such as buffer overflows and
+use-after-free.
+
+Another example is our use of very strict syscall switching to ensure that the
+Sentry is always the first software component that parses and interprets the
+calls being made by the untrusted container. Here is an instance where different
+platforms use different solutions, but all of them share this common trait,
+whether it is through the use of ptrace "a la PTRACE_ATTACH"[^11] or kvm's
+ring0[^12].
+
+Finally, one of the most restrictive choices was to use seccomp, to restrict the
+Sentry from being able to open or create a file descriptor on the host. All file
+I/O is required to go through Gofers. Preventing the opening or creation of file
+descriptions eliminates whole categories of bugs around file permissions
+[like this one](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-4557)[^13].
+
+# To be continued - Part 2
+
+In part 2 of this blog post, we will explore gVisor from an attacker's point of
+view. We will use it as an opportunity to examine the specific strengths and
+weaknesses of each gVisor component.
+
+We will also use it to introduce Google's Vulnerability Reward Program[^14], and
+other ways the community can contribute to help make gVisor safe, fast and
+stable.
+
+## Notes
+
+[^1]: [https://en.wikipedia.org/wiki/Secure_by_design](https://en.wikipedia.org/wiki/Secure_by_design)
+[^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/)
+[^3]: [https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/linux64_amd64.go](https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/syscalls.go)
+
+<!-- mdformat off(mdformat formats this into multiple lines) -->
+[^4]: Internally that is, it doesn't call to the Host OS to implement them, in fact that is explicitly disallowed, more on that in the future.
+<!-- mdformat on -->
+
+[^5]: [https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345](https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345)
+[^6]: [https://github.com/google/gvisor/tree/master/runsc/boot/filter](https://github.com/google/gvisor/tree/master/runsc/boot/filter)
+[^7]: [https://en.wikipedia.org/wiki/Dirty_COW](https://en.wikipedia.org/wiki/Dirty_COW)
+[^8]: [https://github.com/google/gvisor/blob/master/runsc/boot/config.go](https://github.com/google/gvisor/blob/master/runsc/boot/config.go)
+
+<!-- mdformat off(mdformat breaks this url by escaping the parenthesis) -->
+[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_(protocol))
+<!-- mdformat on -->
+
+[^10]: [https://gvisor.dev/docs/user_guide/networking/#network-passthrough](https://gvisor.dev/docs/user_guide/networking/#network-passthrough)
+[^11]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390)
+[^12]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182)
+[^13]: [https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-4557](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-4557)
+[^14]: [https://www.google.com/about/appsecurity/reward-program/index.html](https://www.google.com/about/appsecurity/reward-program/index.html)
diff --git a/website/blog/2020-04-02-networking-security.md b/website/blog/2020-04-02-networking-security.md
new file mode 100644
index 000000000..f3ce02d11
--- /dev/null
+++ b/website/blog/2020-04-02-networking-security.md
@@ -0,0 +1,183 @@
+# gVisor Networking Security
+
+In our
+[first blog post](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/),
+we covered some secure design principles and how they guided the architecture of
+gVisor as a whole. In this post, we will cover how these principles guided the
+networking architecture of gVisor, and the tradeoffs involved. In particular, we
+will cover how these principles culminated in two networking modes, how they
+work, and the properties of each.
+
+## gVisor's security architecture in the context of networking
+
+Linux networking is complicated. The TCP protocol is over 40 years old, and has
+been repeatedly extended over the years to keep up with the rapid pace of
+network infrastructure improvements, all while maintaining compatibility. On top
+of that, Linux networking has a fairly large API surface. Linux supports
+[over 150 options](https://github.com/google/gvisor/blob/960f6a975b7e44c0efe8fd38c66b02017c4fe137/pkg/sentry/strace/socket.go#L476-L644)
+for the most common socket types alone. In fact, the net subsystem is one of the
+largest and fastest growing in Linux at approximately 1.1 million lines of code.
+For comparison, that is several times the size of the entire gVisor codebase.
+
+At the same time, networking is increasingly important. The cloud era is
+arguably about making everything a network service, and in order to make that
+work, the interconnect performance is critical. Adding networking support to
+gVisor was difficult, not just due to the inherent complexity, but also because
+it has the potential to significantly weaken gVisor's security model.
+
+As outlined in the previous blog post, gVisor's
+[secure design principles](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#design-principles)
+are:
+
+1. Defense in Depth: each component of the software stack trusts each other
+ component as little as possible.
+1. Least Privilege: each software component has only the permissions it needs
+ to function, and no more.
+1. Attack Surface Reduction: limit the surface area of the host exposed to the
+ sandbox.
+1. Secure by Default: the default choice for a user should be safe.
+
+gVisor manifests these principles as a multi-layered system. An application
+running in the sandbox interacts with the Sentry, a userspace kernel, which
+mediates all interactions with the Host OS and beyond. The Sentry is written in
+pure Go with minimal unsafe code, making it less vulnerable to buffer overflows
+and related memory bugs that can lead to a variety of compromises including code
+injection. It emulates Linux using only a minimal and audited set of Host OS
+syscalls that limit the Host OS's attack surface exposed to the Sentry itself.
+The syscall restrictions are enforced by running the Sentry with seccomp
+filters, which enforce that the Sentry can only use the expected set of
+syscalls. The Sentry runs as an unprivileged user and in namespaces, which,
+along with the seccomp filters, ensure that the Sentry is run with the Least
+Privilege required.
+
+gVisor's multi-layered design provides Defense in Depth. The Sentry, which does
+not trust the application because it may attack the Sentry and try to bypass it,
+is the first layer. The sandbox that the Sentry runs in is the second layer. If
+the Sentry were compromised, the attacker would still be in a highly restrictive
+sandbox which they must also break out of in order to compromise the Host OS.
+
+To enable networking functionality while preserving gVisor's security
+properties, we implemented a
+[userspace network stack](https://github.com/google/gvisor/tree/master/pkg/tcpip)
+in the Sentry, which we creatively named Netstack. Netstack is also written in
+Go, not only to avoid unsafe code in the network stack itself, but also to avoid
+a complicated and unsafe Foreign Function Interface. Having its own integrated
+network stack allows the Sentry to implement networking operations using up to
+three Host OS syscalls to read and write packets. These syscalls allow a very
+minimal set of operations which are already allowed (either through the same or
+a similar syscall). Moreover, because packets typically come from off-host (e.g.
+the internet), the Host OS's packet processing code has received a lot of
+scrutiny, hopefully resulting in a high degree of hardening.
+
+![Figure 1](/assets/images/2020-04-02-networking-security-figure1.png "Network and gVisor.")
+
+## Writing a network stack
+
+Netstack was written from scratch specifically for gVisor. Because Netstack was
+designed and implemented to be modular, flexible and self-contained, there are
+now several more projects using Netstack in creative and exciting ways. As we
+discussed, a custom network stack has enabled a variety of security-related
+goals which would not have been possible any other way. This came at a cost
+though. Network stacks are complex and writing a new one comes with many
+challenges, mostly related to application compatibility and performance.
+
+Compatibility issues typically come in two forms: missing features, and features
+with behavior that differs from Linux (usually due to bugs). Both of these are
+inevitable in an implementation of a complex system spanning many quickly
+evolving and ambiguous standards. However, we have invested heavily in this
+area, and the vast majority of applications have no issues using Netstack. For
+example,
+[we now support setting 34 different socket options](https://github.com/google/gvisor/blob/815df2959a76e4a19f5882e40402b9bbca9e70be/pkg/sentry/socket/netstack/netstack.go#L830-L1764)
+versus
+[only 7 in our initial git commit](https://github.com/google/gvisor/blob/d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296/pkg/sentry/socket/epsocket/epsocket.go#L445-L702).
+We are continuing to make good progress in this area.
+
+Performance issues typically come from TCP behavior and packet processing speed.
+To improve our TCP behavior, we are working on implementing the full set of TCP
+RFCs. There are many RFCs which are significant to performance (e.g.
+[RACK](https://tools.ietf.org/id/draft-ietf-tcpm-rack-03.html) and
+[BBR](https://tools.ietf.org/html/draft-cardwell-iccrg-bbr-congestion-control-00))
+that we have yet to implement. This mostly affects TCP performance with
+non-ideal network conditions (e.g. cross continent connections). Faster packet
+processing mostly improves TCP performance when network conditions are very good
+(e.g. within a datacenter). Our primary strategy here is to reduce interactions
+with the Go runtime, specifically the garbage collector (GC) and scheduler. We
+are currently optimizing buffer management to reduce the amount of garbage,
+which will lower the GC cost. To reduce scheduler interactions, we are
+re-architecting the TCP implementation to use fewer goroutines. Performance
+today is good enough for most applications and we are making steady
+improvements. For example, since May of 2019, we have improved the Netstack
+runsc
+[iperf3 download benchmark](https://github.com/google/gvisor/tree/master/test/benchmarks/network)
+score by roughly 15% and upload score by around 10,000X. Current numbers are
+about 17 Gbps download and about 8 Gbps upload versus about 42 Gbps and 43 Gbps
+for native (Linux) respectively.
+
+## An alternative
+
+We also offer an alternative network mode: passthrough. This name can be
+misleading as syscalls are never passed through from the app to the Host OS.
+Instead, the passthrough mode implements networking in gVisor using the Host
+OS's network stack. (This mode is called
+[hostinet](https://github.com/google/gvisor/tree/master/pkg/sentry/socket/hostinet)
+in the codebase.) Passthrough mode can improve performance for some use cases as
+the Host OS's network stack has had an enormous number of person-years poured
+into making it highly performant. However, there is a rather large downside to
+using passthrough mode: it weakens gVisor's security model by increasing the
+Host OS's Attack Surface. This is because using the Host OS's network stack
+requires the Sentry to use the Host OS's
+[Berkeley socket interface](https://en.wikipedia.org/wiki/Berkeley_sockets). The
+Berkeley socket interface is a much larger API surface than the packet interface
+that our network stack uses. When passthrough mode is in use, the Sentry is
+allowed to use
+[15 additional syscalls](https://github.com/google/gvisor/blob/b1576e533223e98ebe4bd1b82b04e3dcda8c4bf1/runsc/boot/filter/config.go#L312-L517).
+Further, this set of syscalls includes some that allow the Sentry to create file
+descriptors, something that
+[we don't normally allow](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#sentry-host-os-interface)
+as it opens up classes of file-based attacks.
+
+There are some networking features that we can't implement on top of syscalls
+that we feel are safe (most notably those behind
+[ioctl](http://man7.org/linux/man-pages/man2/ioctl.2.html)) and therefore are
+not supported. Because of this, we actually support fewer networking features in
+passthrough mode than we do in Netstack, reducing application compatibility.
+That's right: using our networking stack provides better overall application
+compatibility than using our passthrough mode.
+
+That said, gVisor with passthrough networking still provides a high level of
+isolation. Applications cannot specify host syscall arguments directly, and the
+sentry's seccomp policy restricts its syscall use significantly more than a
+general purpose seccomp policy.
+
+## Secure by Default
+
+The goal of the Secure by Default principle is to make it easy to securely
+sandbox containers. Of course, disabling network access entirely is the most
+secure option, but that is not practical for most applications. To make gVisor
+Secure by Default, we have made Netstack the default networking mode in gVisor
+as we believe that it provides significantly better isolation. For this reason
+we strongly caution users from changing the default unless Netstack flat out
+won't work for them. The passthrough mode option is still provided, but we want
+users to make an informed decision when selecting it.
+
+Another way in which gVisor makes it easy to securely sandbox containers is by
+allowing applications to run unmodified, with no special configuration needed.
+In order to do this, gVisor needs to support all of the features and syscalls
+that applications use. Neither seccomp nor gVisor's passthrough mode can do this
+as applications commonly use syscalls which are too dangerous to be included in
+a secure policy. Even if this dream isn't fully realized today, gVisor's
+architecture with Netstack makes this possible.
+
+## Give Netstack a Try
+
+If you haven't already, try running a workload in gVisor with Netstack. You can
+find instructions on how to get started in our
+[Quick Start](/docs/user_guide/quick_start/docker/). We want to hear about both
+your successes and any issues you encounter. We welcome your contributions,
+whether that be verbal feedback or code contributions, via our
+[Gitter channel](https://gitter.im/gvisor/community),
+[email list](https://groups.google.com/forum/#!forum/gvisor-users),
+[issue tracker](https://gvisor.dev/issue/new), and
+[Github repository](https://github.com/google/gvisor). Feel free to express
+interest in an [open issue](https://gvisor.dev/issue/), or reach out if you
+aren't sure where to start.
diff --git a/website/blog/BUILD b/website/blog/BUILD
new file mode 100644
index 000000000..01c1f5a6e
--- /dev/null
+++ b/website/blog/BUILD
@@ -0,0 +1,37 @@
+load("//website:defs.bzl", "doc", "docs")
+
+package(
+ default_visibility = ["//website:__pkg__"],
+ licenses = ["notice"],
+)
+
+exports_files(["index.html"])
+
+doc(
+ name = "security_basics",
+ src = "2019-11-18-security-basics.md",
+ authors = [
+ "jsprad",
+ "zkoopmans",
+ ],
+ layout = "post",
+ permalink = "/blog/2019/11/18/gvisor-security-basics-part-1/",
+)
+
+doc(
+ name = "networking_security",
+ src = "2020-04-02-networking-security.md",
+ authors = [
+ "igudger",
+ ],
+ layout = "post",
+ permalink = "/blog/2020/04/02/gvisor-networking-security/",
+)
+
+docs(
+ name = "posts",
+ deps = [
+ ":" + rule
+ for rule in existing_rules()
+ ],
+)
diff --git a/website/blog/index.html b/website/blog/index.html
new file mode 100644
index 000000000..5c67c95fc
--- /dev/null
+++ b/website/blog/index.html
@@ -0,0 +1,22 @@
+---
+title: Blog
+layout: blog
+feed: true
+pagination:
+ enabled: true
+---
+
+{% for post in paginator.posts %}
+<div>
+ <h2><a href="{{ post.url }}">{{ post.title }}</a></h2>
+ <div class="blog-meta">
+ {% include byline.html authors=post.authors date=post.date %}
+ </div>
+ <p>{{ post.excerpt | strip_html }}</p>
+ <p><a href="{{ post.url }}">Full Post &raquo;</a></p>
+</div>
+{% endfor %}
+
+{% if paginator.total_pages > 1 %}
+{% include paginator.html %}
+{% endif %}
diff --git a/website/cmd/server/BUILD b/website/cmd/server/BUILD
new file mode 100644
index 000000000..6b5a08f0d
--- /dev/null
+++ b/website/cmd/server/BUILD
@@ -0,0 +1,10 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "server",
+ srcs = ["main.go"],
+ pure = True,
+ visibility = ["//website:__pkg__"],
+)
diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go
new file mode 100644
index 000000000..c401b6abd
--- /dev/null
+++ b/website/cmd/server/main.go
@@ -0,0 +1,215 @@
+// Copyright 2019 The gVisor Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Server is the main gvisor.dev binary.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "regexp"
+ "strings"
+)
+
+var redirects = map[string]string{
+ // GitHub redirects.
+ "/change": "https://github.com/google/gvisor",
+ "/issue": "https://github.com/google/gvisor/issues",
+ "/issue/new": "https://github.com/google/gvisor/issues/new",
+ "/pr": "https://github.com/google/gvisor/pulls",
+
+ // For links.
+ "/faq": "/docs/user_guide/faq/",
+
+ // From 2020-05-12 to 2020-06-30, the FAQ URL was uppercase. Redirect that
+ // back to maintain any links.
+ "/docs/user_guide/FAQ/": "/docs/user_guide/faq/",
+
+ // Redirects to compatibility docs.
+ "/c": "/docs/user_guide/compatibility/",
+ "/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+
+ // Redirect for old URLs.
+ "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/",
+ "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/",
+ "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/",
+ "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/",
+ "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/",
+
+ // Deprecated, but links continue to work.
+ "/cl": "https://gvisor-review.googlesource.com",
+}
+
+var prefixHelpers = map[string]string{
+ "change": "https://github.com/google/gvisor/commit/%s",
+ "issue": "https://github.com/google/gvisor/issues/%s",
+ "pr": "https://github.com/google/gvisor/pull/%s",
+
+ // Redirects to compatibility docs.
+ "c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/#%s",
+
+ // Deprecated, but links continue to work.
+ "cl": "https://gvisor-review.googlesource.com/c/gvisor/+/%s",
+}
+
+var (
+ validID = regexp.MustCompile(`^[A-Za-z0-9-]*/?$`)
+ goGetHTML5 = `<!doctype html><html><head><meta charset=utf-8>
+<meta name="go-import" content="gvisor.dev/gvisor git https://github.com/google/gvisor">
+<meta name="go-import" content="gvisor.dev/website git https://github.com/google/gvisor-website">
+<title>Go-get</title></head><body></html>`
+)
+
+// cronHandler wraps an http.Handler to check that the request is from the App
+// Engine Cron service.
+// See: https://cloud.google.com/appengine/docs/standard/go112/scheduling-jobs-with-cron-yaml#validating_cron_requests
+func cronHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("X-Appengine-Cron") != "true" {
+ http.NotFound(w, r)
+ return
+ }
+ // Fallthrough.
+ h.ServeHTTP(w, r)
+ })
+}
+
+// wrappedHandler wraps an http.Handler.
+//
+// If the query parameters include go-get=1, then we redirect to a single
+// static page that allows us to serve arbitrary Go packages.
+func wrappedHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ gg, ok := r.URL.Query()["go-get"]
+ if ok && len(gg) == 1 && gg[0] == "1" {
+ // Serve a trivial html page.
+ w.Write([]byte(goGetHTML5))
+ return
+ }
+ // Fallthrough.
+ h.ServeHTTP(w, r)
+ })
+}
+
+// redirectWithQuery redirects to the given target url preserving query parameters.
+func redirectWithQuery(w http.ResponseWriter, r *http.Request, target string) {
+ url := target
+ if qs := r.URL.RawQuery; qs != "" {
+ url += "?" + qs
+ }
+ http.Redirect(w, r, url, http.StatusFound)
+}
+
+// hostRedirectHandler redirects the www. domain to the naked domain.
+func hostRedirectHandler(h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.HasPrefix(r.Host, "www.") {
+ // Redirect to the naked domain.
+ r.URL.Scheme = "https" // Assume https.
+ r.URL.Host = r.Host[4:] // Remove the 'www.'
+ http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
+ return
+ }
+
+ if *projectID != "" && r.Host == *projectID+".appspot.com" && *customHost != "" {
+ // Redirect to the custom domain.
+ r.URL.Scheme = "https" // Assume https.
+ r.URL.Host = *customHost
+ http.Redirect(w, r, r.URL.String(), http.StatusMovedPermanently)
+ return
+ }
+ h.ServeHTTP(w, r)
+ })
+}
+
+// prefixRedirectHandler returns a handler that redirects to the given formated url.
+func prefixRedirectHandler(prefix, baseURL string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if p := r.URL.Path; p == prefix {
+ // Redirect /prefix/ to /prefix.
+ http.Redirect(w, r, p[:len(p)-1], http.StatusFound)
+ return
+ }
+ id := r.URL.Path[len(prefix):]
+ if !validID.MatchString(id) {
+ http.Error(w, "Not found", http.StatusNotFound)
+ return
+ }
+ target := fmt.Sprintf(baseURL, id)
+ redirectWithQuery(w, r, target)
+ })
+}
+
+// redirectHandler returns a handler that redirects to the given url.
+func redirectHandler(target string) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ redirectWithQuery(w, r, target)
+ })
+}
+
+// redirectRedirects registers redirect http handlers.
+func registerRedirects(mux *http.ServeMux) {
+ if mux == nil {
+ mux = http.DefaultServeMux
+ }
+
+ for prefix, baseURL := range prefixHelpers {
+ p := "/" + prefix + "/"
+ mux.Handle(p, hostRedirectHandler(wrappedHandler(prefixRedirectHandler(p, baseURL))))
+ }
+
+ for path, redirect := range redirects {
+ mux.Handle(path, hostRedirectHandler(wrappedHandler(redirectHandler(redirect))))
+ }
+}
+
+// registerStatic registers static file handlers
+func registerStatic(mux *http.ServeMux, staticDir string) {
+ if mux == nil {
+ mux = http.DefaultServeMux
+ }
+ mux.Handle("/", hostRedirectHandler(wrappedHandler(http.FileServer(http.Dir(staticDir)))))
+}
+
+func envFlagString(name, def string) string {
+ if val := os.Getenv(name); val != "" {
+ return val
+ }
+ return def
+}
+
+var (
+ addr = flag.String("http", envFlagString("HTTP", ":"+envFlagString("PORT", "8080")), "HTTP service address")
+ staticDir = flag.String("static-dir", envFlagString("STATIC_DIR", "_site"), "static files directory")
+
+ // Uses the standard GOOGLE_CLOUD_PROJECT environment variable set by App Engine.
+ projectID = flag.String("project-id", envFlagString("GOOGLE_CLOUD_PROJECT", ""), "The App Engine project ID.")
+ customHost = flag.String("custom-domain", envFlagString("CUSTOM_DOMAIN", "gvisor.dev"), "The application's custom domain.")
+)
+
+func main() {
+ flag.Parse()
+
+ registerRedirects(nil)
+ registerStatic(nil, *staticDir)
+
+ log.Printf("Listening on %s...", *addr)
+ log.Fatal(http.ListenAndServe(*addr, nil))
+}
diff --git a/website/cmd/syscalldocs/BUILD b/website/cmd/syscalldocs/BUILD
new file mode 100644
index 000000000..c5a0ed7fe
--- /dev/null
+++ b/website/cmd/syscalldocs/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "syscalldocs",
+ srcs = ["main.go"],
+ visibility = ["//website:__pkg__"],
+)
diff --git a/website/cmd/syscalldocs/main.go b/website/cmd/syscalldocs/main.go
new file mode 100644
index 000000000..327537214
--- /dev/null
+++ b/website/cmd/syscalldocs/main.go
@@ -0,0 +1,211 @@
+// Copyright 2019 The gVisor Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary syscalldocs generates system call markdown.
+package main
+
+import (
+ "bufio"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "text/template"
+)
+
+// CompatibilityInfo is the collection of all information.
+type CompatibilityInfo map[string]map[string]ArchInfo
+
+// ArchInfo is compatbility doc for an architecture.
+type ArchInfo struct {
+ // Syscalls maps syscall number for the architecture to the doc.
+ Syscalls map[uintptr]SyscallDoc `json:"syscalls"`
+}
+
+// SyscallDoc represents a single item of syscall documentation.
+type SyscallDoc struct {
+ Name string `json:"name"`
+ Support string `json:"support"`
+ Note string `json:"note,omitempty"`
+ URLs []string `json:"urls,omitempty"`
+}
+
+var mdTemplate = template.Must(template.New("out").Parse(`---
+title: {{.Title}}
+description: Syscall Compatibility Reference Documentation for {{.OS}}/{{.Arch}}
+layout: docs
+category: Compatibility
+weight: 50
+permalink: /docs/user_guide/compatibility/{{.OS}}/{{.Arch}}/
+---
+
+This table is a reference of {{.OS}} syscalls for the {{.Arch}} architecture and
+their compatibility status in gVisor. gVisor does not support all syscalls and
+some syscalls may have a partial implementation.
+
+This page is automatically generated from the source code.
+
+Of {{.Total}} syscalls, {{.Supported}} syscalls have a full or partial
+implementation. There are currently {{.Unsupported}} unsupported
+syscalls. {{if .Undocumented}}{{.Undocumented}} syscalls are not yet documented.{{end}}
+
+<table>
+ <thead>
+ <tr>
+ <th>#</th>
+ <th>Name</th>
+ <th>Support</th>
+ <th>Notes</th>
+ </tr>
+ </thead>
+ <tbody>
+ {{range $i, $syscall := .Syscalls}}
+ <tr>
+ <td><a class="doc-table-anchor" id="{{.Name}}"></a>{{.Number}}</td>
+ <td><a href="http://man7.org/linux/man-pages/man2/{{.Name}}.2.html" target="_blank" rel="noopener">{{.Name}}</a></td>
+ <td>{{.Support}}</td>
+ <td>{{.Note}} {{range $i, $url := .URLs}}<br/>See: <a href="{{.}}">{{.}}</a>{{end}}</td>
+ </tr>
+ {{end}}
+ </tbody>
+</table>
+`))
+
+// Fatalf writes a message to stderr and exits with error code 1
+func Fatalf(format string, a ...interface{}) {
+ fmt.Fprintf(os.Stderr, format, a...)
+ os.Exit(1)
+}
+
+func main() {
+ inputFlag := flag.String("in", "-", "File to input ('-' for stdin)")
+ outputDir := flag.String("out", ".", "Directory to output files.")
+
+ flag.Parse()
+
+ var input io.Reader
+ if *inputFlag == "-" {
+ input = os.Stdin
+ } else {
+ i, err := os.Open(*inputFlag)
+ if err != nil {
+ Fatalf("Error opening %q: %v", *inputFlag, err)
+ }
+ input = i
+ }
+ input = bufio.NewReader(input)
+
+ var info CompatibilityInfo
+ d := json.NewDecoder(input)
+ if err := d.Decode(&info); err != nil {
+ Fatalf("Error reading json: %v", err)
+ }
+
+ weight := 0
+ for osName, osInfo := range info {
+ for archName, archInfo := range osInfo {
+ outDir := filepath.Join(*outputDir, osName)
+ outFile := filepath.Join(outDir, archName+".md")
+
+ if err := os.MkdirAll(outDir, 0755); err != nil {
+ Fatalf("Error creating directory %q: %v", *outputDir, err)
+ }
+
+ f, err := os.OpenFile(outFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
+ if err != nil {
+ Fatalf("Error opening file %q: %v", outFile, err)
+ }
+ defer f.Close()
+
+ weight += 10
+ data := struct {
+ Title string
+ OS string
+ Arch string
+ Weight int
+ Total int
+ Supported int
+ Unsupported int
+ Undocumented int
+ Syscalls []struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }
+ }{
+ Title: strings.Title(osName) + "/" + archName,
+ OS: osName,
+ Arch: archName,
+ Weight: weight,
+ Total: 0,
+ Supported: 0,
+ Unsupported: 0,
+ Undocumented: 0,
+ Syscalls: []struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }{},
+ }
+
+ for num, s := range archInfo.Syscalls {
+ switch s.Support {
+ case "Full Support", "Partial Support":
+ data.Supported++
+ case "Unimplemented":
+ data.Unsupported++
+ default:
+ data.Undocumented++
+ }
+ data.Total++
+
+ for i := range s.URLs {
+ if !strings.HasPrefix(s.URLs[i], "http://") && !strings.HasPrefix(s.URLs[i], "https://") {
+ s.URLs[i] = "https://" + s.URLs[i]
+ }
+ }
+
+ data.Syscalls = append(data.Syscalls, struct {
+ Name string
+ Number uintptr
+ Support string
+ Note string
+ URLs []string
+ }{
+ Name: s.Name,
+ Number: num,
+ Support: s.Support,
+ Note: s.Note, // TODO urls
+ URLs: s.URLs,
+ })
+ }
+
+ sort.Slice(data.Syscalls, func(i, j int) bool {
+ return data.Syscalls[i].Number < data.Syscalls[j].Number
+ })
+
+ if err := mdTemplate.Execute(f, data); err != nil {
+ Fatalf("Error writing file %q: %v", outFile, err)
+ }
+ }
+ }
+}
diff --git a/website/css/main.scss b/website/css/main.scss
new file mode 100644
index 000000000..06106833f
--- /dev/null
+++ b/website/css/main.scss
@@ -0,0 +1,5 @@
+@import 'style.scss';
+@import 'front.scss';
+@import 'navbar.scss';
+@import 'sidebar.scss';
+@import 'footer.scss';
diff --git a/website/defs.bzl b/website/defs.bzl
new file mode 100644
index 000000000..f52946c15
--- /dev/null
+++ b/website/defs.bzl
@@ -0,0 +1,178 @@
+"""Wrappers for website documentation."""
+
+load("//tools:defs.bzl", "short_path")
+
+# DocInfo is a provider which simple adds sufficient metadata to the source
+# files (and additional data files) so that a jeyll header can be constructed
+# dynamically. This is done the via BUILD system so that the plain
+# documentation files can be viewable without non-compliant markdown headers.
+DocInfo = provider(
+ fields = [
+ "layout",
+ "description",
+ "permalink",
+ "category",
+ "subcategory",
+ "weight",
+ "editpath",
+ "authors",
+ ],
+)
+
+def _doc_impl(ctx):
+ return [
+ DefaultInfo(
+ files = depset(ctx.files.src + ctx.files.data),
+ ),
+ DocInfo(
+ layout = ctx.attr.layout,
+ description = ctx.attr.description,
+ permalink = ctx.attr.permalink,
+ category = ctx.attr.category,
+ subcategory = ctx.attr.subcategory,
+ weight = ctx.attr.weight,
+ editpath = short_path(ctx.files.src[0].short_path),
+ authors = ctx.attr.authors,
+ ),
+ ]
+
+doc = rule(
+ implementation = _doc_impl,
+ doc = "Annotate a document for jekyll headers.",
+ attrs = {
+ "src": attr.label(
+ doc = "The markdown source file.",
+ mandatory = True,
+ allow_single_file = True,
+ ),
+ "data": attr.label_list(
+ doc = "Additional data files (e.g. images).",
+ allow_files = True,
+ ),
+ "layout": attr.string(
+ doc = "The document layout.",
+ default = "docs",
+ ),
+ "description": attr.string(
+ doc = "The document description.",
+ default = "",
+ ),
+ "permalink": attr.string(
+ doc = "The document permalink.",
+ mandatory = True,
+ ),
+ "category": attr.string(
+ doc = "The document category.",
+ default = "",
+ ),
+ "subcategory": attr.string(
+ doc = "The document subcategory.",
+ default = "",
+ ),
+ "weight": attr.string(
+ doc = "The document weight.",
+ default = "50",
+ ),
+ "authors": attr.string_list(),
+ },
+)
+
+def _docs_impl(ctx):
+ # Tarball is the actual output.
+ tarball = ctx.actions.declare_file(ctx.label.name + ".tgz")
+
+ # But we need an intermediate builder to translate the files.
+ builder = ctx.actions.declare_file("%s-builder" % ctx.label.name)
+ builder_content = [
+ "#!/bin/bash",
+ "set -euo pipefail",
+ "declare -r T=$(mktemp -d)",
+ "function cleanup {",
+ " rm -rf $T",
+ "}",
+ "trap cleanup EXIT",
+ ]
+ for dep in ctx.attr.deps:
+ doc = dep[DocInfo]
+
+ # Sanity check the permalink.
+ if not doc.permalink.endswith("/"):
+ fail("permalink %s for target %s should end with /" % (
+ doc.permalink,
+ ctx.label.name,
+ ))
+
+ # Construct the header.
+ header = """\
+description: {description}
+permalink: {permalink}
+category: {category}
+subcategory: {subcategory}
+weight: {weight}
+editpath: {editpath}
+authors: {authors}
+layout: {layout}"""
+
+ for f in dep.files.to_list():
+ # Is this a markdown file? If not, then we ensure that it ends up
+ # in the same path as the permalink for relative addressing.
+ if not f.basename.endswith(".md"):
+ builder_content.append("mkdir -p $T/%s" % doc.permalink)
+ builder_content.append("cp %s $T/%s" % (f.path, doc.permalink))
+ continue
+
+ # Is this a post? If yes, then we must put this in the _posts
+ # directory. This directory is treated specially with respect to
+ # pagination and page generation.
+ dest = f.short_path
+ if doc.layout == "post":
+ dest = "_posts/" + f.basename
+ builder_content.append("echo Processing %s... >&2" % f.short_path)
+ builder_content.append("mkdir -p $T/$(dirname %s)" % dest)
+
+ # Construct the header dynamically. We include the title field from
+ # the markdown itself, as this is the g3doc format required. The
+ # title will be injected by the web layout however, so we don't
+ # want this to appear in the document.
+ args = dict([(k, getattr(doc, k)) for k in dir(doc)])
+ builder_content.append("title=\"$(grep -E '^# ' %s | head -n 1 | cut -d'#' -f2- || true)\"" % f.path)
+ builder_content.append("cat >$T/%s <<EOF" % dest)
+ builder_content.append("---")
+ builder_content.append("title: $title")
+ builder_content.append(header.format(**args))
+ builder_content.append("---")
+ builder_content.append("EOF")
+
+ # To generate the final page, we need to strip out the title (which
+ # was pulled above to generate the annotation in the frontmatter,
+ # and substitute the [TOC] tag with the {% toc %} plugin tag. Note
+ # that the pipeline here is almost important, as the grep will
+ # return non-zero if the file is empty, but we ignore that within
+ # the pipeline.
+ builder_content.append("grep -v -E '^# ' %s | sed -e 's|^\\[TOC\\]$|- TOC\\n{:toc}|' >>$T/%s" %
+ (f.path, dest))
+
+ builder_content.append("declare -r filename=$(readlink -m %s)" % tarball.path)
+ builder_content.append("(cd $T && tar -zcf \"${filename}\" .)\n")
+ ctx.actions.write(builder, "\n".join(builder_content), is_executable = True)
+
+ # Generate the tarball.
+ ctx.actions.run(
+ inputs = depset(ctx.files.deps),
+ outputs = [tarball],
+ progress_message = "Generating %s" % ctx.label,
+ executable = builder,
+ )
+ return [DefaultInfo(
+ files = depset([tarball]),
+ )]
+
+docs = rule(
+ implementation = _docs_impl,
+ doc = "Construct a site tarball from doc dependencies.",
+ attrs = {
+ "deps": attr.label_list(
+ doc = "All document dependencies.",
+ ),
+ },
+)
diff --git a/test/runtimes/runner.sh b/website/import.sh
index a8d9a3460..e1350e83d 100755
--- a/test/runtimes/runner.sh
+++ b/website/import.sh
@@ -14,22 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -euf -x -o pipefail
+set -xeuo pipefail
-echo -- "$@"
-
-# Create outputs dir if it does not exist.
-if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
- mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
- chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
+if [[ -d $0.runfiles ]]; then
+ cd $0.runfiles
fi
-# Update the timestamp on the shard status file. Bazel looks for this.
-touch "${TEST_SHARD_STATUS_FILE}"
-
-# Get location of runner binary.
-readonly runner=$(find "${TEST_SRCDIR}" -name runner)
-
-# Pass the arguments of this script directly to the runner.
-exec "${runner}" "$@"
-
+exec docker import \
+ -c "EXPOSE 8080/tcp" \
+ -c "ENTRYPOINT [\"/server\"]" \
+ $(find . -name files.tgz) \
+ gvisor.dev/images/website
diff --git a/website/index.md b/website/index.md
new file mode 100644
index 000000000..84f877d49
--- /dev/null
+++ b/website/index.md
@@ -0,0 +1,50 @@
+<div class="jumbotron jumbotron-fluid">
+ <div class="container">
+ <div class="row">
+ <div class="col-md-3"></div>
+ <div class="col-md-6">
+ <p>gVisor is an <b>application kernel</b> for <b>containers</b> that provides efficient defense-in-depth anywhere.</p>
+ <p style="margin-top: 20px;">
+ <a class="btn" href="/docs/user_guide/quick_start/docker/">Quick start&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
+ <a class="btn" href="/docs/">Learn More&nbsp;<i class="fas fa-arrow-alt-circle-right ml-2"></i></a>
+ </p>
+ </div>
+ <div class="col-md-3"></div>
+ </div>
+ </div>
+</div>
+
+<div class="container"> <!-- Full page container. -->
+
+<div class="row">
+ <div class="col-md-4">
+ <h4 id="seamless-security">Container-native Security <i class="fas fa-lock"></i></h4>
+ <p>By providing each container with its own application kernel, gVisor
+ limits the attack surface of the host. This protection does not limit
+ functionality: gVisor runs unmodified binaries and integrates with container
+ orchestration systems, such as Docker and Kubernetes, and supports features
+ such as volumes and sidecars.</p>
+ <a class="button" href="/docs/architecture_guide/security/">Read More &raquo;</a>
+ </div>
+
+ <div class="col-md-4">
+ <h4 id="resource-efficiency">Resource Efficiency <i class="fas fa-feather-alt"></i></h4>
+ <p>Containers are efficient because workloads of different shapes and sizes
+ can be packed together by sharing host resources. gVisor uses host-native
+ abstractions, such as threads and memory mappings, to co-operate with the
+ host and enable the same resource model as native containers.</p>
+ <a class="button" href="/docs/architecture_guide/resources/">Read More &raquo;</a>
+ </div>
+
+ <div class="col-md-4">
+ <h4 id="platform-portability">Platform Portability <sup>&#9729;</sup>&#9729;</h4>
+ <p>Modern infrastructure spans multiple cloud services and data centers,
+ often with a mix of managed services and virtualized or traditional servers.
+ The pluggable platform architecture of gVisor allows it to run anywhere,
+ enabling consistent security policies across multiple environments without
+ having to rearchitect your infrastructure.</p>
+ <a class="button" href="/docs/architecture_guide/platforms/">Read More &raquo;</a>
+ </div>
+</div>
+
+</div> <!-- container -->
diff --git a/website/performance/README.md b/website/performance/README.md
new file mode 100644
index 000000000..1758fc608
--- /dev/null
+++ b/website/performance/README.md
@@ -0,0 +1,10 @@
+# Performance data
+
+This directory holds the CSVs generated by the now removed benchmark-tools
+repository. The new functionally equivalent
+[benchmark-tools is available.][benchmark-tools]
+
+In the future, these will be automatically posted to a cloud storage bucket and
+loaded dynamically. At that point, this directory will be removed.
+
+[benchmark-tools]: https://github.com/google/gvisor/tree/master/test/benchmarks
diff --git a/website/performance/applications.csv b/website/performance/applications.csv
new file mode 100644
index 000000000..7b4661c60
--- /dev/null
+++ b/website/performance/applications.csv
@@ -0,0 +1,13 @@
+runtime,method,metric,result
+runc,http.node,transfer_rate,3814.85
+runc,http.node,latency,11.0
+runc,http.node,requests_per_second,885.81
+runc,http.ruby,transfer_rate,2874.38
+runc,http.ruby,latency,18.0
+runc,http.ruby,requests_per_second,539.97
+runsc,http.node,transfer_rate,1615.54
+runsc,http.node,latency,27.0
+runsc,http.node,requests_per_second,375.13
+runsc,http.ruby,transfer_rate,1382.71
+runsc,http.ruby,latency,38.0
+runsc,http.ruby,requests_per_second,259.75
diff --git a/website/performance/density.csv b/website/performance/density.csv
new file mode 100644
index 000000000..729b44941
--- /dev/null
+++ b/website/performance/density.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,density.empty,memory_usage,4092149.76
+runc,density.node,memory_usage,76709888.0
+runc,density.ruby,memory_usage,45737000.96
+runsc,density.empty,memory_usage,23695032.32
+runsc,density.node,memory_usage,124076605.44
+runsc,density.ruby,memory_usage,106141777.92
+runc,density.redis,memory_usage,1055323750.4
+runsc,density.redis,memory_usage,1076686028.8
diff --git a/website/performance/ffmpeg.csv b/website/performance/ffmpeg.csv
new file mode 100644
index 000000000..08661c749
--- /dev/null
+++ b/website/performance/ffmpeg.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,run_time,82.000625
+runsc,run_time,88.24018
diff --git a/website/performance/fio-tmpfs.csv b/website/performance/fio-tmpfs.csv
new file mode 100644
index 000000000..99777d2e4
--- /dev/null
+++ b/website/performance/fio-tmpfs.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,fio.read,bandwidth,4240686080
+runc,fio.write,bandwidth,3029744640
+runsc,fio.read,bandwidth,2533604352
+runsc,fio.write,bandwidth,1207536640
+runc,fio.randread,bandwidth,1221472256
+runc,fio.randwrite,bandwidth,1046094848
+runsc,fio.randread,bandwidth,68940800
+runsc,fio.randwrite,bandwidth,67286016
diff --git a/website/performance/fio.csv b/website/performance/fio.csv
new file mode 100644
index 000000000..80d6ae289
--- /dev/null
+++ b/website/performance/fio.csv
@@ -0,0 +1,9 @@
+runtime,method,metric,result
+runc,fio.read,bandwidth,252253184
+runc,fio.write,bandwidth,457767936
+runsc,fio.read,bandwidth,252323840
+runsc,fio.write,bandwidth,431845376
+runc,fio.randread,bandwidth,5284864
+runc,fio.randwrite,bandwidth,107758592
+runsc,fio.randread,bandwidth,4403200
+runsc,fio.randwrite,bandwidth,69161984
diff --git a/website/performance/httpd100k.csv b/website/performance/httpd100k.csv
new file mode 100644
index 000000000..e92c7e9e0
--- /dev/null
+++ b/website/performance/httpd100k.csv
@@ -0,0 +1,17 @@
+connections,runtime,metric,result
+1,runc,transfer_rate,565.35
+1,runc,latency,1.0
+1,runsc,transfer_rate,282.84
+1,runsc,latency,2.0
+5,runc,transfer_rate,3260.57
+5,runc,latency,1.0
+5,runsc,transfer_rate,832.69
+5,runsc,latency,3.0
+10,runc,transfer_rate,4672.01
+10,runc,latency,1.0
+10,runsc,transfer_rate,1095.47
+10,runsc,latency,4.0
+25,runc,transfer_rate,4964.14
+25,runc,latency,2.0
+25,runsc,transfer_rate,961.03
+25,runsc,latency,12.0
diff --git a/website/performance/httpd10240k.csv b/website/performance/httpd10240k.csv
new file mode 100644
index 000000000..60dbe7b40
--- /dev/null
+++ b/website/performance/httpd10240k.csv
@@ -0,0 +1,17 @@
+connections,runtime,metric,result
+1,runc,transfer_rate,674.05
+1,runc,latency,1.0
+1,runsc,transfer_rate,243.35
+1,runsc,latency,2.0
+5,runc,transfer_rate,3089.83
+5,runc,latency,1.0
+5,runsc,transfer_rate,981.91
+5,runsc,latency,2.0
+10,runc,transfer_rate,4701.2
+10,runc,latency,1.0
+10,runsc,transfer_rate,1135.08
+10,runsc,latency,4.0
+25,runc,transfer_rate,5021.36
+25,runc,latency,2.0
+25,runsc,transfer_rate,963.26
+25,runsc,latency,12.0
diff --git a/website/performance/iperf.csv b/website/performance/iperf.csv
new file mode 100644
index 000000000..1f3b41aec
--- /dev/null
+++ b/website/performance/iperf.csv
@@ -0,0 +1,5 @@
+runtime,method,metric,result
+runc,network.download,bandwidth,746386000.0
+runc,network.upload,bandwidth,709808000.0
+runsc,network.download,bandwidth,640303500.0
+runsc,network.upload,bandwidth,482254000.0
diff --git a/website/performance/redis.csv b/website/performance/redis.csv
new file mode 100644
index 000000000..369b16712
--- /dev/null
+++ b/website/performance/redis.csv
@@ -0,0 +1,35 @@
+runtime,method,metric,result
+runc,PING_INLINE,requests_per_second,30525.03
+runc,PING_BULK,requests_per_second,30293.85
+runc,SET,requests_per_second,30257.19
+runc,GET,requests_per_second,30312.21
+runc,INCR,requests_per_second,30525.03
+runc,LPUSH,requests_per_second,30712.53
+runc,RPUSH,requests_per_second,30459.95
+runc,LPOP,requests_per_second,30367.45
+runc,RPOP,requests_per_second,30665.44
+runc,SADD,requests_per_second,30030.03
+runc,HSET,requests_per_second,30656.04
+runc,SPOP,requests_per_second,29940.12
+runc,LRANGE_100,requests_per_second,24224.81
+runc,LRANGE_300,requests_per_second,14302.06
+runc,LRANGE_500,requests_per_second,11728.83
+runc,LRANGE_600,requests_per_second,9900.99
+runc,MSET,requests_per_second,30120.48
+runsc,PING_INLINE,requests_per_second,14528.55
+runsc,PING_BULK,requests_per_second,15627.44
+runsc,SET,requests_per_second,15403.57
+runsc,GET,requests_per_second,15325.67
+runsc,INCR,requests_per_second,15269.51
+runsc,LPUSH,requests_per_second,15172.2
+runsc,RPUSH,requests_per_second,15117.16
+runsc,LPOP,requests_per_second,15257.86
+runsc,RPOP,requests_per_second,15188.33
+runsc,SADD,requests_per_second,15432.1
+runsc,HSET,requests_per_second,15163.0
+runsc,SPOP,requests_per_second,15561.78
+runsc,LRANGE_100,requests_per_second,13365.41
+runsc,LRANGE_300,requests_per_second,9520.18
+runsc,LRANGE_500,requests_per_second,8248.78
+runsc,LRANGE_600,requests_per_second,6544.07
+runsc,MSET,requests_per_second,14367.82
diff --git a/website/performance/startup.csv b/website/performance/startup.csv
new file mode 100644
index 000000000..6bad00df6
--- /dev/null
+++ b/website/performance/startup.csv
@@ -0,0 +1,7 @@
+runtime,method,metric,result
+runc,startup.empty,startup_time_ms,1193.10768
+runc,startup.node,startup_time_ms,2557.95336
+runc,startup.ruby,startup_time_ms,2530.12624
+runsc,startup.empty,startup_time_ms,1144.1775
+runsc,startup.node,startup_time_ms,2441.90284
+runsc,startup.ruby,startup_time_ms,2455.69882
diff --git a/website/performance/sysbench-cpu.csv b/website/performance/sysbench-cpu.csv
new file mode 100644
index 000000000..f4e6b69a6
--- /dev/null
+++ b/website/performance/sysbench-cpu.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,cpu_events_per_second,103.62
+runsc,cpu_events_per_second,103.21
diff --git a/website/performance/sysbench-memory.csv b/website/performance/sysbench-memory.csv
new file mode 100644
index 000000000..626ff4994
--- /dev/null
+++ b/website/performance/sysbench-memory.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,memory_ops_per_second,13098.73
+runsc,memory_ops_per_second,13107.44
diff --git a/website/performance/syscall.csv b/website/performance/syscall.csv
new file mode 100644
index 000000000..40bdce49e
--- /dev/null
+++ b/website/performance/syscall.csv
@@ -0,0 +1,4 @@
+runtime,metric,result
+runc,syscall_time_ns,1939.0
+runsc,syscall_time_ns,38219.0
+runsc-kvm,syscall_time_ns,763.0
diff --git a/website/performance/tensorflow.csv b/website/performance/tensorflow.csv
new file mode 100644
index 000000000..03498bef0
--- /dev/null
+++ b/website/performance/tensorflow.csv
@@ -0,0 +1,3 @@
+runtime,metric,result
+runc,run_time,207.1118165
+runsc,run_time,244.473401